设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 3000|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 7 c- i' K% ~( q2 j' R8 X, v2 V( A: c

    . R6 J9 ]6 ]6 a4 }为预防老年痴呆,时不时学点新东东玩一玩。
    ) ?' I  k0 C' BPytorch 下面的代码做最简单的一元线性回归:' c, A/ u6 C5 x
    ----------------------------------------------
    + X" Y3 `/ V4 t9 ^3 \/ C; Dimport torch
    1 h$ e& L( g, u: e5 S9 f% Limport numpy as np$ r5 q, D8 i% V) i) K: N
    import matplotlib.pyplot as plt
    * \# T6 W  b7 E. s8 Simport random5 G0 }) U' m) q4 y* c+ {# q

    ( e8 t! S, F( Q1 qx = torch.tensor(np.arange(1,100,1)); N  P: O: F; Y9 t% H2 M0 i
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    9 D' X$ W0 y& h1 o- \$ {- y" |# U4 K# T. H  F' Y. G
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b: ?  N& G9 z, ^+ V4 I7 N0 U
    b = torch.tensor(0.,requires_grad=True)# E1 {1 C$ l- ?6 _; v7 B
    # O6 A  N. }' h, D
    epochs = 100
    7 ]+ H: C) X8 j2 J* I" a, P' k6 R% f' j2 x
    losses = []9 W2 A; X% _# \* G/ [
    for i in range(epochs):! Y8 J' \2 O5 d9 q  I1 b6 K
      y_pred = (x*w+b)    # 预测; F# R- H) o* M: c% m+ p& M/ i
      y_pred.reshape(-1)- A# h$ M, y$ A( I0 d

    , f$ S9 z( o9 ^5 J  loss = torch.square(y_pred - y).mean()   #计算 loss
    ' h2 q! Y2 s$ g( H& t3 |  losses.append(loss)1 k" M, A& h8 T
      4 S! V: e$ T2 J! \, X5 t
      loss.backward() # autograd
    + a9 Y- b9 h; Y: j  with torch.no_grad():
    4 M# l/ q- D. ?6 l9 {    w  -= w.grad*0.0001   # 回归 w3 S3 F4 H4 C! V. B2 ~1 x! E/ N
        b  -= b.grad*0.0001    # 回归 b 5 c, V" ]5 x. m+ [8 ]& G1 |
      w.grad.zero_()  - q( J1 u+ v% F7 O
      b.grad.zero_()9 C/ N4 ], @+ Q* H( f
    ( g' @9 J5 J! Z2 {
    print(w.item(),b.item()) #结果
    / r/ s5 ^9 ?. {
    . A$ v1 u) @6 `; t# l- VOutput: 27.26387596130371  0.4974517822265625
    + a, r* m* ]% C4 b----------------------------------------------
    2 o0 X' S8 O1 x7 w8 G" k: J8 ]最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。- F$ |2 w2 U8 Z( K7 H
    高手们帮看看是神马原因?$ D. @$ X, ?$ |% c" S

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    " x9 j: g( B4 h: V5 k
    + D" Y& \& u# Y* j( b9 s没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    % \) P5 ~6 b: C-------: T6 I2 b; q2 l6 j
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。: {. d! h0 f- y$ x) {; s# z( i
    -------7 u) ?1 m7 C" q, h1 b0 a  Y
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23' j6 k- {/ k- Q
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?6 ]2 M& w4 G% d5 H+ r& s" `
    -------: m( y7 ?4 v4 T* _, z! @
    不好意思, ...

    " i6 K3 V1 b9 E; a谢谢,算法应该没问题,就是最简单的线性回归。. h1 @8 _% Y. b; ^
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    + E$ U0 B+ U2 f% ~$ r
    雷达 发表于 2023-2-14 21:525 S# B2 W( Z9 F! R, w' r
    谢谢,算法应该没问题,就是最简单的线性回归。. u% r% V9 o. o8 X/ x
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    7 o- V; M' o( {0 W* u0 w5 {( M
    3 {: R" G7 c4 c( O1 B8 L1 k7 `" H. d! F刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。* G  f& y9 D) P+ w$ k6 e& w

    4 n$ W4 }  L* i. R/ k或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 / Z# N/ \8 U7 l+ ?( r
    老福 发表于 2023-2-14 22:00
    " G' |; s3 j6 v刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。% p# u4 [3 D. `# {

    ; [; \3 @) F2 D: ]  ^( F或者把b但的起点改为1试试。 ...

    ) @; x5 t% M8 w* L7 f+ m, ?( H; s8 D  O# x/ A
    你是对的。
    - p; U7 F! T6 p$ T2 N/ F去掉了随机部分
    + m) i) Q9 v- o! v5 U) a#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    4 H% ~) K* ~3 O; g: I+ Wy = (x*27+15).reshape(-1)
    % x: \7 w. h8 J* ^0 I
    $ _/ s: ?0 z# H7 i! }循环次数加成10倍,就看到 b 收敛了
    7 m9 F& G) X6 Z$ Uw , b" N) c+ M! Y; B, O% `4 a1 U
    27.002620697021484 14.826167106628418* O, v& D4 i5 a( p) l5 K! A

    2 n! Q% F/ Z! q) [6 z4 q和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-5-31 11:31 , Processed in 0.059380 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表