torch.mm

torch.mm(input, mat2, *, out=None) → Tensor

Performs a matrix multiplication of the matrices input and mat2.

If input is a (n × m) tensor and mat2 is a (m × p) tensor, out will be a (n × p) tensor.
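The shape rule can be checked directly. The following is a minimal sketch; the variable names and the sizes n=2, m=3, p=4 are chosen only for illustration.

```python
import torch

a = torch.randn(2, 3)   # (n × m) with n=2, m=3
b = torch.randn(3, 4)   # (m × p) with m=3, p=4

c = torch.mm(a, b)      # result is (n × p)
print(c.shape)

# The inner dimensions must agree; a (2 × 3) by (2 × 4) product is rejected.
try:
    torch.mm(a, torch.randn(2, 4))
    mismatch_rejected = False
except RuntimeError as err:
    mismatch_rejected = True
    print("shape mismatch:", err)
```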

Note

This function does not broadcast. For broadcasting matrix products, see torch.matmul().

This operator supports TensorFloat32.
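To illustrate the lack of broadcasting, the sketch below passes a batched 3-D tensor to torch.mm, which is expected to raise an error, and then to torch.matmul, which broadcasts over the leading batch dimension. The tensor names and sizes are illustrative assumptions.

```python
import torch

batch = torch.randn(5, 2, 3)   # five (2 × 3) matrices stacked along dim 0
w = torch.randn(3, 4)          # a single (3 × 4) matrix

# torch.mm accepts only 2-D inputs, so the batched tensor is rejected.
try:
    torch.mm(batch, w)
    mm_failed = False
except RuntimeError as err:
    mm_failed = True
    print("torch.mm rejected the 3-D input:", err)

# torch.matmul broadcasts w across the batch dimension.
out = torch.matmul(batch, w)
print(out.shape)

# For plain 2-D inputs the two functions compute the same product.
same_2d = torch.allclose(torch.mm(batch[0], w), torch.matmul(batch[0], w))
```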

Parameters
  • input (Tensor) – the first matrix to be multiplied

  • mat2 (Tensor) – the second matrix to be multiplied

Keyword Arguments

  • out (Tensor, optional) – the output tensor.
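A short sketch of the out keyword argument follows. It assumes a preallocated buffer (here named buf, a hypothetical name) that already has the result shape and dtype; the product is written into that buffer rather than into a newly allocated tensor.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)

# Preallocate the (2 × 4) result buffer and let torch.mm write into it.
buf = torch.empty(2, 4)
result = torch.mm(a, b, out=buf)

# The returned tensor shares storage with the buffer that was passed in.
shares_storage = result.data_ptr() == buf.data_ptr()
matches_product = torch.allclose(buf, a @ b)
```

Reusing a buffer in this way can avoid repeated allocations when the same-sized product is computed many times, for example inside a loop.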

Example:

>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])
