设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2200|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    4 z# }( _+ ?) N9 K8 {* G
    % T) E3 d; }# ~6 J& Y: Q为预防老年痴呆,时不时学点新东东玩一玩。8 ?% n& O2 W! _0 h. E8 o
    Pytorch 下面的代码做最简单的一元线性回归:3 R, i" d8 D* O& x6 T
    ----------------------------------------------6 M' T6 j4 M+ x5 }0 U9 _
    import torch% O' x1 ~, T: z, D4 U
    import numpy as np
    5 [: m3 j) t) M1 h) ?, Oimport matplotlib.pyplot as plt
    ) F- l8 ^8 v) r1 ?import random
    1 W0 J, ~8 b* E) W  N0 g
    % o' F# {2 L8 Z/ K1 ^" px = torch.tensor(np.arange(1,100,1))) {8 j% c$ e- M$ S3 f6 {
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    ( X- w0 e% c4 }5 ]) e" d, \* [" ^# G- B4 U% u% [$ B
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b" ?& `5 X  V9 w7 L& W5 P
    b = torch.tensor(0.,requires_grad=True)
    ) z, t: [' m8 @3 G
    7 ]8 x. D% `8 H, Pepochs = 1002 j5 u1 w8 \) c6 I

    2 }* _' e' X& Plosses = []# w7 M" R7 C5 c* M5 P
    for i in range(epochs):
    8 K* w0 q% h: T4 y6 r* D! B; J  y_pred = (x*w+b)    # 预测
    0 b" F, }3 B2 N: D( o1 w  y_pred.reshape(-1)* I9 |! g& X5 @" f

    / m1 ?, d4 y1 T6 C  loss = torch.square(y_pred - y).mean()   #计算 loss
    2 D5 F3 Z, ?  B/ a# Z0 F7 i# Z. h  losses.append(loss)
    ; U7 w# b3 `5 m$ I  * V% Q3 e, Z% u3 r. c, W
      loss.backward() # autograd
    : J( ]+ V2 f2 j# i# ~5 x  with torch.no_grad():! T! `8 m, X& h2 q1 {$ a/ k
        w  -= w.grad*0.0001   # 回归 w
    7 A4 [" M; F, \' |3 X  k- B    b  -= b.grad*0.0001    # 回归 b
    * q0 J, }+ B& l0 L% G7 N$ C  w.grad.zero_()  
    4 k4 G$ V& F4 J9 g  b.grad.zero_()
    ( y% q; O$ [* T" H) p7 W. b, g0 v
    print(w.item(),b.item()) #结果
    $ i2 D+ K6 e& r  F2 D) Z; h8 d' g
    5 x( k7 Q, |6 I# H( ~5 n& zOutput: 27.26387596130371  0.4974517822265625& n) Q- t4 f. e% ]1 w# M+ C
    ----------------------------------------------: k7 l. ^  m$ }4 A, {$ r# ^$ o
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    # o5 C) @- X4 y- i* @高手们帮看看是神马原因?
    " z; i! d2 w1 L3 V2 T1 D; Z  N- S

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    + p( ]; v" g5 A3 U4 ]. j1 Y& `% N, v1 J
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?( q+ S1 i  ^- p- k5 ?& A5 V
    -------, I+ M( Y/ v' O4 ^, R
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    1 z$ {. B* p: e' p-------
    % X2 J0 ?3 `, y0 Z/ b算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    ( {+ U5 n0 u- \没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    & H2 G" P9 a# x9 w2 J-------
    % N/ e! Y; S  l( d& e8 S$ e不好意思, ...
    & q8 Q8 x& t* j+ A: `
    谢谢,算法应该没问题,就是最简单的线性回归。5 D2 x+ |: {1 G8 X
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 & S' F& X# l1 u! {3 i, g
    雷达 发表于 2023-2-14 21:52
    2 q' g& t' R9 a& G; p1 p8 @7 i' l谢谢,算法应该没问题,就是最简单的线性回归。
    ( B; [* S1 x0 p+ _/ c4 ^1 ~& M我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    # c2 @2 c7 V3 d& R

    9 O+ l1 b, g4 g刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    # A) K! A' Q7 j: z' s- }6 o" u  x  ~& g4 [' e
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    ) k5 k) u; \( y8 ]) b4 O! B1 q
    老福 发表于 2023-2-14 22:00
    9 R: X: P' [& k. _; F* c7 u0 H刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。9 P: A4 m/ `) Y; h2 S) _. X) C

    - B( e8 k6 }  Y: g或者把b但的起点改为1试试。 ...
    " m: v7 g* E( r

    , K. L( A* [3 `+ c你是对的。
    ; v) |; y: L- D去掉了随机部分
    ' `; \$ p! x% G% ^- I3 k' F: ^#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    % }( ~  o* C7 Zy = (x*27+15).reshape(-1)
    + O: c4 |- F# ^* [' Z, P% J& [( d* W' w* r. @% L
    循环次数加成10倍,就看到 b 收敛了
    # A  n+ y  x8 _, s2 u' B0 aw , b
    % u- {. k( r' g/ v2 E; y27.002620697021484 14.826167106628418
    8 L1 n2 A+ Z. E' [) J  s2 p, S4 M' a+ {: Y) b2 ?
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-12-10 20:54 , Processed in 0.030769 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表