torch.bmm

torch.bmm(input, mat2, *, deterministic=False, out=None) → Tensor

Performs a batch matrix-matrix product of matrices stored in input and mat2.

input and mat2 must be 3-D tensors each containing the same number of matrices.

If input is a $(b \times n \times m)$ tensor and mat2 is a $(b \times m \times p)$ tensor, out will be a $(b \times n \times p)$ tensor.

$$\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i$$
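The per-batch relationship above can be checked against an explicit loop. This is only an illustrative sketch: the tensor names `a` and `b`, the shapes, and the tolerance are arbitrary choices, not part of the API.

```python
import torch

torch.manual_seed(0)

# Two batches of 4 matrices: (4 x 2 x 3) and (4 x 3 x 5).
a = torch.randn(4, 2, 3)
b = torch.randn(4, 3, 5)

# Batched product in a single call.
batched = torch.bmm(a, b)

# The same result built one matrix pair at a time: out_i = a_i @ b_i.
looped = torch.stack([a[i] @ b[i] for i in range(a.size(0))])

# Small tolerance, since the two paths may round slightly differently.
matches = torch.allclose(batched, looped, atol=1e-5)
print(batched.shape, matches)
```

The loop is shown only for comparison; the single torch.bmm call is the form to use in practice.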

This operator supports TensorFloat32.

Note

This function does not broadcast. For broadcasting matrix products, see torch.matmul().
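A short sketch of what the lack of broadcasting means in practice, assuming a single 2-D matrix has to be applied to every matrix in a batch. The names `batch` and `single` are made up for illustration, and the exact error message raised by torch.bmm may differ between versions.

```python
import torch

torch.manual_seed(0)

batch = torch.randn(10, 3, 4)   # 10 matrices of shape 3 x 4
single = torch.randn(4, 5)      # one 2-D matrix, no batch dimension

# torch.matmul broadcasts the 2-D matrix across the batch.
broadcast = torch.matmul(batch, single)

# torch.bmm requires both arguments to be 3-D, so this call is rejected.
try:
    torch.bmm(batch, single)
    bmm_failed = False
except RuntimeError:
    bmm_failed = True

# To use torch.bmm anyway, add the batch dimension explicitly.
expanded = torch.bmm(batch, single.unsqueeze(0).expand(10, -1, -1))

print(broadcast.shape, bmm_failed, torch.allclose(broadcast, expanded, atol=1e-5))
```

Here expand creates a view rather than copying `single` ten times, so the explicit version does not duplicate the matrix in memory before the multiplication.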

Parameters
  • input (Tensor) – the first batch of matrices to be multiplied

  • mat2 (Tensor) – the second batch of matrices to be multiplied

Keyword Arguments
  • deterministic (bool, optional) – if True, uses a slower deterministic calculation; if False, uses a faster non-deterministic calculation. This argument is only available for sparse-dense CUDA bmm. Default: False

  • out (Tensor, optional) – the output tensor.
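As a minimal sketch of the out keyword, the result can be written into a preallocated tensor. The buffer name `buf` and the shapes are assumptions chosen for illustration; the buffer is created with the expected (b × n × p) shape.

```python
import torch

torch.manual_seed(0)

a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)

# Preallocate the (b x n x p) result buffer.
buf = torch.empty(10, 3, 5)

# The product is written into buf, which is also what the call returns.
result = torch.bmm(a, b, out=buf)

# Both names refer to the same underlying memory.
same_storage = result.data_ptr() == buf.data_ptr()
print(result.shape, same_storage)
```

Reusing a buffer this way can avoid a fresh allocation on each call, for example inside a loop that repeatedly multiplies batches of the same shape.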

Example:

>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
