#!/usr/bin/env python2.7
#
# Copyright (c) 2015-2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from abc import ABCMeta, abstractmethod, abstractproperty
from MssPolicyMonitor.Lib import ( HA_ACTIVE_PASSIVE, HA_ACTIVE_ACTIVE, HA_ACTIVE,
                                   HA_ACTIVE_PRIMARY, LINK_STATE_UNKNOWN, t3 )


class ServiceDevicePolicy( object ):
   ''' Raw, vendor-neutral policy record.

       Service device vendor API plugins create one instance per policy and
       fill in the attributes below; __str__ renders a one/multi-line trace
       summary of the record.
   '''
   def __init__( self, policyName, managementIp='', vsys=None, number=0 ):
      # identity of the policy and the device it came from
      self.name = policyName
      self.managementIp = managementIp
      self.vsys = vsys
      self.number = number         # loading order number
      # rule disposition
      self.action = ''             # typically ALLOW or DENY/DROP
      self.tags = ''               # associated tags; empty when none
      self.isOffloadPolicy = None
      # zones
      self.srcZoneName = ''
      self.dstZoneName = ''
      self.srcZoneType = ''        # 'vwire', 'layer2' or 'layer3'
      self.dstZoneType = ''        # 'vwire', 'layer2' or 'layer3'
      self.srcZoneInterfaces = []  # NetworkInterface objects
      self.dstZoneInterfaces = []  # NetworkInterface objects
      # match criteria
      self.srcIpAddrList = []      # IP address, range or subnet strings
      self.dstIpAddrList = []      # IP address, range or subnet strings
      self.dstL4Services = {}      # e.g. { 'TCP': [ 22, 80 ], 'UDP': [ 67, 68 ] }
      # neighbor / test data
      self.intfNeighbors = {}      # key = zone interface name
      self.interceptZone = ''      # used for testing only

   def __str__( self ):
      def _intfNames( intfs ):
         # comma-separated displayName() of each interface object
         return ', '.join( [ intf.displayName() for intf in intfs ] )

      def _nbrField( nbr, key ):
         # neighbor entries may be empty/None or lack a key; render as ''
         return nbr[ key ] if nbr and key in nbr else ''

      tagText = 'Tags: %s' % self.tags if self.tags else ''
      l4Text = 'L4: %s' % self.dstL4Services if self.dstL4Services else ''

      parts = [
         'rawPolicy: %s Dev: %s Vsys: %s Num: %s %s Action: %s ' % (
            self.name, self.managementIp, self.vsys, self.number, tagText,
            self.action ),
         'ZnTypes: %s/%s %s' % ( self.srcZoneType, self.dstZoneType, l4Text ),
         '\n  SrcZone: %s  Intf: %s  IP_Addrs: %s  ->' % (
            self.srcZoneName, _intfNames( self.srcZoneInterfaces ),
            ', '.join( self.srcIpAddrList ) ),
         '  DstZone: %s  Intf: %s  IP_Addrs: %s' % (
            self.dstZoneName, _intfNames( self.dstZoneInterfaces ),
            ', '.join( self.dstIpAddrList ) ),
      ]
      if self.interceptZone:
         parts.append( '\n  InterceptZone: %s' % self.interceptZone )
      if self.intfNeighbors:
         neighbors = []
         for intfName, nbr in self.intfNeighbors.items():
            neighbors.append( '%s--%s_swId=%s' % (
               intfName, _nbrField( nbr, 'switchIntf' ),
               _nbrField( nbr, 'switchChassisId' ) ) )
         parts.append( '\n  Neighbors: %s' % neighbors )
      return ''.join( parts )

