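You can use the **take()** function, passing the indices of the elements you want as an argument. It returns a tensor with the same shape as the indices. The take() function treats the input tensor as if it were flattened into a 1D tensor (in row-major order), so each index refers to a position in that flattened view.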
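A minimal sketch of the "treated as 1D" behaviour, assuming a recent PyTorch. It uses `torch.arange` instead of `torch.randn` so the values are deterministic (`a[r, c] == 4*r + c`). It compares `torch.take` with indexing the flattened tensor and shows the row-major position formula `r * ncols + c`:

```python
import torch

# Deterministic 6x4 tensor: a[r, c] == 4 * r + c, so results are easy to check
a = torch.arange(24, dtype=torch.float32).reshape(6, 4)
idx = torch.tensor([1, 5, 6, 8])

# torch.take(a, idx) selects the same elements as indexing the flattened tensor
taken = torch.take(a, idx)
flat = a.flatten()[idx]
print(taken)                     # tensor([1., 5., 6., 8.])
print(torch.equal(taken, flat))  # True

# In row-major order, element a[r, c] sits at flat position r * a.size(1) + c
r, c = 2, 0
print(torch.take(a, torch.tensor([r * a.size(1) + c])))  # tensor([8.])
```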

Here are some examples:

**When indices are a 1D tensor**

```python
>>> import torch
>>> a = torch.randn(6, 4)
>>> a
tensor([[-0.3410, -2.3171,  0.2685, -1.4083],
        [-0.1782,  0.4501,  0.4013, -0.4777],
        [-0.8800, -0.8078, -1.0272,  0.0961],
        [-1.2799, -0.5404, -1.3871, -1.5463],
        [-0.3515, -0.0466, -1.5026,  0.6122],
        [ 0.7668, -1.1009, -0.5753, -0.0123]])
>>> i = torch.tensor([1, 5, 6, 8])
>>> torch.take(a, i)
tensor([-2.3171,  0.4501,  0.4013, -0.8800])
```
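In the example above, each row has 4 elements, so flat index 8 lands on row 2, column 0 (the value -0.8800). A short sketch of that mapping, again using a deterministic `torch.arange` tensor in place of `torch.randn`: recover `(row, column)` with floor division and modulo, then check that advanced indexing picks the same elements as `take()`:

```python
import torch

a = torch.arange(24, dtype=torch.float32).reshape(6, 4)
i = torch.tensor([1, 5, 6, 8])

# Recover the (row, column) of each flat index for a tensor with 4 columns
ncols = a.size(1)
rows = torch.div(i, ncols, rounding_mode="floor")  # tensor([0, 1, 1, 2])
cols = i % ncols                                   # tensor([1, 1, 2, 0])

# Advanced indexing with those coordinates selects the same elements as take()
print(torch.equal(a[rows, cols], torch.take(a, i)))  # True
```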

**When indices are a 2D tensor**

```python
>>> import torch
>>> a = torch.randn(6, 4)
>>> a
tensor([[-0.3410, -2.3171,  0.2685, -1.4083],
        [-0.1782,  0.4501,  0.4013, -0.4777],
        [-0.8800, -0.8078, -1.0272,  0.0961],
        [-1.2799, -0.5404, -1.3871, -1.5463],
        [-0.3515, -0.0466, -1.5026,  0.6122],
        [ 0.7668, -1.1009, -0.5753, -0.0123]])
>>> i = torch.tensor([[1, 2], [3, 4]])
>>> torch.take(a, i)
tensor([[-2.3171,  0.2685],
        [-1.4083, -0.1782]])
```
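The output shape follows the index tensor rather than the input. A small sketch checking this for a 2D and a 3D index tensor, assuming a deterministic `torch.arange` input so the selected values are predictable:

```python
import torch

a = torch.arange(24, dtype=torch.float32).reshape(6, 4)

# The result has the shape of the index tensor, not of the input
i2 = torch.tensor([[1, 2], [3, 4]])
out = torch.take(a, i2)
print(out.shape)  # torch.Size([2, 2])

# A 3D index tensor gives a 3D result, still drawn from the flattened input
i3 = torch.tensor([[[0, 23]], [[4, 19]]])
print(torch.take(a, i3).shape)  # torch.Size([2, 1, 2])
```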