Source code for msdnet.train

#Copyright 2019 Centrum Wiskunde & Informatica, Amsterdam
#Author: Daniel M. Pelt
#License: MIT
#This file is part of MSDNet, a Python implementation of the
#Mixed-Scale Dense Convolutional Neural Network.

"""Module for training networks."""

import abc
import os

import numpy as np
import tqdm

from . import store

class TrainAlgorithm(abc.ABC):
    """Abstract interface shared by all training algorithms.

    Subclasses provide the training step and the conversion of their state
    to and from a dictionary. Loading from and saving to a file are
    implemented here on top of that dictionary conversion.
    """

    @abc.abstractmethod
    def step(self, n, dlist):
        """Perform one step of the training algorithm.

        :param n: :class:`.network.Network` to train with
        :param dlist: list of :class:`.data.DataPoint` to train with
        """

    @abc.abstractmethod
    def to_dict(self):
        """Return the state of the algorithm as a dictionary."""

    @abc.abstractmethod
    def load_dict(self, dct):
        """Overwrite the state of the algorithm with the contents of ``dct``."""

    @classmethod
    @abc.abstractmethod
    def from_dict(cls, dct):
        """Create an algorithm object from the dictionary ``dct``."""

    @classmethod
    def from_file(cls, fn):
        """Create an algorithm object from file ``fn``.

        The state dictionary is fetched with ``store.get_dict`` under the
        name ``'trainalgorithm'`` and handed to :meth:`from_dict`.
        """
        return cls.from_dict(store.get_dict(fn, 'trainalgorithm'))

    def to_file(self, fn):
        """Write the algorithm state to file ``fn``.

        The dictionary returned by :meth:`to_dict` is passed to
        ``store.store_dict`` under the name ``'trainalgorithm'``.
        """
        state = self.to_dict()
        store.store_dict(fn, 'trainalgorithm', state)
class AdamAlgorithm(TrainAlgorithm):
    """Implementation of the ADAM algorithm.

    :param network: :class:`.network.Network` to train with, or ``None`` to
        create an object without moment buffers (as done by
        :meth:`from_dict`, which fills them in through :meth:`load_dict`)
    :param a: ADAM parameter (factor applied to every update)
    :param b1: ADAM parameter (decay factor of the running gradient average)
    :param b2: ADAM parameter (decay factor of the running squared-gradient
        average)
    :param e: ADAM parameter (added to the denominator of every update)
    """

    def __init__(self, network, a = 0.001, b1 = 0.9, b2 = 0.999, e = 10**-8):
        self.a = a
        self.b1 = b1
        # b1t / b2t track b1**t and b2**t for the current step t; step() uses
        # them for the bias correction of the running averages.
        self.b1t = b1
        self.b2 = b2
        self.b2t = b2
        self.e = e
        # Compare with None explicitly: only a missing network should skip
        # the buffer allocation, not a network object that happens to
        # evaluate as false.
        if network is not None:
            self.npars = network.getgradients().shape[0]
            self.m = np.zeros(self.npars)
            self.v = np.zeros(self.npars)

    def step(self, n, dlist):
        """Take a single ADAM step.

        The error ``target - output`` of every data point is propagated
        backwards through the network. The gradients returned by
        ``n.getgradients()`` are divided by the number of error entries that
        were not masked out, and the resulting ADAM update is handed to
        ``n.updategradients``.

        If no error entry contributed at all (empty ``dlist`` or everything
        masked out), the step is skipped: dividing by zero would otherwise
        put NaN values into the running averages, where they would persist
        for all later steps.

        :param n: :class:`.network.Network` to train with
        :param dlist: list of :class:`.data.DataPoint` to train with
        """
        n.gradient_zero()
        tpix = 0
        for d in dlist:
            inp, tar, msk = d.getall()
            out = n.forward(inp)
            err = tar - out
            if msk is None:
                tpix += err.size
            else:
                # Positions where the mask is zero are excluded: their error
                # is zeroed for every index of the first axis, and they do
                # not count towards the normalisation.
                msk = (msk == 0)
                err[:, msk] = 0
                tpix += err.size - err.shape[0]*msk.sum()
            n.backward(err)
            n.gradient()
        if tpix == 0:
            return
        g = n.getgradients()
        g /= tpix
        # Running averages are updated in place so the buffers keep their
        # identity and dtype.
        self.m *= self.b1
        self.m += (1-self.b1)*g
        self.v *= self.b2
        self.v += (1-self.b2)*(g**2)
        # Bias correction for the zero-initialised running averages.
        mhat = self.m/(1-self.b1t)
        vhat = self.v/(1-self.b2t)
        self.b1t *= self.b1
        self.b2t *= self.b2
        upd = self.a * mhat/(np.sqrt(vhat) + self.e)
        n.updategradients(upd)

    def to_dict(self):
        """Save algorithm state to dictionary.

        The returned dictionary holds the parameters, the bias-correction
        factors, the number of parameters, and copies of both running-average
        buffers.
        """
        dct = {}
        dct['a'] = self.a
        dct['b1'] = self.b1
        dct['b1t'] = self.b1t
        dct['b2'] = self.b2
        dct['b2t'] = self.b2t
        dct['e'] = self.e
        dct['npars'] = self.npars
        dct['m'] = self.m.copy()
        dct['v'] = self.v.copy()
        return dct

    def load_dict(self, dct):
        """Load algorithm state from dictionary.

        :param dct: dictionary with the keys written by :meth:`to_dict`; the
            buffers are copied, so ``dct`` is not modified by later steps
        """
        self.a = dct['a']
        self.b1 = dct['b1']
        self.b1t = dct['b1t']
        self.b2 = dct['b2']
        self.b2t = dct['b2t']
        self.e = dct['e']
        self.npars = dct['npars']
        self.m = dct['m'].copy()
        self.v = dct['v'].copy()

    @classmethod
    def from_dict(cls, dct):
        """Load algorithm from dictionary.

        An object is created without a network and then filled in with
        :meth:`load_dict`.
        """
        t = cls(None)
        t.load_dict(dct)
        return t
