设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2568|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 : p+ y0 {+ w. O' p$ o  I

    4 W3 b3 M. z$ o: C为预防老年痴呆,时不时学点新东东玩一玩。% n2 E. O% k& O
    Pytorch 下面的代码做最简单的一元线性回归:
    . d' }4 V: z, H, l% L+ C----------------------------------------------3 K/ B$ r. u; L* v5 N* P
    import torch
    5 _% S( ~+ _& w2 z% Simport numpy as np
    ( i. W, v. t: P( Pimport matplotlib.pyplot as plt
    / w- h. t) c6 T3 Pimport random
    : h: l1 o, U( S$ _
    / H8 S" x- k+ Qx = torch.tensor(np.arange(1,100,1))# L, Y  @2 Y2 G5 L7 Y, s% d
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    $ o- j  \$ {! y9 ?; w( r. D6 c* l, E9 n' X8 E  |
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b- h8 y. N7 N# w  x
    b = torch.tensor(0.,requires_grad=True)
    - E; z) e$ A5 W% t) Z
    4 [" n! ]9 [' L, ?: ?epochs = 100
    3 [6 y$ B6 `# A! _9 Y, J- H- y6 c& q( l' `3 G4 E8 O
    losses = []. }' P, u4 P6 h0 j/ [
    for i in range(epochs):! Y4 P  }# E  }8 b
      y_pred = (x*w+b)    # 预测
    . r' j6 M5 k/ |- f! M  y_pred.reshape(-1)
    + d& C2 T/ ^( Z  {* d
    / D5 j# H* x5 U6 z9 b  loss = torch.square(y_pred - y).mean()   #计算 loss( Y4 N9 ~- Y; k5 s3 r* o
      losses.append(loss)' i/ I' l5 R  X  [1 w4 z! F7 E$ I
      0 l, ]# Z+ G1 h
      loss.backward() # autograd
    8 R/ x0 L+ [* H9 \6 G  with torch.no_grad():
    % h+ u  W9 B4 Q; J, n    w  -= w.grad*0.0001   # 回归 w
    . K. [. o' y- t  [5 ]7 K    b  -= b.grad*0.0001    # 回归 b 7 @) a0 m( L1 o  |1 \: D
      w.grad.zero_()  
    # N8 d7 n# @8 Q1 U$ |: s1 L5 c' c  o  b.grad.zero_()
    / C/ q# s7 N% r- n* S+ b! ]! B$ V; _3 J
    print(w.item(),b.item()) #结果" d1 Z7 Z7 G6 e" C  x$ T6 k

    $ ~+ W( j5 Y. WOutput: 27.26387596130371  0.4974517822265625( b$ E0 Y8 Z& L( o5 W* N- y9 Y
    ----------------------------------------------# g, ~1 q% c# V- J  F! c! o' C* @4 ~
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。: J# Z% A% O# n5 n$ u! O7 t
    高手们帮看看是神马原因?
    ' b4 w* F  T5 X$ f

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    , ^4 I, p8 {5 {7 j- J( G6 b: O; I0 X6 V3 M
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    # p0 i" \, C& ?-------
    % @; v( y/ S! Q* T9 n0 d不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。2 m) Z9 s6 O! v* Y( w
    -------
    : d+ l! N. T6 J  S4 D7 X算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    # n! T9 Q9 r8 x6 x没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    4 R. i9 c, J, U0 j-------
    " e( D# F( S  V( k+ P, p不好意思, ...

    0 }  W' F, k$ ]( E/ m+ k谢谢,算法应该没问题,就是最简单的线性回归。. G; X  B* ~; S8 q
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 6 n/ z. T+ m( G7 d
    雷达 发表于 2023-2-14 21:529 g  ~/ W+ `( X( R* n+ o
    谢谢,算法应该没问题,就是最简单的线性回归。
    # v: b9 K3 f9 a6 [: r+ s我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    ; N- G$ X/ O% }& N% I

    3 z/ a; u6 K8 V# Q刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。( K7 z7 U+ J8 z7 r4 W
    & K" M  y; a9 @2 `
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    5 q+ e$ E9 o& X9 q& f2 L7 W
    老福 发表于 2023-2-14 22:00
    , ~+ X2 w$ t% u8 P刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    . g7 l  w( H7 l1 y( N$ \3 ~+ G3 e6 r$ d, y" P8 ^! S1 z2 @
    或者把b但的起点改为1试试。 ...
    * h2 k6 b! ^5 a/ W
    $ h$ z3 a6 J! K) u
    你是对的。' `' v; x4 q4 n' e  p
    去掉了随机部分
    & D4 H( \4 ?) A+ ]" q9 g#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    . f5 Z0 Y( m8 B) O8 j" T$ W" ]y = (x*27+15).reshape(-1)
    " _4 x" S, g( S, C3 I3 C7 A
    : X5 \) c$ W8 H' e$ \# l循环次数加成10倍,就看到 b 收敛了' `4 m: O, @% |2 W2 m; K
    w , b
    0 [2 c/ r, J+ Q' M5 Y27.002620697021484 14.826167106628418/ N/ K+ Q" J) o3 E' H3 {

    , w( {3 ]& H3 l1 i& W" v和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-3-9 13:52 , Processed in 0.081103 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表