#!/usr/bin/env python
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

""" Common code shared among TrafficPolicy CliPlugins """

import CliCommand
import PolicyMapCliLib
from ClassificationCliContextLib import MatchRuleBaseContext
from CliMode.TrafficPolicy import ( MatchRuleIpv4ConfigMode,
                                    MatchRuleIpv6ConfigMode,
                                    MatchRuleDefaultConfigMode,
                                    ActionsConfigMode,
                                    TrafficPolicyConfigMode )
from ClassificationCliLib import ( ProtocolMixin,
                                   generateTcpFlagExpression )
import Tac
import BasicCliUtil

# Tac types / enum values shared by the traffic-policy CLI code below.
ActionType = Tac.Type( "PolicyMap::ActionType" )
# Comparison results used by TrafficPolicyContext.identicalPolicyMap(),
# re-exported from PolicyMapCliLib.
CHANGED = PolicyMapCliLib.CHANGED
IDENTICAL = PolicyMapCliLib.IDENTICAL
RESEQUENCED = PolicyMapCliLib.RESEQUENCED
UniqueId = Tac.Type( "Ark::UniqueId" )
# Fixed priorities for the reserved default rules and the priority upper bound.
ClassPriorityConstant = Tac.Type( 'TrafficPolicy::ClassPriorityConstant' )
tacMatchOption = Tac.Type( 'PolicyMap::ClassMapMatchOption' )
# Match options for IPv4 and IPv6 match rules respectively.
matchIpAccessGroup = tacMatchOption.matchIpAccessGroup
matchIpv6AccessGroup = tacMatchOption.matchIpv6AccessGroup
# Names of the built-in IPv4/IPv6 default match rules.
ReservedClassMapNames = Tac.Type( 'TrafficPolicy::ReservedClassMapNames' )

# User-facing error messages for conflicting match subcommands.
neighborsConfigConflictMsg = (
   "The 'protocol neighbors' subcommand is not supported"
   " when any other match subcommands are configured" )
matchL4ProtocolConflictMsg = (
   "The 'protocol bgp' subcommand is not supported"
   " when any other match subcommands are configured" )
# '%s' is filled in with the name of the conflicting subcommand.
invalidPortConflictMsg = (
      "The '%s' subcommand is not supported if protocols other than "
      "'{tcp|udp|tcp udp}' are configured"
      )
invalidProtocolConflictMsg = (
      "The 'protocol' subcommand only supports 'tcp' or 'udp' if"
      " '{source|destination} port' is configured"
      )

def policyHasAction( action, config, status, policyName ):
   """Return True if any class of policy `policyName` has an `action` action.

   `config.pmapType.pmap` is searched for the policy; a missing policy yields
   False. `status` is accepted for interface compatibility but is not used.
   """
   pmaps = config.pmapType.pmap
   if policyName not in pmaps:
      return False
   classActions = pmaps[ policyName ].currCfg.classAction.values()
   return any( actionType == action
               for classAction in classActions
               for actionType in classAction.policyAction.keys() )

def protectedTrafficPolicyNamesRegex():
   """Return the regex used to match a user-supplied traffic-policy name.

   The name pattern is prefixed with one BasicCliUtil.notAPrefixOf() guard per
   protected keyword ('cpu', 'interface'), presumably so that names which are
   prefixes of those keywords are rejected.
   """
   namePattern = r'[A-Za-z0-9_:{}\[\]-]+'
   guards = [ BasicCliUtil.notAPrefixOf( keyword )
              for keyword in ( 'cpu', 'interface' ) ]
   return ''.join( guards ) + namePattern

