* 废话不多说
# --- Prepare data ---
def init_data(n=100, noise=10.0):
    """Build a noisy, roughly linear 1-D regression dataset for the experiment.

    Args:
        n: number of samples (default 100, the size originally hard-coded).
        noise: multiplier applied to standard-normal noise added to y.

    Returns:
        (x, y): float32 tensors of shape (n,). x is evenly spaced over
        [0, 100]; y = x + noise * N(0, 1).
    """
    # `steps` must be passed explicitly: recent torch.linspace no longer
    # supplies a default. Tying it to `n` keeps x and the noise vector the
    # same length.
    x = torch.linspace(0, 100, steps=n, dtype=torch.float32)
    y = x + torch.randn(n, dtype=torch.float32) * noise
    return x, y


def train_test_split(x, y, radio=0.1):
    """Split paired samples into a leading train part and a trailing test part.

    There is no shuffling: the last int(len(x) * radio) samples become the
    test set. With init_data's evenly spaced x, the test set is therefore
    the high-x tail.

    Args:
        x, y: sliceable sequences/tensors of equal length.
        radio: fraction of samples held out for testing, in [0, 1]. The
            name means "ratio" and is kept for caller compatibility.

    Returns:
        (x_train, x_test, y_train, y_test)

    Raises:
        ValueError: if x and y differ in length or radio is outside [0, 1].
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(x) != len(y):
        raise ValueError("length of x is not equal to y")
    if not 0 <= radio <= 1:
        raise ValueError(f"radio must be in [0, 1], got {radio}")
    split = len(x) - int(len(x) * radio)
    return x[:split], x[split:], y[:split], y[split:]
训练数据可视化
def plot(x_train, y_train):
    """Show the training samples as a scatter of y against x.

    Opens a new 10x8-inch figure, draws one 'o' marker per sample, labels
    the axes "X" and "Y", then calls plt.show().
    """
    plt.figure(figsize=(10, 8))
    xs, ys = x_train.data.numpy(), y_train.data.numpy()
    plt.plot(xs, ys, 'o')
    for set_label, text in ((plt.xlabel, "X"), (plt.ylabel, "Y")):
        set_label(text)
    plt.show()
# --- Gradient descent: fit y = w * x + b by hand ---
x, y = init_data()
print(x.shape)
x_train, x_test, y_train, y_test = train_test_split(x, y)
plot(x_train, y_train)

# Randomly initialised parameters tracked by autograd.
w = torch.nn.Parameter(torch.randn(1))
b = torch.nn.Parameter(torch.randn(1))
learn_rate = 0.0001
n_iters = 1000
log_every = 100  # print the loss periodically instead of on all 1000 steps

for i in range(n_iters):
    # Broadcasting applies the 1-element w and b to every sample.
    pred = w * x_train + b
    loss = torch.mean((pred - y_train) ** 2)
    if i % log_every == 0 or i == n_iters - 1:
        print(i, loss.item())
    loss.backward()
    # Plain SGD step. no_grad keeps the in-place update out of the autograd
    # graph; it replaces the legacy `.data` manipulation.
    with torch.no_grad():
        w -= learn_rate * w.grad
        b -= learn_rate * b.grad
    # Gradients accumulate across backward() calls, so reset them each step.
    w.grad.zero_()
    b.grad.zero_()

# Learned parameters as plain floats, and a legend label that renders the
# sign of b correctly (avoids "…x+-3.1").
w_val = w.item()
b_val = b.item()
fit_label = f"y = {w_val:.4f}x {'-' if b_val < 0 else '+'} {abs(b_val):.4f}"

# Figure 1: training data with the fitted line.
x_data = x_train.detach().numpy()
plt.figure(figsize=(10, 7))
xplot, = plt.plot(x_data, y_train.detach().numpy(), 'o')
yplot, = plt.plot(x_data, w_val * x_data + b_val)
plt.xlabel("X")
plt.ylabel("Y")
plt.legend([xplot, yplot], ['Data', fit_label])
plt.show()

# Predict on the held-out test samples.
with torch.no_grad():
    pred = w * x_test + b
print(pred)

# Figure 2: train data, test data, fitted line over both, and test predictions.
# Legend handles come from THIS figure, not from figure 1's artists.
x_pred = x_test.detach().numpy()
x_all = np.r_[x_data, x_pred]
plt.figure(figsize=(10, 7))
train_pts, = plt.plot(x_data, y_train.detach().numpy(), 'o')
test_pts, = plt.plot(x_pred, y_test.detach().numpy(), 's')
fit_line, = plt.plot(x_all, w_val * x_all + b_val)
pred_pts, = plt.plot(x_pred, pred.numpy(), 'o')
plt.xlabel("X")
plt.ylabel("Y")
plt.legend(
    [train_pts, test_pts, fit_line, pred_pts],
    ['Train data', 'Test data', fit_label, 'Test predictions'],
)
plt.show()
* 训练集结果显示

* 训练集待判结果显示
*
* 待判和回判结果显示
*
* 理解pytorch底层原理
* 梯度下降法的基本实现
* matplotlib基本绘图

技术
今日推荐
阅读数 133
下载桌面版
GitHub
百度网盘(提取码:draw)
Gitee
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:ixiaoyang8@qq.com
QQ群:766591547
关注微信