设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2786|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 , ?$ p) G8 s6 W2 S
    ( i/ |& F/ g  K8 }* j/ N
    为预防老年痴呆,时不时学点新东东玩一玩。1 d+ \# P6 b5 m$ g8 ?' ]1 H
    Pytorch 下面的代码做最简单的一元线性回归:' L( Q/ e  Z2 M- ]2 a& I( u- Q
    ----------------------------------------------
    & w- p6 s2 O: G, {+ b7 timport torch" ^! y' A# V" e) R3 @9 H3 _4 m" z
    import numpy as np6 K6 \8 Q, f% t1 V; A- n: K
    import matplotlib.pyplot as plt
    0 T  u& `$ g3 {0 c' ]import random
    3 r# |* _/ G; B/ h$ ^
    . X1 x/ N* S) q" m+ nx = torch.tensor(np.arange(1,100,1))) J. e' T# j7 j0 B" E
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    , d! N1 m1 R% l
    " D0 P( k( s1 y! H4 e3 U- cw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    ) T. @3 D; N. T  Q+ j, Q+ _1 _b = torch.tensor(0.,requires_grad=True)
    , M7 M5 V- n. }4 v$ t& s3 G
    + H) U& x. \" C$ s! P4 _& h. E7 zepochs = 100
    # h$ s# `6 S0 D. w
    ) w+ y. E1 O# c+ w, Mlosses = []; n( |* u  S+ w0 x
    for i in range(epochs):+ m# ^4 ?. s: l/ ~; G, w$ q
      y_pred = (x*w+b)    # 预测* A' U& R' j; t
      y_pred.reshape(-1)
    3 r5 V5 `! F6 k4 |; c6 r! N
    % y1 g  s" U$ j( D9 V: ~3 M  loss = torch.square(y_pred - y).mean()   #计算 loss
    " `. U3 I9 l2 L) o2 U  losses.append(loss)
    * s* Y* y5 j* |' G4 T2 Y  % q! F" ~  P! Z* c
      loss.backward() # autograd8 t. M5 {* v* T! l* g& l% J0 d9 r  m
      with torch.no_grad():. R; N. Z0 v9 p/ C* E
        w  -= w.grad*0.0001   # 回归 w6 z+ a0 Y$ _, `. p: Q$ z+ z
        b  -= b.grad*0.0001    # 回归 b
    4 `* G+ m1 d2 W. T  w.grad.zero_()  - M4 o& m0 u& H# \- ]/ V
      b.grad.zero_()
    2 e/ y) }9 E4 A9 U8 m
    3 f6 O, j- m2 H3 fprint(w.item(),b.item()) #结果: v6 z2 P$ {' Q8 c/ y( B

    ; h0 I6 c9 I* C; }) dOutput: 27.26387596130371  0.49745178222656253 A$ C# V) g9 g2 D! E
    ----------------------------------------------
    1 n# t- }# c- E: g9 R! d& I最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。3 F( b0 Z) J; ]+ g
    高手们帮看看是神马原因?+ s( X% J6 G" x9 W  O

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 ; P( @9 o( H6 j( H) d
    % y2 A# J5 D% o& Z% |. O
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    0 b# ^' x7 \( d( V-------+ M! K; I* x' ^- Q6 E
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    0 R4 B$ ?3 P6 {1 V: X-------: J! e2 K0 G" D0 x- J: ^( z0 l1 f
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    : t  b4 J  i/ [没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    5 b$ n/ B  S! k* U" }0 J1 I( y% @+ }% s- i-------: ^8 u- Z( s5 |1 ]! z* N1 @0 W
    不好意思, ...
    " w) V1 H4 I# ?- z- ]4 w4 h) u5 R
    谢谢,算法应该没问题,就是最简单的线性回归。4 b' x$ t+ K' I$ d4 A% ^
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
      O) ?" P% ?/ P- p
    雷达 发表于 2023-2-14 21:52
    ; V7 O8 H% o$ ?! _/ I谢谢,算法应该没问题,就是最简单的线性回归。% q, m0 ]. y# S
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    / h% G, ]! K* H' \( v
    # T, z. h3 E2 J# g# a刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    4 ^4 t" _  }* O, w8 z  X$ J' }  d" [) k: ]
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 : z/ d! e1 m) W" G; O
    老福 发表于 2023-2-14 22:00
    ) m: v4 @& R' B刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。) I- m1 \3 G# E2 w: ^! b5 L8 W4 U+ |
    # @) i# _- A* d
    或者把b但的起点改为1试试。 ...

    4 w! b& u; ]5 B3 w3 W* k5 D8 Q. f) x6 l# `, u; o
    你是对的。6 j2 O1 O# f: u1 b& i7 t
    去掉了随机部分
    " x: M; L. _) H- i- k4 r( X5 v#y = (x*27+15+random.randint(-2,3)).reshape(-1)- k( G( w( [; s/ ?/ h/ B
    y = (x*27+15).reshape(-1)
      }* W+ b/ B3 K; l7 e9 }$ S: x7 n5 |% g9 c
    循环次数加成10倍,就看到 b 收敛了7 a, g. w: m  d3 Y
    w , b
    , A7 x' O' U9 Q. B$ ?# W% L27.002620697021484 14.8261671066284182 ]1 P2 S( i; _' g' G: p

    3 g/ U" t. n; ]3 j/ C% j" X: [和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-4-24 03:21 , Processed in 0.060175 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表