class TrafficPolicyMatchRuleAction( PolicyMapCliLib.PolicyRawRuleActionBase ):
   """A traffic-policy match rule: a raw class map holding one structured
   filter, together with the actions attached to it.

   Rules are identified by their class-map name (see key()).
   """
   keyTag_ = 'matchRuleFilter'

   def __init__( self, trafficPolicyContext, ruleName, matchOption, sfilter ):
      super( TrafficPolicyMatchRuleAction, self ).__init__(
         trafficPolicyContext, ruleName, None, matchOption )
      # Structured filter that addToPmap() copies into the class map.
      self.filter = sfilter

   def key( self ):
      """Return the class-map name identifying this rule."""
      return self.cmapName

   def addToPmap( self, pmap, seqnum ):
      """Add this rule to `pmap` at `seqnum` and (re)write its filter.

      The raw class map named after the rule is created if missing. Its match
      collection is then reset to a single entry for self.matchOption whose
      structuredFilter is a copy of self.filter.
      """
      super( TrafficPolicyMatchRuleAction, self ).addToPmap( pmap, seqnum )
      name = self.cmapName
      if name not in pmap.rawClassMap:
         rawCmap = pmap.rawClassMap.newMember( name, UniqueId() )
      else:
         rawCmap = pmap.rawClassMap[ name ]
      rawCmap.match.clear()
      self.cmapMatch = rawCmap.match.newMember( self.matchOption )
      self.cmapMatch.structuredFilter = ( "", )
      # addRuleCommon removes and re-adds the match rule, so the structured
      # filter must be copied in again here for it to be preserved.
      self.cmapMatch.structuredFilter.copy( self.filter )

   def actionCombinationError( self ):
      """Return an error string for an invalid action combination, else None.

      'log' is only valid together with 'deny' (shown to users as 'drop').
      """
      configured = self.actions()
      if 'log' in configured and 'deny' not in configured:
         return "The 'log' action cannot be used without the 'drop' action"
      return None

