#!/usr/bin/env python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import os
import optparse
import sys
import socket
from cmd         import Cmd
from shlex       import split
from threading   import Lock
from time        import strftime, gmtime
from qpid.disp   import Display
import cqpid
import qmf2

class OptsAndArgs(object):
  """Command-line parser for qmf-tool.

  Builds an optparse parser with a "Connection Options" group and a
  "QMF Session Options" group; parse() turns the parsed result into the
  host string, connection-option dict and QMF option string used by the
  rest of the tool.
  """

  def __init__(self, argv):
    """Create the option parser.

    argv -- complete argument vector, program name included (argv[0]).
    """
    self.argv = argv
    self.usage = """qmf-tool [OPTIONS] [<broker-host>[:<port>]]"""
    self.option_parser = optparse.OptionParser(usage=self.usage)

    self.conn_group = optparse.OptionGroup(self.option_parser, "Connection Options")
    self.conn_group.add_option("-u", "--user", action="store", type="string", help="User name for authentication")
    self.conn_group.add_option("-p", "--password", action="store", type="string", help="Password for authentication")
    self.conn_group.add_option("-t", "--transport", action="store", type="string",  help="Transport type (tcp, ssl, rdma)")
    self.conn_group.add_option("-m", "--mechanism", action="store", type="string", help="SASL Mechanism for security")
    self.conn_group.add_option("-s", "--service", action="store", type="string", default="qpidd", help="SASL Service name")
    self.conn_group.add_option("--min-ssf", action="store", type="int", metavar="<n>", help="Minimum acceptable security strength factor")
    self.conn_group.add_option("--max-ssf", action="store", type="int", metavar="<n>", help="Maximum acceptable security strength factor")
    self.conn_group.add_option("--conn-option", action="append", default=[], metavar="<NAME=VALUE>", help="Additional connection option(s)")
    self.option_parser.add_option_group(self.conn_group)

    self.qmf_group = optparse.OptionGroup(self.option_parser, "QMF Session Options")
    self.qmf_group.add_option("--domain", action="store", type="string", help="QMF Domain")
    self.qmf_group.add_option("--agent-age", action="store", type="int", metavar="<n>", help="Time, in minutes, to age out non-communicating agents")
    self.qmf_group.add_option("--qmf-option", action="append", default=[], metavar="<NAME=VALUE>", help="Additional QMF session option(s)")
    self.option_parser.add_option_group(self.qmf_group)

  def _splitPair(self, text, option_name):
    """Split 'NAME=VALUE' into (NAME, VALUE).

    Only the first '=' separates the pair, so VALUE may itself contain '='.
    Raises Exception (the type the caller of parse() catches) when there is
    no '=' at all.
    """
    try:
      key, val = text.split('=', 1)
    except ValueError:
      raise Exception("Improperly formatted text for %s: '%s'" % (option_name, text))
    return key, val

  def parse(self):
    """Parse self.argv.

    Returns a (host, conn_options, qmf_string) tuple:
      host         -- first positional argument after the program name,
                      or "localhost" when none was given
      conn_options -- dict of connection option name => value
      qmf_string   -- QMF session options rendered as "{name:value,...}"

    Raises Exception when a --conn-option or --qmf-option value is not of
    the form NAME=VALUE.
    """
    host = "localhost"
    conn_options = {}
    qmf_options = []

    # Earlier revisions tried to decode the positional arguments with
    # locale.getpreferredencoding(), but this file never imports `locale`;
    # the NameError was swallowed by a bare except and the raw arguments
    # were always used.  Use them directly and drop the dead code path.
    options, args = self.option_parser.parse_args(args=self.argv)

    # The full argv is handed to optparse, so args[0] is the program name.
    if len(args) > 1:
      host = args[1]

    if options.user:
      conn_options["username"] = options.user
    if options.password:
      conn_options["password"] = options.password
    if options.transport:
      conn_options["transport"] = options.transport
    if options.mechanism:
      conn_options["sasl_mechanisms"] = options.mechanism
    if options.service:
      conn_options["sasl_service"] = options.service
    # Compare against None so that an explicit 0 is passed on instead of
    # being silently dropped.
    if options.min_ssf is not None:
      conn_options["sasl_min_ssf"] = options.min_ssf
    if options.max_ssf is not None:
      conn_options["sasl_max_ssf"] = options.max_ssf
    for x in options.conn_option:
      key, val = self._splitPair(x, "--conn-option")
      conn_options[key] = val

    if options.domain:
      qmf_options.append("domain:'%s'" % options.domain)
    if options.agent_age:
      qmf_options.append("max-agent-age:%d" % options.agent_age)
    for x in options.qmf_option:
      qmf_options.append("%s:%s" % self._splitPair(x, "--qmf-option"))

    qmf_string = '{' + ','.join(qmf_options) + '}'

    return host, conn_options, qmf_string



