设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 3074|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    2 \( Z0 W0 P) R- L% ]" |6 O7 D0 C4 D! x- B4 A* B
    为预防老年痴呆,时不时学点新东东玩一玩。& b- i5 g$ V; |$ t' _
    Pytorch 下面的代码做最简单的一元线性回归:
    / E1 {5 D) H) P----------------------------------------------
    - _6 Z1 k: m1 W0 {- jimport torch
    % ^9 @1 n7 |* c4 ~* C, d- N8 D( himport numpy as np
    ' e7 f, |( [0 ]# O4 \" simport matplotlib.pyplot as plt
    # {6 S8 h" ]6 Z0 `4 Aimport random
    2 R9 W" D7 k" P8 [$ v" X- _4 v: C# q# M, [
    x = torch.tensor(np.arange(1,100,1))
    5 k+ a: f& I- V$ Q6 _y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    " a! t' y( Y8 }: A5 A, U) {+ o5 {4 z! P
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b  u& s8 {% e9 n
    b = torch.tensor(0.,requires_grad=True)
    : i- e7 L/ x6 Q* y
    $ q4 u' f% m! G7 }  m$ jepochs = 100
    ) o, E; H1 f) m9 V# @% i6 H! Q" N4 C/ H6 K
    losses = []
    / J" D$ f# A  G$ b" c& cfor i in range(epochs):! w& o) k5 K, f& g3 I0 ?/ w  C; r5 S
      y_pred = (x*w+b)    # 预测
    ' l1 m& v5 e9 y' ?  y_pred.reshape(-1)
    ( ~% m2 b4 [  H. a/ ~ 8 D4 n" ^& Y1 c# j7 z. p( N, b
      loss = torch.square(y_pred - y).mean()   #计算 loss
    ' E' N' \  _, _7 _) J  losses.append(loss)8 Q( A0 }1 R& s& i1 h3 ]$ D
      9 h, F, E: j& C- Z+ N
      loss.backward() # autograd  f6 t2 Q; Y9 }+ w! m$ o" q& p
      with torch.no_grad():
    5 [% e: r" M/ a    w  -= w.grad*0.0001   # 回归 w
    4 @! D9 X- B- ^: G$ p: e& P    b  -= b.grad*0.0001    # 回归 b , t0 ]; S* _5 [3 e1 G0 e
      w.grad.zero_()  
    * y( a) d2 V' E! R. e  Y  b.grad.zero_()8 s; V. @$ g% k. M; m3 a

    0 M7 m% y7 J/ _) cprint(w.item(),b.item()) #结果
    ; I& z( I9 o. E( b, t( o# \7 \/ K; X8 i  c4 c
    Output: 27.26387596130371  0.4974517822265625
    7 p4 r0 |6 d! l1 t! J----------------------------------------------
    2 T  H# \8 b6 U/ z% b最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。# W! X+ G4 n: H
    高手们帮看看是神马原因?1 y6 y# C- b+ U8 g) P

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    / s7 Y9 j- ?2 H& n0 x4 k; e0 _& _9 c# c% P. G
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    : W0 o2 q" W" @5 ^, F-------" s% v* a( }$ G( P$ g
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。+ C3 V/ J: u3 E: S
    -------; v4 A- \; w5 S& O! R; Y/ ~: N
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23# T( y0 [7 N( w
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?# K& _  ~7 j; N8 S  M6 x# {" N
    -------. ]/ H9 i0 b5 i5 s5 J7 s
    不好意思, ...
    . _, n8 i. `# A& {5 D
    谢谢,算法应该没问题,就是最简单的线性回归。) e/ j% t) w/ |$ V7 P2 x* K$ i
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 8 d9 [, c$ {7 X4 D& \
    雷达 发表于 2023-2-14 21:52
    $ M9 ]( F7 y4 r2 O5 e谢谢,算法应该没问题,就是最简单的线性回归。7 e+ z- C( T8 B7 U1 y* x' |
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    . Z, N7 s9 }. n  e& A* N  }( ^3 K& ~7 m) f5 W
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    ) M) I- c: L! }3 e
    9 U- M0 m" U4 s1 n或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    6 r; x# }$ V" w, C4 B  l' P
    老福 发表于 2023-2-14 22:003 G6 x6 [# W4 @
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。, `" n6 J) e8 v8 G; ^* \7 g

    # a# r& b! A1 R/ y或者把b但的起点改为1试试。 ...
    5 U% E- d2 T/ I7 P! j9 F3 H
    . m: Q3 G+ u6 U& ^' f. X
    你是对的。1 G  K7 Z1 S! R3 K
    去掉了随机部分2 |1 W5 J* g; o% T8 N3 L
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)0 n0 h' \. f% }3 E" Z4 |. m: p
    y = (x*27+15).reshape(-1)
    ; s0 r6 c! r  i0 b* j& K5 J7 K% n3 i% l% O- D& A7 i
    循环次数加成10倍,就看到 b 收敛了
    / N0 s' D3 @9 y" iw , b
    ( A* Y+ G- \* ]3 P1 U27.002620697021484 14.826167106628418
    , Z6 D4 }+ N; h& r8 A( ?. p* @! Z# B: R, ]# }+ y
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-6-10 20:09 , Processed in 0.055480 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表