设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2682|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    ( e2 [2 o" n. l) {1 c- U5 j2 z% O
    为预防老年痴呆,时不时学点新东东玩一玩。
    ' ^: a% c9 Q+ h6 yPytorch 下面的代码做最简单的一元线性回归:9 L# ^8 S# T. H/ K, Q6 \! ~& n0 v% P
    ----------------------------------------------6 K5 G5 p/ C. S
    import torch
      Y- F* [1 B5 G- W$ Yimport numpy as np
    ! N# t# G  C7 E4 \9 `import matplotlib.pyplot as plt" D% u0 v. }( |$ j$ a
    import random
    # z& E, ?5 m/ G  _- [* e! w! n, e8 ~+ \- j/ f
    x = torch.tensor(np.arange(1,100,1))
    " G# c& L7 p1 Fy = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=154 p2 K+ a; H$ C1 u3 h4 |& ~

    * v: k+ ?- r/ Vw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    2 c. a7 h3 ~5 h1 k: H) `" y0 _b = torch.tensor(0.,requires_grad=True)) _. {' p) y8 U- Z( W' u( `- e
    : v7 a' Z6 i4 _; K7 q) y
    epochs = 100
    " Z0 n" U% q) t2 {5 O3 g
    ' t2 L* g* V+ t0 H; m3 W7 m3 Zlosses = []# ]/ t! M$ K+ w% ?' h& u
    for i in range(epochs):
    / J4 M5 ]& g9 l2 q2 `  y_pred = (x*w+b)    # 预测+ \. C: O4 u# q' M( n
      y_pred.reshape(-1)5 I: n  a8 M; k0 G' s: W& |

    7 B8 E4 w: c7 Z4 d4 ]" I/ X  loss = torch.square(y_pred - y).mean()   #计算 loss
    , G' E1 E! m% {. B  losses.append(loss)
    * p) R) I. X+ J' J( K: @) F) x  
    ( s2 S4 z$ e( Z# i" s  loss.backward() # autograd
    5 y" L9 g: a' b  v  with torch.no_grad():: ?4 H) k9 v( o- W$ I  ]- x
        w  -= w.grad*0.0001   # 回归 w: H8 C+ W/ W0 m7 s0 Q, p3 a+ }- O
        b  -= b.grad*0.0001    # 回归 b
    $ E9 `: t7 _! M* d  w.grad.zero_()  5 H4 \, `4 X2 [
      b.grad.zero_()
    ' m9 q/ v: d# W4 W4 Y+ A, H; V& i: c- x- l: |
    print(w.item(),b.item()) #结果
    , t! E3 \# ?+ K0 d
    2 T, k- u% M  t) `2 GOutput: 27.26387596130371  0.4974517822265625; h' Z; p8 |' X; ~
    ----------------------------------------------
    ) W+ P* ~* t6 ?9 s( W. H5 P最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    9 P. T) V2 G$ d# s0 c+ V8 v- E5 E高手们帮看看是神马原因?% L. D& A% v8 d1 E& a/ d

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 5 t7 f2 P+ C7 _, S

    # e# i; O( _, Q/ X5 v% Q没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    . x& ]1 ^' b& \% t$ h-------
    . E* g8 A- N8 X9 V  h不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。" c5 ?5 z) U) s) s6 T4 w6 j. ?
    -------
    # O+ I' m; b0 ~) c: b2 D算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    $ y- d% H$ X: A; w" ^没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    . f/ a6 m8 o3 r4 e/ m-------
      z8 N+ C) e# D; v) P! @/ E5 u不好意思, ...
    9 h' b! z, f  I: |
    谢谢,算法应该没问题,就是最简单的线性回归。
    1 j( Y+ ~4 ~& [. g! u我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 + V3 ]! _! X$ y+ \
    雷达 发表于 2023-2-14 21:52  I. P, I) T% h" y  c- p' ~& A$ f
    谢谢,算法应该没问题,就是最简单的线性回归。
    / ~6 Q+ p* X: G$ H5 }( D5 W+ M我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    + b8 @( z+ T# O, g' ?9 |

    2 q; z) n1 v3 A/ p刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    6 R" v  `( a( T2 @! \- Y6 c# H2 H
    ) w5 E5 I2 ~2 L/ _& q0 ?或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    + D! A7 R; _' a7 t- h8 _
    老福 发表于 2023-2-14 22:00
    $ }$ H6 m- _% x+ [# A/ G刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    0 ]5 t5 {1 E1 [! P
    3 Q- R4 X0 O( k或者把b但的起点改为1试试。 ...

    : W6 R3 X7 ?& K
    4 a" D! ~/ g& I+ l* r你是对的。( T" u  f- c- l
    去掉了随机部分8 X( }7 `. @+ Z) i9 q8 p& D
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)  O8 t* `& N* I6 e
    y = (x*27+15).reshape(-1)
    / t  g* W& ]0 {$ O
    ( C5 p& @2 q, X* s0 ?循环次数加成10倍,就看到 b 收敛了
    2 \% E3 N; P, C' D# ~0 jw , b) Q8 {6 n5 E4 h7 v. M1 k/ z
    27.002620697021484 14.8261671066284186 v/ s% k3 G3 A/ W) k9 H' j

    - l% ^8 T# b5 \! F# i; a和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-4-1 09:30 , Processed in 0.059906 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表