设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2884|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑   @: i7 J' p; k
    : B) y7 c1 M% _1 g* o
    为预防老年痴呆,时不时学点新东东玩一玩。) s# Z; ]% a% U+ M
    Pytorch 下面的代码做最简单的一元线性回归:
    # W1 [- t! b7 B6 |& T----------------------------------------------8 N7 m* m# I5 `* T3 P- Y, ?, j4 r1 ~
    import torch+ h4 I# q8 s( v" b! M) S
    import numpy as np
    + e  w: E! k' R% b7 `& Cimport matplotlib.pyplot as plt. V7 n7 ?  u2 Y. |
    import random" Z% ?5 M/ N" }. r0 X

    / h1 f9 W% [9 m' hx = torch.tensor(np.arange(1,100,1))
    + m0 _* y/ k* b4 F& By = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    4 _5 y6 Z, V, g* X& J0 w, v
    6 \8 N5 C* B) @7 v7 |' k4 Y- vw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    $ g* \' Z/ |( t3 Jb = torch.tensor(0.,requires_grad=True)
    : x8 _0 `- V0 t2 i; L
      C: f% ^8 R1 K6 B( r* depochs = 100
    ; s! Z% g5 c0 z: K, i+ P) M/ u; _# e& ~' z6 b
    losses = []
    . O! C6 Y& ?6 T) S) Ofor i in range(epochs):& @4 k: A1 ~' L2 N7 q- D; _
      y_pred = (x*w+b)    # 预测2 t- E  q. W& V" D1 |
      y_pred.reshape(-1)  i0 H/ f. G' Z
    . g! Z/ K$ f' z8 ~- k' n+ C. o
      loss = torch.square(y_pred - y).mean()   #计算 loss
    # ^2 |# W# f5 K  losses.append(loss)
    1 H) R: G# u7 H- ]- D0 v  
    ! D# ?# N1 D* U$ J6 e  loss.backward() # autograd4 e5 O7 _4 `' {9 ?& L) |
      with torch.no_grad():" W6 k8 T, K5 {0 C1 f3 R# N( ^
        w  -= w.grad*0.0001   # 回归 w  J& ^8 B+ x7 I- r# E
        b  -= b.grad*0.0001    # 回归 b
    ' M- {! b. z2 p; `6 h  w.grad.zero_()  
    2 ?1 |* u' w/ H2 Q7 o7 ?$ @' `  b.grad.zero_()
    8 T; L" M7 I2 i0 _) v  M! W/ ]+ W' O% q! V0 l
    print(w.item(),b.item()) #结果
    ) x' x' {' H: b0 w* u' W+ C9 K4 A& K$ A3 a2 l: g" P" |
    Output: 27.26387596130371  0.4974517822265625$ }6 b* c9 {! `" B( n* S
    ----------------------------------------------6 ^2 e6 I9 O1 B/ b4 x  {
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。7 |! J! C8 \4 [8 G
    高手们帮看看是神马原因?
    % D) g+ U$ a* f- P1 _

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    7 T$ u. W; O1 [
    , M) p% u% F4 K! J5 |没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    ; V3 L/ Y. Y: K, n* F( E-------$ Q: R  n: `) `# B2 m. w
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    - ^4 S  x6 k8 z-------
    ' z9 J: [3 n( Q, d  [2 x% j. B% {算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    5 D6 j; B% B- Q" i# _& M' \没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    & h! ?( f+ K+ m$ d! A- x-------
    7 c# O% y" m. f- o' L不好意思, ...

    3 d( u! |& t4 ]. m! O; s* O谢谢,算法应该没问题,就是最简单的线性回归。
    . o$ `7 ^0 L7 M" u1 z我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 ' D) [9 d* [  b
    雷达 发表于 2023-2-14 21:52
    / q, O4 {' u6 k( o' j5 t谢谢,算法应该没问题,就是最简单的线性回归。
    0 ?- d/ C1 {* O0 E; I3 d' W我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    % {6 t2 t2 Z4 d$ [7 h

    7 S. N% B1 q2 B' H& r  y& Y刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。7 Z/ s7 G0 F. g7 E
    1 E0 k$ L; U: ~& b- G
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    3 e+ {3 ]- b2 X
    老福 发表于 2023-2-14 22:00, D+ r9 ]. Q$ [* k0 x
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。% A8 J/ u/ ~3 R$ d' C, H+ }; d& O
    2 n& b0 W. w3 W7 {1 H( k
    或者把b但的起点改为1试试。 ...
    ) q. A5 B0 l4 x

    % ~0 |& V5 _0 e8 S& ~' [你是对的。
    / P1 M7 O4 z/ u0 d" A+ C  \" N去掉了随机部分
    7 f; A8 s, A( D) b#y = (x*27+15+random.randint(-2,3)).reshape(-1)6 D5 k7 H5 V( Z3 y* u6 s
    y = (x*27+15).reshape(-1)- f* w3 y% l; X; J( n9 ?
    0 F( O' `% n! l: H$ k* C
    循环次数加成10倍,就看到 b 收敛了5 Z7 T; m% y& J
    w , b4 h# y- x' n' C) D' i3 `$ }2 m
    27.002620697021484 14.826167106628418' o9 A, {8 Y# y% r
    " I! y4 J* S0 W1 ]' [6 ]) h
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-5-14 02:20 , Processed in 0.092853 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表