#-----------------------------------------------------------------------
#Copyright 2019 Centrum Wiskunde & Informatica, Amsterdam
#
#Author: Daniel M. Pelt
#Contact: D.M.Pelt@cwi.nl
#Website: http://dmpelt.github.io/msdnet/
#License: MIT
#
#This file is part of MSDNet, a Python implementation of the
#Mixed-Scale Dense Convolutional Neural Network.
#-----------------------------------------------------------------------
"""
Module for defining and processing validation sets.
"""
from . import store
from . import operations
import abc
import numpy as np
class Validation(abc.ABC):
    """Base class for processing a validation set.

    Subclasses implement the actual metric computation (:meth:`validate`)
    and the dictionary (de)serialization (:meth:`to_dict`,
    :meth:`load_dict`, :meth:`from_dict`). File storage is provided here
    on top of those methods.
    """

    @abc.abstractmethod
    def validate(self, n):
        """Compute validation metrics.

        :param n: :class:`.network.Network` to validate with
        :return: True if validation metric is lower than best validation error encountered, False otherwise.
        """

    @abc.abstractmethod
    def to_dict(self):
        """Return a dictionary containing all validation variables and parameters.

        :return: all validation variables and parameters
        """

    @abc.abstractmethod
    def load_dict(self, dct):
        """Load validation variables and parameters from dictionary.

        :param dct: dictionary with all parameters, as produced by :meth:`to_dict`
        """

    @classmethod
    @abc.abstractmethod
    def from_dict(cls, dct):
        """Initialize Validation object from dictionary.

        :param dct: dictionary with all parameters
        """

    @classmethod
    def from_file(cls, fn):
        """Initialize Validation object from file.

        :param fn: filename
        :return: object created by :meth:`from_dict` from the stored dictionary
        """
        # The parameters are stored under the 'validation' key of the file.
        dct = store.get_dict(fn, 'validation')
        return cls.from_dict(dct)

    def to_file(self, fn):
        """Save all Validation object parameters to file.

        :param fn: filename
        """
        # Same 'validation' key that from_file reads back.
        store.store_dict(fn, 'validation', self.to_dict())
class MetricValidation(Validation):
    """Validation object that computes simple difference metrics.

    Subclasses supply the metric by overriding :meth:`errorfunc`.

    :param data: list of :class:`.data.DataPoint` objects to validate with.
    :param keep: (optional) whether to keep the best, worst, and typical result in memory.
    """

    def __init__(self, data, keep=True):
        self.d = data
        self.keep = keep
        # Lowercase np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.best = np.inf

    def errorfunc(self, output, target, msk):
        """Error function used for validation.

        Must be overridden by subclasses.

        :param output: network output image.
        :param target: target image.
        :param msk: mask image to indicate where to compute error function for.
        :return: error function value.
        :raises NotImplementedError: if the subclass does not override this method.
        """
        # Fail loudly here instead of returning None, which would only
        # surface later as an opaque comparison error inside validate().
        raise NotImplementedError(
            'errorfunc must be implemented by subclasses of MetricValidation'
        )

    def _getresult(self, which):
        """Return [input, target, network output] for one tracked result.

        :param which: slot in ``self.idx`` / ``self.outputs``:
            0 = best, 1 = worst, 2 = median.
        """
        d = self.d[self.idx[which]]
        if self.keep:
            # Output was stored during the last validate() call.
            output = self.outputs[which]
        else:
            # Nothing stored: recompute with the last validated network.
            output = self.n.forward(d.input)
        return [d.input, d.target, output]

    def getbest(self):
        """Return the input, target, and network output for best result.

        Requires that :meth:`validate` has been called.

        :return: list of images (input, target, network output)
        """
        return self._getresult(0)

    def getworst(self):
        """Return the input, target, and network output for worst result.

        Requires that :meth:`validate` has been called.

        :return: list of images (input, target, network output)
        """
        return self._getresult(1)

    def getmedian(self):
        """Return the input, target, and network output for median result.

        Requires that :meth:`validate` has been called.

        :return: list of images (input, target, network output)
        """
        return self._getresult(2)

    def validate(self, n):
        """Compute the mean error of a network over the validation data.

        Stores the network in ``self.n``, the indices of the best, worst
        and median data points in ``self.idx``, their outputs in
        ``self.outputs`` (only if ``keep`` is set), and the mean error
        in ``self.curerr``.

        :param n: :class:`.network.Network` to validate with
        :return: True if the mean error is lower than the best mean error encountered so far, False otherwise.
        """
        self.n = n
        errs = np.zeros(len(self.d))
        # Slots: 0 = best (lowest error), 1 = worst (highest), 2 = median.
        self.idx = [0, 0, 0]
        if self.keep:
            self.outputs = [0, 0, 0]
        low = np.inf
        high = -np.inf
        for i, d in enumerate(self.d):
            out = self.n.forward(d.input)
            err = self.errorfunc(out, d.target, d.mask)
            errs[i] = err
            if err < low:
                low = err
                self.idx[0] = i
                if self.keep:
                    self.outputs[0] = out
            if err > high:
                high = err
                self.idx[1] = i
                if self.keep:
                    self.outputs[1] = out

        # Element at position len//2 of the sorted errors (the upper
        # median when the number of data points is even).
        median = np.argsort(errs)[errs.shape[0] // 2]
        self.idx[2] = median
        if self.keep:
            if median == self.idx[0]:
                self.outputs[2] = self.outputs[0]
            elif median == self.idx[1]:
                self.outputs[2] = self.outputs[1]
            else:
                # Per-item outputs are not retained in the loop above, so
                # the median output needs one extra forward pass.
                self.outputs[2] = self.n.forward(self.d[median].input)

        error = errs.mean()
        self.curerr = error
        if error < self.best:
            self.best = error
            return True
        return False

    def to_dict(self):
        """Return a dictionary with the best error so far and the ``keep`` flag.

        :return: dictionary with keys ``'best'`` and ``'keep'``
        """
        return {'best': self.best, 'keep': self.keep}

    def load_dict(self, dct):
        """Restore the best error and ``keep`` flag from a dictionary.

        :param dct: dictionary with keys ``'best'`` and ``'keep'``
        """
        self.best = dct['best']
        self.keep = dct['keep']

    @classmethod
    def from_dict(cls, dct):
        """Initialize object from dictionary.

        The returned object has no validation data (``self.d`` is None);
        only the values handled by :meth:`load_dict` are restored.

        :param dct: dictionary with all parameters
        """
        v = cls(None, None)
        v.load_dict(dct)
        return v
class MSEValidation(MetricValidation):
    """Validation object that uses mean-squared error"""

    def errorfunc(self, output, target, msk):
        """Mean-squared error between output and target.

        :param output: network output image.
        :param target: target image.
        :param msk: mask image, or None to use all pixels. Pixels where the
            mask equals zero are excluded from the error.
        :return: sum of squared differences divided by the number of
            pixels that were not masked out.
        """
        err = output - target
        npix = err.size
        if msk is not None:
            excluded = (msk == 0)
            # The mask is applied along every axis after the first, so each
            # excluded mask pixel removes err.shape[0] values.
            err[:, excluded] = 0
            npix -= err.shape[0] * excluded.sum()
        # NOTE(review): if every pixel is masked out, npix is 0 and this
        # divides by zero - confirm whether callers can pass such a mask.
        return operations.squaresum(err) / npix