PyTorch 深度学习项目 100 例】--基于 pytorch 使用 LSTM 实现本篇新闻的分类任务 | 例 9
前言
大家好,我是阿光。
本专栏整理了《PyTorch 深度学习项目实战 100 例》,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集。
正在更新中~ ✨
???? 我的项目环境:
- 平台:Windows10
- 语言环境:python3.7
- 编译器:PyCharm
- PyTorch 版本:1.8.1
???? 项目专栏:【PyTorch 深度学习项目实战 100 例】
一、使用 LSTM 实现新闻本文分类任务
本文主要用 LSTM 循环神经网络拟合微调实现一个包含十五个类别的新闻文本分类任务,主要是对新闻内容进行特征抽取,获取语义分析来实现分类任务。
在这里插入图片描述
二、数据集介绍
整个数据集整合划分出 15 个候选分类类别:法治、国际、国内、健康、教育、经济、军事、科技、农经、三农、人物、社会、生活、书画、文娱的文本数据。
数据总共有 4482 条新闻纪录,字段分别为:标题、标题链接、新闻内容、关键词、发布时间、标签、新闻采集时间。
在这里插入图片描述
三、项目实现思路
为了体会 LSTM 的作用,并没有对原始数据进行高纬度的建模,只是使用了新闻内容的这个特征,没有对其他特征进行建模,由于新闻内容是中文文本数据,所以我们需要对其进行向量化,转成数值型数据然后送入到网络模型。
但是对于文本数据来讲,如果只是单纯使用 Embedding 进行嵌入的话,完全没有考虑到语义那种前后联系,会导致模型训练效果较差。
所以我们本项目中使用了 LSTM 这种网络进行捕捉语义信息,因为 LSTM 是一个循环神经网络,它可以将上个学习步的细胞信息传递给下个细胞,这样就会把前面出现的语句信息与当前输入进行结合,来预测之后出现的语句。
首先对于输入数据我们将其进行序列化,将一句话中的所有字转成对应的索引号,如果长度不足,我们需要使用 0 进行填充,保证输入到网络模型中的向量长度一致,然后需要使用 Embedding 进行将其进行嵌入,获得每个字的嵌入连续型向量,此处也可以使用 one-hot 编码,但是这会导致维度爆炸,以及矩阵稀疏问题,之后把生成的嵌入向量导入到 LSTM 层中,然后,因为这个时间片已经保存了整个语句的语义信息
四、网络结构
项目中使用的模型是 LSTM,在模型中我们定义了三个组件,分别是 embedding 层,lstm 层和全连接层。
在这里插入图片描述
- Embedding 层:将每个词生成对应的嵌入向量,就是利用一个连续型向量来表示每个词
- Lstm 层:提取语句中的语义信息
- Linear 层:将结果映射成 2 大小用于二分类,即正反面的概率
注意:在 LSTM 网络中返回的值为最后一个时间片的输出,而不是将整个 output 全部输出,因为我们是需要捕捉整个语句的语义信息,并不是获得特定时间片的数据。
五、语句测试
- 首先需要对我们待测试的语句进行转为序号编码
- 如果序列长度不足,使用 0 进行填充
- 加载训练好的模型
- 加载数据映射字典,获得结果
try:
# 数据预处理
input_shape = 80 # 序列长度,就是时间步大小,也就是这里的每句话中的词的个数
# 用于测试的话
sent = "今年春节档国内电影市场异常火爆,已创造多项新纪录。“七大新片齐上阵”,其中动画电影《新神榜:哪吒重生》乍一看不是特别显眼,除了搭上了《哪吒:魔童降世》的“哪吒”IP,也没有多少前期宣传,从当前票房反应来说也只能算中规中矩,但分析..."
# 将对应的字转化为相应的序号
x = [[word2idx[word] for word in sent]]
# 如果长度不够180,使用0进行填充
x = pad_sequences(maxlen=input_shape, sequences=x, padding='post', value=0)
x = torch.from_numpy(x)
# 加载模型
model_path = './best_model.pkl'
model = LSTM(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_layers,
embedding_dim=embedding_dim, output_dim=output_dim)
model.load_state_dict(torch.load(model_path, 'cpu'))
# 模型预测,注意输入的数据第一个input_shape,就是180
y_pred = model(x.long().transpose(1, 0))
print('输入语句: %s' % sent)
print('新闻分类结果: %s' % idx2label[y_pred.argmax().item()])
except KeyError as err:
print("您输入的句子有汉字不在词汇表中,请重新输入!")
print("不在词汇表中的单词为:%s." % err)
完整源码
【PyTorch深度学习项目实战100例】—— 基于pytorch使用LSTM实现新闻本文分类任务 | 第9例