#!/usr/pkg/bin/ruby32

# Copyright (c) 2015 Alexandra Figlovskaya <fglval@gmail.com>
# Copyright (c) 2015 Aleksey Cheusov <vle@gmx.net>
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

require 'optparse'

# Settings gathered from the command line; keys are symbols such as :raw.
@options = {}

# Stays nil until a malformed input line is reported, then becomes 1.
@err = nil

# Internal label that replaces a predicted class whose score is below the
# requested threshold.
@unspecified_class = "__strelka_i_raketa__"

# Prints one human-readable line with precision, recall and F1.
#
# class_name -- row label, right-aligned in 13 columns
# p, r, f1   -- numeric values, printed with 4 significant digits
# p_comment, r_comment -- text shown next to precision and recall
# f1_comment -- accepted for signature parity with print_raw; the pretty
#               layout has no column for it, so it is not printed
def print_pretty(class_name, p, p_comment, r, r_comment, f1, f1_comment)
  # Exactly six values for six conversions: a surplus argument makes Ruby
  # warn under -w and raise ArgumentError when $DEBUG is set.
  puts format("%13s P, R, F1:  %-6.4g %-13s,  %-6.4g %-13s,  %-6.4g",
              class_name, p, p_comment, r, r_comment, f1)
end

# Prints the accuracy row in the human-readable layout.
#
# a         -- accuracy value, printed with 4 significant digits
# a_comment -- text shown after the value, left-aligned in 13 columns
def print_accuracy_pretty(a, a_comment)
  row = format("Accuracy              :  %-6.4g %-13s", a, a_comment)
  puts row
end

# Prints precision, recall and F1 as three tab-separated lines:
#   <class_name> TAB <metric name> TAB <value> TAB <comment>
# Comments are stripped of surrounding whitespace before printing.
def print_raw(class_name, p, p_comment, r, r_comment, f1, f1_comment)
  rows = [
    ["precision", p, p_comment],
    ["recall", r, r_comment],
    ["f1", f1, f1_comment]
  ]
  rows.each do |metric, value, comment|
    puts "#{class_name}\t#{metric}\t#{value}\t#{comment.strip}"
  end
  nil
end

# Prints the accuracy as one tab-separated line. The first column (the class
# name column in print_raw) is left empty; the comment is stripped.
def print_accuracy_raw(a, a_comment)
  columns = ["", "accuracy", a, a_comment.strip]
  puts columns.join("\t")
end

# Prints one P/R/F1 record, using the raw tab-separated layout when
# @options[:raw] is set and the human-readable layout otherwise.
def print_stat(class_name, p, p_comment, r, r_comment, f1, f1_comment)
  fields = [class_name, p, p_comment, r, r_comment, f1, f1_comment]
  return print_raw(*fields) if @options[:raw]

  print_pretty(*fields)
end

# Prints the accuracy record, using the raw tab-separated layout when
# @options[:raw] is set and the human-readable layout otherwise.
def print_accuracy(a, a_comment)
  return print_accuracy_raw(a, a_comment) if @options[:raw]

  print_accuracy_pretty(a, a_comment)
end

# Renders "a/b" with a right-aligned and b left-aligned, each padded to at
# least five characters, so that the slashes line up between rows.
def pretty_div(a, b)
  numerator = a.to_s.rjust(5)
  denominator = b.to_s.ljust(5)
  "#{numerator}/#{denominator}"
end

# Brings a class label to a canonical textual form so that numerically equal
# labels compare equal as strings:
#   * a leading "+" is dropped        ("+1" => "1")
#   * for decimal numbers, a fractional part made only of zeros is dropped
#     ("-1.0000" => "-1")
# Anything else is returned unchanged (after conversion to a string).
def normalize_tag(tag)
  label = tag.to_s.sub(/^[+]/, "")
  return label unless label =~ /^-?[0-9]+[.][0-9]+$/

  label.sub(/[.]0+$/, "")
end

# Splits one input line into [first_tag, second_tag, score].
#
# line -- raw text; runs of whitespace count as a single separator
# fn   -- file name, used only in the error message
#
# Both tags are passed through normalize_tag. The score is the third token
# converted with to_f when the line has exactly three tokens, and Float::MAX
# otherwise. A line with any token count other than two or three is reported
# on STDERR and sets @err to 1, but a triple is still returned. When the
# score is below @options[:treshold] the second tag is replaced with
# @unspecified_class.
def split_into_3(line, fn)
  squeezed = line.gsub(/\s+/, " ").strip
  fields = squeezed.split(/ /)

  first = normalize_tag(fields[0])
  second = normalize_tag(fields[1])
  score = fields.size == 3 ? fields[2].to_f : Float::MAX

  unless [2, 3].include?(fields.size)
    # Hide the artificial "fake" first column some callers prepend.
    squeezed.sub!(/^fake ?/, "")
    STDERR.puts("Bad line '#{squeezed}' in file '#{fn}'")
    @err = 1
  end

  second = @unspecified_class if score < @options[:treshold]

  [first, second, score]
