基于 Python 的自然语言处理系列(22):模型剪枝(剪枝)
在深度学习领域,尤其是当模型部署到资源有限的环境中时,模型压缩技术变得尤为重要。剪枝(Pruning)是一种常见的模型压缩方法,通过减少模型中不重要的参数,可以在不显著降低模型性能的情况下提升效率。在本文中,我们将详细介绍如何在PyTorch中使用剪枝技术,并通过一些实验展示其效果。
1. 加载数据集与预处理
我们将使用TorchText库加载常用的AG_NEWS数据集,并进行预处理。首先,导入必要的库并设置随机种子以保证实验的可重复性。
import torch, torchdata, torchtext
from torch import nn
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
接下来,加载AG_NEWS数据集,并将其拆分为训练集、验证集和测试集。我们使用TorchText的random_split
方法来进行数据划分。
from torchtext.datasets import AG_NEWS
train, test = AG_NEWS()
train_size = len(list(iter(train)))
too_much, train, valid = train.random_split(total_length=train_size, weights = {"too_much": 0.7, "smaller_train": 0.2, "valid": 0.1}, seed=999)
数据预处理
我们将使用Spacy作为分词器,并将文本转换为整数表示。这里,我们使用build_vocab_from_iterator
来为数据集生成词汇表。
from torchtext.data.utils import get_tokenizer
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
from torchtext.vocab import build_vocab_from_iterator
def yield_tokens(data_iter):
for _, text in data_iter:
yield tokenizer(text)
vocab = build_vocab_from_iterator(yield_tokens(train), specials=['<unk>', '<pad>', '<bos>', '<eos>'])
vocab.set_default_index(vocab["<unk>"])
2. FastText 预训练词向量
接下来,我们将加载FastText的预训练词向量,并将其应用到我们的词汇表中。
from torchtext.vocab import FastText
fast_vectors = FastText(language='simple') # 使用FastText预训练词向量
fast_embedding = fast_vectors.get_vecs_by_tokens(vocab.get_itos()).to(device)
fast_embedding.shape
3. 数据加载器
我们需要定义一个数据加载器collate_fn
,来确保批处理中的序列长度一致(通过填充)。在这里,我们还会生成序列长度信息,以便后续用于LSTM中的打包序列处理。
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
pad_idx = vocab['<pad>']
def collate_batch(batch):
label_list, text_list, length_list = [], [], []
for (_label, _text) in batch:
label_list.append(int(_label) - 1) # 标签从0开始
processed_text = torch.tensor([vocab[token] for token in tokenizer(_text)], dtype=torch.int64)
text_list.append(processed_text)
length_list.append(processed_text.size(0))
return torch.tensor(label_list, dtype=torch.int64), pad_sequence(text_list, padding_value=pad_idx, batch_first=True), torch.tensor(length_list, dtype=torch.int64)
train_loader = DataLoader(train, batch_size=64, shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(valid, batch_size=64, shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(test, batch_size=64, shuffle=False, collate_fn=collate_batch)
4. 模型定义
我们将定义一个双向LSTM模型,并将预训练的FastText词向量作为模型的嵌入层权重初始化。
class LSTM(nn.Module):
def __init__(self, input_dim, emb_dim, hid_dim, output_dim, num_layers, bidirectional, dropout):
super().__init__()
self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout, batch_first=True)
self.fc = nn.Linear(hid_dim * 2, output_dim)
def forward(self, text, text_lengths):
embedded = self.embedding(text)
packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'), enforce_sorted=False, batch_first=True)
packed_output, (hn, cn) = self.lstm(packed_embedded)
output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
hn = torch.cat((hn[-2,:,:], hn[-1,:,:]), dim=1)
return self.fc(hn)
5. 模型训练与评估
我们定义模型的训练与评估函数,并加载预训练好的模型进行测试。
criterion = nn.CrossEntropyLoss()
def accuracy(preds, y):
predicted = torch.max(preds.data, 1)[1]
return (predicted == y).sum().item() / len(y)
def evaluate(model, loader, criterion):
model.eval()
epoch_loss, epoch_acc = 0, 0
with torch.no_grad():
for label, text, text_length in loader:
label, text = label.to(device), text.to(device)
predictions = model(text, text_length).squeeze(1)
loss = criterion(predictions, label)
acc = accuracy(predictions, label)
epoch_loss += loss.item()
epoch_acc += acc
return epoch_loss / len(loader), epoch_acc / len(loader)
test_loss, test_acc = evaluate(model, test_loader, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
6. 剪枝(Pruning)
随机剪枝
首先,我们使用PyTorch的torch.nn.utils.prune
库对模型进行随机剪枝。例如,以下代码将随机剪掉全连接层中95%的连接。
import torch.nn.utils.prune as prune
fc = model.fc
prune.random_unstructured(fc, name="weight", amount=0.95)
print(list(fc.named_buffers())) # 打印权重掩码
基于L1范数的剪枝
我们还可以基于权重的L1范数进行剪枝,以下代码展示了如何根据最小的L1范数剪枝95%的连接。
prune.l1_unstructured(fc, name="weight", amount=0.95)
print(fc.weight)
全局剪枝
全局剪枝是通过在整个模型中移除最低重要性的连接,而不是逐层进行剪枝。我们可以使用global_unstructured
来实现这一目标。
parameters_to_prune = [(model.embedding, 'weight'), (model.lstm, 'weight_ih_l0'), (model.lstm, 'weight_hh_l0'), (model.fc, 'weight')]
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.7)
7. 自定义剪枝方法
我们还可以通过继承torch.nn.utils.prune.BasePruningMethod
类来自定义剪枝方法。下面是一个简单的自定义剪枝示例,剪去张量中的每隔一个元素。
class ExamplePruningMethod(prune.BasePruningMethod):
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
mask.view(-1)[::2] = 0
return mask
结语
在本篇文章中,我们探讨了模型剪枝(Pruning)的多种方法,包括随机剪枝、基于L1范数的剪枝和全局剪枝等。这些技术可以有效减少模型参数量,在不明显降低性能的情况下,显著提升模型的推理效率。剪枝方法的选择应根据模型和任务的特点来决定,不同的剪枝策略适用于不同的场景。
剪枝作为模型压缩的一部分,尤其在部署到计算资源受限的设备时,能够大幅减少计算负担。同时,自定义剪枝方法也提供了灵活性,允许开发者根据需求进行更细粒度的优化。通过本文的实践,大家可以尝试不同的剪枝方法,观察其对模型大小和性能的影响。
在下一篇文章中,我们将介绍DrQA,这是一个针对问答系统的经典模型。我们会探讨如何构建一个可以回答复杂问题的问答系统,继续深入自然语言处理领域的实际应用。
如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!
欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。
谢谢大家的支持!
推荐阅读
-
基于 Python 的自然语言处理系列(22):模型剪枝(剪枝)
-
包婷婷 (201550484)作业一 统计软件简介与数据操作-SPSS(Statistical Product and Service Solutions),"统计产品与服务解决方案"软件。最初软件全称为"(SolutionsStatistical Package for the Social Sciences),但是随着SPSS产品服务领域的扩大和服务深度的增加,SPSS公司已于2000年正式将英文全称更改为"统计产品与服务解决方案",标志着SPSS的战略方向正在做出重大调整。为IBM公司推出的一系列用于统计学分析运算、数据挖掘、预测分析和决策支持任务的软件产品及相关服务的总称SPSS,有Windows和Mac OS X等版本。 1984年SPSS总部首先推出了世界上第一个统计分析软件微机版本SPSS/PC+,开创了SPSS微机系列产品的开发方向,极大地扩充了它的应用范围,并使其能很快地应用于自然科学、技术科学、社会科学的各个领域。世界上许多有影响的报刊杂志纷纷就SPSS的自动统计绘图、数据的深入分析、使用方便、功能齐全等方面给予了高度的评价。 R统计软件介绍 R是一套完整的数据处理、计算和制图软件系统。其功能包括:数据存储和处理系统;数组运算工具(其向量、矩阵运算方面功能尤其强大);完整连贯的统计分析工具;优秀的统计制图功能;简便而强大的编程语言:可操纵数据的输入和输出,可实现分支、循环,用户可自定义功能。 与其说R是一种统计软件,还不如说R是一种数学计算的环境,因为R并不是仅仅提供若干统计程序、使用者只需指定数据库和若干参数便可进行一个统计分析。R的思想是:它可以提供一些集成的统计工具,但更大量的是它提供各种数学计算、统计计算的函数,从而使使用者能灵活机动的进行数据分析,甚至创造出符合需要的新的统计计算方法。 该语言的语法表面上类似 C,但在语义上是函数设计语言(functional programming language)的变种并且和Lisp 以及 APL有很强的兼容性。特别的是,它允许在"语言上计算"(computing on the language)。这使得它可以把表达式作为函数的输入参数,而这种做法对统计模拟和绘图非常有用。 R是一个免费的*软件,它有UNIX、LINUX、MacOS和WINDOWS版本,都是可以免费下载和使用的。在R主页那儿可以下载到R的安装程序、各种外挂程序和文档。在R的安装程序中只包含了8个基础模块,其他外在模块可以通过CRAN获得。 二、R语言 R是用于统计分析、绘图的语言和操作环境。R是属于GNU系统的一个*、免费、源代码开放的软件,它是一个用于统计计算和统计制图的优秀工具。 R作为一种统计分析软件,是集统计分析与图形显示于一体的。它可以运行于UNIX,Windows和Macintosh的操作系统上,而且嵌入了一个非常方便实用的帮助系统,相比于其他统计分析软件,R还有以下特点: 1.R是*软件。这意味着它是完全免费,开放源代码的。可以在它的网站及其镜像中下载任何有关的安装程序、源代码、程序包及其源代码、文档资料。标准的安装文件身自身就带有许多模块和内嵌统计函数,安装好后可以直接实现许多常用的统计功能。[2] 2.R是一种可编程的语言。作为一个开放的统计编程环境,语法通俗易懂,很容易学会和掌握语言的语法。而且学会之后,我们可以编制自己的函数来扩展现有的语言。这也就是为什么它的更新速度比一般统计软件,如,SPSS,SAS等快得多。大多数最新的统计方法和技术都可以在R中直接得到。[2] 3. 所有R的函数和数据集是保存在程序包里面的。只有当一个包被载入时,它的内容才可以被访问。一些常用、基本的程序包已经被收入了标准安装文件中,随着新的统计分析方法的出现,标准安装文件中所包含的程序包也随着版本的更新而不断变化。在另外版安装文件中,已经包含的程序包有:base一R的基础模块、mle一极大似然估计模块、ts一时间序列分析模块、mva一多元统计分析模块、survival一生存分析模块等等.[2] 4.R具有很强的互动性。除了图形输出是在另外的窗口处,它的输入输出窗口都是在同一个窗口进行的,输入语法中如果出现错误会马上在窗口口中得到提示,对以前输入过的命令有记忆功能,可以随时再现、编辑修改以满足用户的需要。输出的图形可以直接保存为JPG,BMP,PNG等图片格式,还可以直接保存为PDF文件。另外,和其他编程语言和数据库之间有很好的接口。[2] 5.如果加入R的帮助邮件列表一,每天都可能会收到几十份关于R的邮件资讯。可以和全球一流的统计计算方面的专家讨论各种问题,可以说是全世界最大、最前沿的统计学家思维的聚集地.[2] R是基于S语言的一个GNU项目,所以也可以当作S语言的一种实现,通常用S语言编写的代码都可以不作修改的在R环境下运行。 R的语法是来自Scheme。R的使用与S-PLUS有很多类似之处,这两种语言有一定的兼容性。S-PLUS的使用手册,只要稍加修改就可作为R的使用手册。所以有人说:R,是S-PLUS的一个“克隆”。 但是请不要忘了:R是免费的(R is free)。R语言源代码托管在github,具体地址可以看参考资料。[3] 。 R语言的下载可以通过CRAN的镜像来查找。 R语言有域名为.cn的下载地址,有六个,其中两个由Datagurn,由 中国科学技术大学提供的。R语言Windows版,其中由两个下载地点是Datagurn和 USTC提供的。 三、stata Stata 是一套提供其使用者数据分析、数据管理以及绘制专业图表的完整及整合性统计软件。它提供许许多多功能,包含线性混合模型、均衡重复反复及多项式普罗比模式。用Stata绘制的统计图形相当精美。 新版本的STATA采用最具亲和力的窗口接口,使用者自行建立程序时,软件能提供具有直接命令式的语法。Stata提供完整的使用手册,包含统计样本建立、解释、模型与语法、文献等超过一万余页的出版品。 除此之外,Stata软件可以透过网络实时更新每天的最新功能,更可以得知世界各地的使用者对于STATA公司提出的问题与解决之道。使用者也可以透过Stata. Journal获得许许多多的相关讯息以及书籍介绍等。另外一个获取庞大资源的管道就是Statalist,它是一个独立的listserver,每月交替提供使用者超过1000个讯息以及50个程序。 四、PYTHON