#!/usr/bin/env python
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliMode.Classification import ( FieldSetPrefixModeBase,
                                     FieldSetL4PortModeBase )
from CliMode.TrafficPolicy import TrafficPoliciesConfigMode
from CliMode.TrafficPolicy import TrafficPoliciesModeBase
from CliMode.TrafficPolicy import TrafficPolicyModeBase
from CliMode.TrafficPolicy import MatchRuleModeBase
from CliMode.TrafficPolicy import ActionsModeBase
from CliMode.TrafficPolicy import FEATURE
import CliSave
from IntfCliSave import IntfConfigMode
from TrafficPolicyLib import structuredFilterToCmds
from ClassificationLib import numericalRangeToRangeString
import Tac

# ClassMapMatchOption enum values; saveMatchRules() compares a rule's match
# option against these to decide whether it is an ipv4 or ipv6 match rule.
tacMatchOption = Tac.Type( 'PolicyMap::ClassMapMatchOption' )
matchIpAccessGroup, matchIpv6AccessGroup = (
   tacMatchOption.matchIpAccessGroup,
   tacMatchOption.matchIpv6AccessGroup )

class TrafficPoliciesSaveMode( TrafficPoliciesModeBase, CliSave.Mode ):
   """CliSave mode for the global traffic-policies configuration mode."""
   def __init__( self, param ):
      # Both bases are initialized explicitly: TrafficPoliciesModeBase takes no
      # param, while CliSave.Mode receives it (callers in this file pass None).
      TrafficPoliciesModeBase.__init__( self )
      CliSave.Mode.__init__( self, param )

# Register traffic-policies as a child of global config, ordered relative to
# interface config via `after`. Then register its two command sequences.
# NOTE(review): registration order presumably determines output order in
# running-config -- confirm against CliSave before reordering.
CliSave.GlobalConfigMode.addChildMode( TrafficPoliciesSaveMode,
                                       after=[ IntfConfigMode ] )
TrafficPoliciesSaveMode.addCommandSequence( 'TrafficPolicy.TrafficPolicies' )
TrafficPoliciesSaveMode.addCommandSequence( 'TrafficPolicyParam' )

class TrafficPolicySaveMode( TrafficPolicyModeBase, CliSave.Mode ):
   """CliSave mode for a single named traffic-policy."""
   def __init__( self, param ):
      # In this file, param is built by TrafficPolicySaver.saveTrafficPolicy as
      # ( feature, parent mode class, policy name ); both bases receive it.
      TrafficPolicyModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Each traffic-policy is nested under the traffic-policies mode and owns one
# command sequence (used for e.g. the "counter ..." line).
TrafficPoliciesSaveMode.addChildMode( TrafficPolicySaveMode )
TrafficPolicySaveMode.addCommandSequence( 'TrafficPolicy.TrafficPolicy' )

class MatchRuleSaveMode( MatchRuleModeBase, CliSave.Mode ):
   """CliSave mode for one match rule inside a traffic-policy.

   Instances compare by ``prio``, presumably so that CliSave orders sibling
   match rules by priority (``prio`` is assumed to be set by
   MatchRuleModeBase from the param tuple -- confirm).
   """
   def __init__( self, param ):
      MatchRuleModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, param )

   def __cmp__( self, other ):
      # Three-way compare on prio, yielding -1, 0 or 1; same result as the
      # Python-2-only cmp( self.prio, other.prio ).
      lhs, rhs = self.prio, other.prio
      return ( lhs > rhs ) - ( lhs < rhs )

TrafficPolicySaveMode.addChildMode( MatchRuleSaveMode,
                                    after=[ 'TrafficPolicy.TrafficPolicy' ] )
MatchRuleSaveMode.addCommandSequence( 'TrafficPolicy.MatchRule' )

