CustomFromMask

class torch.nn.utils.prune.CustomFromMask(mask)
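
classmethod apply(module, name, mask)
Adds the forward pre-hook that enables pruning on the fly and reparametrizes a tensor in terms of the original tensor and the pruning mask.
Parameters
module (nn.Module) – module containing the tensor to prune
name (str) – name of the parameter within module on which pruning will act
mask (torch.Tensor) – mask to apply to the parameter; it must have the same shape as the parameter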
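
The sketch below shows what apply() leaves behind on a module. It assumes a recent PyTorch release; the layer size and the mask values are arbitrary choices for illustration. In everyday code the functional wrapper torch.nn.utils.prune.custom_from_mask(module, name, mask) is the usual entry point; calling the classmethod directly has the advantage of returning the pruning-method object.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# A small layer whose 2x3 weight we prune with a hand-written mask.
torch.manual_seed(0)
layer = nn.Linear(3, 2)

# 1 keeps an entry, 0 prunes it; the shape must match layer.weight.
mask = torch.tensor([[1.0, 0.0, 1.0],
                     [0.0, 1.0, 1.0]])

# Registers the forward pre-hook and reparametrizes "weight".
method = prune.CustomFromMask.apply(layer, "weight", mask)

# The original values now live in the parameter weight_orig ...
print("weight_orig" in dict(layer.named_parameters()))  # True
# ... the mask is stored in the buffer weight_mask ...
print("weight_mask" in dict(layer.named_buffers()))     # True
# ... and weight is a plain attribute equal to weight_orig * weight_mask.
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True
```

Because "weight" is no longer a registered parameter, an optimizer built from layer.parameters() updates weight_orig, and the hook recomputes weight before every forward pass.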

apply_mask(module)
Handles the multiplication between the parameter being pruned and the generated mask. Fetches the mask (name+'_mask') and the original tensor (name+'_orig') from the module and returns the pruned version of the tensor.
Parameters
module (nn.Module) – module containing the tensor to prune
Returns
pruned version of the input tensor
Return type
torch.Tensor
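
A minimal sketch of apply_mask() on a module that has already been pruned with apply(). It assumes a recent PyTorch release; the shapes and mask values are arbitrary. The second half edits the mask buffer in place to show that the forward pre-hook recomputes the pruned weight on each call.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 1)
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
method = prune.CustomFromMask.apply(layer, "weight", mask)

# apply_mask reads layer.weight_orig and layer.weight_mask and multiplies them;
# it returns a new tensor and does not modify the module.
pruned = method.apply_mask(layer)
print(pruned.shape)  # torch.Size([1, 4])
print(torch.equal(pruned, layer.weight_orig * layer.weight_mask))  # True

# Edit the mask buffer in place; the next forward pass runs the pre-hook,
# which calls apply_mask again and refreshes layer.weight.
layer.weight_mask[0, 1] = 0.0
_ = layer(torch.randn(2, 4))
print(bool((layer.weight[0, 1:] == 0).all()))  # True
```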

prune(t, default_mask=None)
Computes and returns a pruned version of the input tensor t according to the pruning rule specified in compute_mask().
Parameters
t (torch.Tensor) – tensor to prune (of the same dimensions as default_mask).
default_mask (torch.Tensor, optional) – mask from the previous pruning iteration, if any. It is taken into account when determining which portion of the tensor the pruning should act on. If None, defaults to a mask of ones.
Returns
pruned version of tensor t.
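
A short sketch of prune() on a bare tensor, with no module involved. It assumes that CustomFromMask's compute_mask() multiplies default_mask by the mask passed to the constructor, which is what the outputs below reflect; the tensor values are arbitrary.

```python
import torch
import torch.nn.utils.prune as prune

t = torch.tensor([1.0, 2.0, 3.0, 4.0])
method = prune.CustomFromMask(torch.tensor([1.0, 0.0, 1.0, 0.0]))

# With no default_mask, a mask of ones is assumed, so only the custom mask applies.
print(method.prune(t))  # tensor([1., 0., 3., 0.])

# A mask from an earlier pruning round is multiplied with the custom mask,
# so entries that were already pruned stay pruned.
previous = torch.tensor([1.0, 1.0, 0.0, 1.0])
print(method.prune(t, default_mask=previous))  # tensor([1., 0., 0., 0.])
```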

remove(module)
Removes the pruning reparametrization from a module. The pruned parameter named name (the name passed to apply()) remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.
Note
Pruning itself is NOT undone or reversed!
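
A minimal sketch of making pruning permanent. Rather than calling remove(module) on the method object directly, it goes through the module-level helper torch.nn.utils.prune.remove(module, name), which, as we understand it, locates the pruning hook registered for that name, calls its remove(module), and also unregisters the forward pre-hook. The layer size and mask values are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(3, 2)
mask = torch.tensor([[1.0, 0.0, 1.0],
                     [1.0, 1.0, 0.0]])
prune.CustomFromMask.apply(layer, "weight", mask)

# Make the pruning permanent and drop weight_orig / weight_mask.
prune.remove(layer, "weight")

print(sorted(dict(layer.named_parameters())))  # ['bias', 'weight']
print(dict(layer.named_buffers()))             # {}
# The zeros are baked into the plain parameter: pruning is not undone.
print(bool((layer.weight[mask == 0] == 0).all()))  # True
```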