torch.nn.utils.spectral_norm
torch.nn.utils.spectral_norm(module: T_module, name: str = 'weight', n_power_iterations: int = 1, eps: float = 1e-12, dim: Optional[int] = None) → T_module

Applies spectral normalization to a parameter in the given module.
Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by rescaling the weight tensor W to W / σ(W), where σ(W) is the spectral norm (largest singular value) of the weight matrix, estimated with the power iteration method. If the weight tensor has more than 2 dimensions, it is reshaped to 2D for the power iteration. This is implemented via a hook that computes the spectral norm and rescales the weight before every forward() call. See Spectral Normalization for Generative Adversarial Networks.
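To make the power-iteration step concrete, here is a minimal standalone sketch in plain PyTorch tensors. The helpers `reshape_to_matrix` and `power_iteration_sigma` are hypothetical names for illustration, not part of `torch.nn.utils`; the sketch assumes a single warm-started vector `u` is carried between calls, whereas PyTorch's hook stores both `u` and `v` as module buffers.

```python
import torch
import torch.nn.functional as F
from typing import Optional, Tuple


def reshape_to_matrix(weight: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: move `dim` to the front and flatten the rest to 2-D."""
    if dim != 0:
        perm = [dim] + [d for d in range(weight.dim()) if d != dim]
        weight = weight.permute(*perm)
    return weight.reshape(weight.size(0), -1)


def power_iteration_sigma(
    weight: torch.Tensor,
    dim: int = 0,
    n_power_iterations: int = 1,
    eps: float = 1e-12,
    u: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: estimate sigma(W), the largest singular value.

    Returns (sigma, u). Feeding `u` back in on the next call warm-starts the
    iteration, so even one iteration per call converges over many calls.
    """
    if n_power_iterations < 1:
        raise ValueError("n_power_iterations must be >= 1")
    w_mat = reshape_to_matrix(weight, dim)
    with torch.no_grad():
        if u is None:
            u = F.normalize(torch.randn(w_mat.size(0), dtype=w_mat.dtype,
                                        device=w_mat.device), dim=0, eps=eps)
        for _ in range(n_power_iterations):
            v = F.normalize(w_mat.t() @ u, dim=0, eps=eps)  # right singular vector estimate
            u = F.normalize(w_mat @ v, dim=0, eps=eps)      # left singular vector estimate
    sigma = torch.dot(u, w_mat @ v)  # u^T W v approximates sigma(W)
    return sigma, u


torch.manual_seed(0)
W = torch.randn(40, 20)
u = None
for _ in range(200):  # one iteration per call, warm-started, like one per forward()
    sigma, u = power_iteration_sigma(W, u=u)

W_sn = W / sigma  # spectrally normalized weight
exact = torch.linalg.matrix_norm(W, ord=2)  # exact largest singular value, for comparison
print(sigma.item(), exact.item())
```

Warm-starting is why a default of `n_power_iterations=1` is usually enough: the estimate keeps improving across training steps instead of restarting from a random vector each time.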
- Parameters
module (nn.Module) – the module containing the parameter to normalize
name (str, optional) – name of the weight parameter (default: 'weight')
n_power_iterations (int, optional) – number of power iterations used to compute the spectral norm (default: 1)
eps (float, optional) – epsilon for numerical stability when computing norms (default: 1e-12)
dim (int, optional) – dimension corresponding to the number of outputs; the default is 0, except for modules that are instances of ConvTranspose{1,2,3}d, for which it is 1
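The effect of `dim` can be seen in the sizes of the registered power-iteration vectors. As a sketch, assuming the reshaping described above (move `dim` to the front, flatten the rest), `weight_u` has `weight.size(dim)` entries and `weight_v` has the product of the remaining sizes. `ConvTranspose2d` stores its weight as `(in_channels, out_channels, kH, kW)`, which is why its default is `dim=1`.

```python
import torch.nn as nn
from torch.nn.utils import spectral_norm

conv = spectral_norm(nn.Conv2d(3, 8, kernel_size=3))             # weight (8, 3, 3, 3), dim=0
deconv = spectral_norm(nn.ConvTranspose2d(3, 8, kernel_size=3))  # weight (3, 8, 3, 3), dim=1

# Both are viewed as an 8 x 27 matrix: 8 output channels, 3*3*3 other entries.
print(conv.weight_u.shape, conv.weight_v.shape)      # torch.Size([8]) torch.Size([27])
print(deconv.weight_u.shape, deconv.weight_v.shape)  # torch.Size([8]) torch.Size([27])

# Forcing dim=0 on a transposed conv treats in_channels as the output axis: 3 x 72.
deconv0 = spectral_norm(nn.ConvTranspose2d(3, 8, kernel_size=3), dim=0)
print(deconv0.weight_u.shape, deconv0.weight_v.shape)  # torch.Size([3]) torch.Size([72])
```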
- Returns
The original module with the spectral norm hook
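The returned module is reparameterized in place. The sketch below inspects what the hook registers on a `Linear` layer (the raw weight moves to the parameter `weight_orig`, the vectors `weight_u` and `weight_v` become buffers, and `weight` becomes a plain tensor recomputed by the hook) and then undoes it with `torch.nn.utils.remove_spectral_norm`. Names are sorted so the checks do not depend on registration order.

```python
import torch.nn as nn
from torch.nn.utils import spectral_norm, remove_spectral_norm

m = spectral_norm(nn.Linear(20, 40))

# Raw weight is stored as `weight_orig`; power-iteration vectors are buffers.
params = sorted(name for name, _ in m.named_parameters())
buffers = sorted(name for name, _ in m.named_buffers())
print(params)   # ['bias', 'weight_orig']
print(buffers)  # ['weight_u', 'weight_v']

# While the hook is attached, `weight` is a plain tensor, not a Parameter.
hooked_weight_is_param = isinstance(m.weight, nn.Parameter)
print(hooked_weight_is_param)  # False

# Remove the hook: `weight` becomes an ordinary Parameter again.
remove_spectral_norm(m)
print(sorted(name for name, _ in m.named_parameters()))  # ['bias', 'weight']
print(list(m.named_buffers()))  # []
```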
Example:
>>> import torch.nn as nn
>>> from torch.nn.utils import spectral_norm
>>> m = spectral_norm(nn.Linear(20, 40))
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_u.size()
torch.Size([40])
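A rough usage check, assuming default settings: in training mode each forward() call runs one warm-started power-iteration step, so after enough calls the largest singular value of the rescaled `m.weight` should approach 1. In eval mode the hook still rescales the weight but leaves `weight_u` and `weight_v` unchanged. The input batch and iteration count here are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)
m = spectral_norm(nn.Linear(20, 40))
x = torch.randn(8, 20)  # arbitrary input batch

# Training mode: every forward() updates weight_u / weight_v by one iteration.
m.train()
with torch.no_grad():
    for _ in range(200):
        m(x)

sigma_top = torch.linalg.matrix_norm(m.weight, ord=2)  # largest singular value
print(sigma_top.item())  # close to 1.0 once the iteration has converged

# Eval mode: the stored vectors are reused as-is, not updated.
m.eval()
u_before = m.weight_u.clone()
with torch.no_grad():
    m(x)
u_unchanged = torch.equal(u_before, m.weight_u)
print(u_unchanged)  # True
```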