# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import Tac
from ClassificationLib import ( addOrDeleteRange, numericalRangeToSet,
                                numericalRangeToRangeString,
                                rangeSetToNumericalRange,
                                prefixesACoveredByB )
from eunuchs.in_h import IPPROTO_TCP
# Module-level Tac value/type handles shared by the contexts below.
# Classification constants; appConstants.maxL4Port is used as the upper bound of
# the "all ports" range in L4PortFieldSetContext.updateFieldSet.
appConstants = Tac.Value( 'Classification::Constants' )
ProtocolRange = Tac.Type( 'Classification::ProtocolRange' )
IpGenPrefix = Tac.Type( 'Arnet::IpGenPrefix' )
FragmentType = Tac.Type( 'Classification::FragmentType' )
# Compared against filter.fragment.fragmentType in maybeDelFragment, which drops a
# matchOffset fragment that has no offset ranges left.
matchOffset = FragmentType.matchOffset
IcmpTypeRange = Tac.Type( 'Classification::IcmpTypeRange' )
# Passed as the key for new port entries and when creating field-set sub-configs.
UniqueId = Tac.Type( "Ark::UniqueId" )

class MatchRuleBaseContext( object ):
   """
   Base edit context for a single classification match rule.

   Wraps a 'Classification::StructuredFilter' instance in self.filter and
   provides helpers that add or remove match criteria on it: prefixes, prefix
   field-sets, numerical ranges, ICMP types/codes, L4 ports and port field-sets,
   TCP flags, fragments and IP options. Subclasses implement setAction,
   copyEditMatchRule, newEditMatchRule and commit.
   """
   def __init__( self, ruleName, matchOption ):
      self.ruleName = ruleName
      self.matchOption = matchOption
      self.filter = Tac.newInstance( "Classification::StructuredFilter", "" )

   def setAction( self, actionType, actionValue=None, no=False, clearActions=None ):
      raise NotImplementedError

   def copyEditMatchRule( self, ruleName, seqnum ):
      raise NotImplementedError

   def newEditMatchRule( self, ruleName, seqnum ):
      raise NotImplementedError

   def isValidConfig( self, conflictType ):
      """
      Return self.filter.isValidConfig( conflictType ), or False when there is
      no filter.
      """
      if not self.filter:
         return False
      return self.filter.isValidConfig( conflictType )

   def addOrRemovePrefix( self, prefixes, filterType, add ):
      """
      Add (add=True) or remove literal prefixes on the 'source' or
      'destination' prefix collection of the filter.

      The prefix field-set of the same direction (srcPrefixSet/dstPrefixSet)
      is cleared first. Any other filterType is ignored.
      """
      if filterType not in ( 'source', 'destination' ):
         return
      # remove any previously configured prefix field-set for this direction
      if filterType == 'source':
         fieldSet = self.filter.srcPrefixSet
      else:
         fieldSet = self.filter.dstPrefixSet
      fieldSet.clear()

      prefixList = getattr( self.filter, filterType )
      for prefix in prefixes:
         prefix = IpGenPrefix( str( prefix ) )
         if add:
            prefixList[ prefix ] = True
         else:
            del prefixList[ prefix ]

   def updatePrefixFieldSet( self, source=True, names=None, add=True ):
      """
      Add or remove prefix field-set names for the source (source=True) or
      destination direction.

      All literal prefixes and except-prefixes of that direction are cleared
      first. When removing with no names, every field-set of that direction is
      removed. Adding with names=None adds nothing.
      """
      # clear all source or destination literal prefix configurations
      if source:
         clearFilterTypes = ( 'source', 'sourceExcept' )
      else:
         clearFilterTypes = ( 'destination', 'destinationExcept' )
      for f in clearFilterTypes:
         getattr( self.filter, f ).clear()

      # add or remove fieldSet
      if source:
         fieldSetList = self.filter.srcPrefixSet
      else:
         fieldSetList = self.filter.dstPrefixSet
      if add:
         for n in names or ():
            fieldSetList.add( n )
      elif names:
         for n in names:
            del fieldSetList[ n ]
      else:
         fieldSetList.clear()

   def getProto( self ):
      """Return the filter's protocol collection."""
      return self.filter.proto

   def updateTcpFlags( self, tcpFlags, notFlags, add=True ):
      """
      Updates TcpFlagAndMask for a given match rule.

      For each flag name in tcpFlags ('est' or 'init'), the flag bit and its
      mask bit are both set to add on the TCP protocol field. Processing stops
      at the first unsupported flag; flags before it have already been applied.
      Nothing is done when no TCP protocol field is configured. notFlags is not
      handled yet.
      """
      # XXX handle NOT option
      acceptedFlags = ( 'est', 'init' )
      tcpProto = ProtocolRange( IPPROTO_TCP, IPPROTO_TCP )
      protoField = self.filter.proto.get( tcpProto )
      if not protoField:
         return
      for tcpFlag in tcpFlags:
         if tcpFlag not in acceptedFlags:
            return
         # Re-read each iteration: the value is replaced at the end of the loop.
         currTcpFlagAndMask = protoField.tcpFlagAndMask
         newTcpFlag = Tac.Value( "Classification::TcpFlag",
                                 currTcpFlagAndMask.tcpFlag.value )
         newTcpFlagMask = Tac.Value( "Classification::TcpFlag",
                                     currTcpFlagAndMask.tcpFlagMask.value )
         if notFlags:
            # XXX to be handled
            pass
         else:
            setattr( newTcpFlag, tcpFlag, add )
            setattr( newTcpFlagMask, tcpFlag, add )
         protoField.tcpFlagAndMask = Tac.Value( "Classification::TcpFlagAndMask",
                                                newTcpFlag, newTcpFlagMask )

   def updateRangeAttr( self, attrName, rangeType, rangeSet, add=True ):
      """
      Add rangeSet to (add=True) or remove it from (add=False) the numerical
      range collection attrName, using addOrDeleteRange.

      'fragmentOffset' refers to filter.fragment.offset. Unsupported attribute
      names are ignored. The collection is updated in place, deleting and adding
      only ranges that changed.
      """
      acceptedAttr = ( 'ttl', 'sport', 'dport', 'proto', 'dscp', 'length',
                       'fragmentOffset' )
      if attrName not in acceptedAttr:
         return
      # retrieve the current collection of NumericalRange objs
      if attrName == "fragmentOffset":
         currentAttrList = self.filter.fragment.offset
      else:
         currentAttrList = getattr( self.filter, attrName )
      newRangeList = addOrDeleteRange( currentAttrList,
                                       rangeSet, rangeType, add )
      # Cannot delete all members and re-add for protocols because protocols have
      # additional protoFields. Iterate over a snapshot since entries are deleted.
      for currAttr in list( currentAttrList ):
         if currAttr not in newRangeList:
            del currentAttrList[ currAttr ]
      for aRange in newRangeList:
         if aRange not in currentAttrList:
            if attrName == 'proto':
               currentAttrList.newMember( aRange )
            else:
               currentAttrList.add( aRange )

   def updateIcmpRangeAttr( self, icmpValue, rangeType, rangeSet, add=True ):
      """
      Add (add=True) or remove (add=False) ICMP type ranges for the protocol
      icmpValue. An empty rangeSet removes all ICMP types. Does nothing if the
      protocol is not configured.
      """
      # Each ICMP type may or may not have ICMP codes configured. For example, we
      # have type [1,2], type [3,3] code [5,10], type [10,20], when type [2,4] is
      # added. Firstly merge type [2,4] with type [1,2] and [10,20]. Then remove
      # type [ 3, 3 ] code [ 5, 10 ] because it covered by the added type [ 2, 4 ].
      icmpProto = ProtocolRange( icmpValue, icmpValue )
      protoField = self.filter.proto.get( icmpProto )
      if protoField is None:
         return
      icmpTypeRangeColl = protoField.icmpType
      # Remove all ICMP types if type is not specified
      if not rangeSet:
         for aRange in list( icmpTypeRangeColl ):
            del icmpTypeRangeColl[ aRange ]
         return
      typeRangeWithCodeList = []
      typeRangeWithoutCodeList = []
      for aTypeRange in icmpTypeRangeColl:
         if icmpTypeRangeColl[ aTypeRange ].icmpCode:
            typeRangeWithCodeList.append( aTypeRange )
         else:
            typeRangeWithoutCodeList.append( aTypeRange )
      newRangeList = addOrDeleteRange( typeRangeWithoutCodeList, rangeSet,
                                       rangeType, add )
      for currRange in typeRangeWithoutCodeList:
         if currRange not in newRangeList:
            del icmpTypeRangeColl[ currRange ]
      for aRange in newRangeList:
         if aRange not in typeRangeWithoutCodeList:
            icmpTypeRangeColl.newMember( aRange )
      # Types with codes are created as single-value ranges ([t, t]) by
      # addIcmpTypeCodeRangeAttr.
      for aRange in typeRangeWithCodeList:
         aRangeValue = numericalRangeToSet( [ aRange ] ).pop()
         if aRangeValue not in rangeSet:
            continue
         if add and len( rangeSet ) == 1:
            # e.g., when 'protocol icmp type 3 code all' is configured after
            # 'protocol icmp type 3 code 1', we have aRangeValue=3 and
            # rangeSet = set( [ 3 ] ), so update to type [3,3] code all.
            icmpTypeRangeColl[ aRange ].icmpCode.clear()
         else:
            # e.g., when 'protocol icmp type 1-3 code all' is configured after
            # 'protocol icmp type 3 code 1', we have aRangeValue=3 and
            # rangeSet = set( [ 1,2,3 ] ), so remove type [3,3] code 1.
            # When removing (add=False), a covered type is always deleted; it
            # must not be widened to 'code all'.
            del icmpTypeRangeColl[ aRange ]

   def addIcmpTypeCodeRangeAttr( self, icmpValue, typeRangeType, typeValue,
                                 codeRangeType, codeSet ):
      """
      Add ICMP codes codeSet for the single ICMP type typeValue.

      If the type is not configured yet, type [typeValue, typeValue] is added
      with the given codes. If that exact type exists with specific codes,
      codeSet is merged into them. If the type already matches all codes, either
      as its own code-less entry or inside a wider type range, nothing changes.
      typeRangeType is not used. Does nothing if the protocol is not configured.
      """
      icmpProto = ProtocolRange( icmpValue, icmpValue )
      protoField = self.filter.proto.get( icmpProto )
      if protoField is None:
         return
      currIcmpTypeRangeSet = protoField.icmpType
      currIcmpTypeValueSet = numericalRangeToSet( currIcmpTypeRangeSet )
      icmpType = IcmpTypeRange( typeValue, typeValue )
      # e.g., IcmpType [3, 3], currIcmpTypeRangeSet [1, 2], add the new ICMP type and
      # its codes.
      if typeValue not in currIcmpTypeValueSet:
         currIcmpTypeRangeSet.newMember( icmpType )
         icmpCodeRangeList = rangeSetToNumericalRange( codeSet, codeRangeType )
         for icmpCodeRange in icmpCodeRangeList:
            currIcmpTypeRangeSet[ icmpType ].icmpCode.add( icmpCodeRange )
      # e.g., icmpType [3, 3], currIcmpTypeRangeSet [1, 2], [3, 3], only add the ICMP
      # codes when it's not code all.
      elif icmpType in currIcmpTypeRangeSet:
         currIcmpCodeRangeSet = currIcmpTypeRangeSet[ icmpType ].icmpCode
         if not currIcmpCodeRangeSet:
            return
         newCodeRangeSet = addOrDeleteRange( currIcmpCodeRangeSet, codeSet,
                                             codeRangeType, add=True )
         # Iterate over a snapshot since entries are deleted.
         for currCode in list( currIcmpCodeRangeSet ):
            if currCode not in newCodeRangeSet:
               del currIcmpCodeRangeSet[ currCode ]
         for aCode in newCodeRangeSet:
            if aCode not in currIcmpCodeRangeSet:
               currIcmpCodeRangeSet.add( aCode )
      # e.g., icmpType [3, 3], currIcmpTypeRangeSet [1, 4], [10, 20], ignore it
      # because it is covered by code all.

   def removeIcmpTypeCodeRangeAttr( self, icmpValue, typeRangeType, typeValue,
                                    codeRangeType, codeSet ):
      """
      Remove ICMP codes codeSet from the single ICMP type typeValue.

      Only an exact [typeValue, typeValue] entry with specific codes is touched.
      The type itself is deleted once its last code is removed. typeRangeType is
      not used. Does nothing if the protocol is not configured.
      """
      icmpProto = ProtocolRange( icmpValue, icmpValue )
      protoField = self.filter.proto.get( icmpProto )
      if protoField is None:
         return
      currIcmpTypeRangeSet = protoField.icmpType
      icmpType = IcmpTypeRange( typeValue, typeValue )
      # If currIcmpTypeRangeSet has icmpType and it's not code all, remove codes of
      # this ICMP type, otherwise do nothing.
      # e.g., icmpType [ 3, 3 ], currIcmpTypeRangeSet [ 1, 2 ], [ 3, 3 ], if the
      # current ICMP type has code all, do nothing. If not, remove the codes from the
      # current codes.
      if icmpType in currIcmpTypeRangeSet:
         currIcmpCodeRangeSet = currIcmpTypeRangeSet[ icmpType ].icmpCode
         if not currIcmpCodeRangeSet:
            return
         newCodeRangeSet = addOrDeleteRange( currIcmpCodeRangeSet, codeSet,
                                             codeRangeType, add=False )
         # Iterate over a snapshot since entries are deleted.
         for currCode in list( currIcmpCodeRangeSet ):
            if currCode not in newCodeRangeSet:
               del currIcmpCodeRangeSet[ currCode ]
         for aCode in newCodeRangeSet:
            if aCode not in currIcmpCodeRangeSet:
               currIcmpCodeRangeSet.add( aCode )
         # When all ICMP codes are removed, also remove the corresponding ICMP type.
         if not currIcmpCodeRangeSet:
            del currIcmpTypeRangeSet[ icmpType ]

   def updatePortFieldSetAttr( self, attrName, fieldSetNames, add=True,
                               protoSet=None ):
      """
      Adds/removes l4 field-sets for a given protocol. When a field-set is
      applied, all configured l4 ports are cleared (both sport and dport).
      Example:
         original config: protocol tcp source port 10 destination port 50
         new config: protocol tcp source port field-set sample
         result: protocol tcp source port field-set sample
                 both source port of 10 and destination port of 50 are removed

      attrName must be 'sportFieldSet' or 'dportFieldSet'; anything else is
      ignored. Each protocol in protoSet is handled independently. A removal
      from a protocol that has no port entry, or no field-set of attrName,
      skips that protocol only.
      """
      if attrName not in ( 'sportFieldSet', 'dportFieldSet' ):
         return
      fieldSetNames = fieldSetNames or ()
      protoColl = self.filter.proto

      protos = rangeSetToNumericalRange( protoSet, "Classification::ProtocolRange" )
      # for each protocol update the l4 field-sets
      # XXXBUG494947
      for proto in protos:
         protoField = protoColl[ proto ]
         if not protoField.port:
            # No port entry yet: create one only when there is something to add.
            # There is nothing to remove from a protocol without a port entry.
            if add and fieldSetNames:
               portField = protoField.port.newMember( UniqueId() )
               for fs in fieldSetNames:
                  getattr( portField, attrName ).add( fs )
            continue
         # There's only one element in protoField.port
         portField = protoField.port.values().pop()
         portProtoAttrList = getattr( portField, attrName )
         # After 'protocol tcp source port 10 destination port 20', when
         # 'no protocol tcp source port field-set bar destination port field-set bar'
         # is configured, 'sport' and 'dport' should not be cleared. Skip only this
         # protocol; the remaining protocols still need processing.
         if not add and not portProtoAttrList:
            continue
         # Ensure l4 ports are cleared. Cannot have l4 ports and field-set
         # configured in the same line
         portField.sport.clear()
         portField.dport.clear()
         for fs in fieldSetNames:
            if add:
               getattr( portField, attrName ).add( fs )
            else:
               getattr( portField, attrName ).remove( fs )

   def updatePortRangeAttr( self, attrName, protoSet, portSet, add=True,
                            clearPrev=False ):
      """
      Adds/removes l4 ports for a given protocol. When ports are applied, all
      configured l4 field-sets are cleared (both sportFieldSet and
      dportFieldSet).

      attrName must be 'sport' or 'dport'; anything else is ignored. With
      clearPrev, existing ranges of attrName are dropped before the update.
      A removal from a protocol that has no port entry is a no-op.
      """
      if attrName not in ( 'sport', 'dport' ):
         return
      protoColl = self.filter.proto

      protos = rangeSetToNumericalRange( protoSet, "Classification::ProtocolRange" )
      # for each protocol update the l4 ports
      # XXXBUG494947
      for proto in protos:
         protoField = protoColl[ proto ]
         if not protoField.port:
            if add:
               newPortList = rangeSetToNumericalRange( portSet,
                                                       "Classification::PortRange" )
               if newPortList:
                  portField = protoField.port.newMember( UniqueId() )
                  for pRange in newPortList:
                     getattr( portField, attrName ).add( pRange )
            # Nothing to remove when the protocol has no port entry.
            continue
         # There's only one element in protoField.port
         portField = protoField.port.values().pop()
         portProtoAttrList = getattr( portField, attrName )
         if clearPrev:
            portProtoAttrList.clear()
         # After 'protocol tcp source port field-set bar destination port field-set
         # bar', when 'no protocol tcp source port 10 destination port 20' is
         # configured, 'sportFieldSet' and 'dportFieldSet' should not be cleared.
         if not add and not portProtoAttrList:
            continue
         # Ensure field-sets are cleared. Cannot have field-set and l4 ports
         # configured together.
         portField.sportFieldSet.clear()
         portField.dportFieldSet.clear()
         newPortList = addOrDeleteRange( portProtoAttrList, portSet,
                                         "Classification::PortRange", add )
         getattr( portField, attrName ).clear()
         for pRange in newPortList:
            getattr( portField, attrName ).add( pRange )

   def maybeUpdateProto( self, protocolRangeSet ):
      """
      Only called when removing additional protocol fields (sport/dport/flags).
      Port entries with no ports and no port field-sets are deleted. A protocol
      left with no port entries, no TCP flag mask and no ICMP types is deleted.
      """
      protos = rangeSetToNumericalRange( protocolRangeSet,
                                         "Classification::ProtocolRange" )
      protoColl = self.filter.proto
      for proto in protos:
         protoField = protoColl[ proto ]
         # Iterate over a snapshot since entries are deleted.
         for portId, portField in list( protoField.port.iteritems() ):
            if ( portField.sport or portField.sportFieldSet or
                 portField.dport or portField.dportFieldSet ):
               continue
            del protoField.port[ portId ]
         tcpFlagMask = protoField.tcpFlagAndMask.tcpFlagMask
         if protoField.port or tcpFlagMask or protoField.icmpType:
            continue
         del protoColl[ proto ]

   def updateMatchAllFragment( self, add=False ):
      """Set filter.matchAllFragments to add."""
      self.filter.matchAllFragments = add

   def updateMatchIpOptions( self, add=False ):
      """Set filter.matchIpOptions to add."""
      self.filter.matchIpOptions = add

   def commit( self ):
      raise NotImplementedError

   def updateFilterAttr( self, attrName, attrValue ):
      """Set an arbitrary filter attribute."""
      setattr( self.filter, attrName, attrValue )

   def updateFragmentType( self, fragmentType ):
      """
      Set the fragment match to fragmentType. An existing fragment of a
      different type is reset to None before being recreated.
      """
      if self.filter.fragment is None:
         self.filter.fragment = ( fragmentType, )
      elif self.filter.fragment.fragmentType != fragmentType:
         self.filter.fragment = None
         self.filter.fragment = ( fragmentType, )

   def clearFragment( self ):
      """Remove the fragment match."""
      self.filter.fragment = None

   def maybeDelFragment( self ):
      """Drop a matchOffset fragment match that has no offset ranges left."""
      if self.filter.fragment and self.filter.fragment.fragmentType == matchOffset \
         and not self.filter.fragment.offset:
         self.filter.fragment = None