def restore_training(fn, netclass, trainclass, valclass, valdata, gpu=True):
    """Rebuild the objects needed to continue training from a file.

    The network is loaded from the ``'checkpoint'`` group of the file. The
    validation data is not taken from the file: the loaded validation object
    gets ``valdata`` assigned to its ``d`` attribute.

    :param fn: filename to load
    :param netclass: :class:`.network.Network` class to use
    :param trainclass: :class:`TrainAlgorithm` class to use
    :param valclass: :class:`.validate.Validation` class to use
    :param valdata: list of :class:`.data.DataPoint` to validate with
    :param gpu: (optional) whether to use GPU or CPU
    :return: network object, training algorithm object, and validation object
    """
    network = netclass.from_file(fn, groupname='checkpoint', gpu=gpu)
    algorithm = trainclass.from_file(fn)
    validation = valclass.from_file(fn)
    validation.d = valdata
    return network, algorithm, validation
def _checkpoint_path(outputfile):
    """Return ``outputfile`` with its extension (if any) replaced by ``.checkpoint``."""
    # splitext only looks at the last path component, so dots in directory
    # names (e.g. './out/net' or 'run.v2/net') are left alone.
    return os.path.splitext(outputfile)[0] + '.checkpoint'


def _iter_loggers(loggers):
    """Return an iterator over ``loggers``, given one logger or an iterable of them."""
    try:
        return iter(loggers)
    except TypeError:
        # Not iterable: treat it as a single logger object.
        return iter((loggers,))


def train(network, trainalg, validation, dataprov, outputfile, val_every=None, loggers=None, stopcrit=np.inf, progress=False):
    """Train network.

    Training steps are taken until ``stopcrit`` consecutive validation steps
    did not improve; with the default ``stopcrit`` this function never
    returns. Before every validation step, the network, training algorithm
    and validation state are stored in a checkpoint file, named after
    ``outputfile`` with its extension replaced by ``.checkpoint``. The network
    is written to ``outputfile`` whenever ``validation.validate`` reports an
    improvement.

    :param network: :class:`.network.Network` to train with
    :param trainalg: :class:`TrainAlgorithm` object that performs training.
    :param validation: :class:`.validate.Validation` object that performs validation.
    :param dataprov: :class:`.data.BatchProvider` object that generates training batches.
    :param outputfile: file to store trained network parameters in
    :param val_every: (optional) number of training steps before each validation step; defaults to the number of validation data points
    :param loggers: (optional) a :class:`loggers.Logger` object or a list of them to perform logging.
    :param stopcrit: (optional) number of validation steps without improvement before stopping training
    :param progress: (optional) whether to show progress during training
    """
    if not val_every:
        val_every = len(validation.d)
    checkpointfile = _checkpoint_path(outputfile)
    nworse = 0
    nstep = 0
    if progress:
        pbar = tqdm.tqdm(total=val_every)
    while True:
        trainalg.step(network, dataprov.getbatch())
        nstep += 1
        if progress:
            pbar.update()
        if nstep < val_every:
            continue
        if progress:
            pbar.clear()
            pbar.close()
        nstep = 0
        # Checkpoint first, so the state right before this validation step
        # can be restored later.
        network.to_file(checkpointfile, groupname='checkpoint')
        trainalg.to_file(checkpointfile)
        validation.to_file(checkpointfile)
        if validation.validate(network):
            network.to_file(outputfile)
            nworse = 0
        else:
            nworse += 1
            if nworse >= stopcrit:
                return
        if loggers:
            # Only the iterability probe is guarded against TypeError, so
            # a TypeError raised inside a logger propagates unchanged.
            for log in _iter_loggers(loggers):
                log.log(validation)
        if progress:
            # Start a fresh bar for the next round of training steps.
            pbar = tqdm.tqdm(total=val_every)