torch.argsort

torch.argsort(input, dim=-1, descending=False) → LongTensor

Returns the indices that sort a tensor along a given dimension by value. The order is ascending by default; pass descending=True to reverse it.

This is the second value returned by torch.sort(). See its documentation for the exact semantics of this function.
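
The relationship to torch.sort() can be checked directly. The following is a minimal sketch, assuming a small hand-written tensor with distinct values (so ties do not affect the index order); the variable names are illustrative only.

```python
import torch

# Small fixed input with distinct values in each row.
a = torch.tensor([[0.0785, 1.5267, -0.8521, 0.4065],
                  [0.1598, 0.0788, -0.0745, -1.2700]])

# torch.sort returns (values, indices); torch.argsort returns only the indices.
values, indices = torch.sort(a, dim=1)
idx = torch.argsort(a, dim=1)

# The argsort result matches the second value returned by torch.sort.
same_indices = torch.equal(idx, indices)

# Gathering the input with those indices reproduces the sorted values.
same_values = torch.equal(torch.gather(a, 1, idx), values)
```

When both the sorted values and the indices are needed, calling torch.sort() once is likely preferable to calling torch.argsort() and then gathering.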

Parameters
  • input (Tensor) – the input tensor.

  • dim (int, optional) – the dimension to sort along. Default: -1 (the last dimension).

  • descending (bool, optional) – controls the sorting order: ascending if False, descending if True. Default: False.

Example:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
        [ 0.1598,  0.0788, -0.0745, -1.2700],
        [ 1.2208,  1.0722, -0.7064,  1.2564],
        [ 0.0669, -0.2318, -0.8229, -0.9280]])


>>> torch.argsort(a, dim=1)
tensor([[2, 0, 3, 1],
        [3, 2, 1, 0],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])
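
The example above sorts in ascending order along dim=1. As a further sketch, the same values can be sorted in descending order, or along dim=0 (down each column). The tensor is hard-coded here rather than drawn from torch.randn so the snippet is reproducible; the names desc and col are illustrative.

```python
import torch

# Same values as the example tensor, written out explicitly.
a = torch.tensor([[0.0785, 1.5267, -0.8521, 0.4065],
                  [0.1598, 0.0788, -0.0745, -1.2700],
                  [1.2208, 1.0722, -0.7064, 1.2564],
                  [0.0669, -0.2318, -0.8229, -0.9280]])

# Indices that sort each row from largest to smallest.
desc = torch.argsort(a, dim=1, descending=True)

# Indices that sort each column from smallest to largest.
col = torch.argsort(a, dim=0)

# With distinct values, descending order is the ascending order reversed.
asc = torch.argsort(a, dim=1)
is_reversed = torch.equal(desc, torch.flip(asc, dims=[1]))
```

In each case the result has the same shape as the input, and entry positions along the sorted dimension hold indices into that dimension.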