class ActionsSaveMode( ActionsModeBase, CliSave.Mode ):
   """CliSave mode for the actions sub-mode of a match rule."""
   def __init__( self, param ):
      # In this file, param is built by TrafficPolicySaver.saveActions as
      # ( feature, policy name, rule name ).
      ActionsModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Actions are nested under a match rule, placed after the rule's own
# command sequence.
MatchRuleSaveMode.addChildMode( ActionsSaveMode,
                                after=[ 'TrafficPolicy.MatchRule' ] )
ActionsSaveMode.addCommandSequence( 'TrafficPolicy.Actions' )

class FSPrefixSaveMode( FieldSetPrefixModeBase, CliSave.Mode ):
   """CliSave mode for an IPv4/IPv6 prefix field-set."""
   def __init__( self, param ):
      # In this file, param is built by FieldSetSaver.saveAllFieldSetPrefix as
      # ( 'tp', address family, field-set name, TrafficPoliciesConfigMode ).
      FieldSetPrefixModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Prefix field-sets live under the traffic-policies mode.
TrafficPoliciesSaveMode.addChildMode( FSPrefixSaveMode )
FSPrefixSaveMode.addCommandSequence( 'TrafficPolicy.FieldSetPrefix' )

class FSL4PortSaveMode( FieldSetL4PortModeBase, CliSave.Mode ):
   """CliSave mode for an L4-port field-set."""
   def __init__( self, param ):
      # In this file, param is built by FieldSetSaver.saveAllFieldSetL4Port as
      # ( 'tp', 'l4-port', field-set name, TrafficPoliciesConfigMode ).
      FieldSetL4PortModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# L4-port field-sets live under the traffic-policies mode.
TrafficPoliciesSaveMode.addChildMode( FSL4PortSaveMode )
FSL4PortSaveMode.addCommandSequence( 'TrafficPolicy.FieldSetL4Port' )

def _hasComments( commentKey, requireMounts ):
   return commentKey in requireMounts[ 'cli/config' ].comment

