设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 1780|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 / K2 @5 r( I1 U& N

    : h+ K; y6 m2 J9 p3 }为预防老年痴呆,时不时学点新东东玩一玩。
    # ?) N0 g: A" @- X& `/ {. PPytorch 下面的代码做最简单的一元线性回归:
    ! L6 e! h: Y+ z$ R+ w% A----------------------------------------------1 r. F" ~% T+ W" Q0 B
    import torch
    : Z" v7 ]. J; Aimport numpy as np5 b  a" j4 H6 S( L
    import matplotlib.pyplot as plt) B2 v( U1 [! ~$ b4 I" z/ z
    import random1 \6 r4 d1 t% o7 I

    + E, ?; T/ @$ K* ix = torch.tensor(np.arange(1,100,1))
    & Q9 h3 R7 f6 i( [; n; ky = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=155 F& A; i# y3 B, m) e1 }

    : y; ~" ?8 g* r$ r! U- Y6 s4 Zw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b( z: V7 f0 m% z% K
    b = torch.tensor(0.,requires_grad=True)
    * _: U4 ]5 k. _% k
    " o. j0 |4 z, x6 B: s) Uepochs = 100
    & s9 d2 [# k  g. k9 }2 J( ]. u; p$ e; r1 |
    losses = []
    $ Z; b* K8 z# p) xfor i in range(epochs):% U9 @, p' ?/ }4 T8 f& p7 }9 ^  e
      y_pred = (x*w+b)    # 预测
    , m* I- w3 v) u8 |3 q. z  y_pred.reshape(-1)5 U1 G: H! S7 @+ l% E

    * k1 B0 T" G* u9 V6 A  loss = torch.square(y_pred - y).mean()   #计算 loss; F9 k5 t6 i+ V7 e8 Z+ R% U
      losses.append(loss)* t9 T' y3 i4 Y
      
    ( L/ X, c) g0 `& f  loss.backward() # autograd
    ; O2 Y. ~+ G0 ]4 u. J& F  with torch.no_grad():, n2 q2 C8 o" ~/ v1 V
        w  -= w.grad*0.0001   # 回归 w: W; @* K6 ~+ E. u8 C+ ?
        b  -= b.grad*0.0001    # 回归 b ; x5 y' c5 g* y3 q# x# D% q/ I
      w.grad.zero_()  ) r8 w2 g* G( X6 x3 Z+ E
      b.grad.zero_()4 ^. C" O7 C' e
    / P/ {4 \- D/ z- f  g' O4 d6 V
    print(w.item(),b.item()) #结果
    # R9 |5 o! o/ S0 K
    2 f! y: i& @8 r3 J( I& y0 tOutput: 27.26387596130371  0.4974517822265625
      B! J* |0 m. p' I6 s: n----------------------------------------------
    ! A* o) Y: }6 g) R& H最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    7 {+ o6 y7 R- z+ y4 Z# b高手们帮看看是神马原因?$ C  M' v* |2 C3 Z3 h) O, x

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    ( |  m; Q7 W$ ~- J1 N$ v9 d$ H: q- z% N. s1 C8 E" Y
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?6 m( ^8 A: f5 E+ w
    -------3 C0 h7 B  Z. q( q1 ~) B
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    + L" X, T/ Y% C5 t. S% Q-------
    # @$ W5 T# \, t3 i' M8 d算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
      \& I: N4 F- i4 h3 @7 i' d没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?$ y: [+ c$ y: j: T* l" W  E0 m
    -------6 k% D/ c+ Y% p" F  d( u
    不好意思, ...
    # W/ T# W8 X6 q6 \
    谢谢,算法应该没问题,就是最简单的线性回归。2 w4 v. p+ _8 Z1 f! ?2 G! w
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 4 c$ s- F8 w8 R
    雷达 发表于 2023-2-14 21:52; K; _0 U/ r/ r6 b* y1 i; e
    谢谢,算法应该没问题,就是最简单的线性回归。3 z1 U7 A  ]/ T$ I8 x9 T/ O
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    7 U. l$ n# N8 D7 h# O

    5 }' R' m% B. }刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    ( o: i8 m* Y7 d7 t4 K8 X5 C) e8 m5 j9 ^& O8 j7 j7 R
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    8 T% V  I2 }/ _. o* U6 T6 p
    老福 发表于 2023-2-14 22:005 F2 z! l' T( s. v9 r
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    " a7 P4 C7 a& F. ~. [, n- V' y6 y% E* U. I2 r3 |: E  g8 ]
    或者把b但的起点改为1试试。 ...

    # M; C8 o7 b# f# q: {3 ^. b* n. i  ]( Z
    你是对的。* R4 O5 g8 {- v  w
    去掉了随机部分
    2 I0 \) |! Y9 _! B& B7 g# R#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    " p. e* |) C3 S( Y1 by = (x*27+15).reshape(-1)+ v' V) h5 \% z0 b8 J" m6 a9 W  @9 U* g
    2 }! H' _; P0 _
    循环次数加成10倍,就看到 b 收敛了& O0 y$ B; E7 G
    w , b* d( D8 L4 m0 t# q
    27.002620697021484 14.826167106628418; Z  n5 `: P, I1 C& ~% U
    & \: _4 D4 p8 X
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-7-16 14:18 , Processed in 0.035973 second(s), 21 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表