
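First, find which elements of the tensor are greater than the given value. The comparison a > value returns a boolean mask of the same shape, and indexing the tensor with that mask returns a 1-D tensor containing only the matching elements. Then apply the torch.numel() function to that 1-D tensor to get the count.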

Here is an example (torch.randn draws random values, so your numbers will differ):

>>> import torch
>>> a = torch.randn(6, 4)
>>> a
tensor([[-0.0457, -0.4924, -0.7026,  0.0567],
        [-0.5104, -0.1395, -0.3003,  0.8491],
        [ 2.2846,  0.5619, -0.1806,  0.9625],
        [ 0.7884,  1.1767,  2.0025, -0.0589],
        [-0.1579,  0.8199, -0.5279,  0.2966],
        [ 0.0946, -0.7405,  0.4907,  1.3673]])
>>> a > 1
tensor([[False, False, False, False],
        [False, False, False, False],
        [ True, False, False, False],
        [False,  True,  True, False],
        [False, False, False, False],
        [False, False, False,  True]])
>>> torch.numel(a[a > 1])
4
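If you need this more than once, the same idea can be wrapped in a small helper. This is a minimal sketch; the function names count_greater and count_greater_sum are hypothetical, not part of PyTorch. The second variant sums the boolean mask directly (True counts as 1), which gives the same count without building the intermediate tensor of selected values.

```python
import torch


def count_greater(t: torch.Tensor, value: float) -> int:
    """Hypothetical helper: count elements of t strictly greater than value."""
    mask = t > value               # boolean tensor, same shape as t
    selected = t[mask]             # 1-D tensor of only the matching elements
    return torch.numel(selected)   # number of elements in that 1-D tensor


def count_greater_sum(t: torch.Tensor, value: float) -> int:
    """Hypothetical alternative: sum the boolean mask instead of indexing."""
    return int((t > value).sum().item())  # each True contributes 1 to the sum


a = torch.randn(6, 4)
print(count_greater(a, 1), count_greater_sum(a, 1))  # both print the same count
```

Both functions should agree for any input tensor; the mask-sum version is convenient when you only need the count and not the values themselves.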