torch.argmin

torch.argmin(input) → LongTensor

Returns the index of the minimum value among all elements of the input tensor. The index refers to the flattened (row-major) tensor.

For the exact semantics, see the documentation of torch.min(), which returns these indices as its second value when it is called with a dim argument.

Note

If the minimum value occurs more than once, the index of its first occurrence is returned.
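
The tie-breaking rule in the note can be checked with a small sketch. The tensor `t` below is hypothetical data chosen so that the minimum appears twice. The behaviour shown assumes a PyTorch version that follows the note above.

```python
import torch

# Hypothetical data: the minimum value 1 appears at positions 1 and 3.
t = torch.tensor([3, 1, 2, 1])

# Per the note, the index of the first occurrence should be returned.
first_min = torch.argmin(t).item()
print(first_min)
```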

Parameters

input (Tensor) – the input tensor.

Example:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.1139,  0.2254, -0.1381,  0.3687],
        [ 1.0100, -1.1975, -0.0102, -0.4732],
        [-0.9240,  0.1207, -0.7506, -1.0213],
        [ 1.7809, -1.2960,  0.9384,  0.1438]])
>>> torch.argmin(a)
tensor(13)
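
Because the result indexes the flattened tensor, it often needs to be converted back to row and column coordinates. A minimal sketch, assuming a 2-D tensor and reusing the values from the example above (the variable names `flat_index`, `row`, and `col` are illustrative only):

```python
import torch

# Same values as the example above, written out so the result is reproducible.
a = torch.tensor([[ 0.1139,  0.2254, -0.1381,  0.3687],
                  [ 1.0100, -1.1975, -0.0102, -0.4732],
                  [-0.9240,  0.1207, -0.7506, -1.0213],
                  [ 1.7809, -1.2960,  0.9384,  0.1438]])

# Index into the flattened tensor, as a plain Python int.
flat_index = torch.argmin(a).item()

# For a 2-D tensor, split the flat index using the number of columns.
row, col = divmod(flat_index, a.shape[1])
print(flat_index, (row, col), a[row, col].item())
```

Here the flat index 13 corresponds to row 3, column 1, which holds -1.2960.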

torch.argmin(input, dim, keepdim=False) → LongTensor

Returns the indices of the minimum values of a tensor across a dimension.

This is the second value returned by torch.min() when it is called with the same dim. See its documentation for the exact semantics of this method.

Parameters
  • input (Tensor) – the input tensor.

  • dim (int) – the dimension to reduce. If None, the argmin of the flattened input is returned.

  • keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.
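
The effect of keepdim on the output shape can be sketched as follows. The tensor values are copied from the examples on this page. Using the keepdim=True result with torch.gather to look up the minimum values is one possible use, shown here as an illustration rather than a recommended pattern.

```python
import torch

a = torch.tensor([[ 0.1139,  0.2254, -0.1381,  0.3687],
                  [ 1.0100, -1.1975, -0.0102, -0.4732],
                  [-0.9240,  0.1207, -0.7506, -1.0213],
                  [ 1.7809, -1.2960,  0.9384,  0.1438]])

# Default keepdim=False: the reduced dimension is dropped, shape (4,).
idx = torch.argmin(a, dim=1)

# keepdim=True: the reduced dimension is kept with size 1, shape (4, 1).
idx_keep = torch.argmin(a, dim=1, keepdim=True)

# The (4, 1) index has the same number of dimensions as `a`,
# so it can be passed to torch.gather directly.
row_min = torch.gather(a, 1, idx_keep)

print(idx.shape, idx_keep.shape, row_min.shape)
```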

Example:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.1139,  0.2254, -0.1381,  0.3687],
        [ 1.0100, -1.1975, -0.0102, -0.4732],
        [-0.9240,  0.1207, -0.7506, -1.0213],
        [ 1.7809, -1.2960,  0.9384,  0.1438]])
>>> torch.argmin(a, dim=1)
tensor([ 2,  1,  3,  1])
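
The relationship to torch.min() described above can be checked directly. A minimal sketch, reusing the same values; the names `values`, `indices`, and `same` are illustrative only:

```python
import torch

a = torch.tensor([[ 0.1139,  0.2254, -0.1381,  0.3687],
                  [ 1.0100, -1.1975, -0.0102, -0.4732],
                  [-0.9240,  0.1207, -0.7506, -1.0213],
                  [ 1.7809, -1.2960,  0.9384,  0.1438]])

# With a dim argument, torch.min returns a (values, indices) pair.
values, indices = torch.min(a, dim=1)

# torch.argmin over the same dim returns only the indices.
arg = torch.argmin(a, dim=1)

same = torch.equal(indices, arg)
print(same, arg.dtype)
```

When only the positions are needed, torch.argmin avoids materializing the values tensor. When both are needed, a single torch.min call with dim returns them together.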
