# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from operator import attrgetter

import ArPyUtils
import TableOutput
from Arnet import sortIntf, IpGenAddr, IpGenPrefix
from CliModel import (
      Bool,
      Dict,
      Enum,
      Float,
      Int,
      List,
      Model,
      Str,
)
from IntfModel import Interface

def printt( msg, indent ):
   """Print `msg` preceded by three spaces for every level of `indent`."""
   prefix = " " * ( indent * 3 )
   print( prefix + msg )

class SwitchAclStatus( Model ):
   """Programming status of one ACL on a switch."""
   aclName = Str(
      help="Name of the ACL that covers the policy" )
   status = Enum(
      help="Status of the ACL on the switch",
      values=( "unknown", "pending", "success", "fail" ) )

class PolicySectionSwitchStatus( Model ):
   """ACL statuses on a single switch for one policy section."""
   switchId = Str(
      help="Switch name or ID" )
   acl = List(
      help="ACLs of the switch covering the policy",
      valueType=SwitchAclStatus )

class PolicySectionStatus( Model ):
   """Overall status of one policy section plus its per-switch breakdown."""
   sectionId = Str(
      help="Section ID" )
   sectionName = Str(
      help="Name of the policy section" )
   status = Enum(
      help="Status of the policy",
      values=( "unknown", "pending", "success", "fail" ) )
   switch = List(
      help="Section status on switch",
      valueType=PolicySectionSwitchStatus )

class PolicyStatus( Model ):
   """Status of every policy section, broken down by switch and ACL."""
   section = List( valueType=PolicySectionStatus,
                   help="A list of policy status" )

   def render( self ):
      """Print a Policy/Switch/ACL/Status table.

      Switches are ordered by switchId and ACLs by aclName; sections keep
      their list order. The policy column is filled only on the first row
      emitted for a given section ID, and the switch column only on the first
      row for a given switch within a section. A section that yields no ACL
      rows is shown as a single row with status "unknown".
      """
      columns = ( "Policy", "Switch", "ACL", "Status" )
      table = TableOutput.createTable( columns, indent=3 )
      leftFmt = TableOutput.Format( justify='left' )
      leftFmt.noPadLeftIs( True )
      leftFmt.padLimitIs( True )
      table.formatColumns( *( [ leftFmt ] * len( columns ) ) )

      # Shared across all sections, so a repeated section ID stays blank too.
      shownSections = set()
      for section in self.section:
         label = section.sectionName or section.sectionId
         shownSwitches = set()
         switches = sorted( section.switch, key=attrgetter( 'switchId' ) )
         for switch in switches:
            for acl in sorted( switch.acl, key=attrgetter( 'aclName' ) ):
               if section.sectionId in shownSections:
                  policyCell = ""
               else:
                  policyCell = label
                  shownSections.add( section.sectionId )
               if switch.switchId in shownSwitches:
                  switchCell = ""
               else:
                  switchCell = switch.switchId
                  shownSwitches.add( switchCell )
               table.newRow( policyCell, switchCell, acl.aclName, acl.status )
         # No ACL rows at all for this section: emit a placeholder row.
         if not shownSwitches:
            table.newRow( label, "", "", "unknown" )

      print( table.output() )

class PolicySwitchErrorMsg( Model ):
   """An error reported by one switch while programming a policy."""
   switchId = Str(
      help="Switch name or ID" )
   error = Str(
      help="Error occurred when programming policy on the switch" )

class RuleCounters( Model ):
   """Aggregate packet counter and per-switch errors for one policy rule."""
   ruleId = Str( help="Rule ID" )
   ruleName = Str( help="Name of the rule" )
   packetCount = Int(
      help="Number of packets matched the rule across the switches" )
   switchError = List( valueType=PolicySwitchErrorMsg,
                       help="Errors occurred on the switches" )

   def render( self, indent=0 ):
      """Print the rule name, its packet count and any switch errors.

      The rule line is printed at `indent`, details one level deeper, and
      individual error messages two levels deeper.
      """
      detailIndent = indent + 1
      printt( "Rule: %s" % ( self.ruleName or self.ruleId ), indent )
      printt( "Packet count: %d" % self.packetCount, detailIndent )
      if not self.switchError:
         return
      printt( "Failure messages from switches:", detailIndent )
      for failure in self.switchError:
         line = "%s : %s" % ( failure.switchId, failure.error )
         printt( line, detailIndent + 1 )

