.. note::
   :class: sphx-glr-download-link-note

   Click :ref:`here <sphx_glr_download_auto_examples_train_regr.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_train_regr.py:


Example 01: Train a network for regression
===========================================

This script trains an MS-D network for regression (i.e. denoising/artifact removal).

Run generatedata.py first to generate the required training data.


.. code-block:: default


    # Import code
    import msdnet
    import glob

    # Define dilations in [1,10] as in the paper.
    dilations = msdnet.dilations.IncrementDilations(10)

    # Create main network object for regression, with 100 layers,
    # [1,10] dilations, 1 input channel, 1 output channel, using
    # the GPU (set gpu=False to use the CPU)
    n = msdnet.network.MSDNet(100, dilations, 1, 1, gpu=True)

    # Initialize network parameters
    n.initialize()

    # Define training data
    # First, create lists of input files (noisy) and target files (noiseless)
    flsin = sorted(glob.glob('train/noisy/*.tiff'))
    flstg = sorted(glob.glob('train/noiseless/*.tiff'))

    # Create list of datapoints (i.e. input/target pairs)
    dats = []
    for i in range(len(flsin)):
        # Create datapoint with file names
        d = msdnet.data.ImageFileDataPoint(flsin[i], flstg[i])
        # Augment data by rotating and flipping
        d_augm = msdnet.data.RotateAndFlipDataPoint(d)
        # Add augmented datapoint to list
        dats.append(d_augm)
    # Note: the above can also be achieved using a utility function for such 'simple' cases:
    # dats = msdnet.utils.load_simple_data('train/noisy/*.tiff', 'train/noiseless/*.tiff', augment=True)

    # Normalize input and output of network to zero mean and unit variance using
    # training data images
    n.normalizeinout(dats)

    # Use batches of a single image
    bprov = msdnet.data.BatchProvider(dats, 1)

    # Define validation data (not using augmentation)
    flsin = sorted(glob.glob('val/noisy/*.tiff'))
    flstg = sorted(glob.glob('val/noiseless/*.tiff'))
    datsv = []
    for i in range(len(flsin)):
        d = msdnet.data.ImageFileDataPoint(flsin[i], flstg[i])
        datsv.append(d)
    # Note: the above can also be achieved using a utility function for such 'simple' cases:
    # datsv = msdnet.utils.load_simple_data('val/noisy/*.tiff', 'val/noiseless/*.tiff', augment=False)

    # Validate with Mean-Squared Error
    val = msdnet.validate.MSEValidation(datsv)

    # Use the ADAM training algorithm
    t = msdnet.train.AdamAlgorithm(n)

    # Log error metrics to console
    consolelog = msdnet.loggers.ConsoleLogger()
    # Log error metrics to file
    filelog = msdnet.loggers.FileLogger('log_regr.txt')
    # Log typical, worst, and best images to image files
    imagelog = msdnet.loggers.ImageLogger('log_regr', onlyifbetter=True)

    # Train the network until the program is stopped manually.
    # Network parameters are saved in regr_params.h5.
    # Validation is run after every len(datsv) (=25) training steps.
    msdnet.train.train(n, t, val, bprov, 'regr_params.h5',
                       loggers=[consolelog, filelog, imagelog], val_every=len(datsv))


.. _sphx_glr_download_auto_examples_train_regr.py:


.. only:: html

   .. container:: sphx-glr-footer
      :class: sphx-glr-footer-example

      .. container:: sphx-glr-download

         :download:`Download Python source code: train_regr.py <train_regr.py>`

      .. container:: sphx-glr-download

         :download:`Download Jupyter notebook: train_regr.ipynb <train_regr.ipynb>`


.. only:: html

   .. rst-class:: sphx-glr-signature

      `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
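
Once training has written ``regr_params.h5``, the network can be loaded again and applied to
unseen images. The snippet below is a minimal sketch of that step, modelled on msdnet's companion
apply example; it assumes the ``MSDNet.from_file`` and ``forward`` calls behave as in that
example, and the ``test/noisy`` input folder and ``results`` output folder are illustrative
names rather than part of this script.

.. code-block:: python

    import glob
    import os

    import msdnet
    import tifffile

    # Create an output folder for the denoised images (illustrative name)
    os.makedirs('results', exist_ok=True)

    # Load the trained network parameters saved by the training script above
    n = msdnet.network.MSDNet.from_file('regr_params.h5', gpu=True)

    # Apply the network to each test image (assumed folder layout)
    flsin = sorted(glob.glob('test/noisy/*.tiff'))
    for i, fn in enumerate(flsin):
        # Datapoint with only an input image (no target is needed for inference)
        d = msdnet.data.ImageFileDataPoint(fn)
        # Compute the network output for this image
        output = n.forward(d.input)
        # Save the single output channel to file
        tifffile.imwrite('results/regr_{:05d}.tiff'.format(i), output[0])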