Initializer base class: all Keras initializers inherit from this class.
Initializers should implement a __call__ method with the following signature:
def __call__(self, shape, dtype=None, **kwargs):
    # returns a tensor of shape `shape` and dtype `dtype`
    # containing values drawn from a distribution of your choice.
Optionally, you can also implement the method get_config and the class method from_config in order to support serialization, just like with any Keras object.
Here's a simple example: a random normal initializer.
import tensorflow as tf

class ExampleRandomNormal(tf.keras.initializers.Initializer):

    def __init__(self, mean, stddev):
        self.mean = mean
        self.stddev = stddev

    def __call__(self, shape, dtype=None, **kwargs):
        return tf.random.normal(
            shape, mean=self.mean, stddev=self.stddev, dtype=dtype)

    def get_config(self):  # To support serialization
        return {"mean": self.mean, "stddev": self.stddev}
Note that we don't have to implement from_config in the example above since the constructor arguments of the class and the keys in the config returned by get_config are the same. In this case, the default from_config works fine.
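When the keys returned by get_config do not match the constructor arguments, from_config has to be overridden to map one onto the other. The following is a minimal sketch of that case, not the library's implementation. To stay runnable without TensorFlow it uses a hypothetical stand-in base class, SketchInitializer, whose default from_config is assumed to pass the config keys straight to the constructor. SymmetricUniform and its limit argument are also hypothetical. With the real library, the subclass would inherit from tf.keras.initializers.Initializer and also implement __call__.

```python
import json


class SketchInitializer:
    """Hypothetical stand-in for tf.keras.initializers.Initializer."""

    def get_config(self):
        return {}

    @classmethod
    def from_config(cls, config):
        # Assumed default behaviour: forward the config keys to the constructor.
        return cls(**config)


class SymmetricUniform(SketchInitializer):
    """Hypothetical initializer configured by a single `limit` argument."""

    def __init__(self, limit):
        self.minval = -limit
        self.maxval = limit

    def get_config(self):
        # These keys differ from the constructor argument `limit`,
        # so the default from_config would fail with a TypeError.
        return {"minval": self.minval, "maxval": self.maxval}

    @classmethod
    def from_config(cls, config):
        # Map the stored keys back onto the constructor argument.
        return cls(limit=config["maxval"])


original = SymmetricUniform(limit=0.05)

# Round-trip the config through JSON, as a serialized model would.
serialized = json.dumps(original.get_config())
restored = SymmetricUniform.from_config(json.loads(serialized))

print(restored.get_config())
```

Storing the derived values instead of the constructor argument is a design choice made here only to illustrate the mismatch. Returning {"limit": ...} from get_config would make the default from_config sufficient.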
Methods
from_config
@classmethod
from_config( config )
Instantiates an initializer from a configuration dictionary.
Example:
initializer = RandomUniform(-1, 1)
config = initializer.get_config()
initializer = RandomUniform.from_config(config)
Args | |
---|---|
config | A Python dictionary, the output of get_config. |
Returns | |
---|---|
A tf.keras.initializers.Initializer instance. |
get_config
get_config()
Returns the configuration of the initializer as a JSON-serializable dict.
Returns | |
---|---|
A JSON-serializable Python dict. |
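Since the returned dict must be JSON-serializable, one way to sanity-check a get_config implementation is to pass its result through json.dumps. The sketch below uses only the Python standard library. The helper name is_json_serializable and both sample configs are hypothetical and not part of the Keras API.

```python
import json


def is_json_serializable(config):
    """Return True if `config` can be encoded by json.dumps (hypothetical helper)."""
    try:
        json.dumps(config)
        return True
    except (TypeError, ValueError):
        return False


# Plain numbers, strings, lists, and nested dicts are valid JSON values.
good_config = {"mean": 0.0, "stddev": 0.05}

# A set is not a JSON type, so this config cannot be serialized as-is.
bad_config = {"mean": 0.0, "axes": {0, 1}}

print(is_json_serializable(good_config))
print(is_json_serializable(bad_config))
```

Converting such values to lists or plain Python scalars inside get_config is usually enough to make the config serializable.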
__call__
__call__(
shape, dtype=None, **kwargs
)
Returns a tensor object initialized as specified by the initializer.
Args | |
---|---|
shape | Shape of the tensor. |
dtype | Optional dtype of the tensor. |
**kwargs | Additional keyword arguments. |