class PolicyCounters( Model ):
   """Status and per-rule counters for one policy."""
   policyId = Str( help="Policy ID" )
   policyName = Str( help="Name of the policy" )
   status = Enum( help="Status of the policy",
                  values=( "unknown", "pending", "success", "fail" ) )
   rules = List( help="Status and counter information of policy rules",
                 valueType=RuleCounters )

   def render( self ):
      """Print the policy header and status, then each rule one level in."""
      title = self.policyName or self.policyId
      print( "Policy: %s" % title )
      printt( "Status: %s" % self.status, indent=1 )
      for ruleCounters in self.rules:
         ruleCounters.render( indent=1 )

class PolicyCountersList( Model ):
   """Collection of per-policy status and counters."""
   policies = List( help="Status and counter information of policies",
                    valueType=PolicyCounters )

   def render( self ):
      """Render each policy in list order, each followed by a blank line."""
      for policyCounters in self.policies:
         policyCounters.render()
         print( "" )

class ControllerRule( Model ):
   """A single firewall-style rule as received from the controller."""
   ruleId = Str( help="Rule ID" )
   ruleName = Str( help="Name" )
   priority = Int( help="Priority of the rule within a section" )
   sources = List( valueType=str, help="List of sources" )
   destinations = List( valueType=str, help="List of destinations" )
   protocol = Enum( help="Protocol",
                    values=( "unknown", "ip", "tcp", "udp", "icmp" ) )
   sourcePort = Str( help="Source port(s)" )
   destinationPort = Str( help="Destination port(s)" )
   action = Enum( help="Action of the rule",
                  values=( "unknown", "allow", "drop", "reject" ) )

   def renderAsRow( self, table ):
      """Append this rule to `table` as one row.

      Sources and destinations are joined with "," (no spaces); the name
      column falls back to the rule ID when the rule has no name.
      """
      cells = ( self.ruleName or self.ruleId,
                ",".join( self.sources ),
                ",".join( self.destinations ),
                self.protocol,
                self.sourcePort,
                self.destinationPort,
                self.action )
      table.newRow( *cells )

class ControllerPolicy( Model ):
   """A policy section from the controller together with its rules."""
   sectionId = Str( help="Section ID" )
   sectionName = Str( help="Name of the policy section" )
   priority = Int( help="Priority of the section" )
   rules = List( help="List of rules of the policy section",
                 valueType=ControllerRule )

   def render( self ):
      """Print the section header followed by a table of its rules.

      Rules are listed in ascending `priority` order.
      """
      print( "Policy : %s" % ( self.sectionName or self.sectionId ) )
      columns = ( "Rule", "Source", "Destination", "IP protocol",
                  "Source port", "Dest port", "Action" )
      table = TableOutput.createTable( columns, indent=3 )
      leftFmt = TableOutput.Format( justify='left' )
      leftFmt.noPadLeftIs( True )
      leftFmt.padLimitIs( True )
      table.formatColumns( *[ leftFmt for _ in columns ] )
      byPriority = sorted( self.rules, key=lambda rule: rule.priority )
      for rule in byPriority:
         rule.renderAsRow( table )
      print( table.output() )

class ControllerPolicyList( Model ):
   """All policy sections received from the controller."""
   policies = List( help="List of policies received from controller",
                    valueType=ControllerPolicy )

   def render( self ):
      """Render every policy section in ascending `priority` order."""
      ordered = sorted( self.policies, key=lambda section: section.priority )
      for section in ordered:
         section.render()

