Python torch.quantile 示例说明
最编程
2024-04-13 17:37:30
...
沿维度 dim
计算 input
张量的每一行的 q-th 分位数。
为了计算分位数,我们将 [0, 1] 中的 q 映射到索引 [0, n] 的范围内,以找到分位数在排序输入中的位置。如果分位数位于两个数据点 a < b
之间,索引 i
和 j
按排序顺序,则使用线性插值计算结果,如下所示:
a + (b - a) * fraction
,其中 fraction
是计算的分位数索引的小数部分。
如果 q
是一维张量,则输出的第一个维度表示分位数并且大小等于 q
的大小,其余维度是归约后剩下的维度。
注意
默认情况下 dim
是 None
导致 input
张量在计算之前被展平。
例子:
>>> a = torch.randn(2, 3)
>>> a
tensor([[ 0.0795, -1.2117, 0.9765],
[ 1.1707, 0.6706, 0.4884]])
>>> q = torch.tensor([0.25, 0.5, 0.75])
>>> torch.quantile(a, q, dim=1, keepdim=True)
tensor([[[-0.5661],
[ 0.5795]],
[[ 0.0795],
[ 0.6706]],
[[ 0.5280],
[ 0.9206]]])
>>> torch.quantile(a, q, dim=1, keepdim=True).shape
torch.Size([3, 2, 1])
>>> a = torch.arange(4.)
>>> a
tensor([0., 1., 2., 3.])