设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2759|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 7 w  t" H( B. j3 q9 b$ o% ~
    % z5 H# T: `/ s
    为预防老年痴呆,时不时学点新东东玩一玩。
    / y# K. n1 h9 c6 Q: h- fPytorch 下面的代码做最简单的一元线性回归:
    9 k# T9 {) G" P3 h----------------------------------------------
      Y" ^( v% ~6 b# ?* Q# M8 himport torch
    9 b, I0 A' [' b) w+ Jimport numpy as np
    1 F4 p5 g; K3 j$ |- t" simport matplotlib.pyplot as plt
    3 G! S0 I, I$ [  `+ Bimport random
      _: R( \- X9 M( m) `
    $ @1 q7 D1 s0 kx = torch.tensor(np.arange(1,100,1))
    " _3 q( C$ q# k7 {( Vy = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    3 a$ ]) a3 j: X3 ?/ Z: L  z' H( Q2 g, z( R; `( @2 n3 @
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    # i3 J9 X1 I; {  v+ sb = torch.tensor(0.,requires_grad=True)
    4 z! g+ K( K* A8 [  ]' S* k% M
    5 t3 B& }! l3 k( Oepochs = 100
    7 v/ l  y, q: E2 g# \) [" N8 H: L9 j9 t% C
    losses = []( D) \. Y$ W; u' A5 X5 f
    for i in range(epochs):
    # u+ N9 F: M$ |/ Q( b( B  y_pred = (x*w+b)    # 预测4 F. S( W5 T5 `% h: L9 C. y
      y_pred.reshape(-1)
    / l9 A4 H) x( h . f& W3 r) W; V( r8 w/ b
      loss = torch.square(y_pred - y).mean()   #计算 loss
    1 }% K  n% D, f! A# Q% {  losses.append(loss), `. h& h0 `+ O* v% }4 o" y- E
      9 A7 O2 R" \' u, V( n4 A9 Z! `) o& l
      loss.backward() # autograd& Y& |5 r) h+ B9 k& j, @
      with torch.no_grad():
    3 L3 ^/ d& k/ P    w  -= w.grad*0.0001   # 回归 w6 f0 i# Z2 U! r7 j
        b  -= b.grad*0.0001    # 回归 b
    6 D" E" c+ P& F. [  w.grad.zero_()  
    # c1 B1 i% {( y4 z  b.grad.zero_()
    , m6 g0 @6 i, J3 _/ o
    & r' F& `9 |2 d& `5 vprint(w.item(),b.item()) #结果
    : S# l9 C6 i3 S( K+ E. ?* i' t/ n
    + h$ S. }: C2 ^( F7 E1 k- i& ?4 _9 GOutput: 27.26387596130371  0.4974517822265625
    9 z/ `0 j# q  C) \7 X----------------------------------------------
    7 h7 w' b; _  Y最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。6 l+ S6 |$ n$ U8 i3 Y3 m
    高手们帮看看是神马原因?
    2 t4 _5 o; J+ I' F

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    ) ~' R) [* N0 [. c% u
    * A$ l4 |6 N" y) P没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    * k; e+ b  V9 Z+ R! Y) H-------
    % l* l* M( P8 o不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。9 h1 Y, s! N% \) {9 C+ E0 b3 S
    -------
    ! V7 O" D: l2 ?) `6 e算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    8 S( j: c- [  n/ u3 B5 [没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    . E( i1 Y% K1 _; C( ]4 M$ X-------# D* d7 h0 I; y( {0 W3 R! \
    不好意思, ...
    7 f! x& U. t4 T
    谢谢,算法应该没问题,就是最简单的线性回归。1 A; t6 s- E: p% l3 E% o" P. L+ V
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    ! T( P0 U) M, h9 m8 Q, c) {3 s
    雷达 发表于 2023-2-14 21:52
    9 D! A0 o2 q  ]+ _9 h谢谢,算法应该没问题,就是最简单的线性回归。/ {' `$ ^5 F7 i' A& `6 ]% s
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    ; Y- w, p& q% K7 K

    8 U" W: U/ I% @2 G2 q: J刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    & e/ C1 s1 M" ?( n
    4 i/ l/ C; d# S  B. D或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    5 w" f. X$ I) T0 l
    老福 发表于 2023-2-14 22:000 @* x4 r; d1 J
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。- `- U) M# F" Q
    - [) \- F2 s4 S
    或者把b但的起点改为1试试。 ...

    : l) v8 Z4 a- S. m# e; _; Y  z( _. F- d: D
    你是对的。
    ' z/ D) L3 T8 T; [( g0 R+ {/ w. q去掉了随机部分1 S# K0 u. s' K4 n& @+ m
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)# Y& s0 f) F. z
    y = (x*27+15).reshape(-1)
    ; f. k* n/ q. {- D  h  e9 Q) R7 W2 R
    循环次数加成10倍,就看到 b 收敛了" l5 d, u$ w* T0 a9 |* o" C
    w , b6 |! T" F4 P7 p; [# q, r( A; D: N
    27.002620697021484 14.826167106628418
      q1 ?$ l# N- C  C( f7 j* T/ O+ N. ?# t& B  z. i, Y8 ~. ]0 {* \
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-4-17 17:27 , Processed in 0.073070 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表