class ControllerNsGroup( Model ):
   """An NS group received from the controller and its member addresses."""
   name = Str( help="NS group name" )
   addrList = List( valueType=str,
                    help="List of IP/IPv6 address making up the NS group" )

   def render( self ):
      """Print the group name and its addresses on a single line.

      Entries without a '/' are treated as host addresses and listed first,
      ordered by `IpGenAddr(...).sortKey`; entries containing '/' are treated
      as prefixes and listed after them, ordered by `IpGenPrefix(...).sortKey`.
      Nothing is printed below the "IP Addresses:" header for an empty group.
      """
      print( "Group name: %s" % self.name )
      print( "IP Addresses:" )
      if not self.addrList:
         return
      hostIps = sorted( ( ip for ip in self.addrList if '/' not in ip ),
                        key=lambda ip: IpGenAddr( ip ).sortKey )
      subnetIps = sorted( ( ip for ip in self.addrList if '/' in ip ),
                          key=lambda ip: IpGenPrefix( ip ).sortKey )
      # Join hosts and prefixes in one pass; appending ", " + prefixes to an
      # empty host string used to yield a leading ", " for prefix-only groups.
      print( ", ".join( hostIps + subnetIps ) )

class ControllerNsGroupList( Model ):
   """All NS groups received from the controller."""
   nsGroups = List(
      help="List of NSX security groups received from the controller",
      valueType=ControllerNsGroup )

   def render( self ):
      """Render each group, sorted by name, with a blank line after each."""
      for nsGroup in sorted( self.nsGroups, key=lambda grp: grp.name ):
         nsGroup.render()
         print( "" )

class ControllerHost( Model ):
   """Tags and NSX security group memberships known for a host."""
   groups = List( help='List of NSX security groups the host belongs to',
                  valueType=str )
   tags = List( help='List of tags assigned to the host',
                valueType=str )

   def render( self ):
      """Print the host's sorted tags, then its sorted group memberships.

      Each section is skipped when empty. A blank line follows the tag list
      whenever the tag list is printed.
      """
      sections = ( ( 'Associated tags:', self.tags, True ),
                   ( 'Member of:', self.groups, False ) )
      for heading, items, blankAfter in sections:
         if not items:
            continue
         print( heading )
         print( ', '.join( sorted( items ) ) )
         if blankAfter:
            print( '' )

class IntfInfo( Model ):
   """Hosts learnt on one interface."""
   name = Interface( help='Interface name' )
   hosts = List( help='List of hosts learnt on the interface', valueType=str )

   def render( self ):
      """Print "<interface>: <host>, <host>, ..." on one line."""
      hostList = ", ".join( self.hosts )
      print( "%s: %s" % ( self.name, hostList ) )

class SwitchInfo( Model ):
   """Interfaces on one switch and the hosts learnt on each."""
   switchId = Str( help='Switch ID or hostname' )
   intfs = Dict( help='List of interfaces the host is attached to',
                 valueType=IntfInfo )

   def render( self ):
      """Print the switch ID, one line per interface, then a blank line.

      Interfaces are visited in the order returned by `Arnet.sortIntf`.
      """
      print( self.switchId )
      for intfName in sortIntf( self.intfs ):
         intfInfo = self.intfs[ intfName ]
         intfInfo.render()
      print( "" )

class TagInfo( Model ):
   """Switches (and their interfaces) on which one tag is configured."""
   tag = Str( help='Tag' )
   switches = Dict( help='List of switches the tag is configured on',
                    valueType=SwitchInfo )

   def render( self ):
      """Print the tag, a dashed underline of equal length, then each switch.

      Switches are visited in `ArPyUtils.naturalsorted` order of their keys.
      """
      print( self.tag )
      print( "-" * len( self.tag ) )
      for switchKey in ArPyUtils.naturalsorted( self.switches ):
         switchInfo = self.switches[ switchKey ]
         switchInfo.render()

class PcsTagInfo( Model ):
   """All tags configured on the switches."""
   tags = Dict( help='List of tags configured on the switches',
                valueType=TagInfo )

   def render( self ):
      """Render each tag in `ArPyUtils.naturalsorted` order of its key."""
      for tagKey in ArPyUtils.naturalsorted( self.tags ):
         tagInfo = self.tags[ tagKey ]
         tagInfo.render()

