#!/usr/bin/env python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import BasicCliUtil
import CliCommand
import CliMatcher
import CliParser
import Ip6AddrMatcher
import IpAddrMatcher

from MultiRangeRule import MultiRangeMatcher
from CliMode.Classification import ( AppConfigMode, AppProfileConfigMode )
from ClassificationCliContextLib import ( IpPrefixFieldSetContext,
                                          L4PortFieldSetContext )
from ClassificationLib import ( numericalRangeToSet, rangeSetToNumericalRange,
                                numericalRangeToRangeString, extraIpv4Protocols,
                                genericIpProtocols, tcpUdpProtocols, getKeywordMap,
                                getProtectedFieldSetNames, icmpV4Types, icmpV6Types,
                                icmpV4Codes, icmpV6Codes, icmpV4TypeWithValidCodes,
                                icmpV6TypeWithValidCodes )
from ClassificationAppProfileId import updateAppProfileId
from eunuchs.in_h import IPPROTO_TCP, IPPROTO_ICMP, IPPROTO_ICMPV6

# Members of the Classification::FieldConflict Tac type, bound to module-level
# names.
tacFieldConflict = Tac.Type( 'Classification::FieldConflict' )
conflictNeighborProtocol = tacFieldConflict.conflictNeighborProtocol
conflictMatchL4Protocol = tacFieldConflict.conflictMatchL4Protocol
conflictMatchAllFragments = tacFieldConflict.conflictMatchAllFragments
conflictFragmentOffset = tacFieldConflict.conflictFragmentOffset
conflictOther = tacFieldConflict.conflictOther
# Members of the Classification::FragmentType Tac type.
FragmentType = Tac.Type( 'Classification::FragmentType' )
matchAll = FragmentType.matchAll
matchNone = FragmentType.matchNone
matchOffset = FragmentType.matchOffset
# Instantiated as the second argument whenever a new field-set sub-config is
# created in this file.
UniqueId = Tac.Type( 'Ark::UniqueId' )

# Error message templates. The ones containing '%s' are completed with the name
# of the offending subcommand before being reported.
configConflictMsg = (
      "The '%s' subcommand is not "
      "supported when either the 'protocol neighbors' or "
      "'protocol bgp' subcommand are configured" )
invalidPortConflictMsg = (
      "The '%s' subcommand is not supported if protocols other than "
      "'{tcp|udp|tcp udp}' are configured" )
invalidProtocolConflictMsg = (
      "The 'protocol' subcommand only supports 'tcp' or 'udp' if"
      " '{source|destination} port' is configured" )
# L4 port and fragment matching are reported as mutually exclusive by the three
# messages below.
invalidL4PortConflictMsg = (
      "The '{source|destination} port' command is not supported when "
      "'fragment' or 'fragment offset' command is configured" )
invalidFragmentConflictMsg = (
      "The 'fragment' command is not supported when 'source port' or "
      "'destination port' command is configured" )
invalidFragOffsetConflictMsg = (
      "The 'fragment offset' command is not supported when 'source port' or "
      "'destination port' command is configured" )

# Classification::Constants value; supplies the upper bounds of the range
# matchers defined in this file.
appConstants = Tac.Value( 'Classification::Constants' )

# Matches the 'tcp' keyword; its value function returns a one-element set holding
# IPPROTO_TCP.
tcpKeywordMatcher = CliMatcher.KeywordMatcher( "tcp", helpdesc="tcp",
                                        value=lambda mode, _:set( [ IPPROTO_TCP ] ) )
# TCP flag keywords mapped to their help text.
tcpFlagTokens = {
   'established': 'match on established',
   'initial': 'match on initial'
}

def generateMultiRangeMatcher( name, maxVal, minVal=0, helpdesc='' ):
   """
   Build a MultiRangeMatcher whose rangeFn reports ( minVal, maxVal ).

   An empty helpdesc is replaced by a default description derived from name.
   The matcher's value function returns set( grList.values() ).
   """
   def bounds():
      return ( minVal, maxVal )

   def toValueSet( mode, grList ):
      return set( grList.values() )

   description = helpdesc
   if not description:
      description = '%s values(s) or range(s) of %s values' % ( name, name )
   return MultiRangeMatcher( rangeFn=bounds,
                             noSingletons=False,
                             helpdesc=description,
                             value=toValueSet )

# Multi-range matchers. Upper bounds come from appConstants; the lower bound is
# the default 0 except for protocols, which start at 1.
ipLengthRangeMatcher = generateMultiRangeMatcher( 'length', appConstants.maxLength )
fragOffsetRangeMatcher = generateMultiRangeMatcher( 'fragment offset',
                                                    appConstants.maxFragment )
portRangeMatcher = generateMultiRangeMatcher( 'port', appConstants.maxL4Port )
protoRangeMatcher = generateMultiRangeMatcher( 'protocol',
                                               appConstants.maxProto, 1 )
# Plain keyword matchers.
portKwMatcher = CliMatcher.KeywordMatcher( 'port', helpdesc='Port' )
fieldSetKwMatcher = CliMatcher.KeywordMatcher( 'field-set', helpdesc='Field set' )
# ICMP type and code range matchers.
icmpTypeRangeMatcher = generateMultiRangeMatcher( 'icmp type',
                                                  appConstants.maxIcmpType )
icmpCodeRangeMatcher = generateMultiRangeMatcher( 'icmp code',
                                                  appConstants.maxIcmpCode )

def getIcmpCodeMap( mode, context, icmpTypeWithValidCodes, icmpCodes ):
   """
   Return the keyword map of the ICMP codes belonging to the ICMP type whose name
   is stored under 'TYPE_NAME' in context.sharedResult.

   icmpTypeWithValidCodes maps a type name to a sequence whose first element is
   the type value; icmpCodes maps a type value to its code definitions. An empty
   dict is returned when the type name is absent or unknown, or when its type
   value is None. mode is unused.
   """
   icmpTypeName = context.sharedResult.get( 'TYPE_NAME' )
   # Look the name up directly; an unknown name must produce an empty map rather
   # than an error.
   typeEntry = icmpTypeWithValidCodes.get( icmpTypeName )
   if typeEntry is None:
      return {}
   icmpTypeValue = typeEntry[ 0 ]
   if icmpTypeValue is None:
      return {}
   return getKeywordMap( icmpCodes.get( icmpTypeValue, {} ) )

# Keyword matchers for ICMP code names. Their keyword map is produced by
# getIcmpCodeMap, which reads the 'TYPE_NAME' shared result from the context
# passed in (passContext=True).
icmpV4CodeNameMatcher = CliMatcher.DynamicKeywordMatcher(
   lambda mode, context: getIcmpCodeMap( mode, context, icmpV4TypeWithValidCodes,
                                         icmpV4Codes ), passContext=True )
icmpV6CodeNameMatcher = CliMatcher.DynamicKeywordMatcher(
   lambda mode, context: getIcmpCodeMap( mode, context, icmpV6TypeWithValidCodes,
                                         icmpV6Codes ), passContext=True )
# Keyword -> ( IP protocol number, description )
icmpProtocols = {
   'icmp': ( IPPROTO_ICMP, 'Internet Control Message Protocol' ),
   'icmpv6': ( IPPROTO_ICMPV6, 'Internet Control Message Protocol version 6' ),
}
# Keyword matchers whose help text reads "<description> (<protocol number>)".
icmpV4KwMatcher = CliMatcher.KeywordMatcher(
   'icmp', helpdesc='%s (%d)' % ( icmpProtocols[ 'icmp' ][ 1 ],
                                  icmpProtocols[ 'icmp' ][ 0 ] ) )
icmpV6KwMatcher = CliMatcher.KeywordMatcher(
   'icmpv6', helpdesc='%s (%d)' % ( icmpProtocols[ 'icmpv6' ][ 1 ],
                                    icmpProtocols[ 'icmpv6' ][ 0 ] ) )

def generateTcpFlagExpression( tcpFlagsSupported=False, notAllowed=False ):
   """
   Return a CliExpression class for "tcp flags [ not ] { TCP_FLAGS }".

   The optional 'not' token is only part of the syntax when notAllowed is set.
   When tcpFlagsSupported is False the expression and its data are empty. In
   both cases the adapter stores the list of short flag names in FLAGS_EXPR.
   """
   # Long flag keywords and the short names stored in FLAGS_EXPR; any other flag
   # is stored unchanged.
   shortFlagNames = { 'established': 'est', 'initial': 'init' }

   if tcpFlagsSupported:
      notToken = '[ not ]' if notAllowed else ''
      exprSyntax = "tcp flags %s { TCP_FLAGS }" % notToken
      exprData = {
         'tcp': tcpKeywordMatcher,
         'flags': 'flags',
         'TCP_FLAGS': CliMatcher.EnumMatcher( tcpFlagTokens )
      }
      if notAllowed:
         exprData[ 'not' ] = 'not'
   else:
      exprSyntax = ""
      exprData = {}

   class TcpFieldsExpression( CliCommand.CliExpression ):
      expression = exprSyntax
      data = exprData

      @staticmethod
      def adapter( mode, args, argsList ):
         args[ 'FLAGS_EXPR' ] = [ shortFlagNames.get( flag, flag )
                                  for flag in args.get( 'TCP_FLAGS', [] ) ]

   return TcpFieldsExpression

