设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2067|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    8 c( }7 j4 N/ C. D2 O. F0 B4 c3 l# b. @& ?1 X0 |3 Q+ s
    为预防老年痴呆,时不时学点新东东玩一玩。
    3 v& H6 C/ d) a6 W! P7 l4 QPytorch 下面的代码做最简单的一元线性回归:
    " A6 m! h) B. f! O6 t% b4 w----------------------------------------------3 `6 a3 F: {+ Y/ f
    import torch
    7 u% b  U/ N7 e+ u: Dimport numpy as np& P5 ^8 e9 K' w- `- h; X, a
    import matplotlib.pyplot as plt5 b  X( n9 Y% F
    import random  B! C  W6 V& M: T& J& X
    . d4 g/ J4 A6 W6 w  g/ {4 L
    x = torch.tensor(np.arange(1,100,1))1 T# f6 z/ R9 o' l$ B
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    2 F% d2 I$ ~6 l9 |- \. T, d" c$ `) z  d
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    5 u5 y+ h' Z0 V" Qb = torch.tensor(0.,requires_grad=True)
    : d6 f. P7 n% `* p9 d, a
      L- l& K5 a" S0 cepochs = 1008 V; t" ^, w4 ~1 [( W+ q, b8 u9 F

    ) r& L! j* D  b& y, u8 wlosses = []
    2 k( x! J7 I  a9 X8 Sfor i in range(epochs):
    3 k% Y7 m! f9 [; n  |  y_pred = (x*w+b)    # 预测
    ) q; x7 A3 U4 m& U  y_pred.reshape(-1)
    " M/ X0 R& P) X5 Q1 E
    - s# C% ~/ r  d( i2 l  loss = torch.square(y_pred - y).mean()   #计算 loss
    6 P2 O# k- l6 }- d" ^  losses.append(loss)
    & k* x8 H& S) U# l5 U  # P) c9 j3 j9 l7 u: b# y
      loss.backward() # autograd
    % y6 j, B7 v9 l2 ]  with torch.no_grad():
    , u6 I+ F. C+ x- q/ T    w  -= w.grad*0.0001   # 回归 w
    - P2 W' |* ]) H& ]+ S    b  -= b.grad*0.0001    # 回归 b
    8 V' x+ \5 Y% J. T/ M  w.grad.zero_()  * h: T% L3 a4 W& n" K2 e
      b.grad.zero_()
    2 y9 w6 ]  g, ]8 L7 M
    ! V  r4 G8 X% v4 Hprint(w.item(),b.item()) #结果" u2 @" _+ J' `# t2 S" L6 Z

    2 \) V  {  Q3 OOutput: 27.26387596130371  0.4974517822265625
    , @$ h$ ~" l6 w6 @0 d----------------------------------------------' b- a) F0 X  v+ {' K5 _
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。2 z( {  t8 ]( K
    高手们帮看看是神马原因?' c0 x- d6 R2 n: t  }

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    # D. f; b) j/ o( Q9 W. B+ x
    $ _7 n0 \6 ]  g5 {0 u  c没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    # f) h" h" x$ {# ^, n' ^; ~# d-------
    0 }# M( |: _1 P& j不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。; R3 E+ t% R$ C! l2 ~
    -------
    ( |( F4 W$ ?0 b0 Z8 Z. E算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23. Z  n9 r/ A: G; l# s: k
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    ) \. y: L7 M/ ]# a" F- Q9 P-------
    - L$ X- P: p& b" I1 A' y不好意思, ...
      p$ J4 R3 ?+ Z/ }6 k: D
    谢谢,算法应该没问题,就是最简单的线性回归。1 [$ Q7 n' C( n$ L" o
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    - e7 D" u' s0 K; T" ^* |
    雷达 发表于 2023-2-14 21:52
    1 @2 G' A8 z& ]% m6 K谢谢,算法应该没问题,就是最简单的线性回归。* F7 m& `' o1 P' v, X2 R
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    9 h- p$ J7 ?1 I2 u

    1 t! H, C1 v9 F7 b9 r' {8 h$ n2 _刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。, c# A% P) T. ^* R

    6 j  d4 v- a2 c6 b* ^6 q或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    , [" {: x( J. j9 s; M" R
    老福 发表于 2023-2-14 22:00
    * P, i6 h5 R3 ]1 Z' Q) K6 p刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。* E! k  u; E6 |/ `8 J8 N
    ) _* A6 o% @) u" ^0 z
    或者把b但的起点改为1试试。 ...

      p3 s$ u9 o4 @% p# n2 u% t6 S0 E5 s# m- D
    你是对的。. D7 j3 a1 v$ s
    去掉了随机部分
    - x9 n( S4 m% g7 T3 a* s#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    3 o4 r# \, @! my = (x*27+15).reshape(-1)
      H/ s7 {% H, i: d5 k8 W
    " _+ \2 {+ {9 u+ c: ?循环次数加成10倍,就看到 b 收敛了
    2 H* W2 Z, g' Nw , b
    7 P/ @" _1 z3 H3 U' e27.002620697021484 14.826167106628418) _! X" Y$ l' i3 Z- k
    4 y3 E+ d9 M( U5 I3 M) T
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-11-5 06:55 , Processed in 0.029506 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表