class TrafficPolicySaver( object ):
   """Write traffic-policy configuration into the CliSave mode tree.

   Every policy map under ``entity.pmapType.pmap`` is emitted as a
   traffic-policy mode inside the top-level traffic-policies mode. Each of
   its rules becomes a match-rule mode, and rules with actions (or an actions
   comment) get an actions sub-mode. The mode classes are class attributes
   and the command-sequence lookups are separate methods, presumably so a
   subclass can reuse this traversal for another feature -- confirm against
   callers.
   """
   trafficPoliciesMode = TrafficPoliciesSaveMode
   trafficPolicyMode = TrafficPolicySaveMode
   matchRuleMode = MatchRuleSaveMode
   actionsRuleMode = ActionsSaveMode

   def __init__( self, entity, root, sysdbRoot, options, requireMounts, feature,
                 commentKey='traffic-policies' ):
      """Store the CliSave context.

      entity: config entity whose ``pmapType.pmap`` holds the policies.
      root: CliSave root used to create the top-level mode.
      requireMounts: mounts; ``'cli/config'`` is read for comments.
      feature: passed as the first element of every mode param tuple.
      commentKey: comment key that forces the top-level mode to be emitted
         even when no policy is configured.
      """
      self.entity = entity
      self.root = root
      self.sysdbRoot = sysdbRoot
      self.options = options
      # Per-policy state: set by saveTrafficPolicy() and read by the
      # per-rule helpers while that policy is being saved.
      self.trafficPolicy = None
      self.currPolicyName = None
      # Not assigned anywhere else in this class; kept for compatibility.
      self.matchRule = None
      self.requireMounts = requireMounts
      self.feature = feature
      self.policyMapType = 'traffic-policy'
      self.commentKey = commentKey

   def actionsRuleCommentKey( self, policyName, matchName ):
      """Return the comment key of the actions mode of one match rule."""
      return '%s-actions-%s-%s' % ( self.policyMapType, policyName, matchName )

   def trafficPolicyModeCmds( self, policyMode ):
      """Return the command sequence of a traffic-policy mode instance."""
      return policyMode[ 'TrafficPolicy.TrafficPolicy' ]

   def matchModeCmds( self, matchMode ):
      """Return the command sequence of a match-rule mode instance."""
      return matchMode[ 'TrafficPolicy.MatchRule' ]

   def actionModeCmds( self, actionsMode ):
      """Return the command sequence of an actions mode instance."""
      return actionsMode[ 'TrafficPolicy.Actions' ]

   def _addrListToStr( self, addrs ):
      """Return *addrs* sorted, stringified and joined with spaces."""
      return " ".join( map( str, sorted( addrs ) ) )

   def saveTrafficPolicyAll( self ):
      """Save every traffic policy, in sorted name order.

      The top-level mode is created only if at least one policy exists or a
      comment is stored under ``self.commentKey``.
      """
      allTrafficPolicy = []
      if self.entity.pmapType is not None:
         allTrafficPolicy = sorted( self.entity.pmapType.pmap.keys() )
      if not allTrafficPolicy and \
            not _hasComments( self.commentKey, self.requireMounts ):
         return
      trafficPoliciesMode = \
         self.root[ self.trafficPoliciesMode ].getOrCreateModeInstance( None )
      for policy in allTrafficPolicy:
         self.saveTrafficPolicy( policy, trafficPoliciesMode )

   def saveTrafficPolicy( self, policyName, policiesMode ):
      """Save one policy: its named counters, then its match rules.

      Policies whose ``currCfg`` is unset are skipped without creating a mode.
      """
      self.trafficPolicy = self.entity.pmapType.pmap[ policyName ].currCfg
      self.currPolicyName = policyName
      if not self.trafficPolicy:
         return
      param = ( self.feature, self.trafficPoliciesMode, policyName )
      trafficPolicyMode = \
         policiesMode[ self.trafficPolicyMode ].getOrCreateModeInstance( param )
      self.saveNamedCounters( trafficPolicyMode )
      self.saveMatchRules( trafficPolicyMode )

   def saveNamedCounters( self, policyMode ):
      """Emit a single "counter <names...>" line with names sorted."""
      if not self.trafficPolicy.namedCounter:
         return
      cmds = self.trafficPolicyModeCmds( policyMode )
      cmd = "counter %s" % ( " ".join( sorted( self.trafficPolicy.namedCounter ) ) )
      cmds.addCommand( cmd )

   @staticmethod
   def _matchTypeForOption( matchOption ):
      """Map a ClassMapMatchOption value to the match type 'ipv4' or 'ipv6'.

      Raises NotImplementedError for any other match option.
      """
      if matchOption == matchIpAccessGroup:
         return "ipv4"
      if matchOption == matchIpv6AccessGroup:
         return "ipv6"
      raise NotImplementedError( "unsupported match option: %s" %
                                 ( matchOption, ) )

   def saveMatchRules( self, policyMode ):
      """Save every match rule of the current policy in classPrio order.

      A match-rule mode is created for each rule. Its contents are filled in
      only when the rule has a structured filter.
      """
      for prio, ruleName in self.trafficPolicy.classPrio.iteritems():
         matchRule = self.trafficPolicy.rawClassMap[ ruleName ]
         # Only the first match option of a rule is considered.
         matchOption = matchRule.match.keys()[ 0 ]
         matchType = self._matchTypeForOption( matchOption )
         param = ( self.feature, self.trafficPolicyMode,
                   self.currPolicyName, ruleName, matchType, prio )
         matchMode = \
               policyMode[ self.matchRuleMode ].getOrCreateModeInstance( param )
         structuredFilter = matchRule.match[ matchOption ].structuredFilter
         if structuredFilter is None:
            # Nothing to render. Do not translate a missing filter or
            # dereference classAction for a result that would be discarded.
            continue
         classAction = self.trafficPolicy.classAction.get( ruleName )
         # NOTE(review): assumes every rule with a filter has a classAction;
         # a missing one raises AttributeError here -- confirm invariant.
         structuredFilterCmds = structuredFilterToCmds( structuredFilter,
                                                        classAction.policyAction,
                                                        matchType )
         self.saveSip( structuredFilterCmds, matchMode )
         self.saveFieldSet( structuredFilterCmds, 'srcPrefixSet', matchMode )
         self.saveDip( structuredFilterCmds, matchMode )
         self.saveFieldSet( structuredFilterCmds, 'dstPrefixSet', matchMode )
         self.saveProtoNeighborsBgp( structuredFilterCmds, matchMode )
         self.saveProtoMatchBgp( structuredFilterCmds, matchMode )
         self.saveProtocol( structuredFilterCmds, matchMode )
         self.saveFragment( structuredFilterCmds, matchMode )
         self.saveOptions( structuredFilterCmds, matchMode )
         self.saveActions( structuredFilterCmds, ruleName, matchMode )
         self.saveNumericalRangeOption( structuredFilterCmds, 'ttl', matchMode )
         self.saveNumericalRangeOption( structuredFilterCmds, 'dscp', matchMode )
         self.saveNumericalRangeOption( structuredFilterCmds, 'length',
                                        matchMode )

   def _addStructuredFilterCmd( self, structuredFilterCmds, key, matchMode ):
      """Add the single command stored under *key* to the match-rule mode.

      Nothing is added when the key is absent or its command is empty.
      """
      cmds = self.matchModeCmds( matchMode )
      cmd = structuredFilterCmds.get( key )
      if cmd:
         cmds.addCommand( cmd )

   def saveSip( self, structuredFilterCmds, matchMode ):
      """Emit the source-address match command, if any."""
      self._addStructuredFilterCmd( structuredFilterCmds, 'source', matchMode )

   def saveDip( self, structuredFilterCmds, matchMode ):
      """Emit the destination-address match command, if any."""
      self._addStructuredFilterCmd( structuredFilterCmds, 'destination',
                                    matchMode )

   def saveFieldSet( self, structuredFilterCmds, prefixSet, matchMode ):
      """Emit the field-set command stored under key *prefixSet*, if any."""
      self._addStructuredFilterCmd( structuredFilterCmds, prefixSet, matchMode )

   def saveActions( self, structuredFilterCmds, ruleName, matchMode ):
      """Create the actions sub-mode and emit its commands.

      The sub-mode is created when there are action commands or when a
      comment exists under the rule's actions comment key.
      """
      actionCmds = structuredFilterCmds.get( 'actions' )
      if actionCmds or \
            _hasComments( self.actionsRuleCommentKey( self.currPolicyName,
                                                      ruleName ),
                          self.requireMounts ):
         param = ( self.feature, self.currPolicyName, ruleName )
         actionsMode = \
            matchMode[ self.actionsRuleMode ].getOrCreateModeInstance( param )
         if actionCmds:
            cmds = self.actionModeCmds( actionsMode )
            for cmd in actionCmds:
               cmds.addCommand( cmd )

   def saveProtocol( self, structuredFilterCmds, matchMode ):
      """Emit every protocol command (stored as a list), in order."""
      cmds = self.matchModeCmds( matchMode )
      for cmd in structuredFilterCmds.get( 'protocol', [] ):
         cmds.addCommand( cmd )

   def saveFragment( self, structuredFilterCmds, matchMode ):
      """Emit the fragment match command, if any."""
      self._addStructuredFilterCmd( structuredFilterCmds, 'fragment', matchMode )

   def saveOptions( self, structuredFilterCmds, matchMode ):
      """Emit the IP-options match command, if any."""
      self._addStructuredFilterCmd( structuredFilterCmds, 'matchIpOptions',
                                    matchMode )

   def saveNumericalRangeOption( self, structuredFilterCmds, optionCmd, matchMode ):
      """Emit the range command stored under key *optionCmd*, if any."""
      self._addStructuredFilterCmd( structuredFilterCmds, optionCmd, matchMode )

   def saveProtoNeighborsBgp( self, structuredFilterCmds, matchMode ):
      """Emit the BGP-neighbors protocol command, if any."""
      self._addStructuredFilterCmd( structuredFilterCmds, 'protoNeighborsBgp',
                                    matchMode )

   def saveProtoMatchBgp( self, structuredFilterCmds, matchMode ):
      """Emit the BGP protocol command, if any."""
      self._addStructuredFilterCmd( structuredFilterCmds, 'protoBgp', matchMode )

   def save( self ):
      """Entry point: save all traffic policies."""
      self.saveTrafficPolicyAll()