def generateIpProtoExpression( name, genericRules, extraRules ):
   """
   Return a CliExpression class that accepts either protocol names or a
   protocol range, and stores the resulting set of protocol values in
   args[ name ].
   """
   keywordMap = getKeywordMap( genericRules, extraRules )
   nameKey = name + "_PROTO"
   rangeKey = name + "_RANGE"

   def protoValue( keyword ):
      # Prefer the generic rule; fall back to the extra rules otherwise
      rule = genericRules.get( keyword, None ) or extraRules[ keyword ]
      return rule[ 0 ]

   class IpProtoExpression( CliCommand.CliExpression ):
      expression = "{ %s } | %s" % ( nameKey, rangeKey )
      data = {
         nameKey: CliMatcher.DynamicKeywordMatcher( lambda mode: keywordMap ),
         rangeKey: protoRangeMatcher,
      }

      @staticmethod
      def adapter( mode, args, argsList ):
         keywords = args.pop( nameKey, None )
         if keywords:
            protocols = set( protoValue( keyword ) for keyword in keywords )
         else:
            protocols = args.pop( rangeKey, None )
         if protocols is None:
            return
         args[ name ] = protocols

   return IpProtoExpression

def generateTcpUdpProtoExpression( name, tcpUdpRules, allowMultiple=True ):
   """
   Return a CliExpression class offering one keyword per entry of tcpUdpRules
   ( e.g. "tcp | udp" ), wrapped in '{' '}' when allowMultiple is set. The
   adapter collects the protocol values of the matched keywords as a set in
   args[ name ].
   """
   keywordNodes = {}
   for proto, rule in tcpUdpRules.iteritems():
      keywordNodes[ name + '_' + proto ] = CliCommand.singleKeyword(
         proto, helpdesc=rule[ 1 ] )

   alternatives = ' | '.join( keywordNodes )
   if allowMultiple:
      alternatives = '{' + alternatives + '}'

   class TcpUdpExpression( CliCommand.CliExpression ):
      expression = alternatives
      data = keywordNodes

      @staticmethod
      def adapter( mode, args, argsList ):
         for key in keywordNodes:
            matched = args.pop( key, None )
            if not matched:
               continue
            args.setdefault( name, set() ).add( tcpUdpRules[ matched ][ 0 ] )

   return TcpUdpExpression

def generateIcmpTypeRangeExpression( name, icmpTypeRules ):
   """
   Return a CliExpression class that accepts either several ICMP type names or
   a range of ICMP type values, and stores the resulting set in args[ name ].
   """
   keywordMap = getKeywordMap( icmpTypeRules )
   nameKey = name + "_NAME"
   rangeKey = name + "_RANGE"

   class IcmpTypeRangeExpression( CliCommand.CliExpression ):
      expression = "{ %s } | %s" % ( nameKey, rangeKey )
      data = {
         nameKey: CliMatcher.DynamicKeywordMatcher( lambda mode: keywordMap ),
         rangeKey: icmpTypeRangeMatcher,
      }

      @staticmethod
      def adapter( mode, args, argsList ):
         matchedNames = args.pop( nameKey, None )
         if matchedNames:
            typeValues = set( icmpTypeRules[ typeName ][ 0 ]
                              for typeName in matchedNames )
         else:
            typeValues = args.pop( rangeKey, None )
         if typeValues is None:
            return
         args[ name ] = typeValues

   return IcmpTypeRangeExpression

def generateIcmpTypeSingleExpression( name, icmpTypeRules ):
   """
   Return a CliExpression class that accepts a single ICMP type name and stores
   the corresponding type value in args[ name ]. The matched name is also kept
   as a shared result ( storeSharedResult=True ).
   """
   nameKey = name + "_NAME"
   keywordMap = getKeywordMap( icmpTypeRules )
   typeNameNode = CliCommand.Node(
      CliMatcher.DynamicKeywordMatcher( lambda mode: keywordMap ),
      storeSharedResult=True )

   class IcmpTypeSingleExpression( CliCommand.CliExpression ):
      expression = "%s" % nameKey
      data = { nameKey: typeNameNode }

      @staticmethod
      def adapter( mode, args, argsList ):
         args[ name ] = icmpTypeRules[ args.pop( nameKey, None ) ][ 0 ]

   return IcmpTypeSingleExpression

def generateIcmpCodeExpression( name, icmpCodeRules, icmpCodeNameMatcher ):
   """
   Return a CliExpression class that accepts either ICMP code names or a range
   of ICMP code values.

   icmpCodeRules maps each ICMP type value to a dict of code name -> rule, where
   the first element of a rule is the code value. The adapter stores the
   resulting set of code values in args[ name ]; nothing is stored when that
   set is empty or missing.
   """
   icmpCodeName = name + "_NAME"
   icmpCodeRange = name + "_RANGE"

   class IcmpCodeExpression( CliCommand.CliExpression ):
      expression = "{ %s } | %s" % ( icmpCodeName, icmpCodeRange )
      data = { icmpCodeName: icmpCodeNameMatcher,
               icmpCodeRange: icmpCodeRangeMatcher }

      @staticmethod
      def adapter( mode, args, argsList ):
         codeNames = args.pop( icmpCodeName, None )
         if codeNames:
            # Flatten the per-type code tables into a single name -> rule map.
            codeNameValueMap = {}
            for codeMap in icmpCodeRules.values():
               codeNameValueMap.update( codeMap )
            # Names without a rule are ignored.
            valueSet = set( codeNameValueMap[ codeName ][ 0 ]
                            for codeName in codeNames
                            if codeName in codeNameValueMap )
         else:
            valueSet = args.pop( icmpCodeRange, None )
         if valueSet:
            args[ name ] = valueSet

   return IcmpCodeExpression

# ICMP type expressions accepting several names or a range; the result is stored
# under 'TYPE'.
icmpV4TypeRangeExpr = generateIcmpTypeRangeExpression( 'TYPE', icmpV4Types )
icmpV6TypeRangeExpr = generateIcmpTypeRangeExpression( 'TYPE', icmpV6Types )

# ICMP type expressions accepting a single name, restricted to the types listed
# in the *TypeWithValidCodes tables.
icmpV4TypeSingleExpr = generateIcmpTypeSingleExpression( 'TYPE',
                                                         icmpV4TypeWithValidCodes )
icmpV6TypeSingleExpr = generateIcmpTypeSingleExpression( 'TYPE',
                                                         icmpV6TypeWithValidCodes )

# ICMP code expressions; the result is stored under 'CODE'.
icmpV4CodeExpr = generateIcmpCodeExpression( 'CODE', icmpV4Codes,
                                             icmpV4CodeNameMatcher )
icmpV6CodeExpr = generateIcmpCodeExpression( 'CODE', icmpV6Codes,
                                             icmpV6CodeNameMatcher )

# TCP/UDP keyword expression; the result is stored under 'TCP_UDP'.
tcpUdpProtoExpr = generateTcpUdpProtoExpression( 'TCP_UDP', tcpUdpProtocols )

# IPv4 protocol expression ( names or range ); the result is stored under
# 'PROTOCOL'.
ipv4ProtoExpr = generateIpProtoExpression( 'PROTOCOL', genericIpProtocols,
                                           extraIpv4Protocols )

def generateFieldSetExpression( nameMatcher, name, allowMultiple=True ):
   """
   Return a CliExpression class matching field-set names with nameMatcher under
   the key name; the key is wrapped in '{' '}' when allowMultiple is set.
   """
   exprSyntax = '{' + name + '}' if allowMultiple else name

   class FieldSetExpression( CliCommand.CliExpression ):
      expression = exprSyntax
      data = { name: nameMatcher }

   return FieldSetExpression