class TrafficPolicyContext( PolicyMapCliLib.PolicyMapContext ):
   """Edit context for a single traffic-policy.

   Specializes PolicyMapContext for traffic-policies:
   - match rules are raw class maps carrying a structured filter;
   - the reserved IPv4/IPv6 default rules live at fixed priorities
     (ClassPriorityConstant) and are kept out of normal sequencing;
   - actions are the traffic-policy action types (drop, police, count, log,
     sample, sample all, set dscp, set traffic class).
   """
   pmapType = 'mapTrafficPolicy'

   def __init__( self, config, statusReqDir, status, trafficPolicyName ):
      PolicyMapCliLib.PolicyMapContext.__init__( self, config, statusReqDir,
                                                 status, trafficPolicyName,
                                                 self.pmapType )
      self.matchRuleContext = None
      self.shouldResequence = False

   def childMode( self ):
      """Return the CLI mode class used to edit this policy."""
      return TrafficPolicyConfigMode

   def hasPolicy( self, name ):
      """Return True if a traffic-policy called `name` is configured."""
      return name in self.config().pmapType.pmap

   def mapTypeStr( self ):
      return 'traffic-policy'

   def reservedClassMapNames( self ):
      """Return the names of the built-in default match rules."""
      return [ ReservedClassMapNames.classV4Default,
               ReservedClassMapNames.classV6Default ]

   def initializeRuleToSeq( self ):
      """Extend the base rule->sequence maps with one for match rules."""
      super( TrafficPolicyContext, self ).initializeRuleToSeq()
      self.ruleToSeq[ TrafficPolicyMatchRuleAction.keyTag() ] = \
         PolicyMapCliLib.PolicyRuleToSeqDict()

   def setRuleToSeq( self, policyMapSubConfig ):
      """Record every match rule's priority from `policyMapSubConfig`.

      Returns the highest priority used by a user-defined rule (0 if none).
      The reserved default rules are recorded in the map but excluded from
      that maximum, since they always sit at their fixed priorities.
      """
      maxSeq = 0
      matchRuleFilterToSeq = self.ruleToSeq[ TrafficPolicyMatchRuleAction.keyTag() ]
      reservedNames = self.reservedClassMapNames()
      for prio, cmapName in policyMapSubConfig.classPrio.items():
         rawCmap = policyMapSubConfig.rawClassMap.get( cmapName, None )
         assert rawCmap
         assert matchIpAccessGroup in rawCmap.match or \
                matchIpv6AccessGroup in rawCmap.match, \
                'Unsupported ClassMap matchTypes %s' % ( rawCmap.match, )
         matchRuleFilterToSeq[ cmapName ] = prio
         if cmapName in reservedNames:
            # The default rules should always be at the highest priority value.
            # Even after resequencing.
            assert prio in [ ClassPriorityConstant.classV4DefaultPriority,
                             ClassPriorityConstant.classV6DefaultPriority ]
            # When sequencing rules, the default rules should not be considered
            # as they should always be at the end.
            continue
         maxSeq = max( maxSeq, prio )
      return maxSeq

   def resequence( self, start, inc ):
      """Resequence user rules from `start` by `inc`, keeping defaults pinned.

      The base resequence renumbers every rule, including the two default
      rules, which are asserted to land on the last two sequence numbers.
      They are then moved back to their reserved priorities and the last
      sequence is reset to the last user-defined rule.
      Returns 'errSequenceOutOfRange' if the base resequence did, otherwise
      'success'.
      """
      result = super( TrafficPolicyContext, self ).resequence( start, inc )
      if result == 'errSequenceOutOfRange':
         return result

      ipv6DefaultRuleSeq = self.lastSequence() - inc
      ipv4DefaultRuleSeq = self.lastSequence() - 2 * inc
      lastUserRuleSeq = self.lastSequence() - 3 * inc

      # Assert the last two rules are the default rules
      # Assert the proper default rules' seqnos are not occupied by other rules
      for className in self.reservedClassMapNames():
         if className == ReservedClassMapNames.classV4Default:
            assert self.npmap.classPrio[ ipv4DefaultRuleSeq ] == className
            assert ClassPriorityConstant.classV4DefaultPriority not in \
               self.npmap.classPrio
         elif className == ReservedClassMapNames.classV6Default:
            assert self.npmap.classPrio[ ipv6DefaultRuleSeq ] == className
            assert ClassPriorityConstant.classV6DefaultPriority not in \
               self.npmap.classPrio

      # Delete the old defaults
      del self.npmap.classPrio[ ipv6DefaultRuleSeq ]
      del self.npmap.classPrio[ ipv4DefaultRuleSeq ]

      # Restore the proper defaults
      for className in self.reservedClassMapNames():
         if className == ReservedClassMapNames.classV4Default:
            self.npmap.classPrio[
               ClassPriorityConstant.classV4DefaultPriority ] = className
         elif className == ReservedClassMapNames.classV6Default:
            self.npmap.classPrio[
               ClassPriorityConstant.classV6DefaultPriority ] = className

      # Clean up the rule to sequence mapping
      matchRuleFilterToSeq = self.ruleToSeq[ TrafficPolicyMatchRuleAction.keyTag() ]
      for className in self.reservedClassMapNames():
         if className == ReservedClassMapNames.classV4Default:
            matchRuleFilterToSeq.pop( ClassPriorityConstant.classV4DefaultPriority,
                                      None )
            matchRuleFilterToSeq[ className ] = \
               ClassPriorityConstant.classV4DefaultPriority
         elif className == ReservedClassMapNames.classV6Default:
            matchRuleFilterToSeq.pop( ClassPriorityConstant.classV6DefaultPriority,
                                      None )
            matchRuleFilterToSeq[ className ] = \
               ClassPriorityConstant.classV6DefaultPriority

      # Reset the last sequence to be the last user-defined rule
      self.lastSequenceIs( lastUserRuleSeq )

      self.shouldResequence = False
      self.moving = False

      return 'success'

   def currentPolicy( self ):
      return self.currentPmap()

   def copyRawClassMap( self, src, dst, mapType ):
      ''' Copy match condition and match entries from src to dst.

      src and dst are of type ClassMapSubConfig. Only IPv4/IPv6 structured
      filter matches are supported.
      '''
      dst.matchCondition = src.matchCondition
      for option in src.match.keys():
         srcMatch = src.match[ option ]
         dst.match.newMember( option )
         dstMatch = dst.match[ option ]
         if option == matchIpAccessGroup or option == matchIpv6AccessGroup:
            dstMatch.structuredFilter = ( "", )
            dstMatch.structuredFilter.copy( srcMatch.structuredFilter )
         else:
            assert False, 'unknown match option %s' % ( option, )

   def copyAction( self, src ):
      """Create and return a new action in config().actions mirroring `src`.

      Every action type that MatchRuleContext.setAction can create is
      supported, so policies holding any of them can be copied.
      """
      actionType = src.actionType
      actions = self.config().actions
      if actionType == ActionType.deny:
         return actions.dropAction.newMember( src.className, UniqueId() )
      elif actionType == ActionType.police:
         return actions.policeAction.newMember( src.className, UniqueId(),
                                                src.rateLimit )
      elif actionType == ActionType.count:
         return actions.countAction.newMember( src.className, UniqueId(),
                                               src.counterName )
      elif actionType == ActionType.log:
         return actions.logAction.newMember( src.className, UniqueId() )
      elif actionType == ActionType.sample:
         return actions.sampleAction.newMember( src.className, UniqueId() )
      elif actionType == ActionType.sampleAll:
         return actions.sampleAllAction.newMember( src.className, UniqueId() )
      elif actionType == ActionType.setDscp:
         return actions.setDscpAction.newMember( src.className, UniqueId(),
                                                 src.dscp )
      elif actionType == ActionType.setTc:
         return actions.setTcAction.newMember( src.className, UniqueId(),
                                               src.tc )
      else:
         assert False, 'unknown actionType %s' % ( actionType, )
      return None

   def updateNamedCounters( self, counterNames, add=True ):
      """Add or remove named counters on the policy being edited.

      With add=True every name in `counterNames` is added. With add=False an
      empty `counterNames` clears all named counters; otherwise only the
      listed counters that are actually present are removed.
      """
      pmap = self.npmap
      if add:
         for c in counterNames:
            pmap.namedCounter.add( c )
      elif not counterNames:
         # Delete all counters
         pmap.namedCounter.clear()
      else:
         # Only delete existing counters
         for c in counterNames:
            if c in pmap.namedCounter:
               del pmap.namedCounter[ c ]

   def maxRules( self ):
      """No rule-count limit is defined here; returns None."""
      pass

   def lastSequenceIs( self, seqnum ):
      """Set the last sequence, ignoring the reserved default priorities."""
      if seqnum not in [ ClassPriorityConstant.classV6DefaultPriority,
                         ClassPriorityConstant.classV4DefaultPriority ]:
         super( TrafficPolicyContext, self ).lastSequenceIs( seqnum )

   def maxSeq( self ):
      return ClassPriorityConstant.classPriorityMax

   def getRawTagAndFilter( self, cmapName ):
      return 'matchRuleFilter', cmapName

   def getRuleAtSeqnum( self, seqnum ):
      '''
        Return only the matching rule. If necessary, we can return
        the actions later on.
        Returns None if no rule is configured at `seqnum`.
      '''
      policyRule = None
      ruleName = self.npmap.classPrio.get( seqnum, None )
      if ruleName:
         matchRule = self.npmap.rawClassMap.get( ruleName, None )
         if matchRule:
            matchOption = matchRule.match.keys()[ 0 ]
            sfilter = matchRule.match.get( matchOption ).structuredFilter
            policyRule = TrafficPolicyMatchRuleAction( self, ruleName,
                                                       matchOption, sfilter )
      return policyRule

   def removeAction( self, action ):
      """Delete `action` from the collection in config().actions for its type."""
      actType = action.actionType
      actions = self.config().actions
      if actType == ActionType.deny:
         del actions.dropAction[ action.id ]
      elif actType == ActionType.police:
         del actions.policeAction[ action.id ]
      elif actType == ActionType.count:
         del actions.countAction[ action.id ]
      elif actType == ActionType.log:
         del actions.logAction[ action.id ]
      elif actType == ActionType.sample:
         del actions.sampleAction[ action.id ]
      elif actType == ActionType.sampleAll:
         del actions.sampleAllAction[ action.id ]
      elif actType == ActionType.setDscp:
         del actions.setDscpAction[ action.id ]
      elif actType == ActionType.setTc:
         del actions.setTcAction[ action.id ]
      else:
         # `action` is an action object, not a string; format the type so the
         # assertion message itself cannot raise.
         assert False, 'unknown action %s' % ( actType, )

   def identicalFilter( self, filter1, filter2 ):
      """Return True if both filters are None or compare equal via isEqual."""
      if filter1 is None and filter2 is None:
         return True
      elif filter1 is None or filter2 is None:
         return False
      return filter1.isEqual( filter2 )

   def identicalMatch( self, matchRule1, matchRule2 ):
      """Return True if two single-entry match rules have the same option and
      an identical structured filter."""
      assert len( matchRule1.match.keys() ) == 1
      if matchRule1.match.keys() != matchRule2.match.keys():
         return False
      filter1 = matchRule1.match.values()[ 0 ].structuredFilter
      filter2 = matchRule2.match.values()[ 0 ].structuredFilter
      return self.identicalFilter( filter1, filter2 )

   def compareMatchRules( self, p1, p2 ):
      """Return CHANGED if the match rules of p1 and p2 differ, else None."""
      # compare matchRules
      p1Rules = p1.rawClassMap.keys()
      p2Rules = p2.rawClassMap.keys()
      if p1Rules != p2Rules:
         return CHANGED
      for k, matchRule in p1.rawClassMap.iteritems():
         if not self.identicalMatch( matchRule, p2.rawClassMap.get( k ) ):
            return CHANGED
      return None

   def identicalPolicyMap( self, p1, p2 ):
      """Compare two policy sub-configs.

      Returns IDENTICAL, RESEQUENCED (same rules in the same order with only
      their priorities changed), or CHANGED.
      """
      if p1 is None or p2 is None:
         return CHANGED

      if p1 == p2:
         return IDENTICAL

      # compare named counters
      p1NamedCounters = p1.namedCounter
      p2NamedCounters = p2.namedCounter
      if set( p1NamedCounters ) != set( p2NamedCounters ):
         return CHANGED

      # compare class Actions
      p1ClassActions = p1.classAction.keys()
      p2ClassActions = p2.classAction.keys()
      if p1ClassActions != p2ClassActions:
         return CHANGED
      for ruleName in p1ClassActions:
         r1Action = p1.classAction[ ruleName ]
         r2Action = p2.classAction[ ruleName ]
         ret = self.identicalClassActions( r1Action, r2Action )
         if ret != IDENTICAL:
            return ret

      # compare class priorities
      p1Prios = p1.classPrio.keys()
      p2Prios = p2.classPrio.keys()
      if p1Prios != p2Prios:
         if p1.classPrio.values() == p2.classPrio.values():
            # The keys were changed, but the values are still in the same order
            # so we've either resequenced or had one of the match rules change
            # but the name is still the same.
            if self.compareMatchRules( p1, p2 ) == CHANGED:
               return CHANGED
            return RESEQUENCED
         else:
            return CHANGED
      if p1.classPrio.values() != p2.classPrio.values():
         return CHANGED
      if self.compareMatchRules( p1, p2 ) == CHANGED:
         return CHANGED

      return IDENTICAL

   def identicalClassActions( self, c1Action, c2Action ):
      """Return IDENTICAL if both class actions have the same action types and
      parameters (police rate, counter name, dscp, traffic class), else
      CHANGED."""
      # compare action types
      c1ActTypes = c1Action.policyAction.keys()
      c2ActTypes = c2Action.policyAction.keys()
      if sorted( c1ActTypes ) != sorted( c2ActTypes ):
         return CHANGED

      # compare police actions
      c1PoliceRates = [ v.rateLimit for k, v in
                        c1Action.policyAction.iteritems() if k == 'police' ]
      c2PoliceRates = [ v.rateLimit for k, v in
                        c2Action.policyAction.iteritems() if k == 'police' ]
      if sorted( c1PoliceRates ) != sorted( c2PoliceRates ):
         return CHANGED

      # compare count actions
      c1Counter = c1Action.policyAction.get( 'count' )
      c1CounterName = c1Counter.counterName if c1Counter else ""
      c2Counter = c2Action.policyAction.get( 'count' )
      c2CounterName = c2Counter.counterName if c2Counter else ""
      if c1CounterName != c2CounterName:
         return CHANGED

      # compare set dscp actions
      c1DscpValues = [ v.dscp for k, v in
                       c1Action.policyAction.iteritems() if k == 'setDscp' ]
      c2DscpValues = [ v.dscp for k, v in
                       c2Action.policyAction.iteritems() if k == 'setDscp' ]
      if sorted( c1DscpValues ) != sorted( c2DscpValues ):
         return CHANGED

      # compare set traffic class actions
      c1Tcs = [ v.tc for k, v in
                c1Action.policyAction.iteritems() if k == 'setTc' ]
      c2Tcs = [ v.tc for k, v in
                c2Action.policyAction.iteritems() if k == 'setTc' ]
      if sorted( c1Tcs ) != sorted( c2Tcs ):
         return CHANGED

      return IDENTICAL

   def delMatchRule( self, ruleName ):
      """Delete the raw class map `ruleName` from the current policy."""
      del self.currentPolicy().rawClassMap[ ruleName ]

   def abort( self ):
      self.delPolicyResources()

   def addDefaultRule( self, cmapName, matchOption, classPriority ):
      """Add reserved default rule `cmapName` at `classPriority` if absent.

      The rule gets a single `matchOption` match with an empty structured
      filter and is registered through addRuleCommon().
      """
      if self.npmap.rawClassMap.get( cmapName, None ):
         # Don't add the rule twice
         return

      cmap = self.npmap.rawClassMap.newMember( cmapName, UniqueId() )

      cmap.match.clear()
      cmapMatch = cmap.match.newMember( matchOption )
      cmapMatch.structuredFilter = ( "", )

      self.npmap.classPrio[ classPriority ] = cmapName

      matchRuleAction = TrafficPolicyMatchRuleAction( self,
                                                      cmapName,
                                                      matchOption,
                                                      cmapMatch.structuredFilter )
      self.addRuleCommon( classPriority, matchRuleAction )

   def newEditPmap( self ):
      """Start a new edit policy and seed it with both default rules."""
      super( TrafficPolicyContext, self ).newEditPmap()
      for className in self.reservedClassMapNames():
         if className == ReservedClassMapNames.classV4Default:
            self.addDefaultRule( className,
                                 matchIpAccessGroup,
                                 ClassPriorityConstant.classV4DefaultPriority )
         elif className == ReservedClassMapNames.classV6Default:
            self.addDefaultRule( className,
                                 matchIpv6AccessGroup,
                                 ClassPriorityConstant.classV6DefaultPriority )

