设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 3022|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    1 S) _  x! O+ w3 ?6 U1 v# Q9 \) b7 u7 }
    为预防老年痴呆,时不时学点新东东玩一玩。6 w! D" \: W& {/ W* w6 [1 O
    Pytorch 下面的代码做最简单的一元线性回归:
    & {6 n, R  A# `----------------------------------------------3 J$ p3 @- C5 h$ k2 Q
    import torch1 i! n" w2 D# L% d4 T9 D
    import numpy as np3 [8 b  B$ {! _  s6 `
    import matplotlib.pyplot as plt
    0 v: t* k  f# F( {4 l! i1 b3 Aimport random: w; n0 i- U0 _" s1 C
    " L, s# [- Z) s* i; ?: j# _
    x = torch.tensor(np.arange(1,100,1))4 q) v! ]" R( z; Y: d! W+ O
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    0 p) ~" S( O- m6 v) D" ^8 H; A0 F0 Q% D- C% o: x
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b* ~" Y  N- g( ?9 ~& n
    b = torch.tensor(0.,requires_grad=True)0 T& c/ k2 m. b+ _+ U

    2 g# Y/ v8 P4 ]epochs = 100) E+ O* u! w' a) M. ~6 D
    ' |$ h0 ?9 |0 k; H7 L( v0 f4 |
    losses = []
    ! B; I0 X- L8 f$ H: ?$ B2 f5 T" d+ ffor i in range(epochs):
    ; d$ Q4 o( B$ d4 l/ m% `  y_pred = (x*w+b)    # 预测( b5 f  x1 V1 x
      y_pred.reshape(-1)0 {& o  T) L, L7 g
    7 M# c1 S7 x+ l7 q
      loss = torch.square(y_pred - y).mean()   #计算 loss: d/ `; b, j4 m& p7 B2 P
      losses.append(loss). Q* F* _( \8 T6 o; {
      " t% X8 g( b1 |& u" b/ |
      loss.backward() # autograd
    3 K5 ?3 z6 W" y  with torch.no_grad():+ c( r+ ^2 G7 L
        w  -= w.grad*0.0001   # 回归 w
    0 i( @* f' I- U& W' L9 z2 W: ]: L    b  -= b.grad*0.0001    # 回归 b
    , F3 Y8 w! H% N$ {  w.grad.zero_()  
    " l8 G' a/ }. e7 m; w) i1 H  b.grad.zero_()
    7 @- P: D* y, b5 i1 K1 ]. z, |# t( w
    " ]4 p' ~8 b5 _( T/ v& u7 c) G, hprint(w.item(),b.item()) #结果
    % `9 D, Q. G; r) ~) {$ h4 v5 e: ~" R! ^+ c* U- G( C+ O( _
    Output: 27.26387596130371  0.4974517822265625
    : Z; c1 K4 ^& Z----------------------------------------------, W  r" E4 n/ p# n9 _6 I, r
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。+ N/ M9 [( n! g
    高手们帮看看是神马原因?. i5 p; E$ s  o

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    / D  u+ |; u$ p' g8 M2 m( m7 B0 P+ R% M7 a, O/ _& l: I
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?) Z3 f" u1 z  A1 |
    -------
    0 S# Y: L( ^' O, j/ ?' Z2 E2 G' n6 N不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。- r+ r7 Q5 e/ H4 R" \) ^
    -------
    6 P6 d2 B, X; ?5 M, V算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    4 M. ~. U% V2 f/ G9 g没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    ; W' l; v9 p% f0 v-------1 W9 N/ j: W! `. W3 h( I
    不好意思, ...

    $ ^) p# v, O2 X) K. r% N谢谢,算法应该没问题,就是最简单的线性回归。" ~3 b& `1 J  T% M
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    ( H8 i* T& W9 {: e
    雷达 发表于 2023-2-14 21:52
    * t8 G3 d1 o8 B6 s. y: o谢谢,算法应该没问题,就是最简单的线性回归。
    5 ]  N  v* a+ `& n- k+ A1 ~我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    " |5 {* r6 r6 Z5 ?

    ! ]! G; }: T) o刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    : ^* f7 K3 l* J4 q
    # G8 i6 |6 y1 N* @或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 . S% f! c7 h9 a/ v9 l- F
    老福 发表于 2023-2-14 22:00
    7 q8 E" T2 s7 h+ V3 G刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。: N0 n* y) L" I& z9 p
    ; U! |' o8 T/ a
    或者把b但的起点改为1试试。 ...

    9 K& W' o% K7 X
    ; _6 d6 v3 p  K) V* {你是对的。( u" j: u* o3 c  B6 \8 L7 z
    去掉了随机部分! Z2 r% m' Y: J$ Y4 r& l4 \3 h
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)
    , F2 X8 n" y2 S  N; vy = (x*27+15).reshape(-1)
    2 w! |. g' q, I: {
    & T2 c7 ?, U0 S" H- m' s循环次数加成10倍,就看到 b 收敛了
    3 q, W# S/ K! k/ M* x/ xw , b
    ' [* K6 A0 ]$ _$ c2 x8 Z# u7 @+ e27.002620697021484 14.826167106628418
    " o3 I! w( A5 z$ N/ P# @9 S; `4 b5 A% x! z1 U, d
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-6-3 08:16 , Processed in 0.068718 second(s), 22 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表