class AppRecognitionContext( object ):
   """
   Editing context for the application recognition and field-set configuration.

   Holds the current configuration entities ( "CurrConfig" ) together with
   scratchpad copies ( "EditConfig" ). copyEditAppRecognitionConfig() fills the
   scratchpad from the current configuration, commit() copies the scratchpad
   back and abort() discards it.
   """
   def __init__( self, appRecognitionConfig, appProfileIdMap, fieldSetConfig ):
      self.appRecognitionCurrConfig = appRecognitionConfig
      self.fieldSetCurrConfig = fieldSetConfig
      self.appRecognitionEditConfig = None
      self.appProfileIdMap = appProfileIdMap
      self.fieldSetEditConfig = None
      self.mode_ = None

   def modeIs( self, mode ):
      """
      Record the CLI mode that owns this context.
      """
      self.mode_ = mode

   def copyEditAppRecognitionConfig( self ):
      """
      Create fresh scratchpad entities and populate them from the current
      configuration.
      """
      self.appRecognitionEditConfig = Tac.newInstance(
              'Classification::AppRecognitionConfig', 'appRecognitionConfig' )
      self.fieldSetEditConfig = Tac.newInstance(
              'Classification::FieldSetConfig', 'fieldSetConfig' )
      self.copyFieldSet( toSysdb=False )
      self.copyAppRec( toSysdb=False )

   def copyFieldSetL4Port( self, src, dst ):
      """
      Make the l4-port field sets in dst match those in src.

      A field set whose current port ranges already match is left untouched.
      Otherwise a new sub-config is created, filled with the source ports and
      made current, and the previously current sub-config is deleted. Field
      sets present only in dst are removed.
      """
      for name, srcFsCfg in src.iteritems():
         dstFsCfg = dst.get( name )
         srcSubCfg = srcFsCfg.currCfg
         if not dstFsCfg:
            dstFsCfg = dst.newMember( name )
            prevDstCurrCfg = None
         else:
            prevDstCurrCfg = dstFsCfg.currCfg
            # don't copy if no change in ports set
            srcPortStr = numericalRangeToRangeString( srcSubCfg.ports )
            dstPortStr = numericalRangeToRangeString( prevDstCurrCfg.ports )
            if srcPortStr == dstPortStr:
               continue

         dstSubCfg = dstFsCfg.subConfig.newMember( name, UniqueId() )

         # clear all the ports before copying to avoid stale entry
         dstSubCfg.ports.clear()
         for portRange in srcSubCfg.ports:
            dstSubCfg.ports.add( portRange )

         dstFsCfg.currCfg = dstSubCfg

         if prevDstCurrCfg is not None:
            assert prevDstCurrCfg.version in dstFsCfg.subConfig
            del dstFsCfg.subConfig[ prevDstCurrCfg.version ]

      # Delete stale entries. Iterate over a snapshot of the names because the
      # collection is modified inside the loop.
      for name in list( dst ):
         if name not in src:
            del dst[ name ]

   def copyFieldSetIpPrefix( self, src, dst ):
      """
      Make the IP prefix field sets in dst match those in src.

      A field set whose current prefixes and except-prefixes already match is
      left untouched. Otherwise a new sub-config is created, filled from the
      source and made current, and the previously current sub-config is
      deleted. Field sets present only in dst are removed.
      """
      for name, srcFsCfg in src.iteritems():
         dstFsCfg = dst.get( name )
         srcSubCfg = srcFsCfg.currCfg
         if not dstFsCfg:
            # XXX - We're always copying the same AF, right?
            dstFsCfg = dst.newMember( name, srcFsCfg.af )
            prevDstCurrCfg = None
         else:
            prevDstCurrCfg = dstFsCfg.currCfg
            if ( sorted( srcSubCfg.prefixes ) == sorted( prevDstCurrCfg.prefixes )
                 and ( sorted( srcSubCfg.exceptPrefix ) ==
                      sorted( prevDstCurrCfg.exceptPrefix ) ) ):
               continue
         dstSubCfg = dstFsCfg.subConfig.newMember( name, UniqueId() )

         # clear all the prefixes before copying to avoid stale entry
         dstSubCfg.prefixes.clear()
         for prefix in srcSubCfg.prefixes:
            dstSubCfg.prefixes.add( prefix )
         dstSubCfg.exceptPrefix.clear()
         for exceptPrefix in srcSubCfg.exceptPrefix:
            dstSubCfg.exceptPrefix.add( exceptPrefix )

         dstFsCfg.currCfg = dstSubCfg

         if prevDstCurrCfg is not None:
            assert prevDstCurrCfg.version in dstFsCfg.subConfig
            del dstFsCfg.subConfig[ prevDstCurrCfg.version ]

      # Delete stale entries. Iterate over a snapshot of the names because the
      # collection is modified inside the loop.
      for name in list( dst ):
         if name not in src:
            del dst[ name ]

   def copyAppProfile( self, src, dst ):
      """
      Make the application profiles in dst match those in src.

      The application set and version of a profile are copied only when the
      set of application names differs. updateAppProfileId is then called for
      every profile name found in dst, and profiles present only in dst are
      removed.
      """
      for name in src:
         if name not in dst:
            dst.newMember( name )
         # don't copy if no change in app set
         if sorted( src[ name ].app.keys() ) == sorted( dst[ name ].app.keys() ):
            continue

         # clear all the apps before copying to avoid stale entry
         dst[ name ].app.clear()
         for appName in src[ name ].app:
            dst[ name ].app[ appName ] = True

         dst[ name ].version = src[ name ].version

      # Delete stale entries. Iterate over a snapshot of the names because the
      # collection is modified inside the loop.
      for name in list( dst ):
         updateAppProfileId( self.appProfileIdMap, name )
         if name not in src:
            del dst[ name ]

   def copyApp( self, src, dst ):
      """
      Make the applications in dst match those in src.

      The fields and version of an application are copied only when a prefix
      or port field set, the address family or the protocol ranges differ.
      Applications present only in dst are removed.
      """
      for name in src:
         if name not in dst:
            dst.newMember( name )

         # don't copy if no change in any fields
         srcProto = numericalRangeToRangeString( src[ name ].proto )
         dstProto = numericalRangeToRangeString( dst[ name ].proto )
         if src[ name ].srcPrefixFieldSet == dst[ name ].srcPrefixFieldSet and \
            src[ name ].dstPrefixFieldSet == dst[ name ].dstPrefixFieldSet and \
            src[ name ].srcPortFieldSet == dst[ name ].srcPortFieldSet and \
            src[ name ].dstPortFieldSet == dst[ name ].dstPortFieldSet and \
            src[ name ].af == dst[ name ].af and \
            srcProto == dstProto:
            continue

         dst[ name ].srcPrefixFieldSet = src[ name ].srcPrefixFieldSet
         dst[ name ].dstPrefixFieldSet = src[ name ].dstPrefixFieldSet
         dst[ name ].srcPortFieldSet = src[ name ].srcPortFieldSet
         dst[ name ].dstPortFieldSet = src[ name ].dstPortFieldSet
         dst[ name ].af = src[ name ].af
         dst[ name ].proto.clear()
         for protoRange in src[ name ].proto:
            dst[ name ].proto.add( protoRange )
         dst[ name ].version = src[ name ].version

      # Delete stale entries. Iterate over a snapshot of the names because the
      # collection is modified inside the loop.
      for name in list( dst ):
         if name not in src:
            del dst[ name ]

   def copyFieldSet( self, toSysdb=False ):
      """
      Copy the l4-port and IP prefix field sets between the scratchpad and the
      current configuration.
      """
      # When copying to Sysdb, copy the "scratchpad" EditConfig to the current
      # configuration "CurrConfig". Otherwise, the direction is reversed.
      if toSysdb:
         src = self.fieldSetEditConfig
         dst = self.fieldSetCurrConfig
      else:
         src = self.fieldSetCurrConfig
         dst = self.fieldSetEditConfig
      self.copyFieldSetL4Port( src.fieldSetL4Port, dst.fieldSetL4Port )
      self.copyFieldSetIpPrefix( src.fieldSetIpPrefix, dst.fieldSetIpPrefix )

   def copyAppRec( self, toSysdb=False ):
      """
      Copy the applications and application profiles between the scratchpad
      and the current configuration.
      """
      # When copying to Sysdb, copy the "scratchpad" EditConfig to the current
      # configuration "CurrConfig". Otherwise, the direction is reversed.
      if toSysdb:
         src = self.appRecognitionEditConfig
         dst = self.appRecognitionCurrConfig
      else:
         src = self.appRecognitionCurrConfig
         dst = self.appRecognitionEditConfig
      self.copyApp( src.app, dst.app )
      self.copyAppProfile( src.appProfile, dst.appProfile )

   def abort( self ):
      """
      Discard the scratchpad entities.
      """
      self.appRecognitionEditConfig = None
      self.fieldSetEditConfig = None

   def commit( self ):
      """
      Copy whichever scratchpad entities exist to the current configuration,
      field sets first.
      """
      if self.fieldSetEditConfig:
         self.copyFieldSet( toSysdb=True )

      if self.appRecognitionEditConfig:
         self.copyAppRec( toSysdb=True )

class AppProfileContext( object ):
   """
   Editing context for a single application profile.

   Changes are made on a private copy ( appProfileEdit ) and written to the
   parent context's scratchpad profile collection by commit().
   """
   def __init__( self, appProfileName, parentContext ):
      self.mode_ = None
      self.childMode = AppProfileConfigMode
      self.appProfileName = appProfileName
      self.appProfileEdit = None
      self.appProfile = parentContext.appRecognitionEditConfig.appProfile

   def copyEditAppProfile( self ):
      """
      Start editing from a copy of the existing profile.
      """
      self.appProfileEdit = Tac.newInstance( 'Classification::AppProfile',
                                             self.appProfileName )
      existing = self.appProfile[ self.appProfileName ]
      self.copyApps( existing, self.appProfileEdit )

   def newEditAppProfile( self ):
      """
      Start editing from an empty profile.
      """
      self.appProfileEdit = Tac.newInstance( 'Classification::AppProfile',
                                             self.appProfileName )

   def updateApp( self, appName, add=True ):
      """
      Add appName to the edited profile, or remove it if it is present.
      """
      apps = self.appProfileEdit.app
      if add:
         apps[ appName ] = True
         return
      if appName in apps:
         del apps[ appName ]

   def copyApps( self, src, dst, commit=False ):
      """
      Replace the applications of dst with those of src when the two sets of
      names differ. On commit, dst.version is incremented after such a change.
      """
      if sorted( dst.app.keys() ) != sorted( src.app.keys() ):
         dst.app.clear()
         for appName in src.app:
            dst.app[ appName ] = True

         # increment the version because there is some config change
         if commit:
            dst.version += 1

   def hasAppProfile( self, name ):
      """
      Tell whether a profile called name exists in the parent collection.
      """
      return name in self.appProfile

   def delAppProfile( self, name ):
      """
      Remove the profile called name from the parent collection.
      """
      del self.appProfile[ name ]

   def modeIs( self, mode ):
      """
      Record the CLI mode that owns this context.
      """
      self.mode_ = mode

   def commit( self ):
      """
      Write the edited profile, if any, to the parent context, creating the
      profile there first when needed.
      """
      if not self.appProfileEdit:
         return
      if self.appProfileName not in self.appProfile:
         self.appProfile.newMember( self.appProfileName )
      target = self.appProfile[ self.appProfileName ]
      self.copyApps( self.appProfileEdit, target, commit=True )

   def abort( self ):
      """
      Drop the edited profile and forget the profile name.
      """
      self.appProfileEdit = None
      self.appProfileName = None

