用PyTorch和Python实现LSTM:时间序列预测的实战教程
原文链接:
https://stackabuse.com/time-series-prediction-using-lstm-with-pytorch-in-python/
时间序列数据,顾名思义是一种随时间变化的数据类型。例如,24小时时间段内的温度,一个月内各种产品的价格,一个特定公司一年的股票价格。高级的深度学习模型,如长短期记忆网络(LSTM),能够捕捉时间序列数据中的模式,因此可以用来预测数据的未来趋势。在本文中,您将看到如何使用LSTM算法使用时间序列数据进行未来预测。
Dataset and Problem Definition
我们将使用的数据集内置在Python Seaborn库中。让我们先导入所需的库,然后再导入数据集:
1 import torch 2 import torch.nn as nn 3 4 import seaborn as sns 5 import numpy as np 6 import pandas as pd 7 import matplotlib.pyplot as plt 8 %matplotlib inline
读入数据
1 import pandas as pd 2 flight_data = pd.read_csv('./data/flights.csv') #或者flight_data = sns.load_datasets("flights") 3 flight_data.head()
输出:
flight_data.shape
#输出: (144,3)
数据集有三列:year、month以及passengers,包含了12年的乘客出行纪录;
任务:
根据面132个月的出行数据预测后12个月的出行数据;
绘制每个月的乘客出行频率:
1 plt.plot(flight_data['passengers']) 2 plt.grid(True) 3 plt.title("Month vs passenger") 4 plt.ylabel("Total passengers") 5 plt.xlabel("Months") 6 plt.autoscale(axis='x',tight=True)
从输出结果中可以看出,每年的乘客数量是在逐渐递增的;
在同一年内,乘客的数量是波动的,这是符合常识的,因为在节假日的时候,乘客的数量相较于一年中的其他日子是会变多的;
数据预处理:
首先,看一下数据集中的列的数据类型:
1 flight_data.columns
输出:
all_data = flight_data['passengers'].values.astype(float)
接下来,将数据集划分为训练数据集和验证数据集:
test_data_size = 12 train_data = all_data[:-test_data_size] #size=132 test_data = all_data[-test_data_size:] #size=12
此时,数据集并没有经过标准化处理;
但是乘客数量在刚开始的年份要远小于近两年的数量;
我们将使用Min/max进行标准化;
1 from sklearn.preprocessing import MinMaxScaler 2 3 scaler = MinMaxScaler(feature_range=(-1,1)) 4 train_data_normalized = scaler.fit_transform(train_data.reshape(-1,1))
之后,将其转化为tensor的数据形式:
train_data_normalized = torch.FloatTensor(train_data_normalized).view(-1) #转换成1维张量
最后,就是将数据处理成sequences和对应标签的形式;
在这里,我们取时间窗口为12,因为一年有12个月,这个是比较合理的;
1 train_window=12 2 3 def create_inout_sequences(input_data,tw): 4 inout_seq = [] 5 L = len(input_data) 6 for i in range(L-tw): 7 train_seq = input_data[i:i+tw] 8 train_label = input_data[i+tw:i+tw+1] 9 inout_seq.append((train_seq,train_label)) 10 return inout_seq 11 12 train_inout_seq = create_inout_sequences(train_data_normalized,train_window)
#一共有120个样本 132-12=120
创建LSTM模型:
1 class LSTM(nn.Module): 2 def __init__(self, input_size=1,hidden_layer_size=100,output_size=1): 3 super().__init__() 4 5 self.hidden_layer = hidden_layer_size 6 self.lstm = nn.LSTM(input_size,hidden_layer_size) 7 self.linear = nn.linear(hidden_layer_size,uotput_size) 8 self.hidden_cell = (torch.zeros(1,1,self.hidden_layer_size), 9 torch.zeros(1,1,self.hidden_layer_size)) 10 def forward(self, input_seq): 11 lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq),1,-1), self.hidden_cell) 12 predictions = self.linear(lstm_out.view(len(input_seq),-1)) 13 return predictions[-1]
input_size:对应的输入数据特征; 虽然我们的序列长度是12,但是对于每个月来说,我们只有1个值,例如,乘客的总数量,因此输入的size是1;
hidden_layer_size: 每层的神经元的数量,我们每层一共有100个神经元;
output_size:预测下一个月的输出数量,输出的size为1;
之后我们创建hidden_layer_size, lstm, linear 以及hidden_cell。
LSTM算法接收三个输入:之前的输入状态;之前的的cell状态以及当前的输入;
hidden_cell变量包含了先前的隐藏状态和cell状态;
lstm和linear层变量,用于创建LSTM和线性层;
在forward()算法中,使用input_seq作为输入参数,首先被传递给lstm;
lstm的输出,包含了当前时间戳下的隐藏层和细胞状态,以及输出;
lstm层的输出被传递给Linear层,预测的乘客数量,就是predictions的最后一项;
1 model = LSTM() 2 loss_function = nn.MSELoss() 3 optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
1 # 模型的训练 2 epochs = 15 3 4 for i in range(epochs): 5 for seq, labels in train_inout_seq: 6 optimizer.zero_grad() 7 model.hidden_cell = (torch.zeros(1,1,model.hidden_layer_size), 8 torch.zeros(1,1,model.hidden_layer_size)) 9 10 y_pred = model(seq) 11 12 single_loss = loss_function(y_pred, labels) 13 single_loss.backward() 14 optimizer.step() 15 16 if i%2 ==1 : 17 print(f'epoch:{i:3} loss:{single_loss.item():10.8f}') 18 print(f'epoch:{i:3} loss:{single_loss.item():10.8f}')
模型的预测:
1 # 预测: 2 fut_pre = 12 3 4 test_inputs = train_data_normalized[-train_window:].tolist() 5 print(test_inputs) 6 7 model.eval() 8 for i in range(fut_pre): 9 seq = torch.FloatTensor(test_inputs[-train_window:]) 10 with torch.no_grad(): 11 model.hidden = (torch.zeros(1,1,model.hidden_layer_size), 12 torch.zeros(1,1,model.hidden_layer_size)) 13 14 test_inputs.append(model(seq).item())
上一篇: 实现自动驾驶的运动预测:VectorNet论文初探(第一部分)
下一篇: 【资料大全】路径表示
推荐阅读
-
用PyTorch和Python实现LSTM:时间序列预测的实战教程
-
包婷婷 (201550484)作业一 统计软件简介与数据操作-SPSS(Statistical Product and Service Solutions),"统计产品与服务解决方案"软件。最初软件全称为"(SolutionsStatistical Package for the Social Sciences),但是随着SPSS产品服务领域的扩大和服务深度的增加,SPSS公司已于2000年正式将英文全称更改为"统计产品与服务解决方案",标志着SPSS的战略方向正在做出重大调整。为IBM公司推出的一系列用于统计学分析运算、数据挖掘、预测分析和决策支持任务的软件产品及相关服务的总称SPSS,有Windows和Mac OS X等版本。 1984年SPSS总部首先推出了世界上第一个统计分析软件微机版本SPSS/PC+,开创了SPSS微机系列产品的开发方向,极大地扩充了它的应用范围,并使其能很快地应用于自然科学、技术科学、社会科学的各个领域。世界上许多有影响的报刊杂志纷纷就SPSS的自动统计绘图、数据的深入分析、使用方便、功能齐全等方面给予了高度的评价。 R统计软件介绍 R是一套完整的数据处理、计算和制图软件系统。其功能包括:数据存储和处理系统;数组运算工具(其向量、矩阵运算方面功能尤其强大);完整连贯的统计分析工具;优秀的统计制图功能;简便而强大的编程语言:可操纵数据的输入和输出,可实现分支、循环,用户可自定义功能。 与其说R是一种统计软件,还不如说R是一种数学计算的环境,因为R并不是仅仅提供若干统计程序、使用者只需指定数据库和若干参数便可进行一个统计分析。R的思想是:它可以提供一些集成的统计工具,但更大量的是它提供各种数学计算、统计计算的函数,从而使使用者能灵活机动的进行数据分析,甚至创造出符合需要的新的统计计算方法。 该语言的语法表面上类似 C,但在语义上是函数设计语言(functional programming language)的变种并且和Lisp 以及 APL有很强的兼容性。特别的是,它允许在"语言上计算"(computing on the language)。这使得它可以把表达式作为函数的输入参数,而这种做法对统计模拟和绘图非常有用。 R是一个免费的*软件,它有UNIX、LINUX、MacOS和WINDOWS版本,都是可以免费下载和使用的。在R主页那儿可以下载到R的安装程序、各种外挂程序和文档。在R的安装程序中只包含了8个基础模块,其他外在模块可以通过CRAN获得。 二、R语言 R是用于统计分析、绘图的语言和操作环境。R是属于GNU系统的一个*、免费、源代码开放的软件,它是一个用于统计计算和统计制图的优秀工具。 R作为一种统计分析软件,是集统计分析与图形显示于一体的。它可以运行于UNIX,Windows和Macintosh的操作系统上,而且嵌入了一个非常方便实用的帮助系统,相比于其他统计分析软件,R还有以下特点: 1.R是*软件。这意味着它是完全免费,开放源代码的。可以在它的网站及其镜像中下载任何有关的安装程序、源代码、程序包及其源代码、文档资料。标准的安装文件身自身就带有许多模块和内嵌统计函数,安装好后可以直接实现许多常用的统计功能。[2] 2.R是一种可编程的语言。作为一个开放的统计编程环境,语法通俗易懂,很容易学会和掌握语言的语法。而且学会之后,我们可以编制自己的函数来扩展现有的语言。这也就是为什么它的更新速度比一般统计软件,如,SPSS,SAS等快得多。大多数最新的统计方法和技术都可以在R中直接得到。[2] 3. 所有R的函数和数据集是保存在程序包里面的。只有当一个包被载入时,它的内容才可以被访问。一些常用、基本的程序包已经被收入了标准安装文件中,随着新的统计分析方法的出现,标准安装文件中所包含的程序包也随着版本的更新而不断变化。在另外版安装文件中,已经包含的程序包有:base一R的基础模块、mle一极大似然估计模块、ts一时间序列分析模块、mva一多元统计分析模块、survival一生存分析模块等等.[2] 4.R具有很强的互动性。除了图形输出是在另外的窗口处,它的输入输出窗口都是在同一个窗口进行的,输入语法中如果出现错误会马上在窗口口中得到提示,对以前输入过的命令有记忆功能,可以随时再现、编辑修改以满足用户的需要。输出的图形可以直接保存为JPG,BMP,PNG等图片格式,还可以直接保存为PDF文件。另外,和其他编程语言和数据库之间有很好的接口。[2] 5.如果加入R的帮助邮件列表一,每天都可能会收到几十份关于R的邮件资讯。可以和全球一流的统计计算方面的专家讨论各种问题,可以说是全世界最大、最前沿的统计学家思维的聚集地.[2] R是基于S语言的一个GNU项目,所以也可以当作S语言的一种实现,通常用S语言编写的代码都可以不作修改的在R环境下运行。 R的语法是来自Scheme。R的使用与S-PLUS有很多类似之处,这两种语言有一定的兼容性。S-PLUS的使用手册,只要稍加修改就可作为R的使用手册。所以有人说:R,是S-PLUS的一个“克隆”。 但是请不要忘了:R是免费的(R is free)。R语言源代码托管在github,具体地址可以看参考资料。[3] 。 R语言的下载可以通过CRAN的镜像来查找。 R语言有域名为.cn的下载地址,有六个,其中两个由Datagurn,由 中国科学技术大学提供的。R语言Windows版,其中由两个下载地点是Datagurn和 USTC提供的。 三、stata Stata 是一套提供其使用者数据分析、数据管理以及绘制专业图表的完整及整合性统计软件。它提供许许多多功能,包含线性混合模型、均衡重复反复及多项式普罗比模式。用Stata绘制的统计图形相当精美。 新版本的STATA采用最具亲和力的窗口接口,使用者自行建立程序时,软件能提供具有直接命令式的语法。Stata提供完整的使用手册,包含统计样本建立、解释、模型与语法、文献等超过一万余页的出版品。 除此之外,Stata软件可以透过网络实时更新每天的最新功能,更可以得知世界各地的使用者对于STATA公司提出的问题与解决之道。使用者也可以透过Stata. Journal获得许许多多的相关讯息以及书籍介绍等。另外一个获取庞大资源的管道就是Statalist,它是一个独立的listserver,每月交替提供使用者超过1000个讯息以及50个程序。 四、PYTHON
-
用 Python 中的先知模型进行天气时间序列预测和异常检测 - 方法
-
PyTorch-LSTM 实现单变量时间序列预测:一步步详解教程