设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 3090|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 + l& R9 M7 ]: s. i" h/ |
    ) }* m9 A8 ]5 K9 ]
    为预防老年痴呆,时不时学点新东东玩一玩。1 L" r) n; C( C1 d+ L
    Pytorch 下面的代码做最简单的一元线性回归:2 y" Y) E& S* ?/ j
    ----------------------------------------------! _7 q! @0 {& y! d- C  }
    import torch
    2 \8 B# Z/ l* dimport numpy as np
    ! M7 u; L" i$ x4 Kimport matplotlib.pyplot as plt
    5 F, P4 I: a; `* }3 Z, D7 Y. iimport random& m* o" ?, a8 H; `
    - Z3 P2 @) c3 O6 [3 E5 f/ x% A: G
    x = torch.tensor(np.arange(1,100,1))0 p- C2 j' I; s( Z9 ^8 h$ n
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    ! X7 @3 U* B% L% Z7 G  l. d; n' u  `" h  F" G
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    # z+ Y' O1 H' s4 {! Lb = torch.tensor(0.,requires_grad=True), m; b+ j! D  y9 `$ }9 G
    $ S, X0 I, @8 s5 ?
    epochs = 100
    0 N  E$ x! K# _& k5 G4 o% \+ f! s1 Q7 p
    losses = []
    & N, x! x4 ~* h% U& Afor i in range(epochs):, j- l! a( Z$ k4 f
      y_pred = (x*w+b)    # 预测# _8 S& Y7 B- \
      y_pred.reshape(-1)
    # |' `8 R+ J  j' j' {% u0 z# S ) L- O/ a9 P1 T4 t2 p, N
      loss = torch.square(y_pred - y).mean()   #计算 loss" O$ a" F" {8 o' r$ D8 j$ e: C9 k
      losses.append(loss)
    & c+ U5 ^# R' \. l) g  
    % b  E% h# C, `- c1 u9 b# R( o  loss.backward() # autograd
    7 j7 ?3 G; T( Z3 @! M9 Y  with torch.no_grad():
    . @' w+ G' N1 ?4 Y    w  -= w.grad*0.0001   # 回归 w0 A0 }% i% {9 Z2 |( c# d; t
        b  -= b.grad*0.0001    # 回归 b 6 F3 V8 ?: m: M, N
      w.grad.zero_()  
    3 k" J) d, n. j  b.grad.zero_()0 l0 Q' H" _% d/ q7 g
    ; j* V- i5 ~7 s6 _( F# }
    print(w.item(),b.item()) #结果$ M, z8 Q, Q# p0 k9 Z. m1 A9 i2 y! a' O

    % H) n9 f% K; X3 ^4 m% ZOutput: 27.26387596130371  0.4974517822265625- Q& L) g6 t% ~+ N. g! n. r' \
    ----------------------------------------------
    2 l9 Q+ D) f6 I最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    ! m' d/ T+ D4 ^高手们帮看看是神马原因?0 d9 {: N" f3 f

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 , a, ^. h4 n* ?" F+ A" ]. d' L
    3 H/ D9 Z. u4 e9 L) H; X1 i
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    ! v9 S) \( S* l- t. |-------0 U' e! l9 d9 Q+ T5 n  {
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    ) U+ Z' W# l' r3 C-------
    8 l" V& C3 f" w- \算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    , t, g: |; `: v没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    ) t$ h# w) ~* z-------
    ; S/ x) h3 C! `- X8 v不好意思, ...
    : [) {2 @7 A$ ?+ p
    谢谢,算法应该没问题,就是最简单的线性回归。% [) t( j# O, B& n3 c) V3 E, ]
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 . e7 Y; F4 u. S: T2 [
    雷达 发表于 2023-2-14 21:527 {. v0 z4 v0 x& v/ J
    谢谢,算法应该没问题,就是最简单的线性回归。1 Y; }% {$ J! L( A5 |1 N
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    ; k, ^$ t# b2 T5 n5 _; s) W# }6 t$ \) e( Q! P
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。0 X- o, w* \% L( @1 K
    * v$ ^" v$ m$ e5 l3 O3 y+ I- W
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    $ m# g6 e) q, J7 v$ M. Z/ d5 E1 y
    老福 发表于 2023-2-14 22:00; Q) D: b- h5 ~8 y8 q& {1 c0 I
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。8 H, _9 s, {+ g3 X
    $ U, c9 O* T- l+ z; O8 }; k
    或者把b但的起点改为1试试。 ...

    ' ?' a' J- ?! K; T, \5 ?8 z4 Z  h8 A" ?
    你是对的。
    5 v* o5 g* F6 g. P/ A" T去掉了随机部分
    7 b! k, Z2 g$ }6 l* e; z* N2 e#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    + ~7 V2 C: D; Z( d" |0 j* ]& u4 By = (x*27+15).reshape(-1)
    7 i+ X4 ], B# \
    : T/ I+ D+ r9 R' \* e' l6 G/ E, y循环次数加成10倍,就看到 b 收敛了
    1 {: M2 i% ]# o5 Cw , b
    9 J* u# ^; y7 L$ S$ s; T27.002620697021484 14.8261671066284182 @6 n6 F8 {2 K+ w

    7 Q4 V+ T2 S% Z和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-6-13 18:35 , Processed in 0.056479 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表