class AppContext( object ):
   """
   Editing context for a single application.

   Changes are made on a private copy ( appEdit ) and written to the parent
   context's scratchpad application collection by commit().
   """
   def __init__( self, appName, parentContext ):
      self.childMode = AppConfigMode
      self.appName = appName
      self.app = parentContext.appRecognitionEditConfig.app
      self.appEdit = None
      self.mode_ = None

   def copyEditApp( self ):
      """
      Start editing from a copy of the existing application.
      """
      self.appEdit = Tac.newInstance( 'Classification::AppConfig', self.appName )
      self.appEdit.af = 'ipv4'
      self.copyAppFields( self.app[ self.appName ], self.appEdit )

   def newEditApp( self ):
      """
      Start editing from a new application whose address family is 'ipv4'.
      """
      self.appEdit = Tac.newInstance( 'Classification::AppConfig', self.appName )
      self.appEdit.af = 'ipv4'

   def updatePrefixFieldSet( self, source=True, names=None, add=True ):
      """
      Set ( add=True ) or clear ( add=False ) the source or destination prefix
      field set of the edited application.

      names holds at most one field-set name and is only read when adding;
      IndexError is raised when adding without a name.
      """
      # names defaults to None; clearing needs no name, so treat None as empty
      # instead of failing on len( None ).
      names = names or []
      assert len( names ) <= 1
      fieldSetName = names[ 0 ] if add else ""
      if source:
         self.appEdit.srcPrefixFieldSet = fieldSetName
      else:
         self.appEdit.dstPrefixFieldSet = fieldSetName

   def portAndProtoConfigured( self ):
      """
      Return ( set of configured protocols, whether a source or destination
      port field set is configured ).
      """
      protoSet = numericalRangeToSet( self.appEdit.proto )
      port = not ( self.appEdit.srcPortFieldSet == '' and
                   self.appEdit.dstPortFieldSet == '' )
      return ( protoSet, port )

   def updatePortFieldSetAttr( self, attrName, fieldSetName='', add=True, **kwargs ):
      """
      Set ( add=True ) or clear ( add=False ) the port field set attribute
      attrName. Attribute names other than 'srcPortFieldSet' and
      'dstPortFieldSet' are ignored, as are extra keyword arguments.
      """
      if attrName not in [ 'srcPortFieldSet', 'dstPortFieldSet' ]:
         return
      if add:
         setattr( self.appEdit, attrName, fieldSetName )
      else:
         setattr( self.appEdit, attrName, '' )

   def updateTcpFlags( self, tcpFlags, notFlags, add=True ):
      """
      Does not support tcp flags yet
      """
      pass

   def maybeUpdateProto( self, protocolRangeSet ):
      """
      Takes in a set of protocol ranges where additional fields have been
      cleared; if all fields have been cleared (sport/dport/flags etc) we delete
      the protocol. Currently a no-op.
      """
      pass

   def updateRangeAttr( self, attrName, rangeSet, rangeType, add=True ):
      """
      Add rangeSet to, or remove it from, the numerical range collection stored
      in attribute attrName of the edited application.

      Removing with an empty rangeSet clears the whole collection. When the
      collection ends up empty, both port field sets are cleared as well.
      """
      currRange = getattr( self.appEdit, attrName )
      currentSet = numericalRangeToSet( currRange )
      if add:
         updatedSet = currentSet | rangeSet
      elif not rangeSet:
         # nothing specific to remove: remove everything
         updatedSet = set()
      else:
         updatedSet = currentSet - rangeSet
      newRangeList = rangeSetToNumericalRange( updatedSet,
                                               rangeType )

      currRange.clear()
      for aRange in newRangeList:
         currRange.add( aRange )

      # delete source and dest port when no proto is present
      if not currRange:
         self.updatePortFieldSetAttr( 'srcPortFieldSet', add=False )
         self.updatePortFieldSetAttr( 'dstPortFieldSet', add=False )

   def getProto( self ):
      """
      Return the protocol range collection of the edited application.
      """
      return self.appEdit.proto

   def isEqual( self, src, dst ):
      """
      Tell whether src and dst have the same prefix and port field sets,
      address family and protocol ranges. The version is not compared.
      """
      srcProtoStr = numericalRangeToRangeString( src.proto )
      dstProtoStr = numericalRangeToRangeString( dst.proto )
      if src.srcPrefixFieldSet == dst.srcPrefixFieldSet and \
         src.dstPrefixFieldSet == dst.dstPrefixFieldSet and \
         src.srcPortFieldSet == dst.srcPortFieldSet and \
         src.dstPortFieldSet == dst.dstPortFieldSet and \
         src.af == dst.af and \
         srcProtoStr == dstProtoStr:
         return True
      return False

   def copyAppFields( self, src, dst, commit=False ):
      """
      Copy the fields and version of src to dst unless isEqual() reports no
      difference. On commit, dst.version is incremented after the copy.
      """
      if self.isEqual( src, dst ):
         # no attribute has changed
         return

      dst.af = src.af
      dst.srcPrefixFieldSet = src.srcPrefixFieldSet
      dst.dstPrefixFieldSet = src.dstPrefixFieldSet
      dst.srcPortFieldSet = src.srcPortFieldSet
      dst.dstPortFieldSet = src.dstPortFieldSet
      dst.proto.clear()
      for protoRange in src.proto:
         dst.proto.add( protoRange )
      dst.version = src.version

      if commit:
         # increment the version because there is some config change
         dst.version += 1

   def hasApp( self, name ):
      """
      Tell whether an application called name exists in the parent collection.
      """
      return name in self.app

   def delApp( self, name ):
      """
      Remove the application called name from the parent collection.
      """
      del self.app[ name ]

   def modeIs( self, mode ):
      """
      Record the CLI mode that owns this context.
      """
      self.mode_ = mode

   def commit( self ):
      """
      Write the edited application, if any, to the parent context, creating the
      application there first when needed.
      """
      # commit to parent context
      if self.appEdit:
         if self.appName not in self.app:
            self.app.newMember( self.appName )
         self.copyAppFields( self.appEdit, self.app[ self.appName ], commit=True )

   def abort( self ):
      """
      Drop the edited application and forget the application name.
      """
      self.appName = None
      self.appEdit = None

#------------------------------------------------------------------------------------
# (source | destination) prefix field-set FIELD_SET
#------------------------------------------------------------------------------------
class PrefixFieldSetCmdBase( CliCommand.CliCommandClass ):
   # Sets or clears the source/destination prefix field set through the mode's
   # context. The FIELD_SET matcher itself is not defined here.
   syntax = '( source | destination ) prefix field-set FIELD_SET'
   noOrDefaultSyntax = '( source | destination ) prefix field-set [ FIELD_SET ]'

   _baseData = {
      'source': 'Source',
      'destination': 'Destination',
      'prefix': 'Prefix',
      'field-set': 'Field set',
   }

   @classmethod
   def handler( cls, mode, args ):
      # FIELD_SET may hold a single name or a list of names; always pass a list
      names = args.get( 'FIELD_SET' )
      if not isinstance( names, list ):
         names = [ names ]
      isSource = args.get( 'source', False )
      mode.getContext().updatePrefixFieldSet( source=isSource, names=names,
                                              add=True )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      # FIELD_SET is optional here; an absent name becomes an empty list
      names = args.get( 'FIELD_SET', [] )
      if not isinstance( names, list ):
         names = [ names ]
      isSource = args.get( 'source', False )
      mode.getContext().updatePrefixFieldSet( source=isSource, names=names,
                                              add=False )

#------------------------------------------------------------------------------------
# [ except <exceptPrefix> ]
#------------------------------------------------------------------------------------
def generateExceptExpression( matcher=None, exceptSupported=False,
                              allowMultiple=True ):
   """
   Return a CliExpression class for "except EXCEPT_ITEMS", where the items are
   matched by matcher and may repeat when allowMultiple is set.

   When exceptSupported is False the expression and its data are empty. In both
   cases the adapter leaves a list ( possibly empty ) in args[ 'EXCEPT_ITEMS' ].
   """
   if exceptSupported:
      itemSyntax = "{ EXCEPT_ITEMS }" if allowMultiple else "EXCEPT_ITEMS"
      exprSyntax = "except " + itemSyntax
      exprData = {
         'except': 'except',
         'EXCEPT_ITEMS': matcher
      }
   else:
      exprSyntax = ""
      exprData = {}

   class ExceptExpression( CliCommand.CliExpression ):
      expression = exprSyntax
      data = exprData

      @staticmethod
      def adapter( mode, args, argsList ):
         items = args.get( 'EXCEPT_ITEMS', [] )
         args[ 'EXCEPT_ITEMS' ] = items if isinstance( items, list ) else [ items ]

   return ExceptExpression

