Note
Click here to download the full example code
Example 01: Train a network for regression
This script trains an MS-D network for regression (i.e., denoising/artifact removal). Run generatedata.py first to generate the required training data.
# Import code
import msdnet
import glob
# Dilation values in [1, 10], as used in the MS-D paper.
dilations = msdnet.dilations.IncrementDilations(10)
# Network shape for regression: 100 layers, one input channel (the noisy
# image) and one output channel (the predicted noiseless image).
# Set use_gpu = False to run on the CPU instead of the GPU.
num_layers = 100
num_in, num_out = 1, 1
use_gpu = True
n = msdnet.network.MSDNet(num_layers, dilations, num_in, num_out, gpu=use_gpu)
# Set up the network's parameters before normalization and training.
n.initialize()
# Define training data
# First, create lists of input files (noisy) and target files (noiseless).
# Inputs and targets are paired purely by sorted file-name order, so the two
# directories are assumed to contain correspondingly named files -- TODO confirm.
flsin = sorted(glob.glob('train/noisy/*.tiff'))
flstg = sorted(glob.glob('train/noiseless/*.tiff'))
# Fail early with a clear message: an empty list would only break later inside
# normalizeinout, and zip() below would silently drop unpaired files.
if not flsin:
    raise FileNotFoundError(
        'No training images found in train/noisy/; run generatedata.py first.')
if len(flsin) != len(flstg):
    raise ValueError(
        'Found {} noisy but {} noiseless training images; '
        'the two sets must match.'.format(len(flsin), len(flstg)))
# Create list of datapoints (i.e. input/target pairs)
dats = []
for fin, ftg in zip(flsin, flstg):
    # Create datapoint with file names
    d = msdnet.data.ImageFileDataPoint(fin, ftg)
    # Augment data by rotating and flipping
    d_augm = msdnet.data.RotateAndFlipDataPoint(d)
    # Add augmented datapoint to list
    dats.append(d_augm)
# Note: The above can also be achieved using a utility function for such 'simple' cases:
# dats = msdnet.utils.load_simple_data('train/noisy/*.tiff', 'train/noiseless/*.tiff', augment=True)
# Normalize input and output of network to zero mean and unit variance using
# training data images
n.normalizeinout(dats)
# Use image batches of a single image
bprov = msdnet.data.BatchProvider(dats, 1)
# Define validation data (not using augmentation)
# Inputs and targets are paired by sorted file-name order, as for training.
flsin = sorted(glob.glob('val/noisy/*.tiff'))
flstg = sorted(glob.glob('val/noiseless/*.tiff'))
# Fail early instead of building an empty or misaligned validation set
# (zip() below would silently drop unpaired files).
if not flsin:
    raise FileNotFoundError(
        'No validation images found in val/noisy/; run generatedata.py first.')
if len(flsin) != len(flstg):
    raise ValueError(
        'Found {} noisy but {} noiseless validation images; '
        'the two sets must match.'.format(len(flsin), len(flstg)))
# One plain (non-augmented) datapoint per input/target file pair
datsv = [msdnet.data.ImageFileDataPoint(fin, ftg) for fin, ftg in zip(flsin, flstg)]
# Note: The above can also be achieved using a utility function for such 'simple' cases:
# datsv = msdnet.utils.load_simple_data('val/noisy/*.tiff', 'val/noiseless/*.tiff', augment=False)
# Validate with mean-squared error on the validation datapoints
val = msdnet.validate.MSEValidation(datsv)
# Use the Adam training algorithm
t = msdnet.train.AdamAlgorithm(n)
# Log error metrics to console
consolelog = msdnet.loggers.ConsoleLogger()
# Log error metrics to file
filelog = msdnet.loggers.FileLogger('log_regr.txt')
# Log typical, worst, and best images to image files
imagelog = msdnet.loggers.ImageLogger('log_regr', onlyifbetter=True)
# Train network until program is stopped manually.
# Network parameters are saved in regr_params.h5.
# Validation is run every len(datsv) training steps, i.e. once per number of
# validation images. The exact count depends on the generated data, so it is
# not hard-coded here.
msdnet.train.train(n, t, val, bprov, 'regr_params.h5',
                   loggers=[consolelog, filelog, imagelog],
                   val_every=len(datsv))