class FieldSetSaver( object ):
   """Write IPv4/IPv6 prefix and L4-port field-set config to CliSave."""
   def __init__( self, entity, root, sysdbRoot, options, requireMounts ):
      self.entity = entity
      self.root = root
      self.sysdbRoot = sysdbRoot
      self.options = options
      # The three attributes below are not used by this class; they are
      # kept so the instance's attribute set is unchanged.
      self.trafficPolicy = None
      self.currPolicyName = None
      self.matchRule = None
      self.requireMounts = requireMounts

   def saveFieldSetAll( self ):
      """Save every field-set, each group sorted by name.

      The top-level traffic-policies mode is created only when at least one
      field-set of any kind exists.
      """
      fieldSetIpv4Prefix = sorted( self.entity.fieldSetIpPrefix.keys() )
      fieldSetIpv6Prefix = sorted( self.entity.fieldSetIpv6Prefix.keys() )
      fieldSetL4Port = sorted( self.entity.fieldSetL4Port )
      if not ( fieldSetIpv4Prefix or fieldSetIpv6Prefix or fieldSetL4Port ):
         # No field-sets defined
         return

      trafficPoliciesMode = \
         self.root[ TrafficPoliciesSaveMode ].getOrCreateModeInstance( None )
      self.saveAllFieldSetPrefix( fieldSetIpv4Prefix, 'ipv4', trafficPoliciesMode )
      self.saveAllFieldSetPrefix( fieldSetIpv6Prefix, 'ipv6', trafficPoliciesMode )
      self.saveAllFieldSetL4Port( fieldSetL4Port, trafficPoliciesMode )

   @staticmethod
   def _prefixFieldSetCmds( cfg ):
      """Return the CLI lines for one prefix field-set config.

      The prefix list is one space-separated line, and the exceptions are an
      "except ..." line; both use sorted ``stringValue``s. Empty parts emit
      nothing, so an empty or missing (None) config yields [] instead of a
      blank command.
      """
      if cfg is None:
         return []
      lines = []
      prefixes = [ p.stringValue for p in sorted( cfg.prefixes ) ]
      if prefixes:
         lines.append( " ".join( prefixes ) )
      if cfg.exceptPrefix:
         exceptPrefixes = [ p.stringValue for p in sorted( cfg.exceptPrefix ) ]
         lines.append( "except " + " ".join( exceptPrefixes ) )
      return lines

   def saveAllFieldSetPrefix( self, fieldSetNames, af, policiesMode ):
      """Save the named prefix field-sets for address family *af*.

      af 'ipv4' reads ``entity.fieldSetIpPrefix``; any other value reads
      ``entity.fieldSetIpv6Prefix``.
      """
      if af == 'ipv4':
         fieldSets = self.entity.fieldSetIpPrefix
      else:
         fieldSets = self.entity.fieldSetIpv6Prefix
      for fieldSetName in fieldSetNames:
         fieldSet = fieldSets[ fieldSetName ]
         param = ( 'tp', fieldSet.af,
                   fieldSetName, TrafficPoliciesConfigMode )
         fieldSetPrefixMode = \
            policiesMode[ FSPrefixSaveMode ].getOrCreateModeInstance( param )
         cmds = fieldSetPrefixMode[ 'TrafficPolicy.FieldSetPrefix' ]
         for cmd in self._prefixFieldSetCmds( fieldSet.currCfg ):
            cmds.addCommand( cmd )

   def saveAllFieldSetL4Port( self, fieldSetNames, policiesMode ):
      """Save the named L4-port field-sets as port range strings.

      The mode is always created; the range line is emitted only when the
      field-set has a config with at least one port.
      """
      for fieldSetName in fieldSetNames:
         fieldSet = self.entity.fieldSetL4Port[ fieldSetName ]
         param = ( 'tp', 'l4-port',
                   fieldSetName, TrafficPoliciesConfigMode )
         fieldSetMode = \
            policiesMode[ FSL4PortSaveMode ].getOrCreateModeInstance( param )
         cmds = fieldSetMode[ 'TrafficPolicy.FieldSetL4Port' ]
         if fieldSet.currCfg and fieldSet.currCfg.ports:
            cmds.addCommand(
                  numericalRangeToRangeString( fieldSet.currCfg.ports ) )

   def save( self ):
      """Entry point: save all field-sets."""
      self.saveFieldSetAll()