#------------------------------------------------------------------------------------
# (source | destination) prefix <prefix>
#------------------------------------------------------------------------------------
class PrefixCmdBase( CliCommand.CliCommandClass ):
   # Shared implementation of the "( source | destination ) prefix" commands;
   # subclasses supply the syntax and the PREFIX matcher.
   _baseSyntax = '( source | destination ) prefix'

   _baseData = {
      'source': 'source',
      'destination': 'destination',
      'prefix': 'prefix',
   }

   # When True, adding prefixes first removes the ones already configured on the
   # same side.
   _overwrite = False

   @classmethod
   def _updatePrefix( cls, mode, args, add ):
      context = mode.getContext()
      if 'source' in args:
         filterType = "source"
      else:
         filterType = "destination"

      if cls._overwrite and add:
         current = getattr( context.filter, filterType )
         context.addOrRemovePrefix( current, filterType=filterType, add=False )

      # PREFIX may hold a single prefix or a list of prefixes
      prefixes = args[ 'PREFIX' ]
      if not isinstance( prefixes, list ):
         prefixes = [ prefixes ]
      if not prefixes:
         return
      context.addOrRemovePrefix( prefixes, filterType=filterType, add=add )

   @classmethod
   def handler( cls, mode, args ):
      # Refuse the command when the context reports a conflicting config
      if not mode.getContext().isValidConfig( conflictOther ):
         if 'source' in args:
            attr = 'source prefix'
         else:
            attr = 'destination prefix'
         mode.addError( configConflictMsg % attr )
         return
      cls._updatePrefix( mode, args, add=True )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      cls._updatePrefix( mode, args, add=False )

class PrefixIpv4Cmd( PrefixCmdBase ):
   # "( source | destination ) prefix" taking one or more IPv4 prefixes
   syntax = PrefixCmdBase._baseSyntax + ' { PREFIX }'
   noOrDefaultSyntax = syntax
   data = dict( PrefixCmdBase._baseData )
   data[ 'PREFIX' ] = IpAddrMatcher.ipPrefixExpr(
      'Prefix address',
      'Prefix mask',
      'Prefix',
      overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO )

class PrefixIpv4SingletonCmd( PrefixCmdBase ):
   # "( source | destination ) prefix" taking a single IPv4 prefix, which
   # replaces the configured ones ( _overwrite )
   _overwrite = True
   syntax = PrefixCmdBase._baseSyntax + ' PREFIX'
   noOrDefaultSyntax = syntax
   data = dict( PrefixCmdBase._baseData )
   data[ 'PREFIX' ] = IpAddrMatcher.ipPrefixExpr(
      'Prefix address',
      'Prefix mask',
      'Prefix',
      overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO )

class PrefixIpv6Cmd( PrefixCmdBase ):
   # "( source | destination ) prefix" taking one or more IPv6 prefixes
   syntax = PrefixCmdBase._baseSyntax + ' { PREFIX }'
   noOrDefaultSyntax = syntax
   data = dict( PrefixCmdBase._baseData )
   data[ 'PREFIX' ] = Ip6AddrMatcher.Ip6PrefixValidMatcher(
      'IPv6 address prefix',
      overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO )

# --------------------------------------------------------------------------
# The "field-set l4-port PORT_SET_NAME" command
# --------------------------------------------------------------------------
def protectedFieldSetNamesRegex( field ):
   """
   Return the field-set name pattern for field: one BasicCliUtil.notAPrefixOf()
   fragment per protected field-set name, followed by the character class of
   allowed name characters.
   """
   exclusions = [ BasicCliUtil.notAPrefixOf( protectedName )
                  for protectedName in getProtectedFieldSetNames( field ) ]
   return ''.join( exclusions ) + r'[A-Za-z0-9_:{}\[\]-]*'

class FieldSetL4PortBaseConfigCmd( CliCommand.CliCommandClass ):
   # Base of the "field-set l4-port FIELD_SET_NAME" command. Subclasses must
   # implement _getContextKwargs to supply the context constructor arguments.
   syntax = 'field-set l4-port FIELD_SET_NAME'
   noOrDefaultSyntax = syntax
   _feature = "app"
   _l4PortContext = L4PortFieldSetContext
   _baseData = {
      'field-set': 'Configure field set',
      'l4-port': 'Layer 4 port',
   }

   @classmethod
   def _getContextKwargs( cls, fieldSetL4PortName, mode=None ):
      raise NotImplementedError

   @classmethod
   def _newContext( cls, mode, name ):
      # Build the field-set context for name
      kwargs = cls._getContextKwargs( name, mode )
      return cls._l4PortContext( **kwargs )

   @classmethod
   def handler( cls, mode, args ):
      fieldSetName = args[ 'FIELD_SET_NAME' ]
      context = cls._newContext( mode, fieldSetName )

      # Edit a copy of an existing field set, or start a new one
      if not context.hasL4PortFieldSet( fieldSetName ):
         context.newEditFieldSet()
      else:
         context.copyEditFieldSet()

      mode.session_.gotoChildMode(
         mode.childMode( context.childMode, context=context,
                         feature=cls._feature ) )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      fieldSetName = args[ 'FIELD_SET_NAME' ]
      context = cls._newContext( mode, fieldSetName )
      if context.hasL4PortFieldSet( fieldSetName ):
         context.delFieldSet( fieldSetName )

   @classmethod
   def _removeFieldSet( cls, mode, name ):
      context = cls._newContext( mode, name )
      if not context.hasL4PortFieldSet( name ):
         return
      context.delFieldSet( name )

# --------------------------------------------------------------------------
# The "[remove] ( all | PORT )" command
# --------------------------------------------------------------------------
class FieldSetL4PortConfigCmds( CliCommand.CliCommandClass ):
   """Add ports to, or remove ports from, the l4-port field set being edited."""
   syntax = '[remove] ( all | PORT )'
   data = {
      'remove': 'Remove l4 port(s) from port set',
      'all': 'All l4 ports from 0-%s' % appConstants.maxL4Port,
      'PORT': portRangeMatcher,
   }

   @staticmethod
   def handler( mode, args ):
      """
      Update the field set held by the mode's context.

      'remove' selects removal instead of addition; 'all' targets every port
      (an empty set is passed together with allPorts=True), otherwise the
      parsed PORT value is used.
      """
      adding = not args.get( 'remove' )
      context = mode.getContext()
      if args.get( 'all' ):
         context.updateFieldSet( set(), add=adding, allPorts=True )
      else:
         context.updateFieldSet( args.get( 'PORT' ), add=adding )

