torch.lu

torch.lu(*args, **kwargs)

Computes the LU factorization of a matrix or batches of matrices A. Returns a tuple containing the LU factorization and pivots of A. Pivoting is done if pivot is set to True.

Note

The pivots returned by the function are 1-indexed. If pivot is False, the returned pivots tensor is filled with zeros and has the appropriate size.
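As an illustration of the 1-indexing, the sketch below converts a pivots vector for a single square matrix into a 0-indexed row permutation. It assumes the LAPACK row-interchange convention, which is not stated on this page: entry i of pivots names the row that was swapped with row i at step i. The helper name pivots_to_permutation is hypothetical and not part of torch.

```python
def pivots_to_permutation(pivots):
    """Turn 1-indexed pivots into a 0-indexed row permutation.

    Assumes the LAPACK convention (an assumption, not stated in this page):
    at step i, row i was swapped with row pivots[i] (1-indexed).
    """
    perm = list(range(len(pivots)))
    for i, p in enumerate(pivots):
        j = p - 1  # shift from 1-indexed to 0-indexed
        perm[i], perm[j] = perm[j], perm[i]
    return perm


# Pivots taken from the example output on this page.
perm = pivots_to_permutation([3, 3, 3])
print(perm)  # [2, 0, 1]
```

Under that assumption, row i of the product L @ U corresponds to row perm[i] of the original matrix.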

Note

LU factorization with pivot=False is not available on CPU; attempting it raises an error. It is available on CUDA.

Note

If get_infos is True, this function does not check whether the factorization succeeded. The status of the factorization is instead reported in the third element of the returned tuple, and the caller is expected to inspect it.
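A minimal sketch of inspecting that third element follows. It builds a batch of one invertible matrix (the identity) and one singular matrix (all zeros), and it assumes that a non-zero status marks a matrix whose factorization hit an exactly zero pivot. The variable names are illustrative only.

```python
import torch

# Batch of two 3x3 matrices: the identity (invertible) and all zeros (singular).
batch = torch.stack([torch.eye(3), torch.zeros(3, 3)])

# With get_infos=True no error is raised; the status comes back per matrix.
A_LU, pivots, infos = torch.lu(batch, get_infos=True)

# Assumption: zero means success, non-zero means the factorization failed.
failed = infos != 0
print(failed.tolist())
```

Checking the status per matrix this way lets a caller skip or regularize only the failing entries of a batch instead of aborting the whole computation.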

Note

For batches of square matrices of size less than or equal to 32 on a CUDA device, the LU factorization is repeated for singular matrices because of a bug in the MAGMA library (see magma issue 13).

Note

L, U, and P can be derived using torch.lu_unpack().
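The following sketch shows one way to recover the factors and confirm that they reproduce the input. It assumes torch.lu_unpack returns the triple (P, L, U) with A equal to P @ L @ U. The seed and tolerance are arbitrary choices made for the illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
A = torch.randn(2, 3, 3)

# Packed factorization and 1-indexed pivots.
A_LU, pivots = torch.lu(A)

# Unpack into permutation, unit lower-triangular, and upper-triangular factors.
P, L, U = torch.lu_unpack(A_LU, pivots)

# Rebuild the batch; the result should match A up to floating-point error.
A_rebuilt = P @ L @ U
print(torch.allclose(A_rebuilt, A, atol=1e-5))
```

Here L carries ones on its diagonal, which is why both factors fit in the single packed tensor A_LU.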

Parameters
  • A (Tensor) – the tensor to factor, of size (*, m, n)

  • pivot (bool, optional) – controls whether pivoting is done. Default: True

  • get_infos (bool, optional) – if set to True, returns an info IntTensor. Default: False

  • out (tuple, optional) – optional output tuple. If get_infos is True, then the elements in the tuple are Tensor, IntTensor, and IntTensor. If get_infos is False, then the elements in the tuple are Tensor, IntTensor. Default: None

Returns

A tuple of tensors containing

  • factorization (Tensor): the factorization of size (*, m, n)

  • pivots (IntTensor): the pivots of size (*, m)

  • infos (IntTensor, optional): if get_infos is True, this is a tensor of size (*) holding the status for the matrix or for each matrix in the batch, where zero indicates that the factorization succeeded and a non-zero value indicates that it failed

Return type

(Tensor, IntTensor, IntTensor (optional))

Example:

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...   print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!
