欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

了解 Pytorch 的参数 "batch_first"。

最编程 2024-04-12 08:55:11
...

用过PyTorch的朋友大概都知道,对于不同的网络层,输入的维度虽然不同,但是通常输入的第一个维度都是batch_size,比如torch.nn.Linear的输入(batch_size,in_features),torch.nn.Conv2d的输入(batch_size, C, H, W)。而RNN的输入却是(seq_len, batch_size, input_size),batch_size位于第二维度!虽然你可以将batch_size和序列长度seq_len对换位置,此时只需要令batch_first=True。
但是为什么RNN输入默认不是batch first=True?这是为了便于并行计算。因为cuDNN中RNN的API就是batch_size在第二维度!进一步,为啥cuDNN要这么做呢?因为batch first意味着模型的输入(一个Tensor)在内存中存储时,先存储第一个sequence,再存储第二个... 而如果是seq_len first,模型的输入在内存中,先存储所有序列的第一个单元,然后是第二个单元... 两种区别如下图所示:

batch firth v.s. seq_len first

seq_len first意味着不同序列中同一个时刻对应的输入单元在内存中是毗邻的,这样才能做到真正的batch计算。

[参考资料] https://zhuanlan.zhihu.com/p/32103001