# --------------------------------------------------------------------------
# The "field-set (ipv4 | ipv6) prefix FIELD_SET_NAME" command
# --------------------------------------------------------------------------
class FieldSetIpPrefixBaseConfigCmd( CliCommand.CliCommandClass ):
   """
   Base class for the 'field-set ( ipv4 | ipv6 ) prefix FIELD_SET_NAME' command.

   Subclasses must implement _getContextKwargs(), which supplies the keyword
   arguments used to construct the _ipPrefixFieldSetContext object.
   """
   # Feature name forwarded to the child mode.
   _feature = "app"
   # Context class instantiated by every handler below.
   _ipPrefixFieldSetContext = IpPrefixFieldSetContext
   # Keyword help shared with subclasses. The 'prefix' keyword follows both
   # 'ipv4' and 'ipv6', so its help text must not be IPv4 specific.
   _baseData = {
      'field-set': 'Configure field set',
      'ipv4': 'IPv4',
      'ipv6': 'IPv6',
      'prefix': 'IP prefixes',
   }

   @classmethod
   def _getContextKwargs( cls, fieldSetIpPrefixName, setType, mode=None ):
      """Return the kwargs for _ipPrefixFieldSetContext; must be overridden."""
      raise NotImplementedError

   @classmethod
   def handler( cls, mode, args ):
      """Open an edit copy of the named prefix field set and enter its mode."""
      name = args[ 'FIELD_SET_NAME' ]
      # Whichever address-family keyword was matched.
      setType = args.get( 'ipv4' ) or args.get( 'ipv6' )
      contextKwargs = cls._getContextKwargs( name, setType, mode )
      context = cls._ipPrefixFieldSetContext( **contextKwargs )

      # Edit a copy when the field set already exists, otherwise start fresh.
      if context.hasPrefixFieldSet( name, setType ):
         context.copyEditFieldSet()
      else:
         context.newEditFieldSet()

      childMode = mode.childMode( context.childMode, context=context,
                                  feature=cls._feature )
      mode.session_.gotoChildMode( childMode )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Delete the named prefix field set if it exists."""
      name = args[ 'FIELD_SET_NAME' ]
      setType = args.get( 'ipv4' ) or args.get( 'ipv6' )
      contextKwargs = cls._getContextKwargs( name, setType, mode )
      context = cls._ipPrefixFieldSetContext( **contextKwargs )
      if context.hasPrefixFieldSet( name, setType ):
         context.delFieldSet( name )

   @classmethod
   def _removeFieldSet( cls, mode, name, setType ):
      """Delete prefix field set 'name' of type 'setType' if it exists."""
      contextKwargs = cls._getContextKwargs( name, setType, mode )
      context = cls._ipPrefixFieldSetContext( **contextKwargs )
      if context.hasPrefixFieldSet( name, setType ):
         context.delFieldSet( name )


# --------------------------------------------------------------------------
# The "[ ( no | remove ) ] { PREFIXES }" command
# --------------------------------------------------------------------------
class FieldSetPrefixConfigCmdsBase( CliCommand.CliCommandClass ):
   """
   Base class for adding prefixes to, or removing them from, the prefix
   field set being edited. Subclasses supply the PREFIXES matcher.
   """
   syntax = '[ remove ] {PREFIXES}'
   noOrDefaultSyntax = '{ PREFIXES }'
   # Keyword help shared with subclasses.
   _baseData = {
      'remove': 'Remove IP prefix from IP prefix set'
   }

   @staticmethod
   def handler( mode, args ):
      """
      Apply the update, then undo it and report an error if the context says
      the except prefixes are no longer covered by the accept prefixes.
      """
      prefixList = args.get( 'PREFIXES' )
      adding = 'remove' not in args
      ctx = mode.getContext()
      ctx.updateFieldSet( prefixList, add=adding )
      if ctx.exceptCoveredByAcceptPrefixFieldSet():
         return
      # Roll back by applying the opposite operation.
      ctx.updateFieldSet( prefixList, add=not adding )
      mode.addError( "Except prefix not covered by accept" )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      """
      Remove the prefixes, then restore them and report an error if the
      except prefixes are no longer covered by the accept prefixes.
      """
      prefixList = args.get( 'PREFIXES' )
      ctx = mode.getContext()
      ctx.updateFieldSet( prefixList, add=False )
      if ctx.exceptCoveredByAcceptPrefixFieldSet():
         return
      ctx.updateFieldSet( prefixList, add=True )
      mode.addError( "Except prefix not covered by accept" )

# --------------------------------------------------------------------------
# The "except { PREFIXES }" command
# --------------------------------------------------------------------------
class FieldSetPrefixExceptConfigCmdsBase( CliCommand.CliCommandClass ):
   """
   Base class for the except-prefix command of a prefix field set.
   Subclasses supply the EXCEPT_ITEMS expression.
   """
   syntax = 'EXCEPT_ITEMS'
   noOrDefaultSyntax = syntax

   @staticmethod
   def handler( mode, args ):
      """
      Add the except prefixes; undo the addition and report an error if the
      context says they are not covered by the accept prefixes.
      """
      excluded = args.get( 'EXCEPT_ITEMS' )
      ctx = mode.getContext()
      ctx.updateFieldSet( excluded, add=True, updateExcept=True )
      if ctx.exceptCoveredByAcceptPrefixFieldSet():
         return
      ctx.updateFieldSet( excluded, add=False, updateExcept=True )
      mode.addError( "Except prefix not covered by accept" )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      """Remove the given except prefixes from the field set."""
      mode.getContext().updateFieldSet( args.get( 'EXCEPT_ITEMS' ),
                                        add=False, updateExcept=True )

# --------------------------------------------------------------------------
# The "except PORTS" command
# --------------------------------------------------------------------------
class FieldSetL4PortExceptConfigCmds( CliCommand.CliCommandClass ):
   """The except-ports command of an l4-port field set."""
   syntax = 'EXCEPT_ITEMS'
   data = {
      'EXCEPT_ITEMS': generateExceptExpression( portRangeMatcher,
                                                exceptSupported=True,
                                                allowMultiple=False )
   }

   @staticmethod
   def handler( mode, args ):
      """
      Remove the excepted ports from the field set being edited. When the
      parsed value is not a set, its first element is used instead.
      """
      context = mode.getContext()
      parsed = args.get( 'EXCEPT_ITEMS', set() )
      portsToDrop = parsed if isinstance( parsed, set ) else parsed[ 0 ]
      context.updateFieldSet( portsToDrop, add=False )

class NumericalRangeConfigCmdBase( CliCommand.CliCommandClass ):
   """
   Base class for commands that add or remove a set of numbers on a range
   attribute of the mode's context. Subclasses fill in the three names below.
   """
   # Name of the range attribute passed to context.updateRangeAttr().
   _attrName = None
   # Range type name passed to context.updateRangeAttr().
   _rangeType = None
   # Key under which the parsed number set is found in 'args'.
   _argListName = None

   @classmethod
   def handler( cls, mode, args ):
      """Add the parsed set, or report a conflict error and change nothing."""
      if not mode.context.filter.isValidConfig( conflictOther ):
         mode.addError( configConflictMsg % cls._attrName )
         return
      cls._updateRange( mode, args.get( cls._argListName ) )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Remove the parsed set; an empty set is passed when none was given."""
      parsed = args.get( cls._argListName )
      cls._updateRange( mode, set() if parsed is None else parsed, add=False )

   @classmethod
   def _updateRange( cls, mode, rangeSet, add=True ):
      """Forward 'rangeSet' to the context for this command's attribute."""
      mode.context.updateRangeAttr( attrName=cls._attrName,
                                    rangeSet=rangeSet,
                                    rangeType=cls._rangeType,
                                    add=add )

#--------------------------------------------------
# The "ip length LENGTH" command for traffic-policy
#--------------------------------------------------
class IpLengthConfigCmd( NumericalRangeConfigCmdBase ):
   """'ip length LENGTH': updates the 'length' range attribute."""
   syntax = 'ip length LENGTH'
   noOrDefaultSyntax = 'ip length [ LENGTH ]'
   data = {
      'ip': 'IP',
      'length': 'Configure ip length match criteria',
      'LENGTH': ipLengthRangeMatcher
   }

   # Hooks consumed by NumericalRangeConfigCmdBase.
   _argListName = 'LENGTH'
   _attrName = 'length'
   _rangeType = 'Classification::PacketLengthRange'

class FieldSetIpPrefixConfigCmds( FieldSetPrefixConfigCmdsBase ):
   """IPv4 variant of the '[ remove ] { PREFIXES }' field-set command."""
   # Start from the shared keyword help and add the IPv4 prefix matcher.
   data = dict( FieldSetPrefixConfigCmdsBase._baseData )
   data[ 'PREFIXES' ] = IpAddrMatcher.ipPrefixExpr(
      'Prefix address',
      'Prefix mask',
      'Prefix',
      overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO,
      maskKeyword=False )

class FieldSetIpv6PrefixConfigCmds( FieldSetPrefixConfigCmdsBase ):
   """IPv6 variant of the '[ remove ] { PREFIXES }' field-set command."""
   # Start from the shared keyword help and add the IPv6 prefix matcher.
   data = dict( FieldSetPrefixConfigCmdsBase._baseData )
   data[ 'PREFIXES' ] = Ip6AddrMatcher.Ip6PrefixValidMatcher(
      'IPv6 address prefix',
      overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO )

class FieldSetIpPrefixExceptConfigCmds( FieldSetPrefixExceptConfigCmdsBase ):
   """IPv4 variant of the except-prefix field-set command."""
   data = {
      'EXCEPT_ITEMS': generateExceptExpression(
         IpAddrMatcher.ipPrefixExpr(
            'Prefix address',
            'Prefix mask',
            'Prefix',
            overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO ),
         exceptSupported=True ),
   }

class FieldSetIpv6PrefixExceptConfigCmds( FieldSetPrefixExceptConfigCmdsBase ):
   """IPv6 variant of the except-prefix field-set command."""
   data = {
      'EXCEPT_ITEMS': generateExceptExpression(
         Ip6AddrMatcher.Ip6PrefixValidMatcher(
            'IPv6 address prefix',
            overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO ),
         exceptSupported=True ),
   }

