msdnet.train module
Module for training networks.
class msdnet.train.TrainAlgorithm
Bases: abc.ABC
Base class implementing a training algorithm.
abstract step(n, dlist)
Take a single algorithm step.
Parameters:
  n – network.Network to train
  dlist – list of data.DataPoint to train with
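TrainAlgorithm only fixes the interface: a subclass implements step(n, dlist) to update the network using one batch of data points. The sketch below shows that pattern with stand-ins rather than msdnet's own classes. The local TrainAlgorithm mirrors the documented base class. LinearModel (a one-weight model y = w * x), its gradient() method, and GradientDescentAlgorithm are hypothetical, and plain (x, y) tuples take the place of data.DataPoint objects.

```python
import abc


class TrainAlgorithm(abc.ABC):
    """Stand-in mirroring msdnet.train.TrainAlgorithm: subclasses must define step()."""

    @abc.abstractmethod
    def step(self, n, dlist):
        """Take a single algorithm step on network n using the data in dlist."""


class LinearModel:
    """Hypothetical one-parameter network: prediction = w * x."""

    def __init__(self, w=0.0):
        self.w = w

    def gradient(self, dlist):
        # Gradient of the mean squared error over (x, y) pairs, which stand
        # in for data.DataPoint objects in this sketch.
        return sum(2.0 * (self.w * x - y) * x for x, y in dlist) / len(dlist)


class GradientDescentAlgorithm(TrainAlgorithm):
    """Hypothetical plain gradient descent with a fixed learning rate."""

    def __init__(self, network, lr=0.1):
        # network is accepted only to mirror AdamAlgorithm(network, ...);
        # plain gradient descent keeps no per-parameter state.
        self.lr = lr

    def step(self, n, dlist):
        n.w -= self.lr * n.gradient(dlist)


net = LinearModel()
alg = GradientDescentAlgorithm(net, lr=0.1)
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # samples of y = 2x
for _ in range(200):
    alg.step(net, data)
```

Because step is abstract, instantiating TrainAlgorithm directly raises TypeError; only subclasses that override step can be created.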
class msdnet.train.AdamAlgorithm(network, a=0.001, b1=0.9, b2=0.999, e=1e-08)
Bases: msdnet.train.TrainAlgorithm
Implementation of the Adam optimization algorithm.
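Parameters:
  network – network.Network to train
  a – Adam step size (learning rate)
  b1 – exponential decay rate of the first-moment (mean) estimate of the gradient
  b2 – exponential decay rate of the second-moment (uncentered variance) estimate of the gradient
  e – small constant added to the denominator for numerical stability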
step(n, dlist)
Take a single algorithm step.
Parameters:
  n – network.Network to train
  dlist – list of data.DataPoint to train with
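For reference, the parameters a, b1, b2 and e play the roles they have in the textbook Adam update. The sketch below implements that update for a list of float parameters. AdamSketch and its update() method are hypothetical names, and msdnet's internal implementation may differ in details such as where e enters the denominator.

```python
import math


class AdamSketch:
    """Textbook Adam update for a list of float parameters.

    a, b1, b2 and e follow the names and defaults of AdamAlgorithm's signature.
    """

    def __init__(self, nparams, a=0.001, b1=0.9, b2=0.999, e=1e-08):
        self.a, self.b1, self.b2, self.e = a, b1, b2, e
        self.m = [0.0] * nparams  # first-moment (mean) estimates
        self.v = [0.0] * nparams  # second-moment (uncentered variance) estimates
        self.t = 0                # step counter used for bias correction

    def update(self, params, grads):
        """Return new parameters after one Adam step."""
        self.t += 1
        new = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            # Bias-correct the moment estimates, which start at zero.
            m_hat = self.m[i] / (1 - self.b1 ** self.t)
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            new.append(p - self.a * m_hat / (math.sqrt(v_hat) + self.e))
        return new


# Minimise f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
opt = AdamSketch(nparams=1, a=0.1)
x = [0.0]
x = opt.update(x, [2 * (x[0] - 3)])
first_step = x[0]  # close to a, because m_hat / sqrt(v_hat) is close to 1 at t = 1
for _ in range(999):
    x = opt.update(x, [2 * (x[0] - 3)])
```

At the first step the bias corrections make m_hat / sqrt(v_hat) close to 1, so the first move has magnitude close to a; this is why a is usually described as the step size.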
msdnet.train.restore_training(fn, netclass, trainclass, valclass, valdata, gpu=True)
Restore training from file.
Parameters:
  fn – filename to load
  netclass – network.Network class to use
  trainclass – TrainAlgorithm class to use
  valclass – validate.Validation class to use
  valdata – list of data.DataPoint to validate with
  gpu – (optional) whether to use the GPU (True) or the CPU (False)
Returns:
  the network object, the training algorithm object, and the validation object
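restore_training takes classes rather than objects, so it can rebuild the network, the training algorithm and the validation object from the saved state itself. The sketch below illustrates that design under stated assumptions. The JSON file format, the to_dict()/from_dict() methods, save_training(), restore_training_sketch() and the three Toy classes are all hypothetical. The gpu flag is omitted, and msdnet's real file format is not reproduced.

```python
import json
import os
import tempfile


class ToyNetwork:
    """Hypothetical network whose only state is one weight."""

    def __init__(self, w=0.0):
        self.w = w

    def to_dict(self):
        return {"w": self.w}

    @classmethod
    def from_dict(cls, dct):
        return cls(dct["w"])


class ToyAlgorithm:
    """Hypothetical training algorithm whose state is a step counter."""

    def __init__(self, network, steps=0):
        self.steps = steps

    def to_dict(self):
        return {"steps": self.steps}

    @classmethod
    def from_dict(cls, dct, network):
        return cls(network, dct["steps"])


class ToyValidation:
    """Hypothetical validation object that remembers the best error so far."""

    def __init__(self, valdata, best=float("inf")):
        self.valdata, self.best = valdata, best

    def to_dict(self):
        return {"best": self.best}

    @classmethod
    def from_dict(cls, dct, valdata):
        return cls(valdata, dct["best"])


def save_training(fn, network, trainalg, validation):
    """Write the state of all three objects to one JSON file."""
    state = {"network": network.to_dict(), "trainalg": trainalg.to_dict(),
             "validation": validation.to_dict()}
    with open(fn, "w") as f:
        json.dump(state, f)


def restore_training_sketch(fn, netclass, trainclass, valclass, valdata):
    """Rebuild the objects from fn using the classes passed in by the caller."""
    with open(fn) as f:
        state = json.load(f)
    n = netclass.from_dict(state["network"])
    t = trainclass.from_dict(state["trainalg"], n)
    v = valclass.from_dict(state["validation"], valdata)
    return n, t, v


with tempfile.TemporaryDirectory() as tmp:
    fn = os.path.join(tmp, "training.json")
    net = ToyNetwork(1.5)
    save_training(fn, net, ToyAlgorithm(net, steps=40), ToyValidation([], best=0.25))
    n, t, v = restore_training_sketch(fn, ToyNetwork, ToyAlgorithm, ToyValidation,
                                      valdata=[(1.0, 2.0)])
```

Passing classes keeps the restore function generic: any network, algorithm or validation type that knows how to rebuild itself from saved state can be plugged in.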
msdnet.train.train(network, trainalg, validation, dataprov, outputfile, val_every=None, loggers=None, stopcrit=inf, progress=False)
Train a network.
Parameters:
  network – network.Network to train
  trainalg – TrainAlgorithm object that performs training
  validation – validate.Validation object that performs validation
  dataprov – data.BatchProvider object that generates training batches
  outputfile – file to store trained network parameters in
  val_every – (optional) number of training steps between validation steps
  loggers – (optional) list of loggers.Logger objects that perform logging
  stopcrit – (optional) number of validation steps without improvement before training stops
  progress – (optional) whether to show progress during training
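Together, val_every, stopcrit and outputfile amount to early stopping: validate every val_every steps, save the parameters whenever the validation error improves, and stop after stopcrit validations without improvement. The sketch below shows one way to write such a loop. train_sketch(), the validate and save callables, and the max_steps bound are hypothetical. Loggers and progress display are left out, and stopping at exactly stopcrit non-improving validations (as here) is an assumption about msdnet's behaviour.

```python
import itertools
import math


def train_sketch(network, trainalg, validate, batches, save,
                 val_every=1, stopcrit=math.inf, max_steps=100_000):
    """Hypothetical early-stopping loop mirroring train()'s documented parameters.

    validate(network) -> float : validation error, lower is better (assumption)
    save(network)              : stores the current best parameters,
                                 standing in for writing outputfile
    max_steps                  : safety bound added for this sketch only
    """
    best = math.inf
    without_improvement = 0
    step = 0
    for step in range(1, max_steps + 1):
        trainalg.step(network, next(batches))  # one optimization step
        if step % val_every != 0:
            continue  # not a validation step yet
        err = validate(network)
        if err < best:
            best = err
            without_improvement = 0
            save(network)  # keep only parameters that improved validation
        else:
            without_improvement += 1
            if without_improvement >= stopcrit:
                break  # early stopping
    return best, step


# Toy demo: a counting "algorithm" and a scripted validation-error sequence.
class CountingAlgorithm:
    def __init__(self):
        self.steps = 0

    def step(self, n, dlist):
        self.steps += 1


alg = CountingAlgorithm()
errors = iter([5.0, 4.0, 3.0, 3.5, 3.2, 3.1, 1.0])
saved = []
best, last_step = train_sketch(
    network={}, trainalg=alg,
    validate=lambda n: next(errors),
    batches=itertools.repeat([]),
    save=lambda n: saved.append(alg.steps),
    val_every=2, stopcrit=3)
```

With the scripted errors, the last improvement happens at step 6, and the three non-improving validations at steps 8, 10 and 12 trigger the stop, so the final 1.0 is never reached and the saved parameters are those from step 6.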