设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2832|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 * G3 r" K/ c  Y0 v) C

    5 V1 c5 F: a( D; _2 H# [为预防老年痴呆,时不时学点新东东玩一玩。
    1 ?2 ]1 Z" e6 h; u- c/ b- a: yPytorch 下面的代码做最简单的一元线性回归:$ U' K; N  y/ n( l, ]0 `" j
    ----------------------------------------------% E5 J* E8 f  E
    import torch
    - `* N! X# T% w8 {& s/ ~7 uimport numpy as np! ]+ o9 O' P- C0 n
    import matplotlib.pyplot as plt, V) K$ E' O1 Q' A0 l
    import random. O2 J' I7 G8 d9 U6 U

    . F: o- K! ?1 |3 V2 b; F2 I; sx = torch.tensor(np.arange(1,100,1))
    ( E1 \, N# V- m1 H5 ~y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    7 J2 I0 H) W; C5 h
    : {2 D4 J) i8 _" ]w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    " g4 _$ d3 g. O1 c1 u4 {. o+ xb = torch.tensor(0.,requires_grad=True)% v5 o1 J, O" @0 j* u3 U( s
    8 |/ Z6 d; d* U$ ?, Q0 ]4 P# k
    epochs = 100
    7 O1 _3 P! b/ r; y6 A
    6 |- d$ g) Z! a, J  w1 Z; Klosses = []
    1 K4 }' v9 Q. N' }$ m3 I5 B) @for i in range(epochs):  ?+ M: T% N$ y7 l1 Y. P
      y_pred = (x*w+b)    # 预测
    * ?' n# M" V6 F+ k, \! D% z  y_pred.reshape(-1)
    : q. i6 A' e8 S6 z7 G
    / l9 T( d2 b" E5 j) r7 e# f  loss = torch.square(y_pred - y).mean()   #计算 loss: D, m. U  |4 I1 `/ q: W$ }& Z
      losses.append(loss)
    5 J: a+ A6 {, s6 b  
    & ~1 z+ ]: B2 U9 z. q6 y# j2 `. w  loss.backward() # autograd8 Z6 x- N4 k7 ~7 J& w( {2 j5 d! d
      with torch.no_grad():
    2 ^/ |7 Y8 V! ?) X7 _    w  -= w.grad*0.0001   # 回归 w5 S# s) e8 g; a. O2 p: m& i
        b  -= b.grad*0.0001    # 回归 b
    4 U/ R( p( D9 c/ k$ m' x  w.grad.zero_()  
    # O( j5 d+ K2 |- I  b.grad.zero_()
    : s9 l1 i$ O2 K0 ~
    , Z8 }8 w' T" h! Aprint(w.item(),b.item()) #结果/ f! W$ D; `2 _0 R+ h' U

    ! w/ j1 D/ h& D3 \) n+ d% ?Output: 27.26387596130371  0.49745178222656254 T. n1 }2 t  `& q6 ]
    ----------------------------------------------+ e9 o+ r% l4 y1 r: |% g3 \
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    ; {+ O$ v6 d  N* j高手们帮看看是神马原因?
    4 |3 `4 v. Q. B, T! {

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 # z: q3 h; Y' ?- v2 P/ X" z
    9 Y2 W* H( c" x$ \
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    ) K3 `# j/ X6 P# Y: r-------& t, i* Y; O9 L9 K% d
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    . k/ c) _9 Z& m-------; m5 z; l" Z) v) M% j
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    5 w% G+ Z- p, L* V没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?6 f, N7 V+ x$ o% Y: e$ o7 ~
    -------
    * g$ L8 D, T) u. D不好意思, ...

    , x2 i+ G' D$ ?: j3 e4 R" R9 V谢谢,算法应该没问题,就是最简单的线性回归。
    . x9 w) [+ O' w% l, Z  I我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    : B* q) Z7 r' q& m; l
    雷达 发表于 2023-2-14 21:52! U1 c2 R) E& `! O5 b6 E* Q
    谢谢,算法应该没问题,就是最简单的线性回归。$ `! k& u( X+ @
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    , U  Q* I% M% W$ v/ Q1 d& c6 j0 D4 P$ y) [
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    3 q* H+ M5 V# N9 b- o  \; T. \$ H
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 ; j+ O( p. ~+ G9 L% e( x
    老福 发表于 2023-2-14 22:00# T% q1 g1 [1 \
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    $ r# b8 ?3 S: |" n- E- J
    * ?+ w) t) b& A# o6 x3 }/ Z或者把b但的起点改为1试试。 ...
    % \6 o+ |4 J6 j. w' G' f

    1 V% w6 B3 Q; |/ T+ H/ C你是对的。
    2 n2 t7 |- I+ g8 f$ |' {  q去掉了随机部分
    " l  b; L3 n2 V, Y1 o4 e; z#y = (x*27+15+random.randint(-2,3)).reshape(-1); h* @2 a* J5 b7 I% s
    y = (x*27+15).reshape(-1)$ v8 [' F4 t( u0 ?9 H
    2 Z8 p) R" `+ x/ L* o" e3 f
    循环次数加成10倍,就看到 b 收敛了4 h; [* c: ~+ H0 E# s: L! j) y
    w , b+ g! U! j' G: B* x
    27.002620697021484 14.826167106628418
    5 S% }5 g+ a( Y7 y- M7 ]6 E$ m8 g8 G2 ?, l9 `! D5 f
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-5-2 17:32 , Processed in 0.067727 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表