设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2993|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 ' _# r: s& }  L8 m% H5 g

    ( t' y3 g- v7 H' O5 |为预防老年痴呆,时不时学点新东东玩一玩。- x! W6 `6 O3 \  p7 r% W
    Pytorch 下面的代码做最简单的一元线性回归:8 ~) y! A, n: L5 o  g8 M
    ----------------------------------------------% z/ i* p! n* b+ u. I
    import torch
    ' Q& l7 C$ R7 n, T' u2 @  ]import numpy as np
    7 e" a2 H  F7 v1 U9 Yimport matplotlib.pyplot as plt2 N1 S1 g$ U. N$ W/ q# V
    import random
    " `' l! R9 t: k( ?5 \- Q% N
    - f1 W7 V! |0 Sx = torch.tensor(np.arange(1,100,1))
    & Q2 d& ^6 h4 j5 b+ \& P( H+ ]y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=156 P$ a# ]% ^6 Z0 e2 m

    ( R0 Z. I+ g. Ow = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    * \) Z6 w* q0 e: a# B1 ]b = torch.tensor(0.,requires_grad=True)
    ' T2 L; M6 f+ s% G8 d
    : t8 E, v2 Q9 q* D% f' w3 wepochs = 1007 ^7 [) k6 u  E4 C' s& T+ ], G0 v
    8 {# D/ I, U, I+ ]
    losses = []
    $ o! u& D6 y( {5 a. Y0 X& rfor i in range(epochs):* o3 _' q# P2 i$ Q/ P
      y_pred = (x*w+b)    # 预测2 u0 [# _: ]2 X0 y4 W( t
      y_pred.reshape(-1)
    1 M/ W0 T8 Q) D8 t  ]$ [  H8 ]
    ( v1 \9 P4 H$ _" Q. `  loss = torch.square(y_pred - y).mean()   #计算 loss
    , M% J' |9 {7 G  losses.append(loss)
    $ c. |: F1 B; U" o  
    : `. v  ]' W, g; C2 w& `  loss.backward() # autograd
    $ w# Y+ _# }( l: f4 F6 F( a/ U$ k  with torch.no_grad():+ a9 Z  ]/ D% @4 C: I. N
        w  -= w.grad*0.0001   # 回归 w
    / V/ ^8 B& o4 i    b  -= b.grad*0.0001    # 回归 b
    2 w7 t2 M2 O0 H8 G& u7 u" f3 W3 z  w.grad.zero_()  
    + G) s6 V5 J$ y( g: @+ ]0 P  p  b.grad.zero_()# M0 H9 y( ^' |8 ?( L

    3 I+ L4 c+ e9 d7 N& O* gprint(w.item(),b.item()) #结果
    ) V3 m' [( f( Q7 j4 ^/ b* t  T7 L* z* }7 k2 G4 ]: q8 [; M
    Output: 27.26387596130371  0.4974517822265625
    1 n  y! B9 w! S; m----------------------------------------------$ \2 i; M; c; |! E' d- ?
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。( a7 W- B+ B1 I% h4 d% E8 c7 Z
    高手们帮看看是神马原因?) v5 q' Z2 n) B' F) P; r% f2 {

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 4 k+ r' L% f; Z0 |

    , t$ M& `8 c0 f没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    + |- ~5 @# R1 A" U- o% \+ l; n-------
    " e/ Z- d( B1 @: }. z' s不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    3 X' B0 t. }0 |* m  h-------! U' k3 y( f* z+ u  o. ]  S0 g' M
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    + q: u5 z5 b9 p; X没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    % p1 Y' l8 u* b: R) G-------
    / \6 M5 i/ T+ [) a6 h不好意思, ...
    / m2 \, R7 T7 v5 C
    谢谢,算法应该没问题,就是最简单的线性回归。
    & A9 }& n" d5 d8 Q4 `& `我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    " D+ x9 C% P9 `: Y+ Q1 `- U7 a
    雷达 发表于 2023-2-14 21:52) @% W) f4 e. u
    谢谢,算法应该没问题,就是最简单的线性回归。
    ' R* p8 O# V5 t/ \- A/ ~我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    7 V: n3 q3 [3 }& P- h
    ( N7 x' k+ h: ~' x! F2 D% e
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。) G0 z' T+ X, A- ~3 f
    ' D% J8 _1 ]3 W& x; l
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 : c' Q, |  g! ~6 H
    老福 发表于 2023-2-14 22:00
    1 W9 e+ j; K, @; Q) i0 e6 v/ x刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。& g- I9 k$ c  E6 Z" K# G8 X

    ) @/ R$ ]. V: x或者把b但的起点改为1试试。 ...
    5 [$ `- U0 t9 `4 I2 f3 ?6 {  h

    ) h# X* v/ z& O$ T你是对的。! v% N: O0 I! f& O3 @3 w7 v4 Q
    去掉了随机部分
    & r9 t; M1 s! {# B# F#y = (x*27+15+random.randint(-2,3)).reshape(-1)9 _# O  x1 R, S3 u  v  q; O, s' T
    y = (x*27+15).reshape(-1)
    & @+ z+ r3 h$ |7 u+ V
    1 V) g! o- p* d6 E. t循环次数加成10倍,就看到 b 收敛了
    " h* y( Z; V& v: E' @w , b
    / W+ j4 K) g  ]27.002620697021484 14.826167106628418
    1 w+ [( L3 P$ ], [$ F
    - ?0 @! F( {; H  _和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-5-30 10:59 , Processed in 0.057443 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表