

torch.flatten

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

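Flattens input by reshaping the dimensions from start_dim through end_dim (both inclusive) into a single dimension. Dimensions outside that range are left unchanged, and the order of elements is preserved. With the default arguments, the result is one-dimensional.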
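
The shape arithmetic behind this can be sketched as a small helper. `flattened_shape` is a hypothetical name used only for illustration, not part of PyTorch. It assumes the output shape is the input shape with the sizes of dims start_dim..end_dim multiplied together, with negative indices counted from the end and a zero-dimensional input treated as a one-element, one-dimensional result. The loop checks the sketch against torch.flatten itself.

```python
import math

import torch


def flattened_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper: the shape torch.flatten is expected to produce."""
    ndim = len(shape)
    if ndim == 0:
        # A 0-d tensor flattens to a 1-d tensor holding its single element.
        return (1,)
    # Negative indices count from the end, as in Python sequence indexing.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    # Reject out-of-range or reversed dims.
    if not 0 <= start <= end < ndim:
        raise ValueError(f"invalid dims: start_dim={start_dim}, end_dim={end_dim}")
    # Dims start..end (inclusive) collapse into one dim of their combined size.
    merged = math.prod(shape[start : end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1 :])


t = torch.arange(24).reshape(2, 3, 4)
for start_dim, end_dim in [(0, -1), (1, -1), (0, 1), (1, 1)]:
    expected = flattened_shape(tuple(t.shape), start_dim, end_dim)
    actual = tuple(torch.flatten(t, start_dim, end_dim).shape)
    assert actual == expected, (start_dim, end_dim, actual, expected)
```

For this (2, 3, 4) input the four cases give (24,), (2, 12), (6, 4) and (2, 3, 4); the last one shows that flattening a single dim leaves the shape unchanged.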

Parameters
  • input (Tensor) – the input tensor.

  • start_dim (int) – the first dim to flatten. Default: 0.

  • end_dim (int) – the last dim to flatten, inclusive; negative values count from the end. Default: -1 (the last dim).
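
As an illustration of how start_dim and end_dim select the range, the sketch below uses a random tensor laid out as a batch of images, (N, C, H, W). That layout and the sizes are assumptions chosen only for the example. It also compares against torch.nn.Flatten, whose default start_dim is 1, so it keeps the batch dimension.

```python
import torch

# Illustrative batch: 8 images, 3 channels, 32x32 pixels, laid out as (N, C, H, W).
x = torch.randn(8, 3, 32, 32)

# Keep the batch dim and flatten everything else: 3 * 32 * 32 = 3072.
per_sample = torch.flatten(x, start_dim=1)
print(per_sample.shape)  # torch.Size([8, 3072])

# Flatten only the spatial dims H and W: 32 * 32 = 1024.
per_channel = torch.flatten(x, start_dim=2)
print(per_channel.shape)  # torch.Size([8, 3, 1024])

# end_dim=-3 refers to dim 1 (C) of a 4-d tensor, so N and C are merged: 8 * 3 = 24.
merged_nc = torch.flatten(x, start_dim=0, end_dim=-3)
print(merged_nc.shape)  # torch.Size([24, 32, 32])

# torch.nn.Flatten defaults to start_dim=1, end_dim=-1, matching per_sample.
assert torch.equal(torch.nn.Flatten()(x), per_sample)
```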

Example:

>>> import torch
>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
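
The result is not always a fresh copy. When the input's memory layout allows the flattened shape to be expressed as a view (as for the contiguous t above), the result can share storage with the input; a non-contiguous input such as a transpose has to be copied. The sketch below checks this with data_ptr(), which is only a quick heuristic for storage sharing, and also shows that elements come out in the tensor's logical order rather than its memory order.

```python
import torch

t = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])

# Contiguous input: the flattened result can be a view over the same storage.
flat = torch.flatten(t)
print(flat.data_ptr() == t.data_ptr())  # True

# Transposed (non-contiguous) input: logically [[1, 3], [2, 4]],
# which cannot be viewed as 1-d, so the data is copied.
m = torch.tensor([[1, 2], [3, 4]]).t()
flat_m = torch.flatten(m)
print(flat_m)  # tensor([1, 3, 2, 4])
print(flat_m.data_ptr() == m.data_ptr())  # False
```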