#---------------------------------------------------------------
# The base config command for entering policy configuration
# i.e. [ no|default ] <keyword> <policy-name>
#---------------------------------------------------------------
class PolicyConfigCmdBase( CliCommand.CliCommandClass ):
   """Base for '[ no | default ] <keyword> POLICY_NAME' commands.

   Subclasses provide _feature() and _context(). handler() prepares an edit
   policy (a copy of the existing one, or a new one) and enters the policy's
   child mode; noOrDefaultHandler() removes the policy if it exists.
   """
   @staticmethod
   def _feature():
      """Return the feature passed to the child mode; must be overridden."""
      raise NotImplementedError

   @classmethod
   def _context( cls, name ):
      """Return the edit context for policy `name`; must be overridden."""
      raise NotImplementedError

   @classmethod
   def _removePolicy( cls, mode, name ):
      """Delete policy `name` through a freshly obtained context."""
      cls._context( name ).delPmap( mode, name )

   @classmethod
   def handler( cls, mode, args ):
      """Begin editing POLICY_NAME and enter its child config mode."""
      policyName = args[ 'POLICY_NAME' ]
      ctx = cls._context( policyName )
      # Edit a copy of an existing policy, otherwise start a fresh one.
      startEdit = ( ctx.copyEditPmap if ctx.hasPolicy( policyName )
                    else ctx.newEditPmap )
      startEdit()
      newMode = mode.childMode( ctx.childMode(),
                                context=ctx,
                                feature=cls._feature() )
      mode.session_.gotoChildMode( newMode )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Remove POLICY_NAME if it is configured; otherwise do nothing."""
      policyName = args[ 'POLICY_NAME' ]
      if not cls._context( policyName ).hasPolicy( policyName ):
         return
      cls._removePolicy( mode, policyName )

#------------------------------------------------------------------------------------
# The "protocol tcp flags TCP_FLAGS" command
#------------------------------------------------------------------------------------
class ProtocolTcpFlagsOnlyCmd( ProtocolMixin ):
   """CLI command 'protocol FLAGS_EXPR' configuring TCP flags in a match rule."""
   syntax = 'protocol FLAGS_EXPR'
   noOrDefaultSyntax = syntax

   data = {
      'FLAGS_EXPR' : generateTcpFlagExpression( tcpFlagsSupported=True ),
      'protocol' : 'Protocol',
   }

   @classmethod
   def handler( cls, mode, args ):
      """Add the parsed TCP flags, then let the mixin report any conflicts."""
      flagArgs = args.get( cls._tcpFlagArgsListName )
      if flagArgs or args:
         cls._updateProtoAndPort( mode, args, flagArgs, flags=True, add=True )
         cls._maybeHandleErrors( mode, args, flagArgs )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Remove the parsed TCP flags."""
      flagArgs = args.get( cls._tcpFlagArgsListName )
      cls._updateProtoAndPort( mode, args, flagArgs, flags=True, add=False )

