msdnet.train module
Module for training networks.
-
class
msdnet.train.TrainAlgorithm
Bases: abc.ABC
Base class implementing a training algorithm.
-
abstract
step(n, dlist)
Take a single algorithm step.
- Parameters
n – network.Network to train with
dlist – list of data.DataPoint to train with
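Because TrainAlgorithm derives from abc.ABC and step is abstract, a concrete algorithm has to subclass it and implement step. The sketch below illustrates that pattern with a stand-in base class so that it runs without msdnet installed. The names TrainAlgorithmSketch, CountingAlgorithm and can_instantiate are hypothetical, and the real base class may declare further abstract methods that are not shown here.

```python
import abc


class TrainAlgorithmSketch(abc.ABC):
    """Stand-in for the documented base class (not msdnet's own code)."""

    @abc.abstractmethod
    def step(self, n, dlist):
        """Take a single algorithm step on network n using the data in dlist."""


class CountingAlgorithm(TrainAlgorithmSketch):
    """Hypothetical subclass that only records how often step() is called."""

    def __init__(self):
        self.steps_taken = 0
        self.points_seen = 0

    def step(self, n, dlist):
        # A real algorithm would update the parameters of n here.
        self.steps_taken += 1
        self.points_seen += len(dlist)


def can_instantiate(cls):
    """Return True if cls can be constructed without arguments."""
    try:
        cls()
    except TypeError:
        return False
    return True


alg = CountingAlgorithm()
alg.step(n=None, dlist=["point-a", "point-b"])
```

Instantiating the abstract stand-in raises TypeError, while the subclass that implements step can be created and stepped as usual.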
-
abstract
-
class
msdnet.train.AdamAlgorithm(network, a=0.001, b1=0.9, b2=0.999, e=1e-08)
Bases: msdnet.train.TrainAlgorithm
Implementation of the ADAM algorithm.
- Parameters
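network – network.Network to train with
a – ADAM step size (learning rate)
b1 – ADAM exponential decay rate for the first-moment estimate
b2 – ADAM exponential decay rate for the second-moment estimate
e – ADAM small constant that prevents division by zero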
-
step(n, dlist)
Take a single algorithm step.
- Parameters
n – network.Network to train with
dlist – list of data.DataPoint to train with
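To make the role of a, b1, b2 and e concrete, here is a plain-Python sketch of the textbook Adam update applied to a list of floats. It is not msdnet's implementation: the helper name adam_step and the list-based state (m, v, t) are assumptions made for illustration, and only the parameter names and default values are taken from the signature above.

```python
import math


def adam_step(params, grads, m, v, t, a=0.001, b1=0.9, b2=0.999, e=1e-08):
    """Apply one Adam update in place and return the new step count.

    params, grads, m and v are equal-length lists of floats; m and v hold
    the running first- and second-moment estimates and start at zero.
    """
    t += 1
    for i, g in enumerate(grads):
        m[i] = b1 * m[i] + (1 - b1) * g        # first-moment estimate
        v[i] = b2 * v[i] + (1 - b2) * g * g    # second-moment estimate
        m_hat = m[i] / (1 - b1 ** t)           # bias-corrected moments
        v_hat = v[i] / (1 - b2 ** t)
        params[i] -= a * m_hat / (math.sqrt(v_hat) + e)
    return t


# Minimise f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
x, m, v, t = [0.0], [0.0], [0.0], 0
for _ in range(2000):
    t = adam_step(x, [2 * (x[0] - 3)], m, v, t, a=0.01)
```

In this sketch the very first update has a magnitude close to a regardless of the gradient scale, which is why a acts as a step size rather than a plain gradient multiplier.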
-
msdnet.train.restore_training(fn, netclass, trainclass, valclass, valdata, gpu=True)
Restore training from file.
- Parameters
fn – filename to load
netclass – network.Network class to use
trainclass – TrainAlgorithm class to use
valclass – validate.Validation class to use
valdata – list of data.DataPoint to validate with
gpu – (optional) whether to use GPU or CPU
- Returns
network object, training algorithm object, and validation object
-
msdnet.train.train(network, trainalg, validation, dataprov, outputfile, val_every=None, loggers=None, stopcrit=inf, progress=False)
Train network.
- Parameters
network – network.Network to train with
trainalg – TrainAlgorithm object that performs training
validation – validate.Validation object that performs validation
dataprov – data.BatchProvider object that generates training batches
outputfile – file to store trained network parameters in
val_every – (optional) number of training steps before each validation step
loggers – (optional) list of loggers.Logger objects to perform logging
stopcrit – (optional) number of validation steps without improvement before stopping training
progress – (optional) whether to show progress during training
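The interaction between val_every and stopcrit can be sketched as a loop that validates after a fixed number of training steps and stops once validation has failed to improve a given number of times in a row. The code below is a hypothetical stand-alone illustration, not msdnet's actual train function: train_sketch and its step, validate and save callables are assumptions, val_every is treated as a plain integer, and storing parameters only on improvement is a design choice of this sketch.

```python
import math


def train_sketch(step, validate, save, val_every=10, stopcrit=math.inf):
    """Hypothetical training loop with patience-based early stopping.

    step()     performs one training step.
    validate() returns a validation error (lower is better).
    save()     stores the current parameters.
    Returns the number of training steps taken.
    """
    best = math.inf
    since_improvement = 0
    steps = 0
    # With stopcrit=inf this loop only ends when interrupted from outside.
    while since_improvement < stopcrit:
        for _ in range(val_every):
            step()
            steps += 1
        error = validate()
        if error < best:
            best = error
            since_improvement = 0
            save()  # keep only the best parameters seen so far
        else:
            since_improvement += 1
    return steps


# Scripted validation errors stand in for a real validation object.
errors = iter([5.0, 4.0, 4.5, 4.2, 4.1, 3.0])
seen = []
saved = []


def fake_validate():
    seen.append(next(errors))
    return seen[-1]


steps = train_sketch(
    step=lambda: None,
    validate=fake_validate,
    save=lambda: saved.append(seen[-1]),
    val_every=2,
    stopcrit=3,
)
```

With the scripted errors, the run stops after three consecutive validations without improvement, so the final error of 3.0 is never reached and only the two improving results are saved.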