#-----------------------------------------------------------------------
#Copyright 2019 Centrum Wiskunde & Informatica, Amsterdam
#
#Author: Daniel M. Pelt
#Contact: D.M.Pelt@cwi.nl
#Website: http://dmpelt.github.io/msdnet/
#License: MIT
#
#This file is part of MSDNet, a Python implementation of the
#Mixed-Scale Dense Convolutional Neural Network.
#-----------------------------------------------------------------------
"""Module for logging progress during training."""
import abc
import skimage.transform as skt
import imageio
import numpy as np
import os.path
class Logger(abc.ABC):
    '''Abstract base class for loggers of validation-set progress.

    Subclasses provide two hooks:

    * :meth:`initialize` is called from the constructor with every
      positional and keyword argument except ``onlyifbetter``.
    * :meth:`makelog` is called by :meth:`log` with the validation set.
    '''

    def __init__(self, *args, onlyifbetter=False, **kwargs):
        # ``onlyifbetter`` is keyword-only and consumed here; all other
        # arguments are forwarded untouched to the subclass hook.
        self.onlyifbetter = onlyifbetter
        self.initialize(*args, **kwargs)

    @abc.abstractmethod
    def makelog(self, v):
        '''Write one log entry for the validation set; must be overridden.

        :param v: validation set
        '''

    @abc.abstractmethod
    def initialize(self, *args, **kwargs):
        '''Set up the logger from the constructor's remaining arguments.'''

    def log(self, v):
        '''Log progress of the validation set.

        If ``onlyifbetter`` was set, the entry is skipped unless ``v.curerr``
        equals ``v.best`` (presumably: the current error is the best so far).

        :param v: validation set
        '''
        skip = self.onlyifbetter and v.best != v.curerr
        if not skip:
            self.makelog(v)
class ConsoleLogger(Logger):
    '''Print the current and best validation error to standard output.'''

    def initialize(self):
        '''No set-up is needed; the logger takes no extra arguments.'''

    def makelog(self, v):
        '''Print ``v.curerr`` and ``v.best`` on a single line.

        :param v: validation set
        '''
        fields = ('Current error: ', v.curerr, ', Best error: ', v.best)
        print(*fields)
class FileLogger(Logger):
    '''Append the current and best validation error to a text file.'''

    def initialize(self, fn):
        '''Remember the log file and start it empty.

        :param fn: Filename to log error values to.
        '''
        self.fn = fn
        # Opening in 'w' mode creates the file, or truncates an existing one.
        with open(fn, 'w'):
            pass

    def makelog(self, v):
        '''Append one line containing ``v.curerr`` and ``v.best``.

        :param v: validation set
        '''
        template = 'Current error: {}, Best error: {}\n'
        with open(self.fn, 'a') as fh:
            fh.write(template.format(v.curerr, v.best))
# Module-level header state, initially empty. NOTE(review): presumably a cache
# populated by ``getheaderimage`` (not visible in this chunk) -- confirm.
header_image = None
header_dict = dict()
def _tofloatcopy(im):
    """Return a floating-point copy of ``im``.

    Floating-point arrays keep their dtype. Integer and boolean arrays are
    converted to ``float64``, so the in-place normalisation (``-=``, ``/=``)
    in :func:`stitchimages` does not fail with a casting error.
    """
    im = np.asarray(im)
    if np.issubdtype(im.dtype, np.floating):
        return im.copy()
    return im.astype(np.float64)


def stitchimages(ims, imsize=None, scaleoutput=True):
    """Stitch three images (input, target, output) side by side under a header.

    The input panel is always rescaled to [0, 1]. If ``scaleoutput`` is true,
    the target and output panels are both shifted and scaled by the target's
    min/max, so they share one intensity range. All panels are clipped to
    [0, 1] before conversion to ``uint8``. Out-of-range values therefore
    saturate instead of wrapping around.

    :param ims: list of images to stitch (input, target, output); only the
        first three entries are used and they must have matching shapes
    :param imsize: (optional) Maximum size of image; if the largest spatial
        dimension exceeds it, all three images are downscaled by one factor
    :param scaleoutput: (optional) whether to scale output image to target image range
    :returns: ``uint8`` array with the header image stacked on top of the
        three horizontally concatenated panels
    """
    # Work on float copies: the caller's arrays are never modified, and
    # integer inputs can be divided in place.
    sims = [_tofloatcopy(ims[0]), _tofloatcopy(ims[1]), _tofloatcopy(ims[2])]
    if imsize:
        sz = sims[0].shape
        # Only the two spatial axes determine the scale factor; a trailing
        # colour axis (e.g. RGB) must not be counted.
        fc = imsize / max(sz[:2])
        if fc < 1:
            osz = (int(fc * sz[0]), int(fc * sz[1]))
            sims = [skt.resize(im, osz, preserve_range=True, mode='constant', anti_aliasing=True)
                    for im in sims]
    # Normalise the input panel to [0, 1]; a constant image becomes all zeros.
    mn = sims[0].min()
    mx = sims[0].max()
    sims[0] -= mn
    if mx > mn:
        sims[0] /= (mx - mn)
    if scaleoutput:
        # Use the *target's* range for both target and output so they stay
        # visually comparable.
        mn = sims[1].min()
        mx = sims[1].max()
        sims[1] -= mn
        sims[2] -= mn
        if mx > mn:
            sims[1] /= (mx - mn)
            sims[2] /= (mx - mn)
    # Clip once for all panels so the uint8 conversion below cannot wrap.
    allims = np.clip(np.hstack(sims), 0, 1)
    head_im = getheaderimage(allims.shape[1])
    # Promote a single-channel header to three channels when the panels are RGB.
    if head_im.ndim != allims.ndim:
        head_im = np.repeat(head_im[..., np.newaxis], 3, axis=2)
    return (np.vstack((head_im, allims)) * 255).astype(np.uint8)