class PcsApiState( Model ):
   """Last observed outcome of one HTTP API exchange."""
   method = Enum(
      help='HTTP method',
      values=( 'get', 'post', 'put', 'delete' ) )
   uri = Str(
      help='HTTP URI' )
   success = Bool(
      help='Successful HTTP request or response' )
   errorDetail = Str(
      help='Additional detail for error',
      optional=True )
   timestamp = Float(
      help='UTC timestamp of the last response' )

class PcsStatus( Model ):
   """Policy Control Service status and HTTP communication summary."""
   status = Enum( values=( "running", "notRunning" ),
                  help="Status of Policy Control Service" )
   policyCount = Int( help="Total number of policies received from controller" )
   tagCount = Int( help="Total number of tags configured on switches" )
   apiState = Dict( valueType=PcsApiState,
                    help="A dictionary of API states, keyed by API type",
                    optional=True )

   def render( self ):
      """Print service status, counts and a controller-communication summary.

      When the service is not running, only the first status line is printed.
      """
      printStatus = { 'running' : 'running', 'notRunning' : 'not running' }
      print( "%-35s: %s" % ( "Policy Control Service is",
                             printStatus[ self.status ] ) )
      # `status` holds the Enum value ("notRunning"), not the display string;
      # comparing against "not running" could never match, so the early
      # return was dead code.
      if self.status == "notRunning":
         return
      print( "%-35s: %s" % ( "Policies received from controller",
                             self.policyCount ) )
      print( "%-35s: %s" % ( "Tags configured on switches", self.tagCount ) )

      # Communication status summary logic, failure kinds
      # 1. CVX->Controller : Indicates possible pinned public key issue
      # 2. Controller->CVX : Indicates possible CVX cert issue on controller
      # 3. Both kinds with some Success: Indicates possible Client/Server error
      # 4. All failures: Indicates possible misconfiguration
      # Kinds 3 and 4 are both reported as "API requests failed" below.

      failedApis = set()
      successApis = set()
      # ( METHOD, uri, errorDetail, timestamp ) of the latest failed exchange.
      lastError = ( '', '', '', 0 )
      # apiState is optional; a missing dict is treated as "no requests".
      for apiState in ( self.apiState or {} ).values():
         if apiState.timestamp == 0:
            # A zero timestamp marks an API with no recorded exchange.
            continue
         if apiState.success:
            successApis.add( apiState.uri )
         else:
            failedApis.add( apiState.uri )
            if apiState.timestamp > lastError[ 3 ]:
               lastError = ( apiState.method.upper(), apiState.uri,
                             apiState.errorDetail, apiState.timestamp )

      if failedApis:
         # URIs under "/pcs/" are counted as controller->CVX requests
         # (southbound); every other URI as a CVX->controller request.
         sbFails = set( uri for uri in failedApis if uri.startswith( "/pcs/" ) )
         nbFails = failedApis - sbFails
         sbSuccesses = set( uri for uri in successApis
                            if uri.startswith( "/pcs/" ) )
         nbSuccesses = successApis - sbSuccesses

         if nbFails and not sbFails and not nbSuccesses:
            commStatus = "All API requests from CVX to controller failed"
         elif sbFails and not nbFails and not sbSuccesses:
            commStatus = "All API requests from controller to CVX failed"
         else: # failures in both directions, or mixed with successes
            commStatus = "API requests failed"

         print( "%-35s: %s" % ( "Last failed API request",
                                ' '.join( lastError[ 0:2 ] ) ) )
         if lastError[ 2 ] is not None:
            print( "%-35s: %s" % ( "Error", lastError[ 2 ] ) )
      elif successApis:
         commStatus = "Successful"
      else:
         commStatus = "No requests sent or received"
      # Single space after the colon, like the other summary lines; the old
      # "%-35s: " format plus print's comma separator emitted two spaces.
      print( "%-35s: %s" % ( "HTTP communication with controller",
                             commStatus ) )
