

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
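The three formulas share one pattern. Take every position of index, copy it, replace only the coordinate along dim with the value stored in index, and read input at that position. The sketch below spells this out as an explicit loop. `gather_reference` is a hypothetical helper written only for illustration, not part of PyTorch. It is slow, and the random indices are arbitrary test data. It checks itself against `torch.gather` for each dim of a 3-D tensor:

```python
import itertools

import torch


def gather_reference(source: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based gather (illustration only, not a PyTorch API)."""
    # out always has exactly the shape of index
    out = torch.empty(index.shape, dtype=source.dtype)
    # walk every position of index
    for pos in itertools.product(*(range(s) for s in index.shape)):
        src_pos = list(pos)
        # replace only the coordinate along dim with the stored index value
        src_pos[dim] = int(index[pos])
        out[pos] = source[tuple(src_pos)]
    return out


torch.manual_seed(0)
x = torch.arange(24).reshape(2, 3, 4)
for d in range(x.dim()):
    # every index value along dim d must lie in [0, x.size(d))
    idx = torch.randint(0, x.size(d), (2, 3, 4))
    assert torch.equal(gather_reference(x, d, idx), torch.gather(x, d, idx))
```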

If input is an $n$-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, x_i, x_{i+1}, \ldots, x_{n-1})$ and dim $= i$, then index must be an $n$-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, y, x_{i+1}, \ldots, x_{n-1})$, where $y \geq 1$, and out will have the same size as index.
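As a small illustration of the shape rule, take an input of size $(2, 3)$ and gather along dim $= 0$. The index below keeps $x_1 = 3$ but uses $y = 4 > x_0$ rows, so the output has four rows. The tensor values are made up for the example. Each index value along dim must still lie in $[0, x_0)$:

```python
import torch

# input of size (x_0, x_1) = (2, 3); gather along dim = 0
source = torch.tensor([[10, 11, 12],
                       [20, 21, 22]])

# index keeps x_1 = 3 but has y = 4 along dim 0; each value must be 0 or 1
index = torch.tensor([[0, 1, 0],
                      [1, 1, 1],
                      [0, 0, 0],
                      [1, 0, 1]])

# out[i][j] = source[index[i][j]][j]
out = torch.gather(source, 0, index)
print(out.shape)  # torch.Size([4, 3]), the same as index.shape
```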

Parameters
  • input (Tensor) – the source tensor

  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to gather

  • sparse_grad (bool, optional) – If True, the gradient w.r.t. input will be a sparse tensor.

  • out (Tensor, optional) – the destination tensor
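The two keyword-only options can be exercised as follows. This is a sketch on CPU with made-up values. With sparse_grad=True, the gradient that reaches input is stored as a sparse tensor, and positions gathered more than once accumulate. The out= call uses a detached input, because out= variants do not support autograd:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]], requires_grad=True)
idx = torch.tensor([[0, 0],
                    [2, 1]])

# sparse_grad=True: the gradient w.r.t. x is produced as a sparse tensor
y = torch.gather(x, 1, idx, sparse_grad=True)
y.sum().backward()
print(x.grad.is_sparse)   # True
# x[0][0] was gathered twice, so its dense gradient entry is 2
print(x.grad.to_dense())

# out=: write the result into a preallocated tensor shaped like index
buf = torch.empty(2, 2)
torch.gather(x.detach(), 1, idx, out=buf)
```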

Example:

>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1,  1],
        [ 4,  3]])
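A common use of gathering along dim 1 is picking one entry per row, for example the log-probability of each sample's target class. The sketch below uses hypothetical scores for 3 samples over 4 classes. Because index must have as many dimensions as input, the 1-D target gets a trailing axis before the call, and that axis is squeezed off afterwards. The result matches advanced indexing, and its negated mean matches `torch.nn.functional.nll_loss` with the default mean reduction:

```python
import torch
import torch.nn.functional as F

# hypothetical scores for 3 samples over 4 classes
log_probs = torch.log_softmax(torch.tensor([[2.0, 0.5, 0.1, 0.3],
                                            [0.2, 0.1, 3.0, 0.4],
                                            [1.0, 1.0, 1.0, 1.0]]), dim=1)
target = torch.tensor([0, 2, 3])  # one class index per row

# index must be 2-D like log_probs, so add a trailing axis, then drop it
picked = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)  # shape (3,)

# same entries as indexing with (row, column) pairs
assert torch.equal(picked, log_probs[torch.arange(3), target])
# negated mean equals the default (mean-reduced) NLL loss
assert torch.allclose(-picked.mean(), F.nll_loss(log_probs, target))
```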
