enable_grad

class torch.enable_grad

Context-manager that enables gradient calculation.

Enables gradient calculation, if it has been disabled via no_grad or set_grad_enabled.
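
As a minimal sketch of the behavior described above, the snippet below disables gradients with set_grad_enabled(False) and then re-enables them in a nested enable_grad block. The variable names (a, b, inner_mode, restored_mode) are illustrative only, and torch.is_grad_enabled() is used simply to inspect the current mode.

```python
import torch

x = torch.tensor([1.0], requires_grad=True)

# Turn gradient tracking off, here via set_grad_enabled used as a context manager.
with torch.set_grad_enabled(False):
    a = x * 2                                  # computed while grad mode is off
    with torch.enable_grad():
        b = x * 2                              # grad mode is back on inside this block
        inner_mode = torch.is_grad_enabled()
    # Leaving enable_grad restores the mode that was active before it (disabled here).
    restored_mode = torch.is_grad_enabled()

print(a.requires_grad, b.requires_grad, inner_mode, restored_mode)
# False True True False
```

Only the tensor created inside the enable_grad block is attached to the autograd graph; the outer disabled mode resumes as soon as the block exits.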

This context manager is thread local; it will not affect computation in other threads.
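
The thread-local behavior can be checked with a small experiment, sketched here under the assumption that plain threading.Thread workers are used. The helper names (worker, entered, release, seen) are hypothetical. The main thread stays inside no_grad while a worker thread sits inside enable_grad, and each thread records its own grad mode.

```python
import threading
import torch

entered = threading.Event()   # set once the worker is inside enable_grad()
release = threading.Event()   # set by the main thread to let the worker finish
seen = {}

def worker():
    with torch.enable_grad():
        seen["worker"] = torch.is_grad_enabled()
        entered.set()
        release.wait(timeout=5)   # hold the context open while the main thread checks

with torch.no_grad():
    t = threading.Thread(target=worker)
    t.start()
    entered.wait(timeout=5)
    # The worker is inside enable_grad() right now, yet this thread is still in no_grad mode.
    seen["main"] = torch.is_grad_enabled()
    release.set()
    t.join()

print(seen["worker"], seen["main"])
# True False
```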

Also functions as a decorator. (Make sure to instantiate it with parentheses.)

Example:

>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...     with torch.enable_grad():
...         y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])
>>> @torch.enable_grad()
... def doubler(x):
...     return x * 2
>>> with torch.no_grad():
...     z = doubler(x)
>>> z.requires_grad
True
