torch.solve

torch.solve(input, A, *, out=None) -> (Tensor, Tensor)

Returns the solution X of the system of linear equations AX = B together with the LU factorization of A, as a namedtuple (solution, LU), in that order.

LU contains the L and U factors of the LU factorization of A.
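To make "L and U factors" concrete, here is a minimal sketch of a packed LU matrix being split into its factors. It assumes a PyTorch version that provides torch.linalg.lu_factor and torch.lu_unpack, neither of which is described on this page. It also assumes that the LU returned by torch.solve uses the same packed layout. Note that torch.solve returns no pivot indices, so this sketch obtains them from torch.linalg.lu_factor instead.

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4)

# Packed factorization: a single matrix holding both factors, plus pivot indices.
# Assumption: torch.linalg.lu_factor is available in the installed PyTorch.
LU, pivots = torch.linalg.lu_factor(A)

# Split the packed matrix into a permutation P, a unit lower-triangular L,
# and an upper-triangular U such that A = P @ L @ U.
P, L, U = torch.lu_unpack(LU, pivots)

print(torch.allclose(P @ L @ U, A, atol=1e-4))

# Upper triangle of the packed matrix is U; the strict lower triangle is L
# with its unit diagonal left implicit.
print(torch.allclose(torch.triu(LU), U))
print(torch.allclose(torch.tril(LU, diagonal=-1) + torch.eye(4), L))
```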

torch.solve(B, A) accepts 2D inputs B, A or inputs that are batches of 2D matrices. If the inputs are batches, the outputs solution and LU are batched as well.
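The batching behaviour can be sketched as follows. The helper solve_compat is hypothetical and not part of PyTorch: it tries torch.solve with the signature documented here, and falls back to torch.linalg.solve (which takes A first) on the assumption that some installs no longer provide a working torch.solve. The matrices are built to be symmetric positive definite purely so the example is well conditioned.

```python
import torch

def solve_compat(B, A):
    """Hypothetical helper: solve A X = B for X, batched or not."""
    try:
        X, _LU = torch.solve(B, A)        # signature documented on this page
    except (AttributeError, RuntimeError):
        # Assumption: torch.solve is unavailable in this install.
        X = torch.linalg.solve(A, B)      # note the swapped argument order
    return X

torch.manual_seed(0)

# Batch of 2 x 3 symmetric positive definite 4 x 4 matrices (safely invertible).
M = torch.randn(2, 3, 4, 4)
A = M @ M.transpose(-1, -2) + 4.0 * torch.eye(4)
B = torch.randn(2, 3, 4, 6)

X = solve_compat(B, A)

print(X.shape == B.shape)                   # batch dimensions carry through
print(torch.allclose(A @ X, B, atol=1e-4))  # each system in the batch is solved
```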

Note

Irrespective of the original strides, the returned matrices solution and LU will be transposed, i.e. with strides like B.contiguous().transpose(-1, -2).stride() and A.contiguous().transpose(-1, -2).stride() respectively.
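As a small illustration of what the stride expression in this note evaluates to, the sketch below applies it to a plain 5 x 3 tensor. It does not call torch.solve. It only shows that the expression yields a non-row-major layout, and that .contiguous() can be used if downstream code expects row-major memory.

```python
import torch

B = torch.arange(15.0).reshape(5, 3)     # row-major (C-contiguous) 5 x 3 matrix
print(B.stride())                         # (3, 1)

Bt = B.contiguous().transpose(-1, -2)     # the expression quoted in the note
print(Bt.stride())                        # (1, 3)
print(Bt.is_contiguous())                 # False

# .contiguous() copies the data into row-major layout when that is required.
print(Bt.contiguous().stride())           # (5, 1)
```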

Parameters
  • input (Tensor) – input matrix B of size (*, m, k), where * is zero or more batch dimensions.

  • A (Tensor) – input square matrix of size (*, m, m), where * is zero or more batch dimensions.

Keyword Arguments

out ((Tensor, Tensor), optional) – optional output tuple.

Example:

>>> A = torch.tensor([[6.80, -2.11,  5.66,  5.97,  8.23],
                      [-6.05, -3.30,  5.36, -4.44,  1.08],
                      [-0.45,  2.58, -2.70,  0.27,  9.04],
                      [8.32,  2.71,  4.35,  -7.17,  2.14],
                      [-9.67, -5.14, -7.26,  6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
                      [-1.56,  4.00, -8.67,  1.75,  2.86],
                      [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 *
       7.0977)

>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, A.matmul(X))
tensor(1.00000e-06 *
   3.6386)
