torch.solve
torch.solve(input, A, *, out=None) -> (Tensor, Tensor)
This function returns the solution to the system of linear equations represented by AX = B and the LU factorization of A, in that order, as a namedtuple (solution, LU).
LU contains the L and U factors of the LU factorization of A.
torch.solve(B, A) accepts 2D inputs B and A, or inputs that are batches of 2D matrices. If the inputs are batches, the function returns batched outputs solution and LU.
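To make the relationship between the two outputs concrete, here is a minimal pure-Python sketch of solving AX = B through an LU factorization with partial pivoting. It illustrates the general technique only and is not the library's implementation. The helper names lu_factor and lu_solve are hypothetical, and the packed layout (L below the diagonal with an implicit unit diagonal, U on and above it) is an assumption made for the sketch.

```python
def lu_factor(a):
    """Factor a square matrix (list of rows) as P A = L U.

    Returns (lu, perm): lu packs L (below the diagonal, unit diagonal
    implied) and U (on and above the diagonal); perm[i] is the index of
    the original row that ended up in position i.
    """
    n = len(a)
    lu = [[float(v) for v in row] for row in a]
    perm = list(range(n))
    for col in range(n):
        # Partial pivoting: pick the row with the largest entry in this column.
        pivot = max(range(col, n), key=lambda r: abs(lu[r][col]))
        if lu[pivot][col] == 0.0:
            raise ValueError("matrix is singular")
        if pivot != col:
            lu[col], lu[pivot] = lu[pivot], lu[col]
            perm[col], perm[pivot] = perm[pivot], perm[col]
        for row in range(col + 1, n):
            lu[row][col] /= lu[col][col]  # multiplier, stored in the L part
            for k in range(col + 1, n):
                lu[row][k] -= lu[row][col] * lu[col][k]
    return lu, perm


def lu_solve(lu, perm, b):
    """Solve A X = B given the packed factors; b is an m x k list of rows."""
    n = len(lu)
    k = len(b[0])
    x = [[float(v) for v in b[p]] for p in perm]  # apply the row permutation
    # Forward substitution with the unit lower-triangular factor L.
    for row in range(n):
        for j in range(row):
            for c in range(k):
                x[row][c] -= lu[row][j] * x[j][c]
    # Back substitution with the upper-triangular factor U.
    for row in reversed(range(n)):
        for j in range(row + 1, n):
            for c in range(k):
                x[row][c] -= lu[row][j] * x[j][c]
        for c in range(k):
            x[row][c] /= lu[row][row]
    return x


A = [[2.0, 1.0], [1.0, 3.0]]
B = [[3.0], [5.0]]
lu, perm = lu_factor(A)
X = lu_solve(lu, perm, B)  # X satisfies A X = B up to rounding error
```

Because the factors are computed once and reused for every column of B, a solver of this kind can hand back both the solution and the factorization from a single pass.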
Note
Irrespective of the original strides, the returned matrices solution and LU will be transposed, i.e. with strides like B.contiguous().transpose(-1, -2).stride() and A.contiguous().transpose(-1, -2).stride() respectively.
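As a rough illustration of the stride expression in the note, the sketch below computes in plain Python what B.contiguous().transpose(-1, -2).stride() evaluates to for a given shape. It assumes row-major element strides for a contiguous tensor. The helper names contiguous_strides and transposed_strides are hypothetical and are not part of torch.

```python
def contiguous_strides(shape):
    """Row-major element strides for a contiguous tensor of this shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= max(size, 1)
    return tuple(reversed(strides))


def transposed_strides(shape):
    """Strides after swapping the last two dimensions of a contiguous tensor."""
    strides = list(contiguous_strides(shape))
    strides[-1], strides[-2] = strides[-2], strides[-1]
    return tuple(strides)


# Shapes taken from the examples on this page.
print(transposed_strides((5, 3)))           # B in the 2D example
print(transposed_strides((2, 3, 1, 4, 6)))  # B in the batched example
```

A stride of 1 in the second-to-last position means that consecutive elements along that dimension are adjacent in memory.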
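- Parameters
  - input (Tensor): input matrix B of size (*, m, k), where * is zero or more batch dimensions.
  - A (Tensor): input square matrix of size (*, m, m), where * is zero or more batch dimensions.
- Keyword Arguments
  - out ((Tensor, Tensor), optional): optional output tuple.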
Example:
>>> A = torch.tensor([[6.80, -2.11, 5.66, 5.97, 8.23],
...                   [-6.05, -3.30, 5.36, -4.44, 1.08],
...                   [-0.45, 2.58, -2.70, 0.27, 9.04],
...                   [8.32, 2.71, 4.35, -7.17, 2.14],
...                   [-9.67, -5.14, -7.26, 6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02, 6.19, -8.22, -7.57, -3.03],
...                   [-1.56, 4.00, -8.67, 1.75, 2.86],
...                   [9.81, -4.09, -4.57, -8.61, 8.99]]).t()
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 * 7.0977)

>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, A.matmul(X))
tensor(1.00000e-06 * 3.6386)