设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2759|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 / ^1 J/ g8 Y" i: O
    6 c) h6 y; A% S! ~
    为预防老年痴呆,时不时学点新东东玩一玩。
    ! e4 U7 p6 \, pPytorch 下面的代码做最简单的一元线性回归:+ c& Z. g) P2 g1 N. f
    ----------------------------------------------
    / _' _) j* p/ Yimport torch5 k$ H9 L4 G, r0 w7 w: g' ?
    import numpy as np
    7 g0 Z! i* c0 F- g  limport matplotlib.pyplot as plt
    6 Z* d# X* P. w1 V  s2 \( Aimport random- U- K' y4 _1 p8 [7 X! u
    8 ?- Q8 P6 d* ]9 m
    x = torch.tensor(np.arange(1,100,1))
    " [, u6 V# z7 Z; g/ ry = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    . M  z2 |! d* f8 J  K8 z  l( V
    : V5 V  t4 q# z9 jw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b; v2 z. i8 q. m: H( j  |
    b = torch.tensor(0.,requires_grad=True)
    0 m0 T& X5 A, s5 p9 g9 z( s4 z; Q) E# z: m% w
    epochs = 100' ~) A  P0 }& Z3 |9 N* v

      }/ ], \$ K3 s4 k4 w5 slosses = []
    & m1 e6 E7 K/ a4 a- A3 J; sfor i in range(epochs):
    ! m5 r$ u% g! }1 I0 N$ C+ M  y_pred = (x*w+b)    # 预测# C* C3 S1 @  \$ {3 O, T
      y_pred.reshape(-1)
    & ?( k- C5 R& W( A& l& k
    - T9 ]3 m7 X# |' U# C  loss = torch.square(y_pred - y).mean()   #计算 loss
    6 N4 m2 Z" V1 O) i. L  losses.append(loss)) u; G/ u3 g" y) z7 W( ]  k9 r
      
    0 z9 n( O. r9 o3 S8 Q* |- x. u  loss.backward() # autograd
    & ?/ K9 W( R0 f7 @' \- Y% u) N5 h  with torch.no_grad():4 [+ C) X, Q$ F- x
        w  -= w.grad*0.0001   # 回归 w
    ( a% {+ i& l2 Y1 o    b  -= b.grad*0.0001    # 回归 b 8 |, d' W; C: h$ A, D: S& \8 j
      w.grad.zero_()  
    # z1 i5 T" l5 `8 _& d* N  b.grad.zero_()
    / T: g/ J  D0 j1 E2 N0 q& `
    1 R  U% v( n8 F1 c. eprint(w.item(),b.item()) #结果
    8 i$ d5 c# o& {& L/ ?4 r+ ^2 `2 _3 ?9 `5 I, ]& @$ h$ w/ o+ U4 ~
    Output: 27.26387596130371  0.49745178222656251 G- i* R  y8 V
    ----------------------------------------------; ]& N9 a# f: C  K
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。% G& F# l, ]& Z1 U
    高手们帮看看是神马原因?
    + z4 m7 j5 P! @5 F. {" c! i

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    " F& Y7 ]+ d) P  y7 \7 u2 O; Y+ D# D9 i& N( t
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    3 C! M3 [" U+ [-------+ w% o1 M- @, b2 _6 R6 h  T
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    0 A$ z* P1 |3 ~! K- f-------
    0 `: A4 {( L: m" _2 \1 p算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    ( B) x5 V6 ]2 n2 s3 c( E5 n没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?/ M3 V( P# R. {" s# r6 p2 L
    -------7 Q: u  ^/ V4 G! r9 e0 h
    不好意思, ...
    & f% [$ o! m; w& Y2 e. m) r
    谢谢,算法应该没问题,就是最简单的线性回归。8 O' I; W: r* r4 Y. C% Q* ^
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 2 ^6 \7 y: I. b  H
    雷达 发表于 2023-2-14 21:52( M9 d+ Q  Y" S. {8 }% g9 k8 ?# Y
    谢谢,算法应该没问题,就是最简单的线性回归。/ k9 \- I2 {0 D% {
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    * B4 D2 Y4 j  I3 W& N1 g6 s& `  m0 \' A4 }% ^/ C% S( v& b. E
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    4 |0 a' w0 I- N$ w+ ~( o  a& L; u* ?8 B1 R. J, ?+ K% T, s/ X
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 3 m3 t4 J) N1 R( n0 |* Q$ M
    老福 发表于 2023-2-14 22:004 p6 w/ O, q; {' ~( a
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。6 f1 W6 `. ~0 Q6 T- Y' _* c
    ; O8 V: k) e) P
    或者把b但的起点改为1试试。 ...

    8 L! K5 S& y% e2 }/ w0 T# Q; _+ k5 d( J7 N  R
    你是对的。
    * o; O) {. H# f6 b* F去掉了随机部分
    2 n( V, D. I: T9 k#y = (x*27+15+random.randint(-2,3)).reshape(-1)+ _% o5 f( n# g4 k: I. l, |
    y = (x*27+15).reshape(-1)
    , N& r( _- ?& p* E. @" C; Y; O" z5 }
    循环次数加成10倍,就看到 b 收敛了
    ; }6 e5 c+ e0 W' M- _4 g/ Tw , b* }" }# g; r4 E5 `) J8 Q+ u
    27.002620697021484 14.826167106628418" O7 t  A% w# b( T

    4 e- l9 c. G# ^: Z  U和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-4-17 17:26 , Processed in 0.055401 second(s), 17 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表