设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2340|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    1 `  r* W' Z, L4 T7 S/ A! ]1 D! ?, f, |' D3 q
    为预防老年痴呆,时不时学点新东东玩一玩。
    - D) W! c! y! z# hPytorch 下面的代码做最简单的一元线性回归:
    2 O( T8 n, o( h/ x----------------------------------------------
    9 O, S5 p( o% H( ]import torch
    1 E5 \, ]; q# L% X' Y% Kimport numpy as np  e2 M) D* Q$ V
    import matplotlib.pyplot as plt
    7 a6 K; m: }$ [& X( a  n: C7 ]import random
    5 I3 u: z  C  x" |) ?4 S
    , t0 u% l) p0 ^5 W4 \! hx = torch.tensor(np.arange(1,100,1))
    / U- f2 g" F( Z1 M2 w2 zy = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    4 b0 {6 @( a6 p# @' k' m& u
    ; d) s4 O7 a/ f. c+ s9 O4 }w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    6 N- V, v( }+ S+ T) N5 z+ Vb = torch.tensor(0.,requires_grad=True)3 \! X3 P% l9 R) s

    ! @' m# X5 R, d2 `. D  wepochs = 100
    5 w9 |6 i1 F6 I7 L3 E4 a  J2 c" g; M7 O. G* k. c; `
    losses = []
    * Y& z, B$ }: \& @for i in range(epochs):3 e' [7 T2 R' U: k: Q5 A) `7 h( }
      y_pred = (x*w+b)    # 预测  B, q# \9 r$ [3 p. y/ T' e- M$ R
      y_pred.reshape(-1)- e0 s3 i( {6 w3 Z, p4 s3 R0 X% h

    - K0 x7 o' x3 S2 ?& u1 U  loss = torch.square(y_pred - y).mean()   #计算 loss
    ' G# f7 p* q/ I( S7 ^: S4 c; C" n+ ?& W  losses.append(loss)% N+ v1 h- a: Y/ U$ j% d
      ) O$ X; |$ y% ?* p' G, t
      loss.backward() # autograd& n: o4 I7 A  ], S
      with torch.no_grad():
    7 F# e0 ?- U& \" U* T    w  -= w.grad*0.0001   # 回归 w5 ]  V- `8 N! r' b/ c
        b  -= b.grad*0.0001    # 回归 b 3 s: \/ s/ s9 W8 W* R
      w.grad.zero_()  
    5 k* q5 N5 q7 H  b.grad.zero_()
    4 Y7 Y7 C0 L. t0 A
    3 k1 g$ K4 b! uprint(w.item(),b.item()) #结果
    $ F: k1 s/ N. {- d" u6 H' \
    ; G5 z6 A5 k: I- U& z; I" ?Output: 27.26387596130371  0.4974517822265625
    2 B2 Z. t. k7 [* v$ w/ a----------------------------------------------
    2 `0 |5 `, X% ^2 p: R最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。' j; l1 b3 n" V! k
    高手们帮看看是神马原因?3 l( {# E2 W- q

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    % B2 h9 z; O0 T: C$ a4 |+ O+ B+ [
    1 k  U9 [! {, s# z9 S. E9 ~& ~没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    ' z7 R, q* p2 x' H' s-------
    * j* d: `) n; o, w; U不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。& b5 \; `" `2 R4 X& C( {, k$ f5 C5 Z
    -------
    0 l( X4 ^# e( s8 P. K* W) ]算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    & B& |' Q& J9 w' k4 @$ x) w, G没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?) b; }8 q: L. w" ~5 ^- t
    -------
    ( H3 r2 a0 a2 _! L- a' h0 \+ E: `不好意思, ...

    ! \: Z+ t& u( w8 b% I谢谢,算法应该没问题,就是最简单的线性回归。
    / K* H7 T9 ]8 M- p我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    " _: _, }; W8 m/ s. q9 ?* e5 T
    雷达 发表于 2023-2-14 21:521 d9 g1 P( [" R  n9 @) Y
    谢谢,算法应该没问题,就是最简单的线性回归。
    ' A( R2 d3 C2 A/ ]% c" n我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    2 K5 S- A/ {9 |8 _; r3 r( U9 u# j1 t6 q9 K# @3 v5 p) O
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。, Z' `& ]. U0 L
    , E- ]4 ]- M' M- v# y& M
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    8 l3 m5 V: d# @- G
    老福 发表于 2023-2-14 22:00
    . c$ G2 j6 F, I4 h9 y+ ]刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。2 R/ U( T: h9 ~1 \; t+ O* {

    * j+ g4 V* S% m% ~或者把b但的起点改为1试试。 ...

    , d& d, _8 C4 u/ E7 e8 t& ]  K) e
    7 Y0 Y" y, s/ M9 X你是对的。, @  d1 B3 d0 a! l2 j* c
    去掉了随机部分
    1 @$ I6 o5 Z8 `# x6 e% ?; f6 A% c#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    / C8 N2 h6 B, V% K* wy = (x*27+15).reshape(-1)# e5 {9 r0 T: k4 N1 ?% M, O3 K
    ) s* M! s% S% s& ?: J+ h, X! p
    循环次数加成10倍,就看到 b 收敛了
    " r. p5 P# c( Q- N7 [4 \  d# C$ ew , b
    9 e, H, y; U9 u+ \27.002620697021484 14.826167106628418
    3 R( V4 m9 @* G9 C. E( Q2 A3 ^+ t  w. N& @$ E: N+ u1 B+ V
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-1-16 05:50 , Processed in 0.036690 second(s), 22 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表