说明 pytorch 中 numel 函数的用法
最编程
2024-06-01 20:23:41
...
获取tensor中一共包含多少个元素
import torch x = torch.randn(3,3) print("number elements of x is ",x.numel()) y = torch.randn(3,10,5) print("number elements of y is ",y.numel())
输出:
number elements of x is 9
number elements of y is 150
27和150分别位x和y中各有多少个元素或变量
补充:pytorch获取张量元素个数numel()的用法
numel就是"number of elements"的简写。
numel()可以直接返回int类型的元素个数
import torch a = torch.randn(1, 2, 3, 4) b = a.numel() print(type(b)) # int print(b) # 24
通过numel()函数,我们可以迅速查看一个张量到底又多少元素。
补充:pytorch 卷积结构和numel()函数
看代码吧~
from torch import nn class CNN(nn.Module): def __init__(self, num_channels=1, d=56, s=12, m=4): super(CNN, self).__init__() self.first_part = nn.Sequential( nn.Conv2d(num_channels, d, kernel_size=3, padding=5//2), nn.Conv2d(num_channels, d, kernel_size=(1,3), padding=5//2), nn.Conv2d(num_channels, d, kernel_size=(3,1), padding=5//2), nn.PReLU(d) ) def forward(self, x): x = self.first_part(x) return x model = CNN() for m in model.first_part: if isinstance(m, nn.Conv2d): # print('m:',m.weight.data) print('m:',m.weight.data[0]) print('m:',m.weight.data[0][0]) print('m:',m.weight.data.numel()) #numel() 计算矩阵中元素的个数 结果: m: tensor([[[-0.2822, 0.0128, -0.0244], [-0.2329, 0.1037, 0.2262], [ 0.2845, -0.3094, 0.1443]]]) #卷积核大小为3x3 m: tensor([[-0.2822, 0.0128, -0.0244], [-0.2329, 0.1037, 0.2262], [ 0.2845, -0.3094, 0.1443]]) #卷积核大小为3x3 m: 504 # = 56 x (3 x 3) 输出通道数为56,卷积核大小为3x3 m: tensor([-0.0335, 0.2945, 0.2512, 0.2770, 0.2071, 0.1133, -0.1883, 0.2738, 0.0805, 0.1339, -0.3000, -0.1911, -0.1760, 0.2855, -0.0234, -0.0843, 0.1815, 0.2357, 0.2758, 0.2689, -0.2477, -0.2528, -0.1447, -0.0903, 0.1870, 0.0945, -0.2786, -0.0419, 0.1577, -0.3100, -0.1335, -0.3162, -0.1570, 0.3080, 0.0951, 0.1953, 0.1814, -0.1936, 0.1466, -0.2911, -0.1286, 0.3024, 0.1143, -0.0726, -0.2694, -0.3230, 0.2031, -0.2963, 0.2965, 0.2525, -0.2674, 0.0564, -0.3277, 0.2185, -0.0476, 0.0558]) bias偏置的值 m: tensor([[[ 0.5747, -0.3421, 0.2847]]]) 卷积核大小为1x3 m: tensor([[ 0.5747, -0.3421, 0.2847]]) 卷积核大小为1x3 m: 168 # = 56 x (1 x 3) 输出通道数为56,卷积核大小为1x3 m: tensor([ 0.5328, -0.5711, -0.1945, 0.2844, 0.2012, -0.0084, 0.4834, -0.2020, -0.0941, 0.4683, -0.2386, 0.2781, -0.1812, -0.2990, -0.4652, 0.1228, -0.0627, 0.3112, -0.2700, 0.0825, 0.4345, -0.0373, -0.3220, -0.5038, -0.3166, -0.3823, 0.3947, -0.3232, 0.1028, 0.2378, 0.4589, 0.1675, -0.3112, -0.0905, -0.0705, 0.2763, 0.5433, 0.2768, -0.3804, 0.4855, -0.4880, -0.4555, 0.4143, 0.5474, 0.3305, -0.0381, 0.2483, 0.5133, -0.3978, 0.0407, 0.2351, 0.1910, -0.5385, 0.1340, 0.1811, -0.3008]) bias偏置的值 m: tensor([[[0.0184], [0.0981], [0.1894]]]) 卷积核大小为3x1 m: tensor([[0.0184], [0.0981], [0.1894]]) 卷积核大小为3x1 m: 168 # = 56 x (3 x 1) 输出通道数为56,卷积核大小为3x1 m: tensor([-0.2951, -0.4475, 0.1301, 0.4747, -0.0512, 0.2190, 0.3533, -0.1158, 0.2237, -0.1407, -0.4756, 0.1637, -0.4555, -0.2157, 0.0577, -0.3366, -0.3252, 0.2807, 0.1660, 0.2949, -0.2886, -0.5216, 0.1665, 0.2193, 0.2038, -0.1357, 0.2626, 0.2036, 0.3255, 0.2756, 0.1283, -0.4909, 0.5737, -0.4322, -0.4930, -0.0846, 0.2158, 0.5565, 0.3751, -0.3775, -0.5096, -0.4520, 0.2246, -0.5367, 0.5531, 0.3372, -0.5593, -0.2780, -0.5453, -0.2863, 0.5712, -0.2882, 0.4788, 0.3222, -0.4846, 0.2170]) bias偏置的值 '''初始化后''' class CNN(nn.Module): def __init__(self, num_channels=1, d=56, s=12, m=4): super(CNN, self).__init__() self.first_part = nn.Sequential( nn.Conv2d(num_channels, d, kernel_size=3, padding=5//2), nn.Conv2d(num_channels, d, kernel_size=(1,3), padding=5//2), nn.Conv2d(num_channels, d, kernel_size=(3,1), padding=5//2), nn.PReLU(d) ) self._initialize_weights() def _initialize_weights(self): for m in self.first_part: if isinstance(m, nn.Conv2d): nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel()))) nn.init.zeros_(m.bias.data) def forward(self, x): x = self.first_part(x) return x model = CNN() for m in model.first_part: if isinstance(m, nn.Conv2d): # print('m:',m.weight.data) print('m:',m.weight.data[0]) print('m:',m.weight.data[0][0]) print('m:',m.weight.data.numel()) #numel() 计算矩阵中元素的个数 结果: m: tensor([[[-0.0284, -0.0585, 0.0271], [ 0.0125, 0.0554, 0.0511], [-0.0106, 0.0574, -0.0053]]]) m: tensor([[-0.0284, -0.0585, 0.0271], [ 0.0125, 0.0554, 0.0511], [-0.0106, 0.0574, -0.0053]]) m: 504 m: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) m: tensor([[[ 0.0059, 0.0465, -0.0725]]]) m: tensor([[ 0.0059, 0.0465, -0.0725]]) m: 168 m: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) m: tensor([[[ 0.0599], [-0.1330], [ 0.2456]]]) m: tensor([[ 0.0599], [-0.1330], [ 0.2456]]) m: 168 m: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
以上为个人经验,希望能给大家一个参考,也希望大家多多支持。如有错误或未考虑完全的地方,望不吝赐教。
推荐阅读
-
移位操作函数的应用:circshift、fftshift和ifftshift在matlab中的用法
-
SQL中的几种取整函数:FLOOR、ROUND、CEIL、TRUNC、SIGN的用法介绍
-
简单了解JavaScript中的反正弦函数Math.asin的用法
-
在C++中自定义sort函数的用法指南
-
C++中的转换函数详解:第二部分 - 应用与用法
-
深入理解PyTorch中常用交叉熵损失函数CrossEntropyLoss的用法与解析
-
理解深度学习基础:从神经网络构造到实践 - 1.评分函数介绍 2.SVM损失函数解析 3.正规化惩罚项说明 4.Softmax与交叉熵损失函数详解 5.前向传播中的最优化挑战 6.批量大小(batch_size)实操指南...
-
异步编程RxJava-介绍-前言 前段时间写了一篇对协程的一些理解,里面提到了不管是协程还是callback,本质上其实提供的是一种异步无阻塞的编程模式;并且介绍了java中对异步无阻赛这种编程模式的支持,主要提到了Future和CompletableFuture;之后有同学在下面留言提到了RxJava,刚好最近在看微服务设计这本书,里面提到了响应式扩展(Reactive extensions,Rx),而RxJava是Rx在JVM上的实现,所有打算对RxJava进一步了解。 RxJava简介 RxJava的官网地址:https://github.com/ReactiveX/RxJava, 其中对RxJava进行了一句话描述:RxJava – Reactive Extensions for the JVM – a library for composing asynchronous and event-based programs using observable sequences for the Java VM. 大意就是:一个在Java VM上使用可观测的序列来组成异步的、基于事件的程序的库。 更详细的说明在Netflix技术博客的一篇文章中描述了RxJava的主要特点: 1.易于并发从而更好的利用服务器的能力。 2.易于有条件的异步执行。 3.一种更好的方式来避免回调地狱。 4.一种响应式方法。 与CompletableFuture对比 之前提到CompletableFuture真正的实现了异步的编程模式,一个比较常见的使用场景: CompletableFuture<Integer> future = CompletableFuture.supplyAsync(耗时函数); Future<Integer> f = future.whenComplete((v, e) -> { System.out.println(v); System.out.println(e); }); System.out.println("other..."); 下面用一个简单的例子来看一下RxJava是如何实现异步的编程模式: Observable<Long> observable = Observable.just(1, 2) .subscribeOn(Schedulers.io).map(new Func1<Integer, Long> { @Override public Long call(Integer t) { try { Thread.sleep(1000); //耗时的操作 } catch (InterruptedException e) { e.printStackTrace; } return (long) (t * 2); } }); observable.subscribe(new Subscriber<Long> { @Override public void onCompleted { System.out.println("onCompleted"); } @Override public void onError(Throwable e) { System.out.println("error" + e); } @Override public void onNext(Long result) { System.out.println("result = " + result); } }); System.out.println("other..."); Func1中以异步的方式执行了一个耗时的操作,Subscriber(观察者)被订阅到Observable(被观察者)中,当耗时操作执行完会回调Subscriber中的onNext方法。 其中的异步方式是在subscribeOn(Schedulers.io)中指定的,Schedulers.io可以理解为每次执行耗时操作都启动一个新的线程。 结构上其实和CompletableFuture很像,都是异步的执行一个耗时的操作,然后在有结果的时候主动告诉我结果。那我们还需要RxJava干嘛,不知道你有没有注意,上面的例子中其实提供2条数据流[1,2],并且处理完任何一个都会主动告诉我,当然这只是它其中的一项功能,RxJava还有很多好用的功能,在下面的内容会进行介绍。 异步观察者模式 上面这段代码有没有发现特别像设计模式中的:观察者模式;首先提供一个被观察者Observable,然后把观察者Subscriber添加到了被观察者列表中; RxJava中一共提供了四种角色:Observable、Observer、Subscriber、Subjects Observables和Subjects是两个被观察者,Observers和Subscribers是观察者; 当然我们也可以查看一下源码,看一下jdk中的Observer和RxJava的Observer jdk中的Observer: public interface Observer { void update(Observable o, Object arg); } RxJava的Observer: public interface Observer<T> { void onCompleted; void onError(Throwable e); void onNext(T t); } 同时可以发现Subscriber是implements Observer的: public abstract class Subscriber<T> implements Observer<T>, Subscription 可以发现RxJava中在Observer中引入了2个新的方法:onCompleted和onError onCompleted:即通知观察者Observable没有更多的数据,事件队列完结 onError:在事件处理过程中出异常时,onError会被触发,同时队列自动终止,不允许再有事件发出。 正是因为RxJava提供了同步和异步两种方式进行事件的处理,个人觉得异步的方式更能体现RxJava的价值,所以这里给他命名为异步观察者模式。 好了,下面正式介绍RxJava的那些灵活的操作符,这里仅仅是简单的介绍和简单的实例,具体用在什么场景下,会在以后的文章中介绍 Maven引入
-
匿名函数在 js 中的作用和使用说明
-
关于 R 中 predict 函数用法的简短说明