set_grad_enabled

class torch.set_grad_enabled(mode: bool)

Context manager that turns gradient calculation on or off.

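set_grad_enabled enables or disables gradient calculation according to its argument mode. It can be used as a context manager or as a function. As a context manager, it restores the previous mode on exit; called as a function, it leaves the new mode in effect until it is changed again.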

This context manager is thread-local; it does not affect computation in other threads.
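A minimal sketch of the thread-local behavior, assuming a plain Python `threading.Thread`. The `worker` function and the `observed` dictionary are hypothetical names used only for illustration. The worker thread disables gradients for itself, and the main thread's mode is left unchanged.

```python
import threading

import torch

# Make the main thread's starting mode explicit for this sketch.
torch.set_grad_enabled(True)

observed = {}  # hypothetical container for the modes seen in each thread


def worker():
    # Function-style call: changes the grad mode for this thread only.
    torch.set_grad_enabled(False)
    observed["worker"] = torch.is_grad_enabled()


t = threading.Thread(target=worker)
t.start()
t.join()

# The main thread still sees gradients enabled.
observed["main"] = torch.is_grad_enabled()
print(observed)
```

Here `torch.is_grad_enabled()` is only used to inspect the current mode; the worker sees `False` while the main thread still sees `True`.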

Parameters

mode (bool) – Whether to enable gradient calculation (True) or disable it (False). Because the flag is an ordinary bool, it can be computed at run time to enable gradients conditionally.

Example:

>>> import torch
>>> x = torch.tensor([1.], requires_grad=True)  # only floating point or complex tensors can require gradients
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...   y = x * 2
>>> y.requires_grad
False
>>> _ = torch.set_grad_enabled(True)  # function form; assign the result so the REPL prints nothing
>>> y = x * 2
>>> y.requires_grad
True
>>> _ = torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
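The conditional use described under Parameters might look like the following sketch, where one forward function serves both training and evaluation. The helper `forward_pass` and the toy `torch.nn.Linear` model are assumptions chosen for illustration, not part of the documented API.

```python
import torch


def forward_pass(model, inputs, is_train):
    # Hypothetical helper: gradients are tracked only when is_train is True.
    # The previous grad mode is restored when the with-block exits.
    with torch.set_grad_enabled(is_train):
        return model(inputs)


model = torch.nn.Linear(3, 1)  # parameters require grad by default
inputs = torch.randn(4, 3)

train_out = forward_pass(model, inputs, is_train=True)
eval_out = forward_pass(model, inputs, is_train=False)

print(train_out.requires_grad)  # True
print(eval_out.requires_grad)   # False
```

Passing the flag in this way avoids writing separate code paths with `torch.no_grad()` for evaluation and plain calls for training.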