####################################################################################
class ServiceDevice( object ):
   ''' Bookkeeping record for one service device handled by a plugin thread.

       Identity fields (mgmtIp, name, model) are filled in by setDeviceInfo()
       from the dict returned by a plugin's getDeviceInfo().
   '''

   def __init__( self, deviceId, deviceType, threadName ):
      self.deviceId = deviceId
      self.deviceSetName = ''
      self.threadName = threadName
      self.plugin = None
      self.deviceType = deviceType
      self.mgmtIp = None
      self.haPeerMgmtIp = None
      self.name = ''
      self.model = ''
      self.isSingleLogicalDeviceHaModel = False  #ie. cluster, virtual mac, shared IP
      self.isCurrent = True

   def setDeviceInfo( self, deviceInfo ):
      ''' Copy identity fields from a device-info dict.

          deviceInfo: dict with at least the keys 'ipAddr', 'name' and
             'model'. A falsy value (None or {}) leaves the object unchanged.
          When a plugin is attached, its isSingleLogicalDeviceHaModel() is
          called and the result cached on this object.
      '''
      if not deviceInfo:
         return
      self.mgmtIp = deviceInfo[ 'ipAddr' ]
      self.name = deviceInfo[ 'name' ]
      # manufacturer model name/number, stored as the plain value (a trailing
      # comma here would silently store a 1-tuple instead)
      self.model = deviceInfo[ 'model' ]
      if self.plugin:
         self.isSingleLogicalDeviceHaModel = \
            self.plugin.isSingleLogicalDeviceHaModel()

   def initComplete( self ):
      ''' Return True when deviceId, mgmtIp, threadName and plugin are all
          set (truthy); the outcome is also written to the t3 trace.
      '''
      complete = bool( self.deviceId and self.mgmtIp and self.threadName and
                       self.plugin )
      t3('PluginController.ServiceDevice initComplete:', complete, self.deviceType,
          self.deviceId, 'thread:', self.threadName )
      return complete

   def __str__( self ):
      return '<ServiceDevice IP=%s name=%s model=%s id=%s thread=%s>' % (
         self.mgmtIp, self.name, self.model, self.deviceId, self.threadName )

####################################################################################
class NetworkInterface( object ):
   ''' A service-device interface that connects to an Arista switch.

       The interface may be a plain Ethernet port, a subinterface or a LAG;
       physicalIntfs lists the underlying Ethernet member(s) (expected to hold
       at least one entry once the plugin has populated it).

       vlans holds 802.1Q VLAN ids, each id a string.
       state holds a Lib.LINK_STATE_nn value, where nn = UP, DOWN or UNKNOWN.
   '''

   class EthernetIntf( object ):
      ''' One physical Ethernet port: name, MAC address and link state. '''
      def __init__( self, name, state=LINK_STATE_UNKNOWN ):
         self.name = name
         self.state = state
         self.macAddr = ''

      def __str__( self ):
         fields = ( self.name, self.macAddr, self.state )
         return '%s %s %s' % fields

   def __init__( self, name, vlans=None, state=LINK_STATE_UNKNOWN,
                 isLag=False, isEthernet=False, isSubIntf=False, zone='',
                 ipAddr='' ):
      self.name = name          # name on the service device, e.g. ae3, eth1/7.10, port2
      self.vlans = vlans or []  # fresh list per instance when none supplied
      self.state = state        # LAG link state when isLag, else Ethernet state
      self.zone = zone          # zone name if available
      self.ipAddr = ipAddr      # IP address if L3 intf
      self.vrf = ''             # vrf for IP address if L3 intf
      self.isLag = isLag
      self.isSubIntf = isSubIntf
      self.attribs = {}         # any other useful attributes
      self.physicalIntfs = []   # EthernetIntf objects
      if isEthernet:
         # a plain Ethernet port is its own single physical member
         self.addPhysicalIntf( name, state )

   def addPhysicalIntf( self, name, state=LINK_STATE_UNKNOWN ):
      ''' Append an EthernetIntf member with the given name and state. '''
      member = NetworkInterface.EthernetIntf( name, state )
      self.physicalIntfs.append( member )

   def displayName( self ):
      ''' Return the bare name when it already identifies the port, else
          "name=member1+member2..." listing the physical members.
      '''
      members = self.physicalIntfs
      if members and members[ 0 ].name == self.name:
         return self.name
      if self.isSubIntf and len( members ) < 2:
         return self.name
      memberNames = '+'.join( member.name for member in members )
      return '%s=%s' % ( self.name, memberNames )

   def __str__( self ):
      vlanText = ','.join( self.vlans )
      vlanField = 'vl=%s' % vlanText if vlanText else ''
      ipField = 'ip=%s' % self.ipAddr if self.ipAddr else ''
      lagFlag = 'Y' if self.isLag else 'N'
      subFlag = 'Y' if self.isSubIntf else 'N'
      head = '%s %s lag=%s sub=%s zn=%s %s %s' % (
         self.displayName(), self.state, lagFlag, subFlag, self.zone,
         vlanField, ipField )
      # extra attributes in key order; 'vwire' shortened to keep trace lines short
      extras = [ ' %s=%s' % ( key.replace( 'vwire', 'vw' ), value )
                 for key, value in sorted( self.attribs.items() ) ]
      return head + ''.join( extras )

