设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2136|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 + u8 R; i! S* J. M, v! k$ ^
    & w# m1 `5 C: D
    为预防老年痴呆,时不时学点新东东玩一玩。7 e3 m5 ]8 E7 l  E
    Pytorch 下面的代码做最简单的一元线性回归:1 t- `! ?1 P5 {+ V: L, J" ?. P
    ----------------------------------------------- L" `$ Y- y. Q
    import torch/ e* _% G1 y$ `4 [3 ^& ]; c) C2 W
    import numpy as np
    ! O7 d# W- `% r6 T. b. d! t4 c. ]import matplotlib.pyplot as plt
      o  I; S( j8 e  Simport random
    ( o2 ^8 S5 M: X" S+ s1 N' v! o
    x = torch.tensor(np.arange(1,100,1))4 [- X: H. h, {) J; p
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    / |- e. d: Z8 k$ S0 W) W2 |3 F" r0 Y+ q* p) C- q/ b
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b3 H/ E" w" E  D& H* E+ r" C
    b = torch.tensor(0.,requires_grad=True)# S3 u2 G1 p  K6 R

    % I4 k" Q5 w+ o. A9 gepochs = 100, N- u: W& _% S. w! {

    * K; M7 u( ?$ y7 ?losses = []
    9 r% B4 c8 L& y  J3 cfor i in range(epochs):
    2 ~( J  g7 F; K2 f; g) m  y_pred = (x*w+b)    # 预测4 O) L3 B7 D4 o+ t# |0 C
      y_pred.reshape(-1)* S& B7 u' [8 E+ B6 y" @! d

    ) C8 }! U8 T; D* [* p* ?  loss = torch.square(y_pred - y).mean()   #计算 loss
    : N6 `, h7 z* d  f! i  losses.append(loss)
    5 p/ a4 z! w( ], A% z7 q6 C  ( V1 Y0 h9 p, V7 G6 j0 k# |; L* V
      loss.backward() # autograd
    1 ]$ V$ P' n% v- [' T  with torch.no_grad():5 b+ e; M4 J+ ]  H' L+ J/ c) |7 @
        w  -= w.grad*0.0001   # 回归 w" ?7 q0 a* g' x. N4 |5 O
        b  -= b.grad*0.0001    # 回归 b * n1 f. a) g( ?' t8 }2 \
      w.grad.zero_()  
    & G8 S8 L0 x' {7 q  b.grad.zero_()$ @6 F$ p+ z% L/ x0 F$ u; a

    $ a8 t. l, t+ y8 Z% oprint(w.item(),b.item()) #结果7 R* h& |% U8 g4 D% @$ E+ v
    # P' s  k2 p. @+ R
    Output: 27.26387596130371  0.4974517822265625& U, ~5 L, d, d% u: d
    ----------------------------------------------1 p+ d  v- \( J+ ~+ X, n2 E
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。, z  E% V1 x# a" `' v! {
    高手们帮看看是神马原因?3 C! X& J3 m* D  |3 y

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 " l6 y  j  m5 r; N
    3 S5 E7 k) @& k/ j7 L: L, v+ d/ ?+ a
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    9 F" q" B1 S9 o8 E4 ]' C" d/ M3 F-------
    9 _2 i9 J9 C; _8 K! _不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。$ h5 _1 I, J9 y5 l' O% u: g) D
    -------
    % O6 Q" l2 V1 x& R算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    - x, ~4 r+ f7 f! V  d" T% t没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    9 B4 |- A, N( A-------; P" N, O' y! J# p% v& G
    不好意思, ...
    ) ]% }, T( T/ {% f% h
    谢谢,算法应该没问题,就是最简单的线性回归。
    " S( k- D) B/ w, }) g6 B. p我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    7 l( Z* f/ D5 w) F/ M, v
    雷达 发表于 2023-2-14 21:52; s% X  ^7 [: M" D# [
    谢谢,算法应该没问题,就是最简单的线性回归。
    , C' F6 s# F- t1 Q: l9 ^. B我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    * \' e2 J8 `5 L5 q- ]& p
    1 Y! l4 l- p' }1 I" R/ z
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。- p& o8 e, L: |, q$ M
    , |1 U: U8 i! R( j2 B
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    : t7 z) T1 l! K; ?( C
    老福 发表于 2023-2-14 22:00
    5 ?3 O# x$ z' ^1 ?8 |0 u  [, }刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。* j" e% \" C2 w& Z. i, I

    % t$ W$ C7 g2 x( U  j或者把b但的起点改为1试试。 ...
    % K3 M% j7 d& y5 M% y' l
    ! u/ [5 `. Y9 G
    你是对的。. w$ ], \7 e! I( ^9 i' i6 S: s( m6 d
    去掉了随机部分
    5 E4 M# z3 [0 y#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    ! J, ^1 k4 C* T& o; _y = (x*27+15).reshape(-1)
    * r* p4 Q; c7 H2 K: I8 q3 Z- ?9 z: M! W" c3 {
    循环次数加成10倍,就看到 b 收敛了2 f: m# h  d& O) k6 f
    w , b
    5 F- g7 ~4 U! C& E/ A' ^3 V27.002620697021484 14.826167106628418
    3 [5 ]1 P, x( B$ b* y/ L
    , ^3 {) T- b" T7 o6 F, H) S2 t和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-11-24 17:01 , Processed in 0.029115 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表