torch.tensordot

torch.tensordot(a, b, dims=2)

Returns a contraction of a and b over multiple dimensions.

tensordot implements a generalized matrix product.
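As a small illustration of the "generalized matrix product" description, the sketch below checks that for two 2-D tensors, dims=1 reproduces the ordinary matrix product. The variable names and shapes are chosen only for illustration.

```python
import torch

# Illustrative shapes: any (n, k) and (k, p) pair would work.
x = torch.arange(6.).reshape(2, 3)
y = torch.arange(12.).reshape(3, 4)

# dims=1 contracts the last dimension of x with the first dimension of y,
# which for 2-D inputs is the usual matrix product.
via_tensordot = torch.tensordot(x, y, dims=1)
via_matmul = x @ y

print(via_tensordot.shape)                         # torch.Size([2, 4])
print(torch.allclose(via_tensordot, via_matmul))   # True
```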

Parameters
  • a (Tensor) – Left tensor to contract

  • b (Tensor) – Right tensor to contract

  • dims (int or tuple of two lists of integers) – number of dimensions to contract, or explicit lists of dimensions for a and b respectively

When called with a non-negative integer argument dims = d, where a and b have m and n dimensions respectively, tensordot() computes

r_{i_0,...,i_{m-d}, i_d,...,i_n} = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.
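One way to read the integer form is as a reshape followed by a matrix product. The sketch below, with illustrative shapes and d = 2, flattens the last d dimensions of a and the first d dimensions of b, and compares the result with torch.tensordot and with an equivalent torch.einsum expression. It is a check of the formula, not a description of how tensordot is implemented internally.

```python
import torch

# Integer-valued floats keep the arithmetic exact, so the comparison is deterministic.
a = torch.arange(60.).reshape(3, 4, 5)
b = torch.arange(120.).reshape(4, 5, 6)
d = 2  # number of trailing dims of a / leading dims of b to contract

result = torch.tensordot(a, b, dims=d)

# Flatten the contracted dimensions (4, 5) into a single axis of length 20,
# then contract that axis with an ordinary matrix product.
manual = a.reshape(3, 4 * 5) @ b.reshape(4 * 5, 6)

# The same contraction written with explicit index labels.
via_einsum = torch.einsum('ijk,jkl->il', a, b)

print(result.shape)                         # torch.Size([3, 6])
print(torch.allclose(result, manual))       # True
print(torch.allclose(result, via_einsum))   # True
```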

When called with dims of the list form, the given dimensions will be contracted in place of the last d of a and the first d of b. The sizes of the paired dimensions must match, although tensordot() also handles broadcasted dimensions.
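To make the list form concrete, the sketch below rewrites a list-form call as a permutation followed by an integer-form call. With dims=([1, 0], [0, 1]), dimension 1 of a is paired with dimension 0 of b and dimension 0 of a with dimension 1 of b, so moving those dimensions of a to the end, in that order, should give the same contraction. The names a_perm, list_form and int_form are illustrative.

```python
import torch

a = torch.arange(60.).reshape(3, 4, 5)
b = torch.arange(24.).reshape(4, 3, 2)

# List form: contract a's dims (1, 0) with b's dims (0, 1), pairwise.
list_form = torch.tensordot(a, b, dims=([1, 0], [0, 1]))

# Equivalent integer form: permute a so the contracted dims come last,
# in the pairing order (a dim 1, then a dim 0). b already has its
# contracted dims first, in the right order.
a_perm = a.permute(2, 1, 0)          # shape (5, 4, 3)
int_form = torch.tensordot(a_perm, b, dims=2)

# The same contraction with explicit index labels.
via_einsum = torch.einsum('ijk,jil->kl', a, b)

print(list_form.shape)                         # torch.Size([5, 2])
print(torch.equal(list_form, int_form))        # True
print(torch.equal(list_form, via_einsum))      # True
```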

Examples:

>>> a = torch.arange(60.).reshape(3, 4, 5)
>>> b = torch.arange(24.).reshape(4, 3, 2)
>>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
tensor([[4400., 4730.],
        [4532., 4874.],
        [4664., 5018.],
        [4796., 5162.],
        [4928., 5306.]])

>>> a = torch.randn(3, 4, 5, device='cuda')
>>> b = torch.randn(4, 5, 6, device='cuda')
>>> c = torch.tensordot(a, b, dims=2).cpu()
>>> c
tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
        [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
        [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])