class MatchRuleContext( MatchRuleBaseContext ):
   """Edit context for one match rule of a traffic-policy.

   Holds a TrafficPolicyMatchRuleAction (created by copyEditMatchRule or
   newEditMatchRule) that accumulates the rule's filter and actions; commit()
   hands it to the owning TrafficPolicyContext via addRuleCommon().
   """
   def __init__( self, trafficPolicyContext, ruleName, matchOption ):
      super( MatchRuleContext, self ).__init__( ruleName, matchOption )
      self.trafficPolicyContext = trafficPolicyContext
      self.trafficPolicy = self.trafficPolicyContext.currentPolicy()
      self.trafficPolicyName = self.trafficPolicyContext.pmapName_
      self.matchType = "ipv4" if self.matchOption == matchIpAccessGroup else \
                       "ipv6"
      self.matchRuleAction = None
      # If this match rule existed previously, it will have a
      # structuredFilter containing the previous configuration.  Check
      # for this and copy any existing config to our new structuredFilter.
      if self.trafficPolicy:
         matchRule = self.trafficPolicy.rawClassMap.get( self.ruleName )
         if matchRule:
            # NOTE(review): newMember() is used here rather than a read-only
            # lookup; it presumably creates match[ self.matchOption ] when
            # absent, and copyEditMatchRule() indexes that entry directly.
            # Confirm whether this side effect on the policy is intended.
            if matchRule.match.newMember( self.matchOption ).structuredFilter \
               is not None:
               self.filter.copy(
                  matchRule.match[ self.matchOption ].structuredFilter )
      self.seqnum = 0

   def childMode( self, matchRuleName, matchOption ):
      """Return the config mode for a rule: default, IPv4 or IPv6."""
      if matchRuleName in self.trafficPolicyContext.reservedClassMapNames():
         # The default rules have their own mode
         return MatchRuleDefaultConfigMode
      if matchOption == matchIpAccessGroup:
         return MatchRuleIpv4ConfigMode
      assert matchOption == matchIpv6AccessGroup
      return MatchRuleIpv6ConfigMode

   def actionMode( self ):
      return ActionsConfigMode

   def setAction( self, actionType, actionValue=None, no=False, clearActions=None ):
      """Add or remove an action of `actionType` on the match rule.

      no=True removes the action. For the count action a counter name in
      `actionValue` limits removal to that counter; an empty or omitted
      (None) name removes whichever counter is configured.

      Otherwise the action types in `clearActions` (if any) are removed first,
      then a new action is created in the actions config, using
      `actionValue` as its parameter for count, police, setDscp and setTc.
      """
      if no:
         # Treat both "" and None as "no specific counter": previously a
         # None value made any configured counter look like a mismatch, so the
         # count action was silently left in place.
         if actionType == ActionType.count and actionValue:
            action = self.matchRuleAction.policyActions.get( actionType )
            if action and action.counterName != actionValue:
               # Counter being deleted is not present in match rule, ignore request
               return
         self.matchRuleAction.delAction( actionType )
         return

      if clearActions:
         for action in clearActions:
            self.matchRuleAction.delAction( action )
      actionsConfig = self.trafficPolicyContext.config().actions
      if actionType == ActionType.deny:
         action = actionsConfig.dropAction.newMember( self.trafficPolicyName,
                                                      UniqueId() )
      elif actionType == ActionType.count:
         action = actionsConfig.countAction.newMember( self.trafficPolicyName,
                                                       UniqueId(),
                                                       actionValue )
      elif actionType == ActionType.log:
         action = actionsConfig.logAction.newMember( self.trafficPolicyName,
                                                     UniqueId() )
      elif actionType == ActionType.police:
         action = actionsConfig.policeAction.newMember( self.trafficPolicyName,
                                                        UniqueId(),
                                                        actionValue )
      elif actionType == ActionType.sample:
         action = actionsConfig.sampleAction.newMember( self.trafficPolicyName,
                                                        UniqueId() )
      elif actionType == ActionType.sampleAll:
         action = actionsConfig.sampleAllAction.newMember( self.trafficPolicyName,
                                                           UniqueId() )
      elif actionType == ActionType.setDscp:
         action = actionsConfig.setDscpAction.newMember( self.trafficPolicyName,
                                                         UniqueId(),
                                                         actionValue )
      elif actionType == ActionType.setTc:
         action = actionsConfig.setTcAction.newMember( self.trafficPolicyName,
                                                       UniqueId(),
                                                       actionValue )
      else:
         assert False, 'unknown action %s' % ( actionType, )
      self.matchRuleAction.addAction( actionType, action )

   def copyEditMatchRule( self, ruleName, seqnum ):
      """Start editing existing rule `ruleName`, preloading its actions."""
      self.seqnum = seqnum
      self.matchRuleAction = TrafficPolicyMatchRuleAction( self.trafficPolicyContext,
                                                           ruleName,
                                                           self.matchOption,
                                                           self.filter )
      classAction = self.trafficPolicy.classAction[ ruleName ]
      for actType, action in classAction.policyAction.iteritems():
         self.matchRuleAction.originalActions.add( action )
         self.matchRuleAction.addAction( actType, action )

      # set the cmapMatch
      matchRule = self.trafficPolicy.rawClassMap.get( ruleName )
      self.matchRuleAction.cmapMatch = matchRule.match[ self.matchOption ]

   def newEditMatchRule( self, ruleName, seqnum ):
      """Start editing a new rule `ruleName`; a None seqnum is stored as 0."""
      if seqnum is None:
         seqnum = 0
      self.seqnum = seqnum
      self.matchRuleAction = TrafficPolicyMatchRuleAction( self.trafficPolicyContext,
                                                           ruleName,
                                                           self.matchOption,
                                                           self.filter )

   def commit( self ):
      # commit all pending rule changes
      self.trafficPolicyContext.addRuleCommon( self.seqnum, self.matchRuleAction )