class ProtocolMixin( CliCommand.CliCommandClass ):
   """
   This class defines numerous helper methods for "protocol" commands. These
   are commands like 'protocol udp source port 10-50' as well as the field-set
   variant 'protocol tcp source port field-set FIELD_SET_NAME'.
   """
   # Keys under which the parsed protocol sets are looked up in 'args'.
   _protoArgsListName = 'PROTOCOL'
   _tcpUdpArgsListName = 'TCP_UDP'
   _tcpFlagArgsListName = 'tcp'
   # Attribute names handed to context.updatePortFieldSetAttr() for the
   # source / destination port field sets.
   _sportFieldSetAttr = ""
   _dportFieldSetAttr = ""
   # If this is True, all commands _overwrite_ whatever configuration is currently
   # present.
   _overwrite = False

   @classmethod
   def _maybeHandleErrors( cls, mode, args, proto, source=False, destination=False,
                           fieldSet=False ):
      """
      Validate the configuration after an update and undo the update on error.

      An error is reported on 'mode' for each failed check. If any check
      failed, _updateProtoAndPort() is invoked again with add=False for the
      same proto / source / destination / fieldSet arguments.
      """
      hasError = False
      context = mode.getContext()
      if not context.isValidConfig( conflictOther ):
         mode.addError( configConflictMsg % 'protocol PROTOCOL' )
         hasError = True
      if not context.isValidConfig( conflictMatchAllFragments ) or not \
         context.isValidConfig( conflictFragmentOffset ):
         mode.addError( invalidL4PortConflictMsg )
         hasError = True
      if not hasError:
         # No error, safe to keep changes
         return
      # Errors found, revert changes made
      cls._updateProtoAndPort( mode, args, proto, source, destination,
                               fieldSet=fieldSet, add=False )

   @classmethod
   def _updateProtoAndPort( cls, mode, args, proto,
                            source=False, destination=False,
                            flags=False, fieldSet=False, add=True ):
      """
      Add or remove the protocol together with its optional qualifiers.

      proto: truthy when a protocol is part of the command.
      source / destination: truthy when the respective port qualifier is given.
      flags: truthy when TCP flags are given.
      fieldSet: ports are field-set names rather than port ranges.
      add: True to add the configuration, False to remove it.
      """
      context = mode.getContext()
      if proto:
         # Pick the args key that carries the protocol set for this command.
         if cls._protoArgsListName in args:
            argListName = cls._protoArgsListName
         elif cls._tcpFlagArgsListName in args:
            argListName = cls._tcpFlagArgsListName
         else:
            argListName = cls._tcpUdpArgsListName
         # On removal the protocol itself is only touched when no qualifier
         # was given; otherwise only the qualifiers are removed below.
         if add or ( not add and not source and not destination and not flags ):
            cls._updateProtocol( mode, args, argListName, add=add )

      if flags:
         cls._updateTcpFlags( mode, args, add=add )

      if not context.getProto():
         # all ports removed, this will remove all sport/dport/icmp etc
         return

      # When only proto is removed, no need to update sport/dport
      if source:
         if fieldSet:
            cls._updatePortFieldSet( mode, args, cls._sportFieldSetAttr,
                                     'SRC_FIELD_SET_NAME', add=add )
         else:
            cls._updatePort( mode, args, 'sport', 'SPORT', add=add )

      if destination:
         if fieldSet:
            cls._updatePortFieldSet( mode, args, cls._dportFieldSetAttr,
                                     'DST_FIELD_SET_NAME', add=add )
         else:
            cls._updatePort( mode, args, 'dport', 'DPORT', add=add )

      # Remove protocol once all additional fields have been removed
      if proto and ( source or destination or flags ):
         protocolRangeSet = args.get( argListName, set() )
         context.maybeUpdateProto( protocolRangeSet )

   @classmethod
   def _updateProtocol( cls, mode, args, argListName, add=True ):
      """
      Add or remove the protocol set stored under args[ argListName ].

      When adding with _overwrite set, the currently configured protocols are
      removed first.
      """
      rangeType = 'Classification::ProtocolRange'
      context = mode.getContext()
      if add and cls._overwrite:
         originalSet = numericalRangeToSet( context.getProto() )
         context.updateRangeAttr( attrName='proto', rangeSet=originalSet,
                                  rangeType=rangeType, add=False )
      # do not remove protocol when removing tcp flags
      protocolRangeSet = args.get( argListName, set() )
      context.updateRangeAttr( attrName='proto',
                               rangeType=rangeType,
                               rangeSet=protocolRangeSet,
                               add=add )

   @classmethod
   def _updateTcpFlags( cls, mode, args, add=True ):
      """Add or remove the TCP flags expression; no-op without 'flags'."""
      context = mode.getContext()
      if not args.get( 'flags' ):
         return
      # tcp flags are configured
      notFlags = args.get( 'not' )
      context.updateTcpFlags( args.get( "FLAGS_EXPR" ), notFlags=notFlags,
                              add=add )

   @classmethod
   def _updatePortFieldSet( cls, mode, args, attrName, argListName, add=True ):
      """
      Add or remove the port field-set names stored under args[ argListName ]
      on context attribute 'attrName'.
      """
      protoRangeSet = args.get( cls._tcpUdpArgsListName, set() ) or \
                      args.get( cls._tcpFlagArgsListName, set() )
      fieldSetNames = args.get( argListName, set() )
      context = mode.getContext()
      context.updatePortFieldSetAttr( attrName, fieldSetNames, add=add,
                                      protoSet=protoRangeSet )

   @classmethod
   def _updatePort( cls, mode, args, attrName, argListName, add=True ):
      """
      Add or remove the port set stored under args[ argListName ] on context
      attribute 'attrName'. Previous ports are cleared when adding with
      _overwrite set.
      """
      clearPrev = False
      if add and cls._overwrite:
         clearPrev = True
      # If the protocol set is either TCP or UDP, or both -- simply add the
      # port qualifier. The TCP key comes from _tcpFlagArgsListName, the same
      # key the other helpers in this class use, so subclass overrides apply
      # here as well.
      protoRangeSet = args.get( cls._tcpUdpArgsListName, set() ) or \
                      args.get( cls._tcpFlagArgsListName, set() )
      portRangeSet = args.get( argListName, set() )
      context = mode.getContext()
      context.updatePortRangeAttr( attrName, protoRangeSet,
                                   portRangeSet, add=add, clearPrev=clearPrev )

#--------------------------------------------------------------------------------
# The "protocol icmp | icmpv6 type TYPE code all" command for traffic-policy
#--------------------------------------------------------------------------------
class ProtocolIcmpConfigBase( ProtocolMixin ):
   """
   Base class for 'protocol ( icmp | icmpv6 ) type TYPE code all'.
   Subclasses pick the keyword through _icmpName and supply the matchers.
   """
   # Keyword in 'args', also the key into icmpProtocols.
   _icmpName = 'icmp'
   # Key under which the parsed ICMP type set is found in 'args'.
   _typeArgsListName = 'TYPE'

   # Keyword help shared with subclasses.
   _baseData = {
      'protocol': 'Protocol',
      'type': 'Configure ICMP type',
      'code': 'Configure ICMP code',
      'all': 'Configure ALL ICMP codes'
   }

   @classmethod
   def handler( cls, mode, args ):
      """Add the ICMP protocol and, when non-empty, the given ICMP types."""
      # Swap the matched keyword for a one-element set holding the first
      # entry of icmpProtocols[ keyword ], so _updateProtocol can consume it.
      args.pop( cls._icmpName )
      protoNumber = icmpProtocols[ cls._icmpName ][ 0 ]
      args[ cls._icmpName ] = set( [ protoNumber ] )
      cls._updateProtocol( mode, args, argListName=cls._icmpName, add=True )

      ctx = mode.getContext()
      typeSet = args[ cls._typeArgsListName ]
      if not typeSet:
         return
      ctx.updateIcmpRangeAttr( protoNumber, 'Classification::IcmpTypeRange',
                               typeSet, add=True )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Remove the given ICMP types (an empty set when none were given)."""
      args.pop( cls._icmpName )
      protoNumber = icmpProtocols[ cls._icmpName ][ 0 ]
      protoSet = set( [ protoNumber ] )
      args[ cls._icmpName ] = protoSet

      ctx = mode.getContext()
      typeSet = args.get( cls._typeArgsListName, set() )
      ctx.updateIcmpRangeAttr( protoNumber, 'Classification::IcmpTypeRange',
                               typeSet, add=False )
      # Remove protocol once all additional fields have been removed
      ctx.maybeUpdateProto( protoSet )

class ProtocolIcmpV4ConfigCmd( ProtocolIcmpConfigBase ):
   """'protocol icmp type TYPE code all' command."""
   syntax = 'protocol icmp type TYPE code all'
   noOrDefaultSyntax = 'protocol icmp type [ TYPE [ code all ] ]'
   # Shared keyword help plus the ICMPv4 specific matchers.
   data = dict( ProtocolIcmpConfigBase._baseData )
   data[ 'icmp' ] = icmpV4KwMatcher
   data[ 'TYPE' ] = icmpV4TypeRangeExpr

class ProtocolIcmpV6ConfigCmd( ProtocolIcmpConfigBase ):
   """'protocol icmpv6 type TYPE code all' command."""
   syntax = 'protocol icmpv6 type TYPE code all'
   noOrDefaultSyntax = 'protocol icmpv6 type [ TYPE [ code all ] ]'
   # Shared keyword help plus the ICMPv6 specific matchers.
   data = dict( ProtocolIcmpConfigBase._baseData )
   data[ 'icmpv6' ] = icmpV6KwMatcher
   data[ 'TYPE' ] = icmpV6TypeRangeExpr

   # Keyword / icmpProtocols key used by the inherited handlers.
   _icmpName = 'icmpv6'

#--------------------------------------------------------------------------------
# The "protocol icmp | icmpv6 type TYPE code CODE" command for traffic-policy
#--------------------------------------------------------------------------------
class ProtocolIcmpTypeCodeConfigBase( ProtocolMixin ):
   """
   Base class for 'protocol ( icmp | icmpv6 ) type TYPE code CODE'.
   Subclasses pick the keyword through _icmpName and supply the matchers.
   """
   # Keyword in 'args', also the key into icmpProtocols.
   _icmpName = 'icmp'
   # Keys under which the parsed type and code values are found in 'args'.
   _typeArgsListName = 'TYPE'
   _codeArgsListName = 'CODE'

   # Keyword help shared with subclasses.
   _baseData = {
      'protocol': 'Protocol',
      'type': 'Configure ICMP type',
      'code': 'Configure ICMP code',
   }

   @classmethod
   def handler( cls, mode, args ):
      """Add the ICMP protocol and, when both are given, the type/code pair."""
      # Swap the matched keyword for a one-element set holding the first
      # entry of icmpProtocols[ keyword ], so _updateProtocol can consume it.
      args.pop( cls._icmpName )
      protoNumber = icmpProtocols[ cls._icmpName ][ 0 ]
      args[ cls._icmpName ] = set( [ protoNumber ] )
      cls._updateProtocol( mode, args, argListName=cls._icmpName, add=True )

      ctx = mode.getContext()
      typeValue = args.get( cls._typeArgsListName )
      codeSet = args.get( cls._codeArgsListName )
      if not ( typeValue and codeSet ):
         return
      ctx.addIcmpTypeCodeRangeAttr( protoNumber,
                                    'Classification::IcmpTypeRange', typeValue,
                                    'Classification::IcmpCodeRange', codeSet )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Remove the codes (an empty set when none given) of the given type."""
      args.pop( cls._icmpName )
      protoNumber = icmpProtocols[ cls._icmpName ][ 0 ]
      args[ cls._icmpName ] = set( [ protoNumber ] )
      # NOTE(review): the protocol is *added* here although this is the
      # removal path, and maybeUpdateProto() is never called afterwards
      # (unlike ProtocolIcmpConfigBase.noOrDefaultHandler) -- confirm that
      # this is intentional.
      cls._updateProtocol( mode, args, argListName=cls._icmpName, add=True )

      ctx = mode.getContext()
      typeValue = args.get( cls._typeArgsListName )
      codeSet = args.get( cls._codeArgsListName, set() )
      if not typeValue:
         return
      ctx.removeIcmpTypeCodeRangeAttr( protoNumber,
                                       'Classification::IcmpTypeRange',
                                       typeValue,
                                       'Classification::IcmpCodeRange',
                                       codeSet )