#------------------------------------------------------------------------------------
# Context
# A context is created for a FieldSet and stores the pending changes until the
# user exits the mode or aborts the changes.
#
# If the FieldSet already exists, the context contains an editable copy of its
# contents; otherwise it contains a new, empty editable sub-config.
#------------------------------------------------------------------------------------
class FieldSetBaseContext( object ):
   """
   Common state and lookups shared by the field-set edit contexts.

   Subclasses set fieldSetColl (the config collection being edited) and
   provide the editing primitives: copyEditFieldSet, newEditFieldSet,
   updateFieldSet, identicalFieldSet, copy, commit and abort.
   """
   def __init__( self, fieldSetName, fieldSetConfig, childMode ):
      self.fieldSetName = fieldSetName
      self.fieldSetConfig = fieldSetConfig
      self.childMode = childMode
      self.mode_ = None
      # The attributes below are filled in by the concrete subclass.
      self.fieldSetColl = None
      self.prevFieldSet = None
      self.editFieldSet = None
      self.editFieldSetVersion = None

   # --- abstract editing primitives -------------------------------------------

   def copyEditFieldSet( self ):
      raise NotImplementedError

   def newEditFieldSet( self ):
      raise NotImplementedError

   def updateFieldSet( self, data, add, **kwargs ):
      raise NotImplementedError

   def identicalFieldSet( self, left, right ):
      raise NotImplementedError

   def copy( self, src, dst ):
      raise NotImplementedError

   # --- collection lookups ----------------------------------------------------

   def hasPrefixFieldSet( self, name, af ):
      """True when a field-set called name exists and its af equals af."""
      coll = self.fieldSetColl
      if coll is None:
         return False
      entry = coll.get( name )
      return bool( entry ) and entry.af == af

   def hasL4PortFieldSet( self, name ):
      """True when a field-set called name exists in the collection."""
      coll = self.fieldSetColl
      return coll is not None and name in coll

   def delFieldSet( self, name ):
      """Delete the field-set called name; no-op without a collection."""
      if self.fieldSetColl is not None:
         del self.fieldSetColl[ name ]

   def modeIs( self, mode ):
      """Record the CLI mode owning this context."""
      self.mode_ = mode

   # --- abstract lifecycle ----------------------------------------------------

   def commit( self ):
      raise NotImplementedError

   def abort( self ):
      raise NotImplementedError

