设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2906|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    " |6 q# p4 q  U1 w$ z
    ) c$ Y2 v. Q' Z6 n( c为预防老年痴呆,时不时学点新东东玩一玩。
    2 q& _5 v% y/ i% _Pytorch 下面的代码做最简单的一元线性回归:
    6 I: ^" r2 o7 I----------------------------------------------" g% p9 Q3 _, v' d
    import torch
    " X( v3 R% r+ q4 z) m, Y- Dimport numpy as np' r) U8 |+ x( N: l. j: ]  T
    import matplotlib.pyplot as plt- O+ L1 @( H0 ~
    import random5 b% T0 @: V& T3 m8 g# V$ g1 M" W" H

    ! R) m& m3 V3 Z% Cx = torch.tensor(np.arange(1,100,1))8 d- Z9 W+ ?* N* D3 H" \
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    - r$ q9 ^0 w! H* o6 `+ y- F* C9 g$ D) s3 l- f+ f9 \. |
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    2 r3 m' h8 `+ D) `) Y+ Mb = torch.tensor(0.,requires_grad=True)& j2 }( \$ u* ^/ R7 u2 Q) C2 d

    ' s' |2 h2 U1 a$ bepochs = 100. S6 w* R$ J6 n0 i" k) ^# u1 @/ v

    3 H% l2 n* D8 a; |: J# l$ c% zlosses = []
    2 H' P2 U3 o/ K- m1 P# q( ~& }9 N; tfor i in range(epochs):( O, b$ q, m% l1 Y6 C3 h
      y_pred = (x*w+b)    # 预测  U+ \0 |1 ~8 ~" O
      y_pred.reshape(-1)1 k' u7 _4 b/ ^  c- o/ M
    ) k3 P& N" e3 g/ D
      loss = torch.square(y_pred - y).mean()   #计算 loss
    3 T' D% f5 b0 P  losses.append(loss)) X4 E8 |( a+ d+ J' V
      
    0 w) ^9 F- |& m! ^. g7 E) t4 |  loss.backward() # autograd
      W& p: G9 X$ F  with torch.no_grad():
    % `$ {# G: h! w2 p& j# N9 g    w  -= w.grad*0.0001   # 回归 w' A9 [  z, `+ D6 p* Q
        b  -= b.grad*0.0001    # 回归 b ! I- m0 d; ^) I
      w.grad.zero_()  8 V3 s' y, J+ W: M9 o6 _, _. s1 ~, E
      b.grad.zero_()
    & W# V" f2 q7 G) J0 M0 W5 J2 l  A: P
    ( p" n6 @4 L+ l  e: ?3 }7 Kprint(w.item(),b.item()) #结果* l, x) F# H. J! u4 l
    , O9 h4 {( |" F' q% m5 |) U
    Output: 27.26387596130371  0.4974517822265625
    . C1 e+ P# @# s% ]----------------------------------------------
    ; z' C9 Z* o1 q7 T. x最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    ( _4 ~8 I: I! L高手们帮看看是神马原因?9 f: `1 F) b/ V6 p6 K& H

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 6 m$ |  G2 q5 G# u1 Y* S- Z
    1 b' o: W( V4 E  c5 V9 `
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?& T/ i! B: D3 w4 e, q
    -------+ S9 h" e: N4 A% |9 \: @5 j7 q: O
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。/ c0 E3 u! E% t: _! R/ r
    -------
    / I: j4 m% I* S2 Y算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23) r# C+ _( A* o1 |8 {8 i
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?# b8 \" `  J9 w, E6 ?
    -------
    ( U% i+ f( |' M1 @) O  J不好意思, ...

    + U( \* N% U6 G0 N& e6 Z谢谢,算法应该没问题,就是最简单的线性回归。  Q& l6 V9 u5 A: `
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    ) w% C/ l6 p# }
    雷达 发表于 2023-2-14 21:52' ^0 E* l# L" Z/ l2 T
    谢谢,算法应该没问题,就是最简单的线性回归。+ y+ ]3 ]# V; k! O% ~
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    6 ~* a' j! z& c( Y' j
    * a" B4 ?6 R1 \+ a刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    7 \: \& M$ \4 _# c. @' K. C
    5 {& V2 B; b3 U1 E6 g. b. B或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    - L( m! e6 G7 Z0 T7 ~
    老福 发表于 2023-2-14 22:00
    : n% [3 \) m: ~% z( \4 }" v3 ~) N刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。/ {* Z& M2 A  R, {: R
    ' K8 x7 s2 L1 V9 G4 G( H# R6 Q* r
    或者把b但的起点改为1试试。 ...

    6 a' E* c4 T  k! ^+ \# x3 j3 u/ j# J2 s& z+ ~! ~; g
    你是对的。/ P# O: O* ^$ k% X1 ^# K8 W
    去掉了随机部分# [4 e! F# o, h# m, P; B9 |
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)9 b$ p9 e. r& Q( Y* T9 h3 E
    y = (x*27+15).reshape(-1)
    $ e& A) Z  x' {: }) W$ A2 }7 L5 I0 a! Z
    循环次数加成10倍,就看到 b 收敛了
    & O' Z. R( Y# Y: z2 ?, P9 l6 C: Dw , b
    ' O. k8 X% ?/ b8 l27.002620697021484 14.826167106628418- G1 ~+ n9 M& z, u9 Q& ^9 r

    1 E0 @6 \9 I# v7 s/ y3 s和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-5-18 10:01 , Processed in 0.068685 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表