####################################################################################
class ServiceDeviceHAState( object ):
   ''' Snapshot of a ServiceDevice's High Availability status.
   '''
   def __init__( self ):
      self.isSingleLogicalDeviceHaModel = False
      self.enabled = False
      self.mode = ''             # ACTIVE_PASSIVE or ACTIVE_ACTIVE
      self.state = ''            # ACTIVE, PASSIVE, ACTIVE_PRIMARY or ACTIVE_SECONDARY
      self.peerDeviceState = ''
      self.mgmtIp = ''
      self.peerMgmtIp = ''

   def isHaPassiveOrSecondary( self ):
      ''' True when HA is enabled and this device is not the lead member:
          a non-active device in active/passive mode, or a non-primary device
          in active/active mode. When HA is disabled the (falsy) enabled
          value itself is returned.
      '''
      if not self.enabled:
         return self.enabled
      isPassive = self.mode == HA_ACTIVE_PASSIVE and self.state != HA_ACTIVE
      isSecondary = ( self.mode == HA_ACTIVE_ACTIVE and
                      self.state != HA_ACTIVE_PRIMARY )
      return isPassive or isSecondary

   def getPeerManagementIp( self ):
      ''' Peer management IP with any "/prefix" suffix removed, or ''. '''
      if not self.peerMgmtIp:
         return ''
      return self.peerMgmtIp.partition( '/' )[ 0 ]

   def __str__( self ):
      summary = 'HA enabled: %s' % self.enabled
      if not self.enabled:
         return summary
      if self.isSingleLogicalDeviceHaModel:
         peerText = 'isSingleLogicalDeviceHaModel'
      else:
         peerText = 'peerState: %s peerIP: %s' % ( self.peerDeviceState,
                                                   self.peerMgmtIp )
      return summary + ' mode: %s state: %s IP: %s  %s' % (
         self.mode, self.state, self.mgmtIp, peerText )

####################################################################################
class ServiceDeviceRoutingTables( object ):
   ''' Represents a ServiceDevice's routing tables, one table per vrf.

       routingTables maps a vrf name to a list of
       ( destination, interface, nexthop ) tuples in insertion order.
       featureSupported defaults to False; presumably a plugin sets it when
       the device supports routing-table retrieval (confirm with plugins).
   '''
   def __init__( self ):
      self.routingTables = {}
      self.featureSupported = False

   def iteritems( self ):
      ''' Return an iterator over ( vrfName, routeList ) pairs.
          Built on dict.items() so it works on Python 2 and Python 3.
      '''
      return iter( self.routingTables.items() )

   def getVrfList( self ):
      ''' Returns the list of vrf names (a real list on every Python version).
      '''
      return list( self.routingTables )

   def addVrf( self, vrfName ):
      ''' Create an empty table for vrfName; no-op if it already exists. '''
      self.routingTables.setdefault( vrfName, [] )

   def delVrf( self, vrfName ):
      ''' Drop vrfName and its routes; no-op if it does not exist. '''
      self.routingTables.pop( vrfName, None )

   def addRoute( self, vrfName, destination, interface, nexthop ):
      ''' Append a route, creating the vrf table on first use. '''
      self.addVrf( vrfName )
      self.routingTables[ vrfName ].append( ( destination, interface, nexthop ) )

   def delRoute( self, vrfName, destination, interface, nexthop ):
      ''' Remove the first matching route from vrfName.
          A missing vrf or a missing route is a no-op, mirroring delVrf
          (a missing route used to raise ValueError while a missing vrf
          was silently ignored).
      '''
      routes = self.routingTables.get( vrfName )
      if routes is None:
         return
      try:
         routes.remove( ( destination, interface, nexthop ) )
      except ValueError:
         pass  # route not present: deleting is idempotent

   def __str__( self ):
      lines = [ 'Feature Supported: ' + str( self.featureSupported ) ]
      for vrf, table in self.routingTables.items():
         lines.append( str( vrf ) )
         lines.extend( '   ' + str( route ) for route in table )
      # every line, including the last, is newline-terminated
      return '\n'.join( lines ) + '\n'

