设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 3052|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 : E; ~9 V: y3 F) x" {/ m2 K
    % I8 T# W# P* m9 |3 ?5 I
    为预防老年痴呆,时不时学点新东东玩一玩。# D. z7 o6 P( s- R  m$ ^
    Pytorch 下面的代码做最简单的一元线性回归:8 V' o0 l1 O7 {7 N* W
    ----------------------------------------------( y- Z% ~& [5 W' j7 x0 o) v/ z5 `
    import torch
    # }) w2 \. e$ F0 A4 r" Simport numpy as np6 S+ |, R/ y6 N" Y; n
    import matplotlib.pyplot as plt
    0 \2 l, ~4 t- t' v3 Limport random
    " g& z# M8 q5 B2 S6 C: b6 G3 ~
    , s/ G! A% K1 \8 t' P# Kx = torch.tensor(np.arange(1,100,1))
    ! B  U& T7 k3 D6 q9 r& a4 i" Dy = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    & ~8 g$ d* }* M- L( p( v- s
    $ G3 j; S  E2 W2 r+ ^6 _w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    " Q! o* C7 a6 I' H9 T% ^b = torch.tensor(0.,requires_grad=True)8 W) x7 u/ g0 A, U$ V  @

    9 a8 `) ]; w' N+ m" _epochs = 100
    & B& I* |% _3 i: K
    ! m0 B9 B0 X2 J, e1 {losses = []
    & P! i. m- A6 `8 @for i in range(epochs):' R" j- \5 O! Y6 f  ?- N- K) \
      y_pred = (x*w+b)    # 预测3 B2 x  {8 }+ Y; k+ Z8 H
      y_pred.reshape(-1)3 G# O1 |$ {% n% B4 O

    : n7 c% t) |! g5 L* N  loss = torch.square(y_pred - y).mean()   #计算 loss
    5 f, q2 C! N% Y7 X  losses.append(loss)! a3 E; G+ c- z
      0 _" h9 {3 J: G$ S7 ?& B
      loss.backward() # autograd
    + H; A6 U" Q: ^9 v$ ~$ q8 t/ A  with torch.no_grad():
    % ]5 M5 V" e0 N0 [4 w1 ^    w  -= w.grad*0.0001   # 回归 w5 [3 U4 L% ]( a' I* G' ]" l
        b  -= b.grad*0.0001    # 回归 b
    ! l0 `# N& U* _3 m( X  w.grad.zero_()  7 d9 `, z+ r! p' n4 ?
      b.grad.zero_()
    & N! p3 C* ~# F+ Q
    $ L+ u$ |6 _2 r7 o% R  I( uprint(w.item(),b.item()) #结果
    6 C2 @& o/ |+ C6 P! ]5 n
    / i. o! m/ a) B* s' v$ F1 X, E& Z/ g: ]Output: 27.26387596130371  0.4974517822265625
    " D) S8 X6 F$ k) R6 _, I----------------------------------------------
    4 I# p$ j! S$ c最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    " @% _3 Y  H& X8 N高手们帮看看是神马原因?9 B) j  S$ o7 g- [

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 3 v: p+ O2 X5 J2 K; W5 u! g
      U6 }* [0 G$ N3 @7 x9 ^$ c/ U3 l
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?: j7 K) L/ c- g" ?
    -------# D) N; N& q  N. ~
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    5 P* n/ m: P4 ]-------3 X0 [' |: v3 a6 d8 J
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    8 Y# K- L# y) }没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    . c$ ]4 _# m' }; U-------/ s6 P- G+ s8 b' u: o2 [1 s
    不好意思, ...

    ) y- v# @! y  \5 W$ A. m, D* W谢谢,算法应该没问题,就是最简单的线性回归。/ y2 L$ H6 P! }( S0 V% u
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    0 K3 r( [( l& W/ e, U$ G8 `
    雷达 发表于 2023-2-14 21:52
    * A+ _6 _" Z% u' U6 n( h谢谢,算法应该没问题,就是最简单的线性回归。2 @. I2 ?5 G8 h% z3 w6 `3 ]
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    4 q  x0 B' H5 w  S

    . m$ [# m/ a! |刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    # I7 L! s6 T+ U0 N! L" L: N$ t2 Q  B
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    & ~0 q% g- J1 M9 m2 h/ Y
    老福 发表于 2023-2-14 22:000 v6 u: l: \9 K2 F; J' k
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。7 i" F  [& }& B8 |% n
    # ~8 Q$ v6 T- a# g( r/ Y
    或者把b但的起点改为1试试。 ...
    % b9 D. `" l; x% i: ?
    4 {8 c! Y# z6 Q7 ^, R, {6 K
    你是对的。
    6 u' G/ d4 U" K: [, Q去掉了随机部分  u! i% _& |' p/ w% X  u- t1 w% }
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)
    7 t7 f1 h0 ~& b5 T+ G/ ry = (x*27+15).reshape(-1)
    6 l+ o+ J) a0 S. A7 _
    & v. i# e" C: z! k; ?循环次数加成10倍,就看到 b 收敛了/ u! j, Q; x3 X2 \$ w6 Y, y! q9 w
    w , b; K0 e2 c: Z& i7 R8 P+ ~
    27.002620697021484 14.826167106628418
    * b5 M  ?" u: N/ ^& {; Y8 v. {) y; M' V
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-6-7 09:43 , Processed in 0.060718 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表