class ProtocolIcmpV4TypeCodeConfigCmd( ProtocolIcmpTypeCodeConfigBase ):
   """'protocol icmp type TYPE code CODE' command."""
   syntax = 'protocol icmp type TYPE code CODE'
   noOrDefaultSyntax = 'protocol icmp type TYPE code [ CODE ]'

   # Shared keyword help plus the ICMPv4 specific matchers.
   data = dict( ProtocolIcmpTypeCodeConfigBase._baseData )
   data[ 'icmp' ] = icmpV4KwMatcher
   data[ 'TYPE' ] = icmpV4TypeSingleExpr
   data[ 'CODE' ] = icmpV4CodeExpr

class ProtocolIcmpV6TypeCodeConfigCmd( ProtocolIcmpTypeCodeConfigBase ):
   """'protocol icmpv6 type TYPE code CODE' command."""
   syntax = 'protocol icmpv6 type TYPE code CODE'
   noOrDefaultSyntax = 'protocol icmpv6 type TYPE code [ CODE ]'
   # Shared keyword help plus the ICMPv6 specific matchers.
   data = dict( ProtocolIcmpTypeCodeConfigBase._baseData )
   data[ 'icmpv6' ] = icmpV6KwMatcher
   data[ 'TYPE' ] = icmpV6TypeSingleExpr
   data[ 'CODE' ] = icmpV6CodeExpr

   # Keyword / icmpProtocols key used by the inherited handlers.
   _icmpName = 'icmpv6'

#------------------------------------------------------------------------------------
# The "protocol (tcp [ flags [ not ] TCP_FLAGS ] | udp) source port field-set
# FIELD_SET destination port field-set FIELD_SET" command for traffic-policy
#------------------------------------------------------------------------------------
class ProtocolFieldSetBaseCmd( ProtocolMixin ):
   """
   Base class for the 'protocol ... ( source | destination ) port field-set
   FIELD_SET_NAME' commands. Subclasses are expected to supply the matchers
   not listed in _baseData (FLAGS_EXPR and the field-set name tokens).
   """
   # NOTE(review): the three port alternatives are not wrapped in one outer
   # group -- confirm the parser gives '|' the intended precedence here.
   syntax = ( 'protocol ( ( TCP_UDP ) | ( FLAGS_EXPR ) ) '
              '( source port1 field-set1 SRC_FIELD_SET_NAME ) |'
              '( destination port2 field-set2 DST_FIELD_SET_NAME ) |'
              '( source port1 field-set1 SRC_FIELD_SET_NAME '
              'destination port2 field-set2 DST_FIELD_SET_NAME )' )
   noOrDefaultSyntax = ( 'protocol ( ( TCP_UDP ) | ( FLAGS_EXPR ) )'
                     '( source port1 field-set1 [ SRC_FIELD_SET_NAME ] ) | '
                     '( destination port2 field-set2 [ DST_FIELD_SET_NAME ] ) | '
                     '( source port1 field-set1 SRC_FIELD_SET_NAME '
                     'destination port2 field-set2 DST_FIELD_SET_NAME )' )
   # Keyword help / matchers shared with subclasses.
   _baseData = {
      'protocol': 'Protocol',
      'TCP_UDP': tcpUdpProtoExpr,
      'field-set1': fieldSetKwMatcher,
      'field-set2': fieldSetKwMatcher,
      'port1': portKwMatcher,
      'port2': portKwMatcher,
      'source': 'Source',
      'destination': 'Destination',
   }

   @classmethod
   def handler( cls, mode, args ):
      """Add the protocol and port field sets, reverting on a conflict."""
      protoSet = ( args.get( cls._tcpUdpArgsListName ) or
                   args.get( cls._tcpFlagArgsListName ) )
      if not ( protoSet or args ):
         # Nothing was parsed at all.
         return
      src = args.get( 'source' )
      dst = args.get( 'destination' )
      cls._updateProtoAndPort( mode, args, protoSet, src, dst,
                               flags=args.get( 'flags', False ),
                               fieldSet=True, add=True )
      cls._maybeHandleErrors( mode, args, protoSet, src, dst, fieldSet=True )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Remove the protocol and/or the port field sets named in 'args'."""
      protoSet = ( args.get( cls._tcpUdpArgsListName ) or
                   args.get( cls._tcpFlagArgsListName ) )
      src = args.get( 'source' )
      dst = args.get( 'destination' )
      tcpFlags = args.get( 'flags', False )

      if not ( protoSet or src or dst ):
         # 'no protocol' removes all protocols
         protoSet = True

      cls._updateProtoAndPort( mode, args, protoSet, src, dst,
                               flags=tcpFlags, fieldSet=True, add=False )

#--------------------------------------------------
# The "fragment" command for traffic-policy
#--------------------------------------------------
class MatchAllFragmentConfigCmd( CliCommand.CliCommandClass ):
   """'fragment' command: sets the fragment type to matchAll."""
   syntax = 'fragment'
   noOrDefaultSyntax = syntax
   data = {
     'fragment': 'fragment'
   }

   @classmethod
   def _maybeHandleErrors( cls, mode, args ):
      """Report a conflict and clear the fragment config if it is invalid."""
      ctx = mode.getContext()
      if ctx.isValidConfig( conflictMatchAllFragments ):
         return
      mode.addError( invalidFragmentConflictMsg )
      ctx.clearFragment()

   @classmethod
   def handler( cls, mode, args ):
      """Set the fragment type, then validate the resulting configuration."""
      mode.getContext().updateFragmentType( matchAll )
      cls._maybeHandleErrors( mode, args )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Clear the fragment configuration."""
      mode.getContext().clearFragment()

#--------------------------------------------------------
# The "fragment offset OFFSET" command for traffic-policy
#--------------------------------------------------------
class FragmentOffsetConfigCmd( NumericalRangeConfigCmdBase ):
   """'fragment offset OFFSET': updates the 'fragmentOffset' range attribute."""
   syntax = 'fragment offset OFFSET'
   noOrDefaultSyntax = 'fragment offset [ OFFSET ]'
   data = {
      'fragment': 'fragment',
      'offset': 'Offset keyword',
      'OFFSET': fragOffsetRangeMatcher
   }

   # Hooks consumed by NumericalRangeConfigCmdBase.
   _argListName = 'OFFSET'
   _attrName = 'fragmentOffset'
   _rangeType = 'Classification::FragmentOffsetRange'

   @classmethod
   def _maybeHandleErrors( cls, mode, args ):
      """Report a conflict and clear the fragment config if it is invalid."""
      ctx = mode.getContext()
      if ctx.isValidConfig( conflictFragmentOffset ):
         return
      mode.addError( invalidFragOffsetConflictMsg )
      ctx.clearFragment()

   @classmethod
   def handler( cls, mode, args ):
      """
      Set the fragment type to matchOffset, let the base class add the offset
      range, then validate the resulting configuration.
      """
      mode.getContext().updateFragmentType( matchOffset )
      super( FragmentOffsetConfigCmd, cls ).handler( mode, args )
      cls._maybeHandleErrors( mode, args )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Let the base class remove the offsets, then call maybeDelFragment()."""
      super( FragmentOffsetConfigCmd, cls ).noOrDefaultHandler( mode, args )
      mode.getContext().maybeDelFragment()

#--------------------------------------------------
# The "ip options" command for traffic-policy
#--------------------------------------------------
class IpOptionsConfigCmd( CliCommand.CliCommandClass ):
   """'ip options' command: toggles matching on IPv4 options."""
   syntax = 'ip options'
   noOrDefaultSyntax = syntax
   data = {
      'ip': 'IP',
      'options': 'Match packets with IPv4 options',
   }

   @classmethod
   def handler( cls, mode, args ):
      """Enable the IP options match on the mode's context."""
      mode.getContext().updateMatchIpOptions( add=True )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      """Disable the IP options match on the mode's context."""
      mode.getContext().updateMatchIpOptions( add=False )

# modelet to be added to modes that support commit/abort
class CommitAbortModelet( CliParser.Modelet ):
   """Modelet that carries the 'commit' and 'abort' commands."""
   modeletParseTree = CliParser.ModeletParseTree()

   def __init__( self, mode ):
      # 'mode' is accepted but not forwarded to the base initializer.
      CliParser.Modelet.__init__( self )

   @staticmethod
   def shouldAddModeletRule( mode ):
      """Return True unconditionally, whatever 'mode' is."""
      return True

class CommitCommand( CliCommand.CliCommandClass ):
   """'commit' command: delegates to mode.commit()."""
   syntax = 'commit'
   data = {
      'commit': 'Commit all changes',
   }

   @staticmethod
   def handler( mode, args ):
      """Commit through the current mode; 'args' is unused."""
      mode.commit()

class AbortCommand( CliCommand.CliCommandClass ):
   """'abort' command: delegates to mode.abort()."""
   syntax = 'abort'
   data = {
      'abort': 'Abandon all changes',
   }

   @staticmethod
   def handler( mode, args ):
      """Abort through the current mode; 'args' is unused."""
      mode.abort()

# Register the commit and abort commands with CommitAbortModelet.
CommitAbortModelet.addCommandClass( CommitCommand )
CommitAbortModelet.addCommandClass( AbortCommand )