end

# Command-line parsing: fills @options with a default for every setting and
# then lets the switches found in ARGV override them. parse! removes the
# recognised switches from ARGV, leaving only the positional arguments.
#
# Note that the :micro_avg, :macro_avg, :statistics and :accuracy flags mean
# "this output is disabled" when true.

# Normally initialised at the top of the file; the guard merely keeps this
# section self-contained.
@options ||= {}

OptionParser.new do |opts|
  opts.banner = <<EOF
heri-stat calculates precision, recall, F1
  and some other things for given golden data and predictions.
Usage:
    heri-stat    [OPTIONS] <golden_file> <predictions_file>
    heri-stat -1 [OPTIONS] [files...]
OPTIONS:
EOF

  opts.on('-h', '--help','display this message and exit') do
    puts opts
    exit 0
  end

  @options[:raw] = false
  opts.on('-R', '--raw','raw tab-separated output') do
    @options[:raw] = true
  end

  @options[:micro_avg] = false
  opts.on('-m', '--micro-avg','disable micro averaged P/R/F1 output') do
    @options[:micro_avg] = true
  end

  @options[:macro_avg] = false
  opts.on('-r', '--macro-avg','disable macro averaged P/R/F1 output') do
    @options[:macro_avg] = true
  end

  @options[:statistics] = false
  opts.on('-c', '--per-class','disable output of per-class statistics') do
    @options[:statistics] = true
  end

  @options[:accuracy] = false
  opts.on('-a', '--accuracy','disable output of accuracy') do
    @options[:accuracy] = true
  end

  @options[:single] = false
  opts.on('-1', '--single','obtain both golden and
predicted classes from single source. If this option is specified,
the first token on input represents the golden class
and second one -- predicted class') do
    @options[:single] = true
  end

  @options[:unclassified] = false
  opts.on("-u", "--unclassified=UNCLASSIFIED", 'set the label for "unclassified" object') do |u|
    @options[:unclassified] = u.to_s
  end

  # Float::MIN is the smallest *positive* double, not the most negative one.
  # Using it as the default would turn every prediction with a zero or
  # negative score into the internal "unspecified" class even when -t was
  # never given. No score parsed with to_f is below negative infinity, so
  # with this default the threshold check never fires unless -t is used.
  @options[:treshold] = -Float::INFINITY
  opts.on("-t", "--treshold=TRESHOLD", 'Minimal treshold for score') do |u|
    @options[:treshold] = u.to_f
    # Predictions rejected by the threshold are relabelled with the internal
    # class, so that class becomes the "unclassified" label.
    @options[:unclassified] = @unspecified_class
  end

  opts.separator " "
end.parse!

# Exactly one of the two overall figures is forced off: with an
# "unclassified" label in effect the accuracy output is disabled, without
# one the micro-averaged output is disabled. (A true flag means "disabled".)
suppressed_output = @options[:unclassified] ? :accuracy : :micro_avg
@options[suppressed_output] = true

# Load two parallel arrays: golden_tags[i] is the expected class of object i
# and result_tags[i] is the predicted one.
if @options[:single]
  # Single-source mode: every line carries "<golden> <predicted> [score]".
  # Kernel#gets reads from the files left in ARGV, or from standard input
  # when there are none.
  golden_tags = []
  result_tags = []
  while (line = gets)
    gt, rt, _score = split_into_3(line, "")
    golden_tags << gt
    result_tags << rt
  end
else
  # Two-file mode needs both file names; without this check a missing
  # argument surfaced as a TypeError backtrace.
  if ARGV.size < 2
    STDERR.puts("heri-stat: <golden_file> and <predictions_file> are required")
    exit 1
  end

  # File.read instead of IO.read: IO.read treats a path that starts with "|"
  # as a command to run.
  golden_tags = File.read(ARGV[0]).split("\n")
  result_tags = File.read(ARGV[1]).split("\n")
  if golden_tags.length != result_tags.length
    STDERR.puts("Golden data and predictions should contain the same amount of classes");
    exit 1
  end

  # Each line is "<class> [score]". A dummy first column is prepended so the
  # line has the layout split_into_3 expects; the class is then the second
  # element of the returned triple.
  golden_tags.each_index do |i|
    _, golden_tags[i], _score = split_into_3("fake " + golden_tags[i], ARGV[0])
    _, result_tags[i], _score = split_into_3("fake " + result_tags[i], ARGV[1])
  end
end

# split_into_3 sets @err when it meets a malformed line.
exit 1 if @err

# Per-class counters, all defaulting to zero:
#   tag2golden_cnt -- how many objects really belong to the class
#   tag2result_cnt -- how many objects were predicted as the class
#   tag2TP_cnt     -- how many predictions of the class were correct
tag2golden_cnt = Hash.new(0)
tag2result_cnt = Hash.new(0)
tag2TP_cnt = Hash.new(0)

# Running sums of per-class precision and recall for the macro average.
all_precision = 0
all_recall = 0

unclassified_label = @options[:unclassified]

golden_tags.zip(result_tags).each do |gt, rt|
  tag2golden_cnt[gt] += 1 unless gt == unclassified_label

  unless rt == unclassified_label
    tag2result_cnt[rt] += 1
    hit = gt == rt ? 1 : 0
    # Adding zero still creates the entry, so a class that is only ever
    # predicted wrongly gets a row of its own.
    tag2TP_cnt[rt] += hit
  end

  # Make sure every golden class has an entry, even if it was never predicted.
  tag2TP_cnt[gt] += 0
  tag2result_cnt[gt] += 0
end

# Walk the classes in sorted order, print per-class P/R/F1 (unless that
# output is disabled) and accumulate the totals used by the summaries below.
all_tp = 0
all_f1 = 0

tag2TP_cnt.sort_by { |tag, _count| tag }.each do |t, tp|
  predicted = tag2result_cnt[t]
  expected = tag2golden_cnt[t]

  # An undefined ratio (zero denominator) is reported as 0.0.
  p = predicted > 0.0 ? tp.to_f / predicted : 0.0
  r = expected > 0.0 ? tp.to_f / expected : 0.0
  f1 = p + r > 0.0 ? 2 * p * r / (p + r) : 0.0

  unless @options[:statistics]
    print_stat("Class  %-6s" % [t],
               p, pretty_div(tp, predicted),
               r, pretty_div(tp, expected),
               f1, "")
  end

  all_precision += p
  all_recall += r
  all_tp += tp
  all_f1 += f1
end

# Totals over all classes: the number of counted predictions and the number
# of counted golden labels.
all_rt = tag2result_cnt.values.sum
all_gt = tag2golden_cnt.values.sum

# Accuracy: correct predictions divided by the number of counted predictions.
unless @options[:accuracy]
  # With no predictions at all the ratio is 0/0, which used to print NaN;
  # report 0.0 instead, the same convention the per-class figures use.
  accuracy = all_rt > 0 ? all_tp.to_f / all_rt : 0.0
  print_accuracy(accuracy, pretty_div(all_tp, all_rt))
end

# Micro average: P/R/F1 computed from the counts pooled over all classes.
unless @options[:micro_avg]
  # Every ratio is guarded the same way as in the per-class loop: a zero
  # denominator yields 0.0. Unguarded, the output was NaN whenever no object
  # was classified (for instance with a very high threshold) or when there
  # was not a single true positive.
  micro_avg_precision = all_rt > 0 ? all_tp.to_f / all_rt : 0.0
  micro_avg_recall = all_gt > 0 ? all_tp.to_f / all_gt : 0.0
  pr_sum = micro_avg_precision + micro_avg_recall
  micro_avg_f1 = pr_sum > 0.0 ? 2 * micro_avg_precision * micro_avg_recall / pr_sum : 0.0
  print_stat("Micro average",
             micro_avg_precision, pretty_div(all_tp, all_rt),
             micro_avg_recall, pretty_div(all_tp, all_gt),
             micro_avg_f1, "")
end

# Macro average: the unweighted mean of the per-class precision, recall and
# F1 values. Skipped when disabled or when there are no classes at all.
class_count = tag2TP_cnt.size
if class_count > 0 && !@options[:macro_avg]
  macro_avg_precision, macro_avg_recall, macro_avg_f1 =
    [all_precision, all_recall, all_f1].map { |total| total / class_count }
  print_stat("Macro average",
             macro_avg_precision, "",
             macro_avg_recall, "",
             macro_avg_f1, "")
end
