设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 1757|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    : c, C7 m. O  h! s! F6 h* C
    5 Z& s, `, W1 v- M" o. m为预防老年痴呆,时不时学点新东东玩一玩。
      l9 k7 ?' L3 o8 c0 v0 OPytorch 下面的代码做最简单的一元线性回归:
    5 h9 {8 l1 x  I----------------------------------------------
    9 C( t- d5 ?) b( wimport torch
    3 n4 l0 T  q1 c  W8 C8 Mimport numpy as np
    6 Q$ U# |" C" I5 @- P# Qimport matplotlib.pyplot as plt1 H. Q4 m* @, A3 F  T
    import random
    3 R" t, A. J7 n6 V
    ) E& @1 l5 j8 K. w: i5 t3 sx = torch.tensor(np.arange(1,100,1)), c: y3 U$ G3 S0 v6 q. q
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    ) G0 M+ G( w, R0 S  p; C: l; o  r7 O: _" }9 G) R# H+ r
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b: d3 \2 h& `- d( p0 q: v5 @
    b = torch.tensor(0.,requires_grad=True)
    # R0 K( \' y- \' m& b
    8 Y6 W' u  @; h+ Qepochs = 100* \! I! I& l! H( {  }; M* I5 u
    + s1 p% `( S' y
    losses = []
    % L8 T" z# ?; o9 X, X' Z$ Jfor i in range(epochs):4 }' U0 ~7 D( h% o* b1 H0 o$ e
      y_pred = (x*w+b)    # 预测6 p# @: g# z6 H) L( }! I
      y_pred.reshape(-1)1 r1 {1 T6 E0 V( j6 R- Q3 X

    ; `5 ^1 z$ C% \6 U, i/ b3 F9 Q  loss = torch.square(y_pred - y).mean()   #计算 loss
    & Z) d8 E1 c, `0 D# c  j  P  losses.append(loss)* k1 N0 f% Z# A. d* p6 ~7 V9 H
      
    & c" s( m) E3 L7 n2 R9 g% O) X  loss.backward() # autograd
    " `2 G& a4 n3 R+ b) g* q. S1 T  with torch.no_grad():/ t6 d/ d$ L8 P0 u4 |
        w  -= w.grad*0.0001   # 回归 w
    9 _$ j% j/ [2 {- b" V% h/ V! c    b  -= b.grad*0.0001    # 回归 b
    . {1 c& T9 ~0 ^% C  w.grad.zero_()  4 x! I( E4 Q' g- w* G' ~" u
      b.grad.zero_()
    ( X: R1 O# b7 g. l9 [" [1 m7 c; z' l. Y. W* F8 g
    print(w.item(),b.item()) #结果, a2 h; i5 B" d% C+ T  \/ C

    ' U7 \: t, S9 A. X% V1 k* \Output: 27.26387596130371  0.4974517822265625
    7 w3 q  P% ?6 \% [: r( ~% a----------------------------------------------, I5 t5 E- ~6 Y7 K9 i0 s$ W
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    ( g$ A9 B/ x& G6 y% r$ f( O; m- ]7 n高手们帮看看是神马原因?- \. E8 |* n# O% |

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 " ^, S+ R+ V" G7 R, Z0 V
    % U- Z( \! M- b) r7 r
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?% t/ e0 Q( S# }
    -------% n& F9 m( a. K6 `5 _5 i: w
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。* M5 {- @! J$ a8 O  I  [2 {3 s) T
    -------
    7 @% v$ j1 H/ v6 _& x0 k算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    0 e5 ^! [- ^$ S6 D! W: A没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    7 l) B9 R4 ], l" @+ \, Z+ o-------
    + q  |9 M; h/ b4 j, M1 g' P! T不好意思, ...

    # I& h0 W# a. Y. c5 j谢谢,算法应该没问题,就是最简单的线性回归。4 h% \( b$ n  h$ f
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 ( p5 i! \+ B$ ~# L# m
    雷达 发表于 2023-2-14 21:52; c7 ]) |/ {" H5 d( ~& g, ^. q
    谢谢,算法应该没问题,就是最简单的线性回归。
    / y! n$ N' W  H2 b" A+ x5 |我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    ( K7 ~; X0 [1 b* |, M" ^% V8 k( h* V
      H7 K+ ?# u" T5 F" I, b7 ?
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。) B- n; v  H( g/ w( |3 G# M  l
    9 O/ m0 G) e' E! v( g! K# {
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    1 F  A4 c9 A- n( f& n3 Q
    老福 发表于 2023-2-14 22:00
    " {+ r1 a" g! a0 E  j4 S( k刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    - P& C& C+ @4 }9 x0 D6 Y4 B: w+ T4 O+ x6 |5 o
    或者把b但的起点改为1试试。 ...
    $ ?: y7 w( Q9 n. Y  o( |

    2 Q2 W" \( V/ H# Z你是对的。
    0 r1 g6 j3 a  r8 E去掉了随机部分
    5 M* N( B8 c8 [) `#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    1 B0 v! q8 @6 J- i4 d9 ]) py = (x*27+15).reshape(-1)& m, }4 |9 c, x2 y. L; i

    # `" k+ R( I& u循环次数加成10倍,就看到 b 收敛了
    , y4 j/ S  L2 Z" V8 Q5 Mw , b
    ! E8 c/ J! c  E2 Z* t# l) G27.002620697021484 14.826167106628418
    % ~8 f) Y" f3 W* m' J- s- _9 _& u& s$ c' {9 z# }
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-7-5 02:09 , Processed in 0.036690 second(s), 22 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表