# [Python] How to select elements of a tensor which are greater than a given value

How can I find the elements of a tensor which are greater than some value "k"?


One approach is very similar to selecting elements from a NumPy array. First, compare the tensor with the given value. This returns a boolean tensor (True/False values) with the same shape as the original. Then use that boolean tensor as an index into the original tensor to get the elements that are greater than the given value.

Here is an example:

```python
>>> import torch
>>> a = torch.randn(6, 4)
>>> a
tensor([[-0.0457, -0.4924, -0.7026,  0.0567],
        [-0.5104, -0.1395, -0.3003,  0.8491],
        [ 2.2846,  0.5619, -0.1806,  0.9625],
        [ 0.7884,  1.1767,  2.0025, -0.0589],
        [-0.1579,  0.8199, -0.5279,  0.2966],
        [ 0.0946, -0.7405,  0.4907,  1.3673]])
>>> a > 1
tensor([[False, False, False, False],
        [False, False, False, False],
        [ True, False, False, False],
        [False,  True,  True, False],
        [False, False, False, False],
        [False, False, False,  True]])
>>> a[a > 1]
tensor([2.2846, 1.1767, 2.0025, 1.3673])
```
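To reuse this with any threshold k, the same two steps can be wrapped in a small helper. This is a minimal sketch: `select_greater` is a hypothetical name, and a small fixed tensor replaces `torch.randn` so the result is reproducible. Note that indexing with a boolean mask always returns a 1-D tensor, with the selected elements in row-major order.

```python
import torch

def select_greater(t: torch.Tensor, k: float) -> torch.Tensor:
    """Hypothetical helper: return the elements of t strictly greater than k."""
    mask = t > k      # boolean tensor with the same shape as t
    return t[mask]    # 1-D tensor of the selected elements, row-major order

a = torch.tensor([[-0.5, 2.0, 0.25],
                  [ 1.5, -1.0, 3.0]])

print(select_greater(a, 1))    # tensor([2.0000, 1.5000, 3.0000])
print(select_greater(a, 10))   # an empty tensor: no element exceeds 10
```

If you need to keep the original shape instead of a flat result, the boolean mask can be used differently, but plain indexing as above always flattens.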

Another approach is to use the torch.masked_select() function. Its "mask" argument is simply your selection condition, i.e., tensor > k, passed as a boolean tensor.

Here is an example using this function:

```python
>>> torch.masked_select(a, a > 1)
tensor([2.2846, 1.1767, 2.0025, 1.3673])
```
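Both approaches should return the same elements. The sketch below, using a small fixed tensor chosen only for illustration, checks that `torch.masked_select(a, a > k)` matches `a[a > k]`. It also illustrates that, per the PyTorch documentation, the mask only needs to be broadcastable to the input's shape; the `col_mask` tensor is an assumed example, not something from the original answer.

```python
import torch

a = torch.tensor([[-0.5, 2.0, 0.25],
                  [ 1.5, -1.0, 3.0]])
k = 1

# Approach 1: index the tensor with a boolean mask
by_indexing = a[a > k]

# Approach 2: pass the same boolean mask to torch.masked_select
by_masked_select = torch.masked_select(a, a > k)

print(torch.equal(by_indexing, by_masked_select))  # True

# The mask only has to be broadcastable to a's shape: this length-3 mask
# is applied to every row, keeping columns 0 and 2
# (-0.5 and 0.25 from row 0, then 1.5 and 3.0 from row 1).
col_mask = torch.tensor([True, False, True])
print(torch.masked_select(a, col_mask))
```

In practice the choice between the two is mostly stylistic; boolean indexing reads more like NumPy, while masked_select makes the masking step explicit.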