torch.chain_matmul

torch.chain_matmul(*matrices)

Returns the matrix product of the N 2-D tensors. The product is computed efficiently using the matrix chain order algorithm, which selects the multiplication order that incurs the lowest cost in terms of arithmetic operations ([CLRS]). Because this function computes a product, N should be greater than or equal to 2; if N equals 2, a trivial matrix-matrix product is returned. If N is 1, the call is a no-op and the original matrix is returned as is.
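To illustrate what "lowest cost in terms of arithmetic operations" means, the following is a minimal sketch of the textbook matrix chain order dynamic program. It is not PyTorch's implementation: the function name `matrix_chain_cost` and its `dims` argument are hypothetical, and it only counts scalar multiplications for a chain whose i-th matrix has shape `dims[i] x dims[i + 1]`.

```python
def matrix_chain_cost(dims):
    """Minimum scalar multiplications needed to multiply a matrix chain.

    Hypothetical helper for illustration only (not part of torch).
    dims has length N + 1; matrix i (0-based) has shape dims[i] x dims[i + 1].
    """
    n = len(dims) - 1
    if n < 1:
        raise ValueError("dims must describe at least one matrix")

    # cost[i][j] = cheapest way to multiply matrices i..j (inclusive).
    cost = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):          # length of the sub-chain
        for i in range(n - length + 1):
            j = i + length - 1
            # Try every split point k: (i..k) times (k+1..j).
            cost[i][j] = min(
                cost[i][k] + cost[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                for k in range(i, j)
            )
    return cost[0][n - 1]


# Shapes 10x100, 100x5, 5x50: (AB)C costs 7500, A(BC) costs 75000.
print(matrix_chain_cost([10, 100, 5, 50]))  # 7500
# Shapes 3x4, 4x5, 5x6, 6x7.
print(matrix_chain_cost([3, 4, 5, 6, 7]))   # 276
```

In the first chain the two possible orders differ by a factor of ten, which is the kind of saving the ordering step is meant to capture.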

Parameters

matrices (Tensors...) – a sequence of 2 or more 2-D tensors whose product is to be determined.

Returns

If the $i^{th}$ tensor has dimensions $p_{i} \times p_{i + 1}$, then the product has dimensions $p_{1} \times p_{N + 1}$.

Return type

Tensor

Example:

>>> a = torch.randn(3, 4)
>>> b = torch.randn(4, 5)
>>> c = torch.randn(5, 6)
>>> d = torch.randn(6, 7)
>>> torch.chain_matmul(a, b, c, d)
tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
        [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
        [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])
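The 3 x 7 result above follows from the dimension rule in the Returns section. As a rough sketch, that rule can be checked on shapes alone, without building any tensors. The helper `chain_result_shape` below is hypothetical and is not part of torch; it assumes each shape is a `(rows, cols)` pair.

```python
def chain_result_shape(shapes):
    """Shape of the product of 2-D matrices with the given (rows, cols) shapes.

    Hypothetical helper for illustration only (not part of torch).
    """
    if not shapes:
        raise ValueError("expected at least one shape")
    # Adjacent matrices must agree on their inner dimension.
    for (_, left_cols), (right_rows, _) in zip(shapes, shapes[1:]):
        if left_cols != right_rows:
            raise ValueError(
                f"inner dimensions differ: {left_cols} vs {right_rows}"
            )
    # Rows come from the first matrix, columns from the last.
    return (shapes[0][0], shapes[-1][1])


# Shapes of a, b, c, d from the example above.
print(chain_result_shape([(3, 4), (4, 5), (5, 6), (6, 7)]))  # (3, 7)
```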