class ImageLogger(Logger):
    '''Save stitched images of the best, worst and typical validation samples.'''

    def initialize(self, fn, chan_in=0, chan_out=0, imsize=512):
        """Initialize logger.

        :param fn: base filename to output images to; files are named
            ``<fn>_best.png``, ``<fn>_worst.png`` and ``<fn>_typical.png``.
        :param chan_in: input channel to show
        :param chan_out: output channel to show
        :param imsize: maximum image size to output
        """
        self.fn = fn
        self.ci = chan_in
        self.co = chan_out
        self.imsize = imsize
        # Create (or truncate) one output file per category up front.
        for kind in ['best', 'worst', 'typical']:
            with open(fn + '_' + kind + '.png', 'w'):
                pass

    def toimage(self, ims):
        """Stitch the selected channels of one (input, target, output) triple.

        :param ims: triple ``(inp, tar, out)``; channel ``self.ci`` of the
            input and channel ``self.co`` of target/output are shown.
        :returns: the array produced by ``stitchimages``
        """
        inp, tar, out = ims
        panels = [inp[self.ci], tar[self.co], out[self.co]]
        return stitchimages(panels, self.imsize)

    def makelog(self, v):
        """Write the best, worst and median samples of ``v`` to disk.

        :param v: validation set
        """
        jobs = (('_best.png', 'getbest'),
                ('_worst.png', 'getworst'),
                ('_typical.png', 'getmedian'))
        for suffix, getter in jobs:
            imageio.imsave(self.fn + suffix, self.toimage(getattr(v, getter)()))
class ImageLabelLogger(Logger):
    '''Output best, worst, and typical images for validation set for segmentation problems.'''

    def initialize(self, fn, chan_in=0, imsize=512):
        """Initialize logger.

        :param fn: base filename to output images to; files are named
            ``<fn>_best.png``, ``<fn>_worst.png`` and ``<fn>_typical.png``.
        :param chan_in: input channel to show
        :param imsize: maximum image size to output
        """
        self.fn = fn
        self.ci = chan_in
        self.imsize = imsize
        # Create (or truncate) one output file per category up front.
        for tpe in ['best', 'worst', 'typical']:
            with open(fn + '_' + tpe + '.png', 'w'):
                pass
        # RGB colour (0-255) per label: label i is drawn with colors[i].
        # Extending this list automatically raises the number of labels that
        # can be drawn in colour (see ``toimage``).
        self.colors = [
            [0, 0, 0],
            [31, 120, 180],
            [51, 160, 44],
            [227, 26, 28],
            [255, 127, 0],
            [106, 61, 154],
            [255, 255, 153],
            [177, 89, 40],
            [166, 206, 227],
            [178, 223, 138],
            [251, 154, 153],
            [253, 191, 111],
            [202, 178, 214]
        ]

    def toimage(self, ims):
        """Stitch the input channel, target labels and predicted labels.

        The label at each pixel is the argmax over axis 0 of ``tar``/``out``.
        If every label has a colour in ``self.colors``, the label maps are drawn
        in colour. Otherwise they are drawn as grayscale intensity images.

        :param ims: triple ``(inp, tar, out)``; channel ``self.ci`` of ``inp``
            is shown as the input panel.
        :returns: the array produced by ``stitchimages``
        """
        inp, tar, out = ims
        tar = np.argmax(tar, axis=0)
        out = np.argmax(out, axis=0)
        ncolors = len(self.colors)
        if tar.max() >= ncolors or out.max() >= ncolors:
            # Not enough colours: fall back to grayscale label maps scaled
            # to the target's label range.
            return stitchimages([inp[self.ci], tar.astype(np.float32), out.astype(np.float32)],
                                self.imsize, scaleoutput=True)
        # Replicate the grayscale input into three channels (float32, as
        # before); stitchimages normalises it to [0, 1].
        inp2 = np.zeros((*tar.shape, 3), dtype=np.float32)
        inp2[...] = np.asarray(inp[self.ci])[..., np.newaxis]
        # Vectorised palette lookup: palette[labels] maps every pixel label to
        # its RGB colour, already scaled to [0, 1].
        palette = np.asarray(self.colors, dtype=np.float32) / 255
        return stitchimages([inp2, palette[tar], palette[out]], self.imsize, scaleoutput=False)

    def makelog(self, v):
        """Write the best, worst and median samples of ``v`` to disk.

        :param v: validation set
        """
        imageio.imsave(self.fn + '_best.png', self.toimage(v.getbest()))
        imageio.imsave(self.fn + '_worst.png', self.toimage(v.getworst()))
        imageio.imsave(self.fn + '_typical.png', self.toimage(v.getmedian()))