class L4PortFieldSetContext( FieldSetBaseContext ):
   """
   Edit context for an L4 port field-set ('l4-port').

   Changes are made on a standalone 'Classification::FieldSetL4PortSubConfig'
   copy. commit() publishes it as a new sub-config of the field-set in
   fieldSetConfig.fieldSetL4Port and deletes the previous sub-config.
   """
   def __init__( self, fieldSetL4PortName, fieldSetConfig, childMode ):
      super( L4PortFieldSetContext, self ).__init__( fieldSetL4PortName,
                                                     fieldSetConfig, childMode )
      self.fieldSetColl = self.fieldSetConfig.fieldSetL4Port
      self.setType = 'l4-port'

   def _createEditFieldSet( self ):
      """Start an empty editable sub-config carrying a fresh version id."""
      self.editFieldSet = Tac.newInstance( 'Classification::FieldSetL4PortSubConfig',
                                           self.fieldSetName, UniqueId() )
      self.editFieldSetVersion = self.editFieldSet.version

   def copyEditFieldSet( self ):
      """Begin editing a copy of the field-set's current sub-config."""
      self.prevFieldSet = self.fieldSetColl[ self.fieldSetName ].currCfg
      self._createEditFieldSet()
      self.copy( self.prevFieldSet, self.editFieldSet )

   def newEditFieldSet( self ):
      """Begin editing a brand-new, empty field-set."""
      self._createEditFieldSet()

   def updateFieldSet( self, data, add, **kwargs ):
      """
      Add (add=True) or remove the ports in the set data. With allPorts=True,
      the field-set is reset to the full 0..maxL4Port range (or emptied when
      removing).
      """
      if kwargs.get( 'allPorts', False ):
         self.editFieldSet.ports.clear()
         if add:
            fullRange = Tac.Value( 'Classification::PortRange', 0,
                                   appConstants.maxL4Port )
            self.editFieldSet.ports.add( fullRange )
         return

      existing = numericalRangeToSet( self.editFieldSet.ports.keys() )
      merged = ( existing | data ) if add else ( existing - data )
      self.editFieldSet.ports.clear()

      for portRange in rangeSetToNumericalRange( merged,
                                                 "Classification::PortRange" ):
         self.editFieldSet.ports.add( portRange )

   def identicalFieldSet( self, left, right ):
      """True when both are empty/None, or both hold the same port ranges."""
      if not left or not right:
         return not left and not right
      return ( numericalRangeToRangeString( left.ports ) ==
               numericalRangeToRangeString( right.ports ) )

   def copy( self, src, dst ):
      """Replace dst's ports with src's, unless either is missing or equal."""
      if src and dst and not self.identicalFieldSet( src, dst ):
         dst.ports.clear()
         for portRange in src.ports:
            dst.ports.add( portRange )

   def commit( self ):
      """Publish the edited sub-config to the parent field-set config."""
      if not self.editFieldSet:
         return
      fieldSetCfg = self.fieldSetColl.get( self.fieldSetName )
      if not fieldSetCfg:
         fieldSetCfg = self.fieldSetColl.newMember( self.fieldSetName )
      elif self.identicalFieldSet( fieldSetCfg.currCfg, self.editFieldSet ):
         # Nothing changed; leave the config untouched.
         return
      # Push the edited contents into a new sub-config.
      newSubCfg = fieldSetCfg.subConfig.newMember( self.fieldSetName,
                                                   self.editFieldSetVersion )
      self.copy( self.editFieldSet, newSubCfg )
      # Flip the current pointer, then retire the previous sub-config.
      fieldSetCfg.currCfg = newSubCfg
      if self.prevFieldSet:
         del fieldSetCfg.subConfig[ self.prevFieldSet.version ]

   def abort( self ):
      """Forget all pending edits."""
      self.fieldSetName = self.editFieldSet = None
      self.prevFieldSet = self.editFieldSetVersion = None