class Mcli(Cmd):
  """ Management Command Interpreter

  cmd.Cmd front end for the tool.  Each do_<command> hands the raw argument
  text to the data object (or, for 'set time-format', to the display object)
  and prints any exception instead of letting it end the command loop.
  """

  # Text printed by the 'help' command; empty entries are blank lines.
  _HELP_LINES = (
    "Management Tool for QMF",
    "",
    "Agent Commands:",
    "    set filter <filter-string> - Filter the list of agents",
    "    list agents                - Print a list of the known Agents",
    "    set default <item-number>  - Set the default agent for operations",
    "    show filter                - Show the agent filter currently in effect",
    "    show agent <item-number>   - Print detailed information about an Agent",
    "    show options               - Show option strings used in the QMF session",
    "",
    "Schema Commands:",
    "    list packages                            - Print a list of packages supported by the default agent",
    "    list classes [<package-name>]            - Print all classes supported by the default agent",
    "    show class <class-name> [<package-name>] - Show details of a class",
    "",
    "Data Commands:",
    "    query <class-name> [<package-name>] [<predicate>] - Query for data from the agent",
    "    list                                              - List accumulated query results",
    "    clear                                             - Clear accumulated query results",
    "    show <id>                                         - Show details from a data object",
    "    call <id> <method> [<args>]                       - Call a method on a data object",
    "",
    "General Commands:",
    "    set time-format short           - Select short timestamp format (default)",
    "    set time-format long            - Select long timestamp format",
    "    quit or ^D                      - Exit the program",
    "",
  )

  def __init__(self, dataObject, dispObject):
    """Bind the interpreter to its collaborators.

    dataObject -- receives the do_* calls; must provide setCli(), do_exit()
                  and close() in addition to the command handlers
    dispObject -- display helper; only do_setTimeFormat() is used here
    """
    Cmd.__init__(self)
    self.dataObject = dataObject
    self.dispObject = dispObject
    self.dataObject.setCli(self)
    self.prompt = "qmf: "

  def emptyline(self):
    """Ignore empty input (cmd.Cmd's default repeats the last command)."""
    pass

  def setPromptMessage(self, p=None):
    """Show p inside the prompt, or restore the plain prompt when p is None."""
    if p is None:
      self.prompt = "qmf: "
    else:
      self.prompt = "qmf[%s]: " % p

  def _run(self, name, handler, data, unwrap=False):
    """Call handler(data), reporting any exception as a failure of command `name`.

    With unwrap=True, an exception whose `message` is a qmf2.Data object is
    reported through that object's properties instead of the exception itself.
    """
    try:
      handler(data)
    except Exception as e:
      if unwrap:
        payload = getattr(e, "message", None)
        if payload.__class__ == qmf2.Data:
          e = payload.getProperties()
      print("Exception in %s command: %s" % (name, e))

  def do_help(self, data):
    """Print the command summary."""
    for text in self._HELP_LINES:
      print(text)

  def complete_set(self, text, line, begidx, endidx):
    """ Command completion for the 'set' command """
    tokens = split(line[:begidx])
    if len(tokens) == 1:
      return [i for i in ('filter ', 'default ', 'time-format ') if i.startswith(text)]
    if len(tokens) == 2 and tokens[1] == 'time-format':
      return [i for i in ('long', 'short') if i.startswith(text)]
    return []

  def do_set(self, data):
    """Handle 'set': time-format goes to the display object, the rest to the data object."""
    try:
      # Tokenize inside the try: shlex raises ValueError on an unbalanced
      # quote, and that must not escape and terminate the command loop.
      tokens = split(data)
      if tokens and tokens[0] == "time-format":
        if len(tokens) < 2:
          print("Usage: set time-format short|long")
          return
        self.dispObject.do_setTimeFormat(tokens[1])
      else:
        # Also reached for an empty argument list, so the data object can
        # print its own hint.
        self.dataObject.do_set(data)
    except Exception as e:
      print("Exception in set command: %s" % (e,))

  def complete_list(self, text, line, begidx, endidx):
    """ Command completion for the 'list' command """
    tokens = split(line[:begidx])
    if len(tokens) == 1:
      return [i for i in ('agents', 'packages', 'classes ') if i.startswith(text)]
    return []

  def do_list(self, data):
    self._run("list", self.dataObject.do_list, data)

  def complete_show(self, text, line, begidx, endidx):
    """ Command completion for the 'show' command """
    tokens = split(line[:begidx])
    if len(tokens) == 1:
      return [i for i in ('options', 'filter', 'agent ', 'class ') if i.startswith(text)]
    return []

  def do_show(self, data):
    self._run("show", self.dataObject.do_show, data)

  def complete_query(self, text, line, begidx, endidx):
    """ No completion is offered for 'query' """
    return []

  def do_query(self, data):
    self._run("query", self.dataObject.do_query, data, unwrap=True)

  def do_call(self, data):
    self._run("call", self.dataObject.do_call, data, unwrap=True)

  def do_clear(self, data):
    self._run("clear", self.dataObject.do_clear, data)

  def do_EOF(self, data):
    """Handle end-of-input (^D) like 'quit', echoing the command first."""
    print("quit")
    return self.do_quit(data)

  def do_quit(self, data):
    """Notify the data object and stop the command loop (returns True)."""
    try:
      self.dataObject.do_exit()
    except Exception:
      pass   # best effort - we are leaving regardless
    return True

  def postcmd(self, stop, line):
    return stop

  def postloop(self):
    """Runs once when the command loop ends: release the data object's resources."""
    print("Exiting...")
    self.dataObject.close()


