设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2487|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    4 K) S% u) [: [
    + J- X/ c4 S& w, A为预防老年痴呆,时不时学点新东东玩一玩。% g' x- h5 S1 j' B
    Pytorch 下面的代码做最简单的一元线性回归:# F0 _: X7 M; ]0 ~% Z5 r" }; H, O$ r7 w" B
    ----------------------------------------------
    7 m( L' e5 E% {7 h% F4 _- D6 @import torch$ v) a- U, E5 ~
    import numpy as np
    4 q/ R: r4 F3 s% x# S8 R# n- Uimport matplotlib.pyplot as plt
    5 ~  P$ y. ]6 s$ z" [" \4 [import random
    * I2 S& L9 R/ y- ?$ b& L! _9 p1 t& E7 s' m. {3 }0 H
    x = torch.tensor(np.arange(1,100,1))1 h( C7 n1 V* B1 c' p# m
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15" p' q1 R! A1 l# f6 W" _# O

    ' H3 R9 n# `( Mw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    # j& }: j9 J* k! U% ]& x8 Bb = torch.tensor(0.,requires_grad=True)
    # z. e- j8 v' x
    ; U: {$ B8 R- f) G7 ]8 Xepochs = 100
    3 ^' a/ F4 s+ e7 _& x' X/ d
    ) W$ P! r2 `6 J( E# Q- _* S8 Llosses = []0 y  w$ |* k' P  L% S/ c6 o1 l
    for i in range(epochs):
    ; k* _% `# z! z1 d5 k  y_pred = (x*w+b)    # 预测; ]8 p0 U/ ]5 y" T
      y_pred.reshape(-1)4 x$ q  v5 \( _4 j
    + v% L' ~' |) _
      loss = torch.square(y_pred - y).mean()   #计算 loss9 x, s/ m) b& N7 ~/ W7 H9 l& W: Y
      losses.append(loss)
    , [. a8 O9 \% q+ \2 k( Y  
    . c3 P. \& b, W0 I% _  loss.backward() # autograd0 E! r: ^! ~" }1 a5 a' v
      with torch.no_grad():
    4 ?" M, V* {$ e    w  -= w.grad*0.0001   # 回归 w
    ! T5 p* X8 Y' a' ?    b  -= b.grad*0.0001    # 回归 b
    ' z9 Q/ G7 p8 G; l. L  w.grad.zero_()  - |8 h9 Y0 o" z" t8 [" X; J) q: C
      b.grad.zero_()
    & F- j# X: X" I! D  \+ O5 @* N: N: H) h9 z# g3 h2 S4 F
    print(w.item(),b.item()) #结果7 k% U: H  S, K  C

    / w' ?  E: {: j6 x/ V* U* EOutput: 27.26387596130371  0.4974517822265625
    , }6 X3 W1 K* n0 o8 d; m3 M----------------------------------------------5 }1 l: p5 a/ m( ?) i" F# d- t8 k) l
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。  f) f* ~+ R0 ]5 l& i$ B
    高手们帮看看是神马原因?
    0 R4 b9 I# X6 }0 u. \

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 9 G: x, c$ |; t' J0 _/ L: x! a

    ; Y$ T% q* J! o0 X. L/ U没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?# ~* ^' q2 _3 K
    -------
    & ~/ w1 B- d! p+ y4 T# J不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。+ D( k; q7 ~2 W  L9 T
    -------
    % q, c- D; E1 i/ s% ~; A算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23; z+ N. `4 R7 v
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?# |' _, C5 G( N! q4 s
    -------
    * D9 `' j9 {. m5 a  r不好意思, ...

    & @" M* n' A- |0 R, W0 i, Z, L谢谢,算法应该没问题,就是最简单的线性回归。
    . L0 O) u4 g4 P" V$ u- q, X我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    + T0 h* G7 C" ~3 {/ d/ S8 {
    雷达 发表于 2023-2-14 21:52
    . S2 p5 X, E2 k' o谢谢,算法应该没问题,就是最简单的线性回归。
    # J5 C1 v+ a% o0 y! d( L我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    & @! Y5 O: s1 U1 \: E

    3 ^5 H+ t0 K+ E: u刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。5 O" |0 N+ Y" z7 C8 C5 V+ ~$ }

    % R$ w4 H5 X, e7 q" ~3 G2 }1 Y或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 . N  [( S; `" D
    老福 发表于 2023-2-14 22:00
    ' A& C" }0 N% A, r刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。- J% m# U# S) d+ D. v' F1 B

    * [( E+ m  g2 @1 Z0 q( ^; J或者把b但的起点改为1试试。 ...
    0 j9 A( Z5 g& C. v- d
    0 y, Y* O" E/ |+ E) I
    你是对的。
    * x, b2 ~! I/ q6 w去掉了随机部分% k3 ^$ e. h. |2 ]7 Z
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)& ~) \2 A* k  J- |
    y = (x*27+15).reshape(-1)+ G' o* }5 i7 T  c$ C
    " D. B! t2 v0 g5 E# W" ~
    循环次数加成10倍,就看到 b 收敛了
    $ M; d+ g5 K% v* W8 Y9 r( Ow , b
    # b* W7 B* S& r- M. Q" X- y9 ~. _27.002620697021484 14.826167106628418. \- k5 v) {9 @0 V* @* u: }

    1 S& M- T: n- C( p9 c9 @1 i7 U和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-2-22 00:42 , Processed in 0.057172 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表