class TrafficPolicyParamSaver( object ):
   """Write the traffic-policies interface parameters to CliSave."""
   trafficPoliciesMode = TrafficPoliciesSaveMode

   def __init__( self, entity, root, sysdbRoot, options, requireMounts ):
      self.entity = entity
      self.root = root
      self.sysdbRoot = sysdbRoot
      self.options = options
      self.requireMounts = requireMounts

   def save( self ):
      """Emit the update-action and counter-granularity commands.

      A command is emitted only when its parameter has the value that
      warrants it ('drop' / 'counterPerInterface'). When neither applies,
      nothing is written and the top-level mode is not created.
      """
      dropOnUpdate = self.entity.actionDuringUpdate == 'drop'
      perIntfCounter = \
         self.entity.ingressCounterGranularity == 'counterPerInterface'
      if not dropOnUpdate and not perIntfCounter:
         return
      # cmds is bound unconditionally from here on, instead of being defined
      # inside one branch and used under re-evaluated conditions.
      trafficPoliciesMode = \
         self.root[ self.trafficPoliciesMode ].getOrCreateModeInstance( None )
      cmds = trafficPoliciesMode[ 'TrafficPolicyParam' ]

      if dropOnUpdate:
         cmds.addCommand( "update interface default action drop" )

      if perIntfCounter:
         cmds.addCommand( "counter interface per-interface ingress" )

