#!/usr/pkg/bin/python2.7
# -*- coding: utf-8 -*-

# Copyright (C) 2009 The Tegaki project contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

# Contributors to this file:
# - Mathieu Blondel

import sys
import os
import time
from optparse import OptionParser

from tegaki.charcol import CharacterCollection
from tegaki.recognizer import Recognizer, RecognizerError

from tegakitools.charcol import *

# Version string of this tool; reported by the option parser's --version flag.
VERSION = '0.3.1'

def harmonic_mean(x1, x2):
    """Return the harmonic mean of x1 and x2 as a float.

    When both values are zero the mean is defined here as 0.0, which avoids
    dividing by zero.
    """
    both_zero = (x1 == 0.0 and x2 == 0.0)
    if both_zero:
        return 0.0

    product = float(x1 * x2)
    total = float(x1 + x2)
    return 2 * product / total

class TegakiEvalError(Exception):
    """Error reported to the user by the evaluation tool.

    The script's entry point catches this exception, writes its message to
    stderr, prints the usage help and exits with status 1.
    """

class TegakiEval(object):
    """Evaluate a recognizer/model pair against a set of character samples.

    The samples are aggregated from the sources given on the command line
    (character collections, databases, directories, Tomoe dictionaries and
    Kuchibue databases).  Each labelled sample is passed to the recognizer
    and recall, precision and F1 statistics are printed to stdout.

    Note: ``print`` is always called with a single parenthesised argument so
    the statements produce identical output under the Python 2 print
    statement and the Python 3 print function.
    """

    # Sizes of the candidate lists for which statistics are reported
    # (printed as "match1", "match5" and "match10").
    MATCH_RESULTS = (1, 5, 10)

    def __init__(self, options, args):
        """Store the parsed command-line options.

        options -- object exposing the attributes read below (the values
                   returned by the option parser)
        args    -- positional arguments; args[0] is the recognizer name and
                   args[1] the model name or path.  They are only read when
                   options.list is false.
        """
        self._verbosity_level = options.verbosity_level
        self._directories = options.directories
        self._databases = options.databases
        self._charcols = options.charcols
        self._tomoe = options.tomoe
        self._kuchibue = options.kuchibue
        self._list = options.list
        self._include = options.include
        self._exclude = options.exclude
        self._max_samples = options.max_samples

        if not self._list:
            self._recognizer = args[0]
            self._model = args[1]

    def run(self):
        """List the available models or run the evaluation, as requested."""
        if self._list:
            self._list_recognizers()
        else:
            self._recognize()

    def _list_recognizers(self):
        """Print one "- model (recognizer)" line per available model."""
        avail_recognizers = Recognizer.get_all_available_models()
        lines = ["- %s (%s)" % (model, recog)
                 for recog, model, _meta in avail_recognizers]
        print("\n".join(lines))

    def _recognize(self):
        """Load the samples, set up the recognizer and evaluate it.

        Raises TegakiEvalError when no sample is left after filtering, or
        when the recognizer or model cannot be set up.
        """
        charcol = get_aggregated_charcol(
                        ((TYPE_CHARCOL, self._charcols),
                         (TYPE_CHARCOL_DB, self._databases),
                         (TYPE_DIRECTORY, self._directories),
                         (TYPE_TOMOE, self._tomoe),
                         (TYPE_KUCHIBUE, self._kuchibue)))

        charcol.include_characters_from_files(self._include)
        charcol.exclude_characters_from_files(self._exclude)

        # max samples
        if self._max_samples:
            charcol.remove_samples(keep_at_most=self._max_samples)

        # FIXME: don't load all characters in memory
        all_chars = charcol.get_all_characters()

        if len(all_chars) == 0:
            raise TegakiEvalError("No character samples to evaluate!")

        recognizer_class = self._get_recognizer_class()
        recognizer = self._get_recognizer(recognizer_class)

        self._eval(recognizer, all_chars)

    def _get_recognizer_class(self):
        """Return the recognizer class registered under self._recognizer.

        Raises TegakiEvalError listing the available recognizers when the
        requested one is unknown.
        """
        avail_recognizers = Recognizer.get_available_recognizers()

        if self._recognizer not in avail_recognizers:
            err = "Not an available recognizer!\n"
            err += "Available ones include: %s" % \
                ", ".join(avail_recognizers.keys())
            raise TegakiEvalError(err)

        return avail_recognizers[self._recognizer]

    @staticmethod
    def _get_meta_path(model_path):
        """Return the path where the .meta file of model_path is looked up.

        Only a trailing ".model" extension is swapped for ".meta"; any
        ".model" occurring elsewhere in the path (for instance in a
        directory name) is left untouched.  A path without a ".model"
        extension is returned unchanged.
        """
        root, ext = os.path.splitext(model_path)
        if ext == ".model":
            return root + ".meta"
        return model_path

    def _get_recognizer(self, recognizer_class):
        """Instantiate recognizer_class and load self._model into it.

        self._model is treated as a file path when it exists on disk and as
        an installed model name otherwise.  Raises TegakiEvalError when the
        model name is unknown or when the recognizer reports an error.
        """
        recognizer = recognizer_class()
        if os.path.exists(self._model):
            # the path exists so we consider the parameter to be a model path
            method = recognizer.open

            # try to find a .meta file
            meta_file = self._get_meta_path(self._model)
            if os.path.exists(meta_file) and meta_file.endswith(".meta"):
                try:
                    meta = Recognizer.read_meta_file(meta_file)
                except RecognizerError as e:
                    raise TegakiEvalError(str(e))
            else:
                meta = {}
        else:
            # otherwise we consider the parameter to be a model name
            avail_models = recognizer_class.get_available_models()
            if self._model not in avail_models:
                err = "Not an available model!\n"
                err += "Available ones include: %s" % \
                    ", ".join(["\"%s\"" % k for k in avail_models.keys()])
                raise TegakiEvalError(err)

            meta = avail_models[self._model]
            method = recognizer.set_model
        try:
            method(self._model)
            recognizer.set_options(meta)
        except RecognizerError as e:
            raise TegakiEvalError(str(e))

        return recognizer

    def _eval(self, recognizer, all_chars):
        """Recognize every labelled sample and print the statistics.

        recognizer -- object whose recognize(writing, n=...) returns a
                      sequence of (character, probability) pairs
        all_chars  -- samples exposing get_utf8() and get_writing()

        Samples without a utf8 label are skipped and are not counted as
        evaluated.  Raises TegakiEvalError when no sample carries a label.
        """
        # number of samples present per character
        n_samples = {}
        # number of correctly predicted samples per character
        n_corr_pred = dict((n, {}) for n in self.MATCH_RESULTS)
        # number of times a character was predicted (correctly or not)
        n_pred = dict((n, {}) for n in self.MATCH_RESULTS)

        canddict = {} # store ALL the candidate results for verbosity >= 2

        max_results = max(self.MATCH_RESULTS)

        start_time = time.time()

        for char in all_chars:
            utf8 = char.get_utf8()
            if not utf8:
                continue

            n_samples[utf8] = n_samples.get(utf8, 0) + 1

            results = recognizer.recognize(char.get_writing(), n=max_results)
            # We don't need the probability.  The comprehension variables
            # are deliberately not named "char": under Python 2 they leak
            # into this scope and would rebind the loop variable.
            cand = [label for label, _prob in results]

            if self._verbosity_level >= 2:
                canddict.setdefault(utf8, []).append(cand)

            for n in self.MATCH_RESULTS:
                top = cand[0:n]
                if utf8 in top:
                    n_corr_pred[n][utf8] = n_corr_pred[n].get(utf8, 0) + 1
                for c in top:
                    n_pred[n][c] = n_pred[n].get(c, 0) + 1

        end_time = time.time()

        # Only labelled samples went through the recognizer.
        total_samples = sum(n_samples.values())
        if total_samples == 0:
            # Without this guard the averages below would divide by zero.
            raise TegakiEvalError("No labelled character samples to evaluate!")

        total_time = float(end_time - start_time)
        if total_time > 0:
            speed = total_samples / total_time
        else:
            # The clock resolution was too coarse to measure the run.
            speed = float("inf")

        # Print the overall results
        print("Overall results")
        print("\tRecognizer: %s" % self._recognizer)
        print("\tNumber of characters evaluated: %d\n" % total_samples)
        print("\tTotal time: %0.2f sec" % total_time)
        print("\tAverage time per character: %0.2f sec" %
              (total_time / total_samples))
        print("\tRecognition speed: %0.2f char/sec\n" % speed)

        # Calculate accuracy/recall and precision for each character
        recall = dict((n, {}) for n in self.MATCH_RESULTS)
        precision = dict((n, {}) for n in self.MATCH_RESULTS)
        n_classes = float(len(n_samples))

        for n in self.MATCH_RESULTS:
            recall_sum = 0
            precision_sum = 0

            for k in n_samples:
                correct = float(n_corr_pred[n].get(k, 0))

                # recall accounts for the recognizer "completeness"
                # i.e. number of correct predictions / number of samples
                recall[n][k] = correct / float(n_samples[k])
                recall_sum += recall[n][k]

                # i.e. number of correct predictions / number of predictions
                # (0 when the character was never predicted)
                predicted = n_pred[n].get(k, 0)
                if predicted:
                    precision[n][k] = correct / float(predicted)
                else:
                    precision[n][k] = 0
                precision_sum += precision[n][k]

            # Average over the characters and convert to a percentage.
            recall_sum *= 100 / n_classes
            precision_sum *= 100 / n_classes

            print("\tmatch%d" % n)
            print("\t\tAccuracy/Recall: %0.2f" % recall_sum)
            if n == 1:
                # Precision doesn't make sense for n > 1
                print("\t\tPrecision: %0.2f" % precision_sum)
                print("\t\tF1 score: %0.2f" % harmonic_mean(recall_sum,
                                                            precision_sum))
            print("")

        # verbosity level 1
        if self._verbosity_level < 1:
            return

        print("Result details")
        for k in n_samples:
            print("\tCharacter: %s" % k)
            print("\tNumber of samples: %d\n" % n_samples[k])

            for n in self.MATCH_RESULTS:
                print("\t\tmatch%d" % n)
                print("\t\tAccuracy/Recall: %0.2f" % (recall[n][k] * 100))
                if n == 1:
                    # Precision doesn't make sense for n > 1
                    print("\t\tPrecision: %0.2f" % (precision[n][k] * 100))
                    f1s = harmonic_mean(recall[n][k], precision[n][k]) * 100
                    print("\t\tF1 score: %0.2f" % f1s)
                print("")

            if self._verbosity_level < 2:
                continue

            # verbosity level 2
            print("\tCandidates:")
            for i, cand in enumerate(canddict[k]):
                print("\tsample%d: %s" % (i, ", ".join(cand)))
            print("")