####################################################################################
#  Software Interfaces for MssPolicyMonitor Plugins.
#    See: PEP 3119 for info on abstract classes
#
class IServiceDevicePlugin( object ):
   ''' Software interface for a plugin that talks to an individual firewall
       with an L2 vwire data plane.
   '''
   __metaclass__ = ABCMeta

   @abstractmethod
   def __init__( self, deviceConfig ):
      ''' Create the plugin for one device.

          deviceConfig always carries these keys:
            'protocol'    API access protocol (normally http or https)
            'ipAddress'   Management IP address or DNS name
            'username'    for API authentication
            'password'    for API authentication
      '''

   @abstractmethod
   def getInterfacesInfo( self, resolveZoneNames=True ):
      ''' Return a list of NetworkInterface objects.
      '''

   @abstractmethod
   def getInterfaceNeighbors( self ):
      ''' Fetch interface neighbor information from the service device.

          Return a dict keyed by service device interface name (string);
          each value is a dict with:
            'switchIntf'       Neighbor device interface name
            'switchChassisId'  Neighbor chassis id/MAC address
            'nborMgmtIp'       Neighbor device Mgmt IP address field  (optional)
            'nborSysName'      Neighbor device system/DNS name  (optional)
            'nborDesc'         Neighbor device description field (optional)
      '''

   @abstractmethod
   def getDeviceInfo( self ):
      ''' Return a dict containing at least 'ipAddr', 'name' and 'model'.
      '''

   @abstractmethod
   def getDeviceResources( self ):
      ''' Return a dict describing device resources (e.g. cpu, memory,
          process status, disk and file systems). In phase 1 everything is
          returned under the key 'resourceInfo' as a formatted print string.
      '''

   @abstractmethod
   def getDeviceRoutingTables( self ):
      ''' Return a ServiceDeviceRoutingTables object.
      '''

   @abstractmethod
   def closeApiConnection( self ):
      ''' Close any open connections to the service device.
      '''

#-----------------------------------------------------------------------------------
class IPolicyPlugin( object ):
   ''' Software interface for plugins that supply traffic-attracting policies.
   '''
   __metaclass__ = ABCMeta

   @abstractmethod
   def getPolicies( self, mssTags=None ):
      ''' Return the service device policies.

          An individual service device plugin returns a list of
          ServiceDevicePolicy objects. An aggregation manager plugin returns
          a dict mapping each service device's mgmt IP address to the list
          of ServiceDevicePolicy objects for that device.
      '''

#-----------------------------------------------------------------------------------
class IHAStatePlugin( object ):
   ''' High Availability State plugin software interface
   '''
   __metaclass__ = ABCMeta

   @abstractmethod
   def getHighAvailabilityState( self ):
      ''' Returns a ServiceDeviceHAState object with current
          High Availability State for the service device.
      '''
      pass

   @abstractmethod
   def isSingleLogicalDeviceHaModel( self ):
      ''' Returns True if the service device uses a single IP address for a
          cluster of HA devices (i.e. use shared virtual MAC address).

          Declared as a method rather than a property because
          ServiceDevice.setDeviceInfo invokes it as
          plugin.isSingleLogicalDeviceHaModel(); a property implementation
          would make that call fail.
      '''
      pass

#-----------------------------------------------------------------------------------
class IAggregationMgrPlugin( object ):
   ''' Software interface for aggregation manager plugins.
   '''
   __metaclass__ = ABCMeta

   @abstractmethod
   def __init__( self, deviceConfig ):
      ''' Create the plugin for one aggregation manager.

          deviceConfig includes these keys:
            'protocol'    API access protocol (normally http or https)
            'ipAddress'   Management IP address or DNS name
            'username'    for API authentication
            'password'    for API authentication
      '''

   @abstractmethod
   def getAggMgrGroupMembers( self, groupName ):
      ''' Return a list of the members of groupName that are reachable
          through the aggregation manager (e.g. Palo Alto Networks Panorama,
          Check Point Software Security Management Server).
      '''

   @abstractmethod
   def getDeviceInfo( self ):
      ''' Return a dict containing at least 'ipAddr', 'name' and 'model'.
      '''

   @abstractmethod
   def closeApiConnection( self ):
      ''' Close any open connections to the service device.
      '''