# CliSave hook: emit every configured field-set.
@CliSave.saver( 'Classification::FieldSetConfig', 'trafficPolicies/fieldset/cli',
                requireMounts=( 'cli/config', ) )
def saveFieldSetConfig( entity, root, sysdbRoot, options, requireMounts ):
   """Render the field-set config entity through FieldSetSaver."""
   FieldSetSaver( entity, root, sysdbRoot, options, requireMounts ).save()

# CliSave hook: emit every configured traffic-policy.
@CliSave.saver( 'TrafficPolicy::TrafficPolicyConfig', 'trafficPolicies/input/cli',
                requireMounts=( 'cli/config', ) )
def saveTrafficPolicyConfig( entity, root, sysdbRoot, options, requireMounts ):
   """Render the traffic-policy config entity for FEATURE."""
   saver = TrafficPolicySaver( entity, root, sysdbRoot, options, requireMounts,
                               feature=FEATURE )
   saver.save()

# CliSave hook: emit the interface parameters (update default action and
# ingress counter granularity).
@CliSave.saver( 'TrafficPolicy::TrafficPolicyIntfParamConfig',
                'trafficPolicies/param/config/interface',
                requireMounts=( 'cli/config', ) )
def saveTrafficPolicyParamConfig( entity, root, sysdbRoot, options, requireMounts ):
   """Render the interface-parameter entity through TrafficPolicyParamSaver."""
   saver = TrafficPolicyParamSaver( entity, root, sysdbRoot, options,
                                    requireMounts )
   saver.save()