# Command-line interface: build the option parser, parse sys.argv and run
# the evaluation.  This runs at module level because the file is a script.
usage = """usage: %prog [options] recognizer model

recognizer        a recognizer available on the system

model             a model name available for that recognizer on the system OR
                  the direct file path to the model
"""
parser = OptionParser(usage=usage, version="%prog " + VERSION)

parser.add_option("-v", "--verbosity-level",
                  type="int", dest="verbosity_level", default=0,
                  help="verbosity level between 0 and 2")


# Sample sources; each option may be given several times.
parser.add_option("-d", "--directory",
                  action="append", type="string", dest="directories",
                  default=[],
                  help="directory containing individual XML character files")
parser.add_option("-c", "--charcol",
                  action="append", type="string", dest="charcols",
                  default=[],
                  help="character collection XML files")
# The help text used to be a copy of the --charcol one; this option feeds
# the character collection *database* source.
parser.add_option("-b", "--db",
                  action="append", type="string", dest="databases",
                  default=[],
                  help="character collection database files")
parser.add_option("-t", "--tomoe-dict",
                  action="append", type="string", dest="tomoe",
                  default=[],
                  help="Tomoe XML dictionary files")
parser.add_option("-k", "--kuchibue",
                  action="append", type="string", dest="kuchibue",
                  default=[],
                  help="Kuchibue unipen database")


parser.add_option("-l", "--list",
                  action="store_true", dest="list", default=False,
                  help="List available recognizers and models")


# Sample filtering.
parser.add_option("-i", "--include",
                  action="append", type="string", dest="include",
                  default=[],
                  help="File containing characters to include")
parser.add_option("-e", "--exclude",
                  action="append", type="string", dest="exclude",
                  default=[],
                  help="File containing characters to exclude")
parser.add_option("-m", "--max-samples",
                  type="int", dest="max_samples",
                  help="Maximum number of samples per character")


(options, args) = parser.parse_args()

try:
    # The two positional arguments are only required when evaluating.
    if not options.list and len(args) < 2:
        raise TegakiEvalError("Needs a recognizer and a model!")

    TegakiEval(options, args).run()
except TegakiEvalError as e:
    # Report the problem, remind the user of the usage and signal failure.
    sys.stderr.write(str(e) + "\n\n")
    parser.print_help()
    sys.exit(1)
