.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_train_segm_tomography.py:


Example 07: Train a network for segmentation (tomography)
=========================================================

This script trains an MS-D network for segmentation (i.e., labeling).
Run ``generatedata.py`` first to generate the required training data.


.. code-block:: default

    # Import code
    import msdnet
    import glob

    # Define dilations in [1,10] as in the paper.
    dilations = msdnet.dilations.IncrementDilations(10)

    # Create main network object for segmentation, with 100 layers,
    # [1,10] dilations, 5 input channels (5 slices), 4 output channels (one for each label),
    # using the GPU (set gpu=False to use CPU)
    n = msdnet.network.SegmentationMSDNet(100, dilations, 5, 4, gpu=True)

    # Initialize network parameters
    n.initialize()

    # Define training data
    # First, create lists of input files (low quality) and target files (labels)
    flsin = sorted(glob.glob('tomo_train/lowqual/*.tiff'))
    flstg = sorted(glob.glob('tomo_train/label/*.tiff'))
    # Create list of datapoints (i.e. input/target pairs)
    dats = []
    for i in range(len(flsin)):
        # Create datapoint with file names
        d = msdnet.data.ImageFileDataPoint(flsin[i], flstg[i])
        # Convert datapoint to one-hot, using labels 0, 1, 2, and 3,
        # which are the labels given in each label TIFF file.
        d_oh = msdnet.data.OneHotDataPoint(d, [0, 1, 2, 3])
        # Add datapoint to list
        dats.append(d_oh)
    # Note: The above can also be achieved using a utility function for such 'simple' cases:
    # dats = msdnet.utils.load_simple_data('tomo_train/lowqual/*.tiff', 'tomo_train/label/*.tiff', augment=False, labels=[0,1,2,3])

    # Convert input slices to input slabs (i.e. multiple slices as input)
    dats = msdnet.data.convert_to_slabs(dats, 2, flip=True)
    # Augment data by rotating and flipping
    dats_augm = [msdnet.data.RotateAndFlipDataPoint(d) for d in dats]

    # Normalize input and output of network to zero mean and unit variance using
    # training data images (the non-augmented set is sufficient for the statistics)
    n.normalizeinout(dats)

    # Use image batches of a single image, drawn from the augmented training data
    bprov = msdnet.data.BatchProvider(dats_augm, 1)

    # Define validation data (not using augmentation)
    flsin = sorted(glob.glob('tomo_val/lowqual/*.tiff'))
    flstg = sorted(glob.glob('tomo_val/label/*.tiff'))
    datsv = []
    for i in range(len(flsin)):
        d = msdnet.data.ImageFileDataPoint(flsin[i], flstg[i])
        d_oh = msdnet.data.OneHotDataPoint(d, [0, 1, 2, 3])
        datsv.append(d_oh)
    # Note: The above can also be achieved using a utility function for such 'simple' cases:
    # datsv = msdnet.utils.load_simple_data('tomo_val/lowqual/*.tiff', 'tomo_val/label/*.tiff', augment=False, labels=[0,1,2,3])

    # Convert input slices to input slabs (i.e. multiple slices as input)
    datsv = msdnet.data.convert_to_slabs(datsv, 2, flip=False)

    # Validate with Mean-Squared Error
    val = msdnet.validate.MSEValidation(datsv)

    # Use the ADAM training algorithm
    t = msdnet.train.AdamAlgorithm(n)

    # Log error metrics to console
    consolelog = msdnet.loggers.ConsoleLogger()
    # Log error metrics to file
    filelog = msdnet.loggers.FileLogger('log_tomo_segm.txt')
    # Log typical, worst, and best images to image files
    imagelog = msdnet.loggers.ImageLabelLogger('log_tomo_segm', chan_in=2, onlyifbetter=True)
    # Log typical, worst, and best probability maps of a single output
    # channel (in this case, channel 3) to image files
    singlechannellog = msdnet.loggers.ImageLogger('log_tomo_segm_singlechannel', chan_in=2, chan_out=3, onlyifbetter=True)

    # Train network until program is stopped manually
    # Network parameters are saved in tomo_segm_params.h5
    # Validation is run after every len(datsv) (=256)
    # training steps.
    msdnet.train.train(n, t, val, bprov, 'tomo_segm_params.h5', loggers=[consolelog, filelog, imagelog, singlechannellog], val_every=len(datsv))
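The comment "dilations in [1,10] as in the paper" refers to an MS-D layout in which the dilation of successive layers cycles through a fixed range. Assuming ``IncrementDilations(10)`` gives layer ``i`` (counting from 0) the dilation ``(i mod 10) + 1``, the schedule for the 100 layers used here can be sketched in plain Python. ``increment_dilations`` is a hypothetical helper, not an msdnet function.

.. code-block:: python

    def increment_dilations(n_layers, max_dilation=10):
        """Dilation per layer, cycling 1, 2, ..., max_dilation, 1, 2, ...

        Hypothetical helper mirroring the assumed IncrementDilations behavior.
        """
        return [(i % max_dilation) + 1 for i in range(n_layers)]


    dils = increment_dilations(100)
    print(dils[:12])  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2]

With 100 layers, each dilation from 1 to 10 would then be used exactly 10 times.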
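``OneHotDataPoint`` turns each label image into one target channel per label, which is why the segmentation network in this example has 4 output channels for the labels 0, 1, 2, and 3. The NumPy sketch below illustrates this kind of encoding. The helper ``one_hot`` is hypothetical and not part of msdnet, whose internal layout and dtype may differ.

.. code-block:: python

    import numpy as np


    def one_hot(label_image, labels):
        """Map an integer (H, W) label image to a (len(labels), H, W) float array.

        Channel k is 1.0 where the pixel equals labels[k] and 0.0 elsewhere.
        Hypothetical helper for illustration, not part of msdnet.
        """
        label_image = np.asarray(label_image)
        return np.stack([(label_image == lab).astype(np.float32) for lab in labels])


    # Tiny 2x3 label image that uses the four labels of this example.
    lab = np.array([[0, 1, 2],
                    [3, 3, 0]])
    oh = one_hot(lab, [0, 1, 2, 3])
    print(oh.shape)  # (4, 2, 3)

Because every pixel carries exactly one label, summing the result over the channel axis gives an image of all ones.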
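``convert_to_slabs(dats, 2, ...)`` replaces each single-slice input with a stack of adjacent slices. Assuming the second argument is the number of neighboring slices taken on each side, a slab holds 2 * 2 + 1 = 5 slices, which is consistent with the 5 input channels of the network; channel 2 is then the central slice, presumably why the loggers use ``chan_in=2``. The sketch below is a hypothetical NumPy illustration (``to_slabs`` is not an msdnet function), and it clamps indices at the volume boundary, which may not match how msdnet treats the first and last slices.

.. code-block:: python

    import numpy as np


    def to_slabs(volume, nslices):
        """Stack each slice of a (Z, H, W) volume with `nslices` neighbors on each side.

        Returns an array of shape (Z, 2 * nslices + 1, H, W). Indices outside the
        volume are clamped to the first/last slice (an assumption for this sketch).
        Hypothetical helper for illustration, not part of msdnet.
        """
        z = volume.shape[0]
        slabs = []
        for i in range(z):
            idx = np.clip(np.arange(i - nslices, i + nslices + 1), 0, z - 1)
            slabs.append(volume[idx])
        return np.stack(slabs)


    vol = np.arange(4 * 2 * 2, dtype=np.float32).reshape(4, 2, 2)
    slabs = to_slabs(vol, 2)
    print(slabs.shape)  # (4, 5, 2, 2)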
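``RotateAndFlipDataPoint`` augments the training data with rotated and flipped variants. For segmentation, the same transform has to be applied to the input slab and to the one-hot target, otherwise the labels no longer line up with the pixels. The hypothetical helper ``random_rotate_flip`` below sketches this with NumPy; it assumes one random choice among the eight rotation/flip combinations per call, which may differ from how msdnet actually selects its transforms.

.. code-block:: python

    import numpy as np


    def random_rotate_flip(inp, tgt, rng):
        """Apply the same random rotation (a multiple of 90 degrees) and optional
        left-right flip to an input and a target, both shaped (C, H, W).

        Hypothetical helper for illustration, not part of msdnet.
        """
        k = int(rng.integers(4))      # number of 90-degree rotations: 0..3
        flip = bool(rng.integers(2))  # whether to flip the width axis
        out = []
        for arr in (inp, tgt):
            arr = np.rot90(arr, k, axes=(1, 2))  # rotate in the (H, W) plane
            if flip:
                arr = arr[:, :, ::-1]
            out.append(arr)
        return out[0], out[1]


    rng = np.random.default_rng(0)
    inp = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # e.g. a 2-channel input
    tgt = inp.copy()                             # identical copy, so alignment is easy to check
    a, b = random_rotate_flip(inp, tgt, rng)
    print(np.array_equal(a, b))  # True

Drawing the rotation and flip once and reusing them for both arrays is what keeps input and target aligned.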