torch.jit.trace_module
torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch._C.CompilationUnit object>)

Trace a module and return an executable ScriptModule that will be optimized using just-in-time compilation. When a module is passed to torch.jit.trace, only the forward method is run and traced. With trace_module, you can specify a dictionary that maps method names to example inputs to trace (see the inputs argument below).

See torch.jit.trace for more information on tracing.

- Parameters
mod (torch.nn.Module) – A torch.nn.Module containing methods whose names are specified in inputs. The given methods will be compiled as part of a single ScriptModule.
inputs (dict) – A dict containing sample inputs indexed by method names in mod. The inputs will be passed to the methods whose names correspond to the keys of inputs while tracing:
{'forward': example_forward_input, 'method2': example_method2_input}
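A minimal sketch of how the inputs dict maps method names to example inputs. The module Scaler and its second method scale are hypothetical, used only for illustration. The sketch assumes that, as with torch.jit.trace, a single tensor is treated as a one-argument call and a method taking several arguments receives its example inputs as a tuple.

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):
    """Hypothetical module with two methods to trace."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def scale(self, x, factor):
        # Second method taking two tensor arguments.
        return self.linear(x) * factor


m = Scaler().eval()
x = torch.rand(2, 4)
factor = torch.rand(2, 4)

# Keys are method names; the multi-argument method gets a tuple of inputs.
inputs = {"forward": x, "scale": (x, factor)}
traced = torch.jit.trace_module(m, inputs)

# Both methods are now compiled methods of the returned ScriptModule.
out_forward = traced.forward(x)
out_scale = traced.scale(x, factor)
```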
- Keyword Arguments
check_trace (bool, optional) – Check whether running the same inputs through the traced code produces the same outputs. Default: True. You might want to disable this if, for example, your network contains non-deterministic ops or if you are sure that the network is correct despite a checker failure.
check_inputs (list of dicts, optional) – A list of dicts of input arguments that should be used to check the trace against what is expected. Each dict is equivalent to a set of input arguments that would be specified in inputs. For best results, pass in a set of checking inputs representative of the space of shapes and types of inputs you expect the network to see. If not specified, the original inputs are used for checking.
check_tolerance (float, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.
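The checker keyword arguments can be combined as in the sketch below, which assumes a plain nn.Linear model chosen only for illustration. The trace is recorded with a batch of 2 and then checked against batches of 5 and 1; each entry of check_inputs is a dict keyed by method name, mirroring inputs. check_trace and check_tolerance are passed at their default values purely to show where they go.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 3).eval()

# Example inputs used to record the trace (keyed by method name).
inputs = {"forward": torch.rand(2, 4)}

# Each check entry mirrors the structure of `inputs`; the batch sizes
# differ so the trace is validated on more than one input shape.
check_inputs = [
    {"forward": torch.rand(5, 4)},
    {"forward": torch.rand(1, 4)},
]

traced = torch.jit.trace_module(
    model,
    inputs,
    check_trace=True,       # the default, shown explicitly
    check_inputs=check_inputs,
    check_tolerance=1e-05,  # the default tolerance
)

batch = torch.rand(7, 4)
result = traced(batch)
```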
- Returns
A ScriptModule object containing the traced code for each method named in inputs. The returned ScriptModule will have the same set of sub-modules and parameters as mod.
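One way to inspect what the returned ScriptModule carries over from mod is sketched below. The module Block and its extra method doubled are hypothetical; the sketch compares parameter names on the original and traced modules, lists the traced module's direct submodules, and checks that the extra traced method is present.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical module with one submodule and a second method."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def doubled(self, x):
        return 2 * self.conv(x)


block = Block()
x = torch.rand(1, 1, 5, 5)
traced = torch.jit.trace_module(block, {"forward": x, "doubled": x})

# Parameter and submodule names carried over from the original module.
orig_params = [name for name, _ in block.named_parameters()]
traced_params = [name for name, _ in traced.named_parameters()]
traced_children = [name for name, _ in traced.named_children()]
```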
Example (tracing a module with multiple methods):
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`) and construct
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward': example_forward_input, 'weighted_kernel_sum': example_weight}
module = torch.jit.trace_module(n, inputs)