设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2390|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    - |1 A: O+ ]& o# U; J+ ^1 r$ C) a- N
    为预防老年痴呆,时不时学点新东东玩一玩。
      c5 T9 n2 w  x% B& WPytorch 下面的代码做最简单的一元线性回归:
      O' S" ^3 F( q6 P! ?----------------------------------------------( K: o% q7 ~8 A6 R( R
    import torch8 e0 g+ T- Q, W
    import numpy as np. E; B' u' I- c6 w9 j
    import matplotlib.pyplot as plt
    9 S, q* z1 r! \: }5 {3 p0 wimport random: Y9 N) A7 E$ Y2 q3 @

    1 g( F% G" Y+ t7 b5 }, m# `x = torch.tensor(np.arange(1,100,1))
    ! T: l$ P! q1 o5 n( R3 ?y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=151 [- m* a1 v/ L& ~
    5 U0 {' e; R+ ]9 C, o
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    7 R' \5 U% s& s- H' tb = torch.tensor(0.,requires_grad=True)
    7 B  m9 k2 U( j" V7 f; [7 I( O- t3 D9 }; w) u- h2 w# p
    epochs = 100
    0 H* v5 V& F# |( ?3 \! u6 q5 [. w5 e  K2 d* w! u! g) u3 x
    losses = []
    . z1 W2 j# o7 O0 zfor i in range(epochs):4 S/ s- p2 V; f, O, w
      y_pred = (x*w+b)    # 预测$ [- d+ z: [- y+ |
      y_pred.reshape(-1)
    . j+ J& b* c) G7 s: y
    5 L" _9 R$ o/ d& r# r, c7 S  loss = torch.square(y_pred - y).mean()   #计算 loss3 q& H: s) |8 C  s- \
      losses.append(loss)6 g2 {2 f8 }6 _8 {  ?
      
    & n* ]) s0 ~& i' S  loss.backward() # autograd/ \7 h6 ~( c! \6 h8 Z
      with torch.no_grad():* y. E3 W. y* k8 W$ W
        w  -= w.grad*0.0001   # 回归 w
    , q4 y; C# A9 J7 {1 g6 ]) o    b  -= b.grad*0.0001    # 回归 b
    + \$ Z1 c1 }1 \0 J( o  w.grad.zero_()  
    2 N  S" A+ U6 D  b.grad.zero_()
    7 B& }( f8 Z$ w) h3 d4 @3 r2 X1 p" L. |4 U/ N, w* m3 w; [0 d
    print(w.item(),b.item()) #结果
    9 \5 q3 s7 J) X  K/ m  @- H! ?0 O8 p2 M
    Output: 27.26387596130371  0.4974517822265625: _9 X, M+ q6 z5 O* V" x
    ----------------------------------------------9 d3 e% z1 _; U2 x- G1 b7 o
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。7 W3 c; H) T, L* ]3 v
    高手们帮看看是神马原因?( f' e9 [' s3 Y# }0 ?9 w5 t8 i

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    ) L5 R6 [- f- Q* p+ D1 P& o$ W4 }+ e, K
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    0 v1 C' ?% ]9 w8 y. E- k-------5 b$ K  a' L( p- a& T
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。4 \# K, X3 o3 M0 y6 Q
    -------
    ' G, K: G: K+ U: ?+ N& n算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23: g" M  x3 J( I  \' l* g* I
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    4 g$ \3 _0 |0 u- e-------
    0 \, ~8 r; ]6 a. j" S不好意思, ...
    ; Q& f, U- s4 T, G6 u
    谢谢,算法应该没问题,就是最简单的线性回归。
    . m" H" Y- o2 r; n8 l我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 ! @# G) n& h; V$ h4 c, x) D
    雷达 发表于 2023-2-14 21:52
    : v4 E" H) W/ {  r2 K6 U谢谢,算法应该没问题,就是最简单的线性回归。
    2 J9 [. I$ Z& f8 u6 a" ~' v8 @我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    - v0 j9 H* x& S# q& F6 ~

    / [' v3 S" a, e* U" c- P刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。) v9 }( \( C9 e  o7 a2 p# G* e
    3 W6 ?2 A- n) Y4 L( p  Z
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 : P! [2 e+ d; G4 e" \5 u9 m
    老福 发表于 2023-2-14 22:00
    . F0 l# M. Q+ J0 ^) H6 B. z; h! i* l刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    - G( D+ b, j0 z! w3 S- N+ c/ T6 F0 L1 P+ I' T
    或者把b但的起点改为1试试。 ...

    - A: N* ]" D6 ^2 Q: D
    8 p6 c3 ]# k0 g* B  S你是对的。1 h* F4 }0 U1 N; g! c
    去掉了随机部分; A- @/ h- b2 y* b3 C
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)6 h0 l' j0 k6 b8 z4 K
    y = (x*27+15).reshape(-1)! H+ ?+ H% X  M" \3 v9 o9 H. }

    - x* r2 i1 l+ k( N. ~8 i( {循环次数加成10倍,就看到 b 收敛了
    * W  I' S$ P/ u! rw , b
    0 e+ E' u; M0 |5 F# O8 z- Y: o27.002620697021484 14.826167106628418  K& g; y9 n: U
    7 h( ^- ?& V6 P( N6 t" _
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-1-28 17:54 , Processed in 0.062453 second(s), 21 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表