mxnet.callback

Callback functions that can be used to track training progress and status during each epoch.
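The callbacks on this page follow two calling conventions. As a rough sketch, assuming the conventions of `Module.fit`: batch-end callbacks receive a single parameter object with fields such as `epoch`, `nbatch` and `eval_metric`, while epoch-end callbacks receive `(epoch, symbol, arg_params, aux_params)`. The loop below is a hypothetical, framework-free stand-in for `fit`; `toy_fit` and `BatchParam` are illustrative names, not MXNet APIs.

```python
from collections import namedtuple

# Illustrative stand-in for the parameter object a fit loop hands to
# batch-end callbacks (field names assumed to mirror MXNet's BatchEndParam).
BatchParam = namedtuple("BatchParam", ["epoch", "nbatch", "eval_metric", "locals"])


def toy_fit(num_epoch, batches_per_epoch, batch_end_callback=None, epoch_end_callback=None):
    """Hypothetical training loop showing when each kind of callback fires."""
    for epoch in range(num_epoch):
        for nbatch in range(batches_per_epoch):
            # ... forward, backward and parameter update would happen here ...
            if batch_end_callback is not None:
                batch_end_callback(BatchParam(epoch, nbatch, None, None))
        if epoch_end_callback is not None:
            # A real fit passes the symbol and parameter dicts; placeholders here.
            epoch_end_callback(epoch, None, {}, {})


events = []
toy_fit(
    num_epoch=2,
    batches_per_epoch=3,
    batch_end_callback=lambda p: events.append(("batch", p.epoch, p.nbatch)),
    epoch_end_callback=lambda epoch, sym, arg, aux: events.append(("epoch", epoch)),
)
print(events[:4])
```

Batch-end callbacks therefore fire once per batch and epoch-end callbacks once per epoch, after that epoch's last batch.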

Classes

LogValidationMetricsCallback

Logs the evaluation metrics at the end of an epoch.

ProgressBar(total[, length])

Displays a progress bar, indicating the percentage of batches processed within each epoch.

Speedometer(batch_size[, frequent, auto_reset])

Logs training speed and evaluation metrics periodically.

Functions

do_checkpoint(prefix[, period])

A callback that saves a model checkpoint every few epochs.

log_train_metric(period[, auto_reset])

Callback to log the training evaluation result every period batches.

module_checkpoint(mod, prefix[, period, …])

Callback to checkpoint a Module to prefix every period epochs.

class mxnet.callback.LogValidationMetricsCallback

Bases: object

Logs the evaluation metrics at the end of an epoch.
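A minimal sketch of what this callback does, assuming it is called with the same kind of parameter object as a batch-end callback (for example as the `eval_end_callback` of `Module.fit`) and that the metric exposes `get_name_value()`, as MXNet's `EvalMetric` does. `StubMetric`, `BatchParam` and `log_validation_metrics` are hypothetical names, and the log format is illustrative.

```python
from collections import namedtuple

# Hypothetical stand-in for the parameter object passed to the callback.
BatchParam = namedtuple("BatchParam", ["epoch", "nbatch", "eval_metric", "locals"])


class StubMetric:
    """Hypothetical metric exposing get_name_value(), like MXNet's EvalMetric."""

    def __init__(self, pairs):
        self.pairs = list(pairs)

    def get_name_value(self):
        return list(self.pairs)


def log_validation_metrics(param):
    """Print one line per (name, value) pair of the validation metric."""
    if param.eval_metric is None:
        return []
    lines = []
    for name, value in param.eval_metric.get_name_value():
        lines.append("Epoch[%d] Validation-%s=%f" % (param.epoch, name, value))
    for line in lines:
        print(line)
    return lines


lines = log_validation_metrics(BatchParam(3, 0, StubMetric([("accuracy", 0.875)]), None))
```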

class mxnet.callback.ProgressBar(total, length=80)

Bases: object

Displays a progress bar, indicating the percentage of batches processed within each epoch.

Parameters
  • total (int) – total number of batches per epoch

  • length (int) – maximum length of the progress bar, in characters

Examples

>>> progress_bar = mx.callback.ProgressBar(total=2)
>>> mod.fit(data, num_epoch=5, batch_end_callback=progress_bar)
[========--------] 50.0%
[================] 100.0%
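The bar arithmetic can be sketched in plain Python. This is an illustrative reimplementation, not MXNet's code: `render_progress` is a hypothetical helper that assumes the bar fills in proportion to `count / total` and formats the percentage to one decimal place. With `length=16` it reproduces the output shown in the example above.

```python
def render_progress(count, total, length=80):
    """Return a text bar for `count` of `total` batches, `length` chars wide."""
    # Number of '=' characters, proportional to the fraction of batches done.
    filled = int(round(length * count / float(total)))
    bar = "=" * filled + "-" * (length - filled)
    percent = 100.0 * count / total
    return "[%s] %.1f%%" % (bar, percent)


print(render_progress(1, 2, length=16))  # [========--------] 50.0%
print(render_progress(2, 2, length=16))  # [================] 100.0%
```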

class mxnet.callback.Speedometer(batch_size, frequent=50, auto_reset=True)

Bases: object

Logs training speed and evaluation metrics periodically.

Parameters
  • batch_size (int) – Batch size of data.

  • frequent (int) – Specifies how frequently, in batches, training speed and evaluation metrics are logged. By default, they are logged once every 50 batches.

  • auto_reset (bool) – Reset the evaluation metrics after each log.

Example

>>> # Print training speed and evaluation metrics every ten batches. Batch size is one.
>>> module.fit(iterator, num_epoch=n_epoch,
... batch_end_callback=mx.callback.Speedometer(1, 10))
Epoch[0] Batch [10] Speed: 1910.41 samples/sec  Train-accuracy=0.200000
Epoch[0] Batch [20] Speed: 1764.83 samples/sec  Train-accuracy=0.400000
Epoch[0] Batch [30] Speed: 1740.59 samples/sec  Train-accuracy=0.500000
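The reported speed is the number of samples processed over the last `frequent` batches (`frequent * batch_size`) divided by the wall-clock time they took. The sketch below is a simplified, hypothetical reimplementation (`SpeedSketch` is not an MXNet class) with an injectable clock so it can be checked deterministically; it omits metric logging and `auto_reset`.

```python
import time
from collections import namedtuple

# Hypothetical stand-in for the batch-end parameter object.
BatchParam = namedtuple("BatchParam", ["epoch", "nbatch", "eval_metric", "locals"])


class SpeedSketch:
    """Report samples/sec once every `frequent` batches (simplified sketch)."""

    def __init__(self, batch_size, frequent=50, clock=time.time):
        self.batch_size = batch_size
        self.frequent = frequent
        self.clock = clock
        self.tic = None          # time of the last report (or of the first batch)
        self.last_speed = None   # most recently computed samples/sec

    def __call__(self, param):
        if self.tic is None:
            # First call: just start the timer.
            self.tic = self.clock()
            return
        if param.nbatch % self.frequent == 0:
            now = self.clock()
            # `frequent` batches of `batch_size` samples were processed since tic.
            self.last_speed = self.frequent * self.batch_size / (now - self.tic)
            print("Epoch[%d] Batch [%d] Speed: %.2f samples/sec"
                  % (param.epoch, param.nbatch, self.last_speed))
            self.tic = now


# Deterministic clock: the timer starts at 0 s and the report happens at 2 s.
ticks = iter([0.0, 2.0])
speedometer = SpeedSketch(batch_size=4, frequent=10, clock=lambda: next(ticks))
for nbatch in range(11):
    speedometer(BatchParam(0, nbatch, None, None))
```

Here 10 batches of 4 samples in 2 seconds give 20 samples/sec.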

mxnet.callback.do_checkpoint(prefix, period=1)

A callback that saves a model checkpoint every few epochs. Each checkpoint consists of two files: a model description file and a parameters (weights and biases) file. The model description file is named prefix-symbol.json and the parameters file is named prefix-epoch_number.params, with the epoch number zero-padded to four digits (for example, mymodel-0001.params).

Parameters
  • prefix (str) – Prefix for the checkpoint filenames.

  • period (int, optional) – Interval (number of epochs) between checkpoints. Default period is 1.

Returns

callback – A callback function that can be passed as epoch_end_callback to fit.

Return type

function

Example

>>> module.fit(iterator, num_epoch=n_epoch,
... epoch_end_callback=mx.callback.do_checkpoint("mymodel", 1))
Start training with [cpu(0)]
Epoch[0] Resetting Data Iterator
Epoch[0] Time cost=0.100
Saved checkpoint to "mymodel-0001.params"
Epoch[1] Resetting Data Iterator
Epoch[1] Time cost=0.060
Saved checkpoint to "mymodel-0002.params"
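The naming and scheduling rules can be sketched without MXNet. This is a hypothetical illustration: instead of writing files, `make_checkpoint_callback` records the parameter-file names it would produce. It assumes, consistent with the log above, that the epoch index passed to the callback is zero-based and that files are labelled with the one-based epoch number, zero-padded to four digits.

```python
def checkpoint_filenames(prefix, epoch):
    """Return the (symbol file, parameter file) names for a checkpoint."""
    return "%s-symbol.json" % prefix, "%s-%04d.params" % (prefix, epoch)


def make_checkpoint_callback(prefix, period=1, saved=None):
    """Hypothetical epoch-end callback that records instead of saving."""
    period = max(1, int(period))
    saved = [] if saved is None else saved

    def _callback(epoch, sym, arg_params, aux_params):
        # `epoch` is zero-based; checkpoints are numbered from 1.
        if (epoch + 1) % period == 0:
            saved.append(checkpoint_filenames(prefix, epoch + 1)[1])

    return _callback


saved = []
callback = make_checkpoint_callback("mymodel", period=2, saved=saved)
for epoch in range(4):
    callback(epoch, None, {}, {})
print(saved)  # ['mymodel-0002.params', 'mymodel-0004.params']
```

With `period=2`, only every second epoch produces a checkpoint.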

mxnet.callback.log_train_metric(period, auto_reset=False)

Callback to log the training evaluation result every period batches.

Parameters
  • period (int) – The number of batches between logs of the training evaluation metric.

  • auto_reset (bool) – Reset the metric after each log.

Returns

callback – The callback function that can be passed as batch_end_callback to fit.

Return type

function
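A minimal sketch of the logging rule, assuming the callback is called after every batch with a parameter object carrying `epoch`, `nbatch` and `eval_metric`. `make_train_logger`, `StubMetric` and `BatchParam` are hypothetical; the `sink` argument exists only so the output can be captured, and the log format is illustrative.

```python
from collections import namedtuple

# Hypothetical stand-in for the batch-end parameter object.
BatchParam = namedtuple("BatchParam", ["epoch", "nbatch", "eval_metric", "locals"])


class StubMetric:
    """Hypothetical metric exposing get_name_value() and reset()."""

    def __init__(self):
        self.value = 0.5

    def get_name_value(self):
        return [("accuracy", self.value)]

    def reset(self):
        self.value = 0.0


def make_train_logger(period, auto_reset=False, sink=print):
    """Hypothetical batch-end callback: log the metric every `period` batches."""
    def _callback(param):
        if param.nbatch % period == 0 and param.eval_metric is not None:
            for name, value in param.eval_metric.get_name_value():
                sink("Iter[%d] Batch[%d] Train-%s=%f"
                     % (param.epoch, param.nbatch, name, value))
            if auto_reset:
                # Start accumulating the metric afresh after each log.
                param.eval_metric.reset()
    return _callback


logged = []
metric = StubMetric()
callback = make_train_logger(period=2, auto_reset=True, sink=logged.append)
for nbatch in range(4):
    callback(BatchParam(0, nbatch, metric, None))
print(logged)
```

With `auto_reset=True`, each logged value covers only the batches since the previous log.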

mxnet.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=False)

Callback to checkpoint a Module to prefix every period epochs.

Parameters
  • mod (subclass of BaseModule) – The module to checkpoint.

  • prefix (str) – The file prefix for this checkpoint.

  • period (int) – Interval (number of epochs) between checkpoints. Defaults to 1.

  • save_optimizer_states (bool) – Indicates whether or not to save optimizer states for continued training.

Returns

callback – The callback function that can be passed as epoch_end_callback to fit.

Return type

function
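Where `do_checkpoint` saves the symbol and parameters handed to it, this callback asks the module to save itself, which is how optimizer states can be included. The sketch below assumes the module provides `save_checkpoint(prefix, epoch, save_optimizer_states)`, as MXNet's `Module` does, and uses a hypothetical `RecordingModule` and `make_module_checkpoint` so it runs without MXNet.

```python
class RecordingModule:
    """Hypothetical stand-in for a Module; records save_checkpoint calls."""

    def __init__(self):
        self.calls = []

    def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
        self.calls.append((prefix, epoch, save_optimizer_states))


def make_module_checkpoint(mod, prefix, period=1, save_optimizer_states=False):
    """Sketch of an epoch-end callback that checkpoints `mod` every `period` epochs."""
    period = max(1, int(period))

    def _callback(epoch, sym=None, arg=None, aux=None):
        # `epoch` is zero-based; the checkpoint is labelled with epoch + 1.
        if (epoch + 1) % period == 0:
            mod.save_checkpoint(prefix, epoch + 1, save_optimizer_states)

    return _callback


mod = RecordingModule()
callback = make_module_checkpoint(mod, "mymodel", period=3, save_optimizer_states=True)
for epoch in range(6):
    callback(epoch)
print(mod.calls)  # [('mymodel', 3, True), ('mymodel', 6, True)]
```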