class IpPrefixFieldSetContext( FieldSetBaseContext ):
   """
   Edit context for an IPv4 ('ipv4') or IPv6 prefix field-set.

   Changes are made on a standalone 'Classification::FieldSetIpPrefixSubConfig'
   copy holding accepted prefixes and except prefixes. commit() publishes it as
   a new sub-config and deletes the previous one.
   """
   def __init__( self, fieldSetIpPrefixName, fieldSetConfig, setType='ipv4',
                 childMode=None ):
      super( IpPrefixFieldSetContext, self ).__init__( fieldSetIpPrefixName,
                                                       fieldSetConfig,
                                                       childMode )
      self.setType = setType
      if setType != "ipv4":
         self.fieldSetColl = self.fieldSetConfig.fieldSetIpv6Prefix
      else:
         self.fieldSetColl = self.fieldSetConfig.fieldSetIpPrefix

   def _createEditFieldSet( self ):
      """Start an empty editable sub-config carrying a fresh version id."""
      self.editFieldSet = Tac.newInstance(
         'Classification::FieldSetIpPrefixSubConfig',
         self.fieldSetName, UniqueId() )
      self.editFieldSetVersion = self.editFieldSet.version

   def copyEditFieldSet( self ):
      """Begin editing a copy of the field-set's current sub-config."""
      self.prevFieldSet = self.fieldSetColl[ self.fieldSetName ].currCfg
      self._createEditFieldSet()
      self.copy( self.prevFieldSet, self.editFieldSet )

   def newEditFieldSet( self ):
      """Begin editing a brand-new, empty field-set."""
      self._createEditFieldSet()

   def updateFieldSet( self, data, add=True, **kwargs ):
      """
      Add (add=True) or remove every prefix in data, targeting the except list
      when updateExcept is truthy and the accepted list otherwise.
      """
      if kwargs.get( 'updateExcept' ):
         target = self.editFieldSet.exceptPrefix
      else:
         target = self.editFieldSet.prefixes
      mutate = target.add if add else target.remove
      for item in data:
         mutate( IpGenPrefix( str( item ) ) )

   def exceptCoveredByAcceptPrefixFieldSet( self ):
      """
      Ensure that all except prefixes lie within list of acceptable prefixes
      i.e:
         accept - 10.0.0.0/8
         except - 10.0.0.0/24
         except prefix is a subset of the accept prefix so we are able to exclude
         packets in this range
      """
      accepted = self.editFieldSet.prefixes.keys()
      excluded = self.editFieldSet.exceptPrefix.keys()
      return prefixesACoveredByB( excluded, accepted, af=self.setType )

   def identicalFieldSet( self, left, right ):
      """True when both are empty/None, or both hold the same prefix sets."""
      if not left or not right:
         return not left and not right
      sameAccepted = set( left.prefixes ) == set( right.prefixes )
      return sameAccepted and set( left.exceptPrefix ) == set( right.exceptPrefix )

   def copy( self, src, dst ):
      """Replace dst's prefixes with src's, unless either is missing or equal."""
      if src and dst and not self.identicalFieldSet( src, dst ):
         for collName in ( 'prefixes', 'exceptPrefix' ):
            dstColl = getattr( dst, collName )
            dstColl.clear()
            for prefix in getattr( src, collName ).keys():
               dstColl.add( prefix )

   def commit( self ):
      """Publish the edited sub-config to the parent field-set config."""
      if not self.editFieldSet:
         return
      fieldSetCfg = self.fieldSetColl.get( self.fieldSetName )
      if not fieldSetCfg:
         fieldSetCfg = self.fieldSetColl.newMember( self.fieldSetName,
                                                    self.setType )
      elif self.identicalFieldSet( fieldSetCfg.currCfg, self.editFieldSet ):
         # Nothing changed; leave the config untouched.
         return
      newSubCfg = fieldSetCfg.subConfig.newMember( self.fieldSetName,
                                                   self.editFieldSetVersion )
      self.copy( self.editFieldSet, newSubCfg )
      # Flip the current pointer, then retire the previous sub-config.
      fieldSetCfg.currCfg = newSubCfg
      if self.prevFieldSet:
         del fieldSetCfg.subConfig[ self.prevFieldSet.version ]

   def abort( self ):
      """Forget all pending edits."""
      self.fieldSetName = self.editFieldSet = None
      self.prevFieldSet = self.editFieldSetVersion = None
