设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2210|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    9 ]( K& B; |5 v' z$ q& O5 G; J+ u- }8 _* B
    为预防老年痴呆,时不时学点新东东玩一玩。
    & Y. T. p9 y- A% o/ rPytorch 下面的代码做最简单的一元线性回归:+ a+ ~* \) K' L
    ----------------------------------------------
    7 p& {4 z9 t3 h# zimport torch
    3 n* p# ], p+ K7 b( Dimport numpy as np) B3 E7 v1 _( c' h; i  h2 q
    import matplotlib.pyplot as plt0 x+ ]. H6 |6 V0 a- e4 P
    import random/ s4 y) L5 G6 x5 E2 g) c
    * I. [7 g; b) o3 b& X3 R
    x = torch.tensor(np.arange(1,100,1))6 E1 h3 ^. }" ]2 {
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=151 Z1 J6 S8 g  Q: o# }
    6 m& f0 C' W- I; E4 T
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b; t; r! U- W/ D$ V- V1 t
    b = torch.tensor(0.,requires_grad=True)# ?: i: q& M  `- D

      i* }- w: _1 W( f% a% z& \5 N6 yepochs = 1004 V  g3 d2 ^( [- M  r

    & m  E7 e5 N  _7 `  c5 Slosses = []5 N: w) e. J9 C& g9 ?+ y
    for i in range(epochs):6 d$ z) K6 R4 m+ ~. y/ S
      y_pred = (x*w+b)    # 预测4 O$ M( y2 }9 }7 q/ t2 ^9 ]8 `/ A
      y_pred.reshape(-1)
    / P4 T% R! u- X% F- _( w
    ! b8 K$ E+ s3 q) P. k4 V4 d$ `  loss = torch.square(y_pred - y).mean()   #计算 loss
    6 P5 G5 H; X1 [" u  losses.append(loss)
    % z# N$ M. z  E: u4 f  3 D$ [8 G& W0 F4 ^% ^
      loss.backward() # autograd
    * v. K3 x4 w# ?+ W+ `  z  with torch.no_grad():! J. y: L( A- L. a  e/ O+ Z
        w  -= w.grad*0.0001   # 回归 w
    , T+ h  a2 ?& _( ~/ ]8 R6 c/ g    b  -= b.grad*0.0001    # 回归 b
    7 @4 _! p; ]  Y+ J' }) B  w.grad.zero_()  - [% ]1 v! i( x
      b.grad.zero_(), B2 Z- _4 [. s
    ; C# L; W  B4 g( f4 |- ^) M  e
    print(w.item(),b.item()) #结果
    3 ^2 j; }0 j- p' p& W
    + m8 f4 i( N* gOutput: 27.26387596130371  0.4974517822265625
    5 N3 z( z  Z2 y. e6 D& C----------------------------------------------6 w3 J3 D& X/ \- r- E
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。" a. e& h/ Y% x2 m3 W+ q
    高手们帮看看是神马原因?
    0 d6 T! J& K* ]7 y  d5 \, `2 t4 G

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    5 a- ^  I: C6 {) ?4 y) t/ H" k9 w% D# y) o0 [6 q
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?0 o- c* x* g# K- ?; d
    -------
    6 N% {! ?& S. B2 N5 H( M不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    # y3 S" t- z# R-------7 e. V2 f' w5 s8 U7 o
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:234 }( A2 y& Y' d% ?6 z
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    4 g) j' K* _5 Y-------
    , @; }$ o; S- l% P$ g- u9 ^不好意思, ...
    # A. t  x) s0 o8 e
    谢谢,算法应该没问题,就是最简单的线性回归。, v. }7 T+ V1 m& t
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    & T, Z4 B7 K# M& J
    雷达 发表于 2023-2-14 21:52& k+ H  l  ~. S2 H
    谢谢,算法应该没问题,就是最简单的线性回归。
    2 `; t+ [0 c; |  k. r, [我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    - L: Y, C0 _9 l. d% Z
    ! z6 [7 e6 C4 C
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    ) ?5 S( c% G: x$ m
    % R7 K: Q% Q! D) T! t7 Q或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 * \2 o& f5 }5 g2 R2 a, t
    老福 发表于 2023-2-14 22:00$ y6 b  i4 W4 i1 \- O& g; H) V
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    $ \# U, h2 |# u2 K6 a9 S. c+ n
    # t$ N6 D! w4 J- g或者把b但的起点改为1试试。 ...
    . ]) I4 D; L$ n* Y1 A

    - h3 B' Y5 |" \你是对的。
    , A' V: I: v1 D, M3 m! J% z5 M0 G去掉了随机部分& [6 t. e4 m% C$ Q) S6 c# z
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)
    # }1 H9 i0 _* S9 n/ Qy = (x*27+15).reshape(-1)
    " r6 D! a; q/ V5 [) |0 E* N" v; }" J$ t& q" x9 Q
    循环次数加成10倍,就看到 b 收敛了
    : \- O$ L3 r4 \3 }w , b, j" ~7 x4 `. [: E3 \& W
    27.002620697021484 14.826167106628418
    ! d7 T2 F- a' ~& V( S- E
    $ W6 i" Y8 r& ?" I8 R和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-12-13 16:50 , Processed in 0.030789 second(s), 21 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表