设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2101|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 : g/ o& b3 E2 @8 c, F) G8 o" E2 n) @
    4 K8 U+ Y$ ]1 g5 `5 B7 T
    为预防老年痴呆,时不时学点新东东玩一玩。' p& H7 b& E6 K
    Pytorch 下面的代码做最简单的一元线性回归:
    5 K) `5 s6 z+ }+ l5 A----------------------------------------------# B5 ^$ M% g2 ?" g5 q2 O" p
    import torch
    # [2 W" A& i% W3 V8 p$ W& O- R" z5 Y/ dimport numpy as np
    ' H" e0 J3 c. A' O1 Q2 b% dimport matplotlib.pyplot as plt0 {3 P. ~# T4 q
    import random5 k& E+ K  l/ p2 h/ U, S3 B- b  _
    " n! b" v/ c/ f. b: E$ A
    x = torch.tensor(np.arange(1,100,1))7 |- y. u5 b5 q5 z4 P+ R- U
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    4 t: u; v. @* G
    $ S. x) R8 s+ W% Kw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    4 {# K) Q' g( A6 z6 P) Z$ T) ?b = torch.tensor(0.,requires_grad=True)
    : K7 }/ j  J6 }' j- }
    ) c- v, i/ f" O( t& V) Gepochs = 100
    4 b* N* @. _$ p8 s+ X. T! z3 [, ?) Q+ W' c# K: y
    losses = []
    8 y* M& T. E: Zfor i in range(epochs):
      v* s7 N8 ~) ], L8 ~  y_pred = (x*w+b)    # 预测( V2 y/ u, @( {( ~. J
      y_pred.reshape(-1)
    ; M3 _  v( a4 F, r- i5 v
    ) q# _: r7 ^9 m% G: K  loss = torch.square(y_pred - y).mean()   #计算 loss
    $ [9 A8 n& H4 |  D  losses.append(loss): u; u( L  c0 M, s
      
    7 e( ]/ B% n! F$ d1 L  loss.backward() # autograd
    - ^% K( @3 D. z' a# x; t  with torch.no_grad():
    ; D+ x8 o- f$ k0 m1 Z  a' m% a    w  -= w.grad*0.0001   # 回归 w
    3 Q8 S9 A+ N3 u( o" J    b  -= b.grad*0.0001    # 回归 b
    $ ^7 V1 c* |, g. E  w.grad.zero_()  5 w! d  ]+ X* B+ y
      b.grad.zero_()# s% h# u5 K: t# k" h9 J1 b

    9 c; F, H/ h9 bprint(w.item(),b.item()) #结果
    ( T' b' I' I; R$ c% W# H
    4 ~$ C# c2 ?* bOutput: 27.26387596130371  0.4974517822265625
    / n8 a; s# C4 F" O2 @. g3 I----------------------------------------------) c0 ~3 O) b* Z6 y5 ]
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。) P7 H9 S1 A1 Z  R+ d3 Z+ L
    高手们帮看看是神马原因?" E6 F& Y7 U7 {' _, C0 I

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
      ]9 w. g; O/ v- [: ?% r- d) h# ^( A, G3 O& l
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?9 ?" K- `& P' [' x
    -------
    8 `2 ^. u. @4 s- z  q/ V不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    % W+ q2 V& @4 a3 K. J-------
    ! O: H5 O$ x$ @( O& S6 m8 C算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    0 K% b3 w! d& `1 x没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?, P8 w0 Y" _/ I, a% F4 K) Q
    -------: k  r9 ?+ o; ?) k  @! [0 ]! O, u
    不好意思, ...

    ; K* e$ r8 A9 x- o9 x+ `谢谢,算法应该没问题,就是最简单的线性回归。( M5 F1 `2 H6 f0 p1 w
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 ! a3 ~2 ?3 r3 f8 V  B) k; f
    雷达 发表于 2023-2-14 21:52  b& Y& x  l% A7 b- K8 p9 i! A# J
    谢谢,算法应该没问题,就是最简单的线性回归。/ _/ e: R( l& R
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    - _3 t4 x8 z& T5 @
    1 ~. Y& I- d; M9 C3 \% t刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    1 P3 F: N: p. {6 b
    ' a6 n3 O' F8 d或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    " h7 @7 f4 @* f& D
    老福 发表于 2023-2-14 22:00+ U; X  f& E" l+ u8 `
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    ! ]4 ^# x8 W8 G9 C
    + ~9 t% @( H% c! r: n0 Y或者把b但的起点改为1试试。 ...
    * E7 y% D8 |/ v* a/ Q) c; l

    : }# Y# n( T6 q你是对的。
    & m+ {. B& d+ Q* ?去掉了随机部分% V! l, l4 N: ~* Y
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)) S4 N) ?& z) u8 {& z+ f, h
    y = (x*27+15).reshape(-1), g+ P- b6 P. j' g9 _7 a; Q' A

    ) ]3 A, M0 ~8 s0 X9 L5 o& X循环次数加成10倍,就看到 b 收敛了$ `) g2 _' H; N3 \& }
    w , b3 b5 U2 b* A/ c6 w. V4 I& M1 f1 }
    27.002620697021484 14.826167106628418
    ! N& D( M  a& [+ k- d$ o+ Y9 T5 ?! ^9 M
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-11-15 15:41 , Processed in 0.037870 second(s), 21 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表