#======================================================================================================
# QmfData
#======================================================================================================
class QmfData:
  """
  Session state and command implementations behind the CLI.

  Owns the broker connection and the QMF console session, numbers the agents
  reported by the session so they can be referred to from the command line,
  and accumulates query results in data_list under increasing indices.
  """
  def __init__(self, disp, url, conn_options, qmf_options):
    """Open the connection and the console session.

    disp         -- display helper used for table and timestamp output
    url          -- broker address handed to cqpid.Connection
    conn_options -- dict of keyword options for cqpid.Connection
    qmf_options  -- option string handed to qmf2.ConsoleSession
    """
    self.disp = disp
    self.url = url
    self.conn_options = conn_options
    self.qmf_options = qmf_options
    self.agent_filter = '[]'
    self.connection = cqpid.Connection(self.url, **self.conn_options)
    self.connection.open()
    self.session = qmf2.ConsoleSession(self.connection, self.qmf_options)
    self.session.setAgentFilter(self.agent_filter)
    self.session.open()
    self.lock = Lock()
    self.cli = None
    self.agents = {}            # Map of number => agent object
    self.deleted_agents = {}    # Map of number => agent object
    self.agent_numbers = {}     # Map of agent name => number
    self.next_number = 1
    self.focus_agent = None
    self.data_list = {}         # Map of data index => data object from queries
    self.next_data_index = 1

  #=======================
  # Methods to support CLI
  #=======================
  def setCli(self, cli):
    self.cli = cli

  def close(self):
    """Close the session and the connection, ignoring any failure."""
    try:
      self.session.close()
      self.connection.close()
    except:
      pass   # we're shutting down - ignore any errors

  def do_list(self, data):
    """'list' with no argument shows query results; otherwise agents, packages or classes."""
    tokens = data.split()
    if not tokens:
      self.listData()
      return
    what = tokens[0]
    if what in ('agents', 'agent'):
      self.listAgents()
    elif what in ('packages', 'package'):
      self.listPackages()
    elif what in ('classes', 'class'):
      self.listClasses(tokens[1:])
    else:
      # Previously an unknown sub-command was silently ignored.
      print("What do you want to list?  Type 'help' for more information.")

  def do_set(self, data):
    """'set filter <filter-string>' or 'set default <item-number>'."""
    tokens = split(data)
    if len(tokens) == 2:
      if tokens[0] == 'filter':
        self.setAgentFilter(tokens[1])
        return
      if tokens[0] == 'default':
        self.updateAgents()
        self.focus_agent = self._agentByNumber(int(tokens[1]))
        print("Default Agent: %s" % self.focus_agent.getName())
        return
    # Empty, unknown or wrongly-sized commands all get the hint; malformed
    # 'filter'/'default' commands used to do nothing at all.
    print("What do you want to set?  type 'help' for more information.")

  def do_show(self, data):
    """'show options|agent|filter|default|class|<data-id>'."""
    tokens = split(data)
    if len(tokens) == 0:
      print("What do you want to show?  Type 'help' for more information.")
      return

    what = tokens[0]
    if what == 'options':
      print("Options used in this session:")
      print("  Connection Options : %s" % self.scrubConnOptions())
      print("  QMF Session Options: %s" % self.qmf_options)
    elif what == 'agent':
      self.showAgent(tokens[1:])
    elif what == 'filter':
      print(self.agent_filter)
    elif what == "default":
      if not self.focus_agent:
        self.updateAgents()
      if self.focus_agent:
        print("Default Agent: %s" % self.focus_agent.getName())
      else:
        print("Default Agent not set")
    elif what == "class":
      self.showClass(tokens[1:])
    elif what.isdigit():
      self.showData(what)
    else:
      print("What do you want to show?  Type 'help' for more information.")

  def do_query(self, data):
    """'query <class-name> [<package-name>] [<predicate>]' against the default agent.

    A second token starting with '[' is taken as the predicate; otherwise it
    is the package name and an optional third token is the predicate.  The
    returned objects are appended to data_list and listed.
    """
    tokens = split(data)
    if len(tokens) == 0:
      print("Class name not specified.")
      return
    cname = tokens[0]
    pname = None
    pred  = None
    if len(tokens) >= 2:
      # startswith() rather than [0] so an empty token ('""') cannot raise
      if tokens[1].startswith('['):
        pred = tokens[1]
      else:
        pname = tokens[1]
        if len(tokens) >= 3:
          pred = tokens[2]
    query = "{class:'%s'" % cname
    if pname:
      query += ",package:'%s'" % pname
    if pred:
      query += ",where:%s" % pred
    query += "}"
    if not self.focus_agent:
      self.updateAgents()
    if not self.focus_agent:
      raise Exception("Default Agent not set - use 'set default'")
    d_list = self.focus_agent.query(query)
    rows = []
    for d in d_list:
      index = self.next_data_index
      self.next_data_index += 1
      self.data_list[index] = d
      rows.append((index, d.getAddr().getName()))
    self.disp.table("Data Objects Returned: %d:" % len(d_list), ("Number", "Data Address"), rows)

  def _isNumber(self, arg):
    """True for an optionally negative run of digits with at most one '.'."""
    if arg.count('.') > 1:
      return False
    dashes = arg.count('-')
    if dashes > 1 or (dashes == 1 and arg[0] != '-'):
      return False
    return arg.replace('.', '').replace('-', '').isdigit()

  def _convertArg(self, arg):
    """Convert one 'call' argument token to a Python value.

    "True"/"False" become booleans, tokens starting with '{' or '[' are
    evaluated as Python literals, numeric tokens become int or float, and
    anything else (including the empty string) is passed through as text.
    """
    if arg == "True":
      return True
    if arg == "False":
      return False
    if arg[:1] in ('{', '['):
      # NOTE(security): eval() executes arbitrary expressions.  The text
      # comes from the interactive operator of this tool; do not feed this
      # method input from any less trusted source.
      return eval(arg)
    if self._isNumber(arg):
      # int()/float() instead of eval(): always decimal, so "010" is 10 and
      # "08" no longer fails as an invalid octal literal.
      try:
        if '.' in arg:
          return float(arg)
        return int(arg)
      except ValueError:
        return arg
    return arg

  def do_call(self, data):
    """'call <id> <method> [<args>]' - invoke a method on a queried data object."""
    tokens = split(data)
    if len(tokens) < 2:
      print("Data ID and method-name not specified.")
      return
    idx = int(tokens[0])
    methodName = tokens[1]

    if not idx in self.data_list:
      print("Unknown data index, run 'query' to get a list of data indices")
      return

    args = [self._convertArg(arg) for arg in tokens[2:]]

    target = self.data_list[idx]
    target._getSchema()
    result = target._invoke(methodName, args, {})
    rows = []
    for k,v in result.items():
      rows.append((k,v))
    self.disp.table("Output Parameters:", ("Name", "Value"), rows)

  def do_clear(self, data):
    """Forget all accumulated query results and restart the index at 1."""
    self.data_list = {}
    self.next_data_index = 1
    print("Accumulated query results cleared")

  def do_exit(self):
    pass

  #====================
  # Sub-Command Methods
  #====================
  def setAgentFilter(self, filt):
    self.agent_filter = filt
    self.session.setAgentFilter(filt)

  def _agentByNumber(self, number):
    """Return the agent listed under `number`; raise Exception if there is none."""
    try:
      return self.agents[number]
    except KeyError:
      raise Exception("Unknown agent id %d - run 'list agents' to see valid ids" % number)

  def updateAgents(self):
    """Synchronize the numbered agent tables with the session's agent list.

    Agents not seen before get the next free number.  Agents that are no
    longer reported move to deleted_agents and move back, under their old
    number, if they are reported again.  When no default agent is selected
    the session's connected broker agent becomes the default.
    """
    seen = set()
    for agent in self.session.getAgents():
      name = agent.getName()
      number = self.agent_numbers.get(name)
      if number is None:
        number = self.next_number
        self.next_number += 1
        self.agent_numbers[name] = number
        self.agents[number] = agent
      elif number in self.deleted_agents:
        self.agents[number] = self.deleted_agents.pop(number)
      ## Track seen agents so we can clean out deleted ones
      seen.add(number)
    for number in [n for n in self.agents if n not in seen]:
      self.deleted_agents[number] = self.agents.pop(number)
    if not self.focus_agent:
      self.focus_agent = self.session.getConnectedBrokerAgent()

  def listAgents(self):
    """Print the table of known agents; the default agent is marked with '*'."""
    self.updateAgents()
    rows = []
    for number in sorted(self.agents):
      agent = self.agents[number]
      if self.focus_agent and agent.getName() == self.focus_agent.getName():
        d = '*'
      else:
        d = ''
      rows.append((d, number, agent.getVendor(), agent.getProduct(), agent.getInstance(), agent.getEpoch()))
    self.disp.table("QMF Agents:", ("", "Id", "Vendor", "Product", "Instance", "Epoch"), rows)

  def listPackages(self):
    """Print the packages of the default agent; raise Exception if none is set."""
    if not self.focus_agent:
      # String exceptions are not accepted by Python 2.6+, raise a real one
      raise Exception("Default Agent not set - use 'set default'")
    self.focus_agent.loadSchemaInfo()
    packages = self.focus_agent.getPackages()
    for p in packages:
      print("    %s" % p)

  def getClasses(self, tokens):
    """Return the schema ids of the default agent.

    tokens -- a one-element list restricts the result to that package; any
              other length returns the classes of every package.
    Raises Exception if no default agent is set.
    """
    if not self.focus_agent:
      raise Exception("Default Agent not set - use 'set default'")
    self.focus_agent.loadSchemaInfo()
    if len(tokens) == 1:
      return self.focus_agent.getSchemaIds(tokens[0])
    classes = []
    for p in self.focus_agent.getPackages():
      classes.extend(self.focus_agent.getSchemaIds(p))
    return classes

  def listClasses(self, tokens):
    classes = self.getClasses(tokens)
    rows = []
    for c in classes:
      rows.append((c.getPackageName(), c.getName(), self.classTypeName(c.getType())))
    self.disp.table("Classes:", ("Package", "Class", "Type"), rows)

  def _typeLabel(self, prop):
    """Printable type of a schema property/argument, with its subtype if it has one."""
    label = self.typeName(prop.getType())
    if len(prop.getSubtype()) > 0:
      label += "(%s)" % prop.getSubtype()
    return label

  def showClass(self, tokens):
    """Print the properties and methods of a class.

    tokens -- [<class-name>] or [<class-name>, <package-name>]; the first
              class of the default agent that matches is shown.  Nothing is
              printed when no class matches.
    """
    if len(tokens) < 1:
      return
    classes = self.getClasses([])
    c = tokens[0]
    p = None
    if len(tokens) == 2:
      p = tokens[1]
    schema = None
    sid = None
    for cls in classes:
      if c == cls.getName():
        if not p or p == cls.getPackageName():
          schema = self.focus_agent.getSchema(cls)
          sid = cls
          break
    if not sid:
      return
    print("Class: %s:%s (%s) - %s" % \
        (sid.getPackageName(), sid.getName(), self.classTypeName(sid.getType()), schema.getDesc()))
    print("  hash: %r" % sid.getHash())
    props = schema.getProperties()
    methods = schema.getMethods()
    rows = []
    for prop in props:
      name = prop.getName()
      dtype = self._typeLabel(prop)
      access = self.accessName(prop.getAccess())
      idx = self.yes_blank(prop.isIndex())
      opt = self.yes_blank(prop.isOptional())
      unit = prop.getUnit()
      desc = prop.getDesc()
      rows.append((name, dtype, idx, access, opt, unit, desc))
    self.disp.table("Properties:", ("Name", "Type", "Index", "Access", "Optional", "Unit", "Description"), rows)
    for meth in methods:
      name = meth.getName()
      desc = meth.getDesc()
      if len(desc) > 0:
        desc = " - " + desc
      rows = []
      for prop in meth.getArguments():
        aname = prop.getName()
        dtype = self._typeLabel(prop)
        unit = prop.getUnit()
        adesc = prop.getDesc()
        io = self.dirName(prop.getDirection())
        rows.append((aname, dtype, io, unit, adesc))
      print("")
      print("   Method: %s%s" % (name, desc))
      self.disp.table("Arguments:", ("Name", "Type", "Dir", "Unit", "Description"), rows)

  def showAgent(self, tokens):
    """Print id, name, epoch, attributes and packages of each agent number in tokens."""
    self.updateAgents()
    for token in tokens:
      number = int(token)
      agent = self._agentByNumber(number)
      print("")
      print("    ==================================================================================")
      print("      Agent Id: %d" % number)
      print("    Agent Name: %s" % agent.getName())
      print("         Epoch: %d" % agent.getEpoch())
      print("    Attributes:")
      attrs = agent.getAttributes()
      pairs = []
      for key in sorted(attrs.keys()):
        if key == '_timestamp' or key == '_schema_updated':
          # use the display object given to this instance, not the
          # module-level global of the same name
          val = self.disp.timestamp(attrs[key])
        else:
          val = attrs[key]
        pairs.append((key, val))
      self.printAlignedPairs(pairs)
      agent.loadSchemaInfo()
      print("    Packages:")
      packages = agent.getPackages()
      for package in packages:
        print("        %s" % package)

  def showData(self, idx):
    """Print the properties of the query result stored under index idx."""
    num = int(idx)
    if not num in self.data_list:
      print("Data ID not known, run 'query' first to get data")
      return
    data = self.data_list[num]
    props = data.getProperties()
    rows = []
    for k,v in props.items():
      rows.append((k, v))
    self.disp.table("Properties:", ("Name", "Value"), rows)

  def listData(self):
    """Print every accumulated query result, ordered by index."""
    if len(self.data_list) == 0:
      print("No Query Results - Use the 'query' command")
      return
    rows = []
    for index in sorted(self.data_list):
      val = self.data_list[index]
      rows.append((index, val.getAgent().getName(), val.getAddr().getName()))
    self.disp.table("Accumulated Query Results:", ('Number', 'Agent', 'Data Address'), rows)

  def printAlignedPairs(self, rows, indent=8):
    """Print 'name : value' lines with the names right-aligned.

    Each name is preceded by enough spaces to make the names end in the same
    column, `indent` columns past the longest name's own width.
    """
    maxlen = 0
    for first, second in rows:
      if len(first) > maxlen:
        maxlen = len(first)
    maxlen += indent
    for first, second in rows:
      padding = " " * (maxlen - len(first))
      print("%s%s : %s" % (padding, first, second))

  def classTypeName(self, code):
    if code == qmf2.SCHEMA_TYPE_DATA:   return "Data"
    if code == qmf2.SCHEMA_TYPE_EVENT:  return "Event"
    return "Unknown"

  def typeName (self, typecode):
    """ Convert type-codes to printable strings """
    if   typecode == qmf2.SCHEMA_DATA_VOID:    return "void"
    elif typecode == qmf2.SCHEMA_DATA_BOOL:    return "bool"
    elif typecode == qmf2.SCHEMA_DATA_INT:     return "int"
    elif typecode == qmf2.SCHEMA_DATA_FLOAT:   return "float"
    elif typecode == qmf2.SCHEMA_DATA_STRING:  return "string"
    elif typecode == qmf2.SCHEMA_DATA_MAP:     return "map"
    elif typecode == qmf2.SCHEMA_DATA_LIST:    return "list"
    elif typecode == qmf2.SCHEMA_DATA_UUID:    return "uuid"
    else:
      raise ValueError ("Invalid type code: %s" % str(typecode))

  def _formatDuration(self, val):
    """Render a nanosecond count as '[<d>d ][<h>h ][<m>m ]<s>s'; negatives count as 0."""
    if val < 0: val = 0
    # '//' keeps the arithmetic integral on every Python version
    sec = val // 1000000000
    minutes = sec // 60
    hour = minutes // 60
    day = hour // 24
    result = ""
    if day > 0:
      result = "%dd " % day
    if hour > 0 or result != "":
      result += "%dh " % (hour % 24)
    if minutes > 0 or result != "":
      result += "%dm " % (minutes % 60)
    result += "%ds" % (sec % 60)
    return result

  def valueByType(self, typecode, val):
    """Format val according to a numeric type code.

    Raises ValueError for a type code that is not handled below.
    """
    if typecode in (1, 2, 3, 4, 16, 17, 18, 19):
      return "%d" % val
    if typecode in (6, 7):
      return val
    if typecode == 8:
      # val is divided down from nanoseconds to the seconds gmtime() expects
      return strftime("%c", gmtime(val // 1000000000))
    if typecode == 9:
      return self._formatDuration(val)
    if typecode == 10:
      # NOTE(review): self.idRegistry is never assigned in this class, so
      # this branch raises AttributeError as written - confirm whether type
      # code 10 can still reach this method.
      return str(self.idRegistry.displayId(val))
    if typecode == 11:
      if val:
        return "True"
      return "False"
    if typecode in (12, 13):
      return "%f" % val
    if typecode in (14, 15, 20, 21, 22):
      return "%r" % val
    raise ValueError ("Invalid type code: %s" % str(typecode))

  def accessName (self, code):
    """ Convert element access codes to printable strings """
    if   code == qmf2.ACCESS_READ_CREATE: return "ReadCreate"
    elif code == qmf2.ACCESS_READ_WRITE: return "ReadWrite"
    elif code == qmf2.ACCESS_READ_ONLY: return "ReadOnly"
    else:
      raise ValueError ("Invalid access code: %s" % str(code))

  def dirName(self, io):
    """ Convert argument direction codes to printable strings """
    if   io == qmf2.DIR_IN:     return "in"
    elif io == qmf2.DIR_OUT:    return "out"
    elif io == qmf2.DIR_IN_OUT: return "in_out"
    else:
      raise ValueError("Invalid direction code: %r" % io)

  def notNone (self, text):
    """Return text, or "" when text is None."""
    if text is None:
      return ""
    return text

  def yes_blank(self, val):
    """Return "Y" for a true value and "" otherwise."""
    if val:
      return "Y"
    return ""

  def objectIndex(self, obj):
    """Return the object-id name for V2 ids, else the '.'-joined index property values."""
    if obj._objectId.isV2:
      return obj._objectId.getObject()
    parts = []
    for prop in obj.getProperties():
      if prop[0].index:
        parts.append(self.valueByType(prop[0].type, prop[1]))
    return ".".join(parts)

  def scrubConnOptions(self):
    """Return the connection options as a string with the password masked."""
    scrubbed = {}
    for key, val in self.conn_options.items():
      if key == "password":
        val = "***"
      scrubbed[key] = val
    return str(scrubbed)


#=========================================================
# Main Program
#=========================================================
# Turn the command line into the broker address and the two option sets.
try:
  oa = OptsAndArgs(sys.argv)
  host, conn_options, qmf_options = oa.parse()
except Exception as e:
  print("Parse Error: %s" % e)
  sys.exit(1)

# Shared display helper; kept at module level under this name.
disp = Display()

# Attempt to make a connection to the target broker
try:
  data = QmfData(disp, host, conn_options, qmf_options)
except Exception as e:
  if "Exchange not found" in str(e):
    print("Management not enabled on broker:  Use '-m yes' option on broker startup.")
  else:
    print("Failed: %s - %s" % (e.__class__.__name__, e))
  sys.exit(1)

# Instantiate the CLI interpreter and launch it.
cli = Mcli(data, disp)
print("Management Tool for QMF")
try:
  cli.cmdloop()
except KeyboardInterrupt:
  print("")
  print("Exiting...")
except Exception as e:
  print("Failed: %s - %s" % (e.__class__.__name__, e))

# always attempt to clean up broker resources
data.close()
