#!/usr/bin/env python
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import CliSave, IntfCliSave
from CliMode.PolicyMap import ClassMapModeBase, PolicyMapModeBase, \
   PolicyMapClassModeBase
from AclCliSave import IpAclConfigMode
from IntfCliSave import IntfConfigMode
from AclCliLib import ruleFromValue
import re
import Intf.IntfRange

# Save after interface mode so we can parse in startup-config.
# The 'TapAgg.pmapconfig' global sequence is ordered after both the IP ACL
# and the interface config sections.
# NOTE(review): no code in this file adds commands to 'TapAgg.pmapconfig';
# confirm whether the ordering constraint is still needed.
CliSave.GlobalConfigMode.addCommandSequence( 'TapAgg.pmapconfig',
                           after=[ IpAclConfigMode, IntfConfigMode ] )

# Per-interface sequence that receives the
# 'service-policy type tapagg input <pmap>' lines (see saveServicePolicy).
IntfCliSave.IntfConfigMode.addCommandSequence( 'TapAgg.intfconfig' )

class ClassMapConfigMode( ClassMapModeBase, CliSave.Mode ):
   """CliSave mode object for a TapAgg class-map.

   param is handed unchanged to ClassMapModeBase; the saver in this file
   passes a (mapType, mapStr, cmapName) tuple.  The CliSave mode key is the
   longModeKey attribute, read after ClassMapModeBase.__init__ has run
   (presumably that initializer sets it).
   """
   def __init__( self, param ):
      ClassMapModeBase.__init__( self, param )
      # CliSave.Mode is keyed on the CliMode base's long key.
      modeKey = self.longModeKey
      CliSave.Mode.__init__( self, modeKey )

# Class-map modes are children of global config and are placed after the IP
# ACL section.  Class-maps name ACLs in 'match ... access-group <acl>' lines,
# so presumably the ACLs must be defined first.
CliSave.GlobalConfigMode.addChildMode( ClassMapConfigMode,
                                       after=[ IpAclConfigMode ] )

# Holds the '<prio> match <type> access-group <acl>' lines of each class-map.
ClassMapConfigMode.addCommandSequence( 'TapAgg.cmap' )


#-------------------------------------------------------------------------------
# Object used for saving commands in "config-pmap" mode.
#-------------------------------------------------------------------------------
class PolicyMapConfigMode( PolicyMapModeBase, CliSave.Mode ):
   """CliSave mode object for a TapAgg policy-map.

   param is handed unchanged to PolicyMapModeBase; the saver in this file
   passes a (mapType, mapStr, pmapName) tuple.  The CliSave mode key is the
   longModeKey attribute, read after PolicyMapModeBase.__init__ has run
   (presumably that initializer sets it).
   """
   def __init__( self, param ):
      PolicyMapModeBase.__init__( self, param )
      # CliSave.Mode is keyed on the CliMode base's long key.
      modeKey = self.longModeKey
      CliSave.Mode.__init__( self, modeKey )

# Policy-map modes follow class-map modes: policy-map entries refer to
# class-maps by name ('<prio> class <cmapName>').
CliSave.GlobalConfigMode.addChildMode( PolicyMapConfigMode,
                                       after=[ ClassMapConfigMode ] )
# NOTE(review): the saver selects this sequence (savePMap) but never adds a
# command to it directly; class entries go into their own sub-modes.
PolicyMapConfigMode.addCommandSequence( 'TapAgg.pmap' )

#-------------------------------------------------------------------------------
# Object used for saving commands in "config-pmap-c" mode.
#-------------------------------------------------------------------------------
class PolicyMapClassConfigMode( PolicyMapClassModeBase, CliSave.Mode ):
   """CliSave mode object for one class entry of a TapAgg policy-map.

   param is a 6-tuple (mapType, mapStr, pmapName, cmapName, prio, entCmd).
   The first five fields are forwarded to PolicyMapClassModeBase.  entCmd,
   when set, is a fully rendered entry line (the raw-match path of the saver
   passes one).  It replaces the default '<prio> class <cmapName>' header,
   and modeSeparator() returns False for such entries.
   """
   def __init__( self, param ):
      mapType, mapStr, pmapName, cmapName, prio, entCmd = param
      self.mapType = mapType
      self.mapStr = mapStr
      self.pmapName = pmapName
      self.cmapName = cmapName
      self.prio = prio
      self.entCmd_ = entCmd
      # The base class takes every field except the trailing entry command.
      baseParam = param[ : -1 ]
      PolicyMapClassModeBase.__init__( self, baseParam )
      CliSave.Mode.__init__( self, self.longModeKey )

   def enterCmd( self ):
      # A pre-rendered entry line takes precedence over the class header.
      return self.entCmd_ or '%s class %s' % ( self.prio, self.cmapName )

   def __cmp__( self, other ):
      # Python 2 ordering hook: compare entries by priority (presumably used
      # by CliSave to order sibling class entries).
      return cmp( self.prio, other.prio )

   def modeSeparator( self ):
      # Only entries without a pre-rendered line request a separator.
      return not self.entCmd_

# Class entries are sub-modes of their policy-map.  For named (non-raw)
# classes, 'TapAgg.pmapc' receives the per-line action commands written by
# TapAggPmapCliSaver.savePMapClassAction.
PolicyMapConfigMode.addChildMode( PolicyMapClassConfigMode )
PolicyMapClassConfigMode.addCommandSequence( 'TapAgg.pmapc' )

def CmdPMapClassAction( classAction, requireMounts ):
   """Render the actions of a policy-map class as one CLI fragment, e.g.
   'set aggregation-group group g1 group g2 id-tag 5 inner 7
   remove dot1q outer 10' (on a single line).

   This single-line form is appended to raw match entries; named classes
   write one command per action instead.

   classAction   -- object whose policyAction mapping may contain
                    'setAggregationGroup', 'setIdentityTag' and
                    'stripHeaderBytes'.
   requireMounts -- mount mapping; 'tapagg/cliconfig' supplies
                    toolGroupNewFormat.

   Returns '' when the class has no action that renders to text.
   """
   policyAction = classAction.policyAction
   aggGroup = policyAction.get( 'setAggregationGroup' )
   idTag = policyAction.get( 'setIdentityTag' )
   stripHdr = policyAction.get( 'stripHeaderBytes' )

   toolGroupNewFormat = requireMounts[ 'tapagg/cliconfig' ].toolGroupNewFormat

   # Arguments of the 'set' command, in output order.  'set' itself is only
   # emitted when at least one argument exists, so an aggregation-group action
   # with neither groups nor interfaces does not produce a bare 'set'.
   setArgs = []
   if aggGroup:
      aggGroupList = sorted( aggGroup.aggGroup.keys() )
      aggIntfList = aggGroup.aggIntf.keys()
      if aggGroupList:
         if toolGroupNewFormat or aggIntfList or len( aggGroupList ) > 1:
            # This form repeats 'group' before every group name.
            setArgs.append( 'aggregation-group group %s'
                            % ' group '.join( aggGroupList ) )
         else:
            setArgs.append( 'aggregation-group %s' % aggGroupList[ 0 ] )
      if aggIntfList:
         setArgs.append( 'interface' )
         setArgs.extend( Intf.IntfRange.intfListToCanonical( aggIntfList ) )
   if idTag:
      setArgs.append( 'id-tag %d' % idTag.idTag.outerVid )
      if idTag.idTag.innerVid:
         setArgs.append( 'inner %d' % idTag.idTag.innerVid )

   parts = []
   if setArgs:
      parts.append( 'set ' + ' '.join( setArgs ) )
   # Only the dot1q header type is rendered; any other type contributes
   # nothing (and no stray separator).
   if stripHdr and stripHdr.stripHdrBytes.hdrType == 'dot1q':
      parts.append( 'remove dot1q outer %s'
                    % stripHdr.stripHdrBytes.dot1qRemoveVlans )
   return ' '.join( parts )

class TapAggPmapCliSaver( object ):
   """Emit TapAgg class-maps, policy-maps and interface service-policies into
   the CliSave mode tree.

   entity        -- the saved entity ('tapagg/pmapconfig'); its cmapType.cmap
                    and pmapType.pmap collections are read.
   root          -- mode tree under which class-map, policy-map and interface
                    modes are created.
   requireMounts -- must provide 'tapagg/cliconfig' (toolGroupNewFormat) and
                    'tapagg/intfconfig' (per-interface policy-map names).
   sysdbRoot and options are stored but not used by this class.
   """

   # Class-map match option -> ACL type keyword in 'match ... access-group'.
   _matchOptionToAclType = { 'matchIpAccessGroup' : 'ip',
                             'matchIpv6AccessGroup' : 'ipv6',
                             'matchMacAccessGroup' : 'mac' }

   def __init__( self, entity, root, sysdbRoot, options, requireMounts ):
      self.entity = entity
      self.root = root
      self.sysdbRoot = sysdbRoot
      self.options = options
      self.mapType = 'mapTapAgg'
      self.mapStr = 'tapagg'
      self.requireMounts = requireMounts
      # Working state shared by the save* methods: the command sequence being
      # filled and the policy-map / class-map config currently being saved.
      self.cmds = None
      self.pmap = None
      self.cmap = None
      self.intfConfig = self.requireMounts[ 'tapagg/intfconfig' ]

   def savePMapClassAction( self, cmapName ):
      """Add the actions of class cmapName in the current policy-map
      (self.pmap) to self.cmds, one command per line:

        set aggregation-group ... [id-tag N [inner M]]   (or 'set id-tag ...')
        set interface <range>                             (one per range)
        remove dot1q outer <vlans>
      """
      policyAction = self.pmap.classAction[ cmapName ].policyAction
      aggGroup = policyAction.get( 'setAggregationGroup' )
      idTag = policyAction.get( 'setIdentityTag' )
      stripHdr = policyAction.get( 'stripHeaderBytes' )

      tapAggConfig = self.requireMounts[ 'tapagg/cliconfig' ]
      toolGroupNewFormat = tapAggConfig.toolGroupNewFormat

      # The id-tag rides on the aggregation-group line when there is one,
      # otherwise it gets its own 'set' line.
      idTagCmd = None
      if idTag:
         idTagCmd = 'id-tag %d' % idTag.idTag.outerVid
         if idTag.idTag.innerVid:
            idTagCmd += ' inner %d' % idTag.idTag.innerVid

      aggGroupList = sorted( aggGroup.aggGroup.keys() ) if aggGroup else []
      if aggGroupList:
         cmd = 'set aggregation-group'
         if toolGroupNewFormat or len( aggGroupList ) > 1:
            # This form repeats 'group' before every group name.
            cmd += ' group %s' % ' group '.join( aggGroupList )
         else:
            cmd += ' %s' % aggGroupList[ 0 ]
         if idTagCmd:
            cmd += ' ' + idTagCmd
         self.cmds.addCommand( cmd )
      elif idTagCmd:
         self.cmds.addCommand( 'set ' + idTagCmd )

      aggIntfList = aggGroup.aggIntf.keys() if aggGroup else []
      if aggIntfList:
         for intfRange in Intf.IntfRange.intfListToCanonical( aggIntfList ):
            self.cmds.addCommand( 'set interface %s' % intfRange )

      if stripHdr and stripHdr.stripHdrBytes.hdrType == 'dot1q':
         self.cmds.addCommand( 'remove dot1q outer %s'
                               % stripHdr.stripHdrBytes.dot1qRemoveVlans )

   @staticmethod
   def _ruleToMatch( ruleCmd, aclType ):
      """Rewrite an ACL rule string ('permit ...') as a match string
      ('match ...') that names the ACL type.

      aclType is inserted after the leading keyword, or after the three
      tokens starting at 'inner' / 'vlan' when present, unless the token
      already appears anywhere in the rule.
      """
      tokens = re.sub( '^permit', 'match', ruleCmd ).split()
      insIndex = 1
      # Skip the qualifier keyword and the two tokens after it (presumably its
      # value and mask).  'inner' is checked first, presumably because it
      # follows 'vlan' when both are present.
      if 'inner' in tokens:
         insIndex = tokens.index( 'inner' ) + 3
      elif 'vlan' in tokens:
         insIndex = tokens.index( 'vlan' ) + 3
      if aclType not in tokens:
         tokens.insert( insIndex, aclType )
      return ' '.join( tokens )

   def insertAclTypeAndAddCmd( self, pmapName, cmapName, prio, aclType, cmd ):
      """Render raw rule cmd as '<prio> match <aclType> ... [<actions>]' and
      register it as its own class entry of policy-map pmapName.

      The rendered line is passed as the entry command of a
      PolicyMapClassConfigMode, whose enterCmd() then returns it verbatim.
      """
      entryCmd = '%d %s' % ( prio, self._ruleToMatch( cmd, aclType ) )
      actionCmd = CmdPMapClassAction( self.pmap.classAction[ cmapName ],
                                      self.requireMounts )
      if actionCmd:
         # Only add the separator when there is something to separate, so
         # action-less match lines carry no trailing whitespace.
         entryCmd += ' ' + actionCmd
      pmapMode = self.root[ PolicyMapConfigMode ].getOrCreateModeInstance(
         ( self.mapType, self.mapStr, pmapName ) )
      pmapMode[ PolicyMapClassConfigMode ].getOrCreateModeInstance(
         ( self.mapType, self.mapStr, pmapName, cmapName, prio, entryCmd ) )

   def saveRawClassMap( self, pmapName, cmapName, prio ):
      """Save every rule of the raw class-map cmapName of the current
      policy-map as its own class entry with a rendered match line."""
      self.cmap = self.pmap.rawClassMap.get( cmapName, None )
      if not self.cmap:
         return

      for cmapMatch in self.cmap.match.itervalues():
         # The same keyword selects the rule formatter and the match type.
         for rules, aclType in ( ( cmapMatch.ipRule, 'ip' ),
                                 ( cmapMatch.ip6Rule, 'ipv6' ),
                                 ( cmapMatch.macRule, 'mac' ) ):
            for _, ruleCfg in rules.iteritems():
               self.insertAclTypeAndAddCmd( pmapName, cmapName, prio, aclType,
                                            ruleFromValue( ruleCfg, aclType ) )

   def savePMapClass( self, pmapName, cmapName, prio ):
      """Save class cmapName (priority prio) of policy-map pmapName.

      Raw class-maps are expanded into per-rule match entries; other classes
      get a '<prio> class <cmapName>' sub-mode holding their actions.
      """
      pmapMode = self.root[ PolicyMapConfigMode ].getOrCreateModeInstance(
         ( self.mapType, self.mapStr, pmapName ) )
      if self._rawClassMap( cmapName ):
         self.saveRawClassMap( pmapName, cmapName, prio )
         return
      classMode = pmapMode[ PolicyMapClassConfigMode ].getOrCreateModeInstance(
         ( self.mapType, self.mapStr, pmapName, cmapName, prio, None ) )
      self.cmds = classMode[ 'TapAgg.pmapc' ]
      self.savePMapClassAction( cmapName )

   def savePMapClassAll( self, pmapName ):
      """Save every class of policy-map pmapName in classPrio iteration order.
      Sets self.pmap to the policy-map's currCfg; does nothing if it is
      empty."""
      self.pmap = self.entity.pmapType.pmap[ pmapName ].currCfg
      if not self.pmap:
         return
      for prio, cmapName in self.pmap.classPrio.iteritems():
         self.savePMapClass( pmapName, cmapName, prio )

   def savePMap( self, pmapName ):
      """Create the mode for policy-map pmapName and save its classes."""
      pmapMode = self.root[ PolicyMapConfigMode ].getOrCreateModeInstance(
         ( self.mapType, self.mapStr, pmapName ) )
      self.cmds = pmapMode[ 'TapAgg.pmap' ]
      self.savePMapClassAll( pmapName )

   def savePMapAll( self ):
      """Save all policy-maps, sorted by name."""
      for pmapName in sorted( self.entity.pmapType.pmap.keys() ):
         self.savePMap( pmapName )

   def saveCMapMatch( self, cmapName, option ):
      """Add one '<prio> match <type> access-group <acl>' line per ACL of
      match option `option` of the current class-map (self.cmap).
      cmapName is not used; it is kept for interface compatibility."""
      cmapMatch = self.cmap.match[ option ]
      for prio, aclName in cmapMatch.acl.iteritems():
         self.cmds.addCommand( '%s match %s access-group %s'
                               % ( prio, self._matchOptionToAclType[ option ],
                                   aclName ) )

   def saveCMapMatchAll( self, cmapName ):
      """Save all match options of class-map cmapName (its currCfg)."""
      self.cmap = self.entity.cmapType.cmap[ cmapName ].currCfg
      if not self.cmap:
         return
      for option in self.cmap.match:
         self.saveCMapMatch( cmapName, option )

   def _rawClassMap( self, cmapName ):
      """Return True if cmapName is a raw class-map of the current pmap."""
      return cmapName in self.pmap.rawClassMap

   def saveCMap( self, cmapName ):
      """Create the mode for class-map cmapName and save its matches."""
      cmapMode = self.root[ ClassMapConfigMode ].getOrCreateModeInstance(
         ( self.mapType, self.mapStr, cmapName ) )
      self.cmds = cmapMode[ 'TapAgg.cmap' ]
      self.saveCMapMatchAll( cmapName )

   def saveCMapAll( self ):
      """Save all class-maps, sorted by name."""
      for cmapName in sorted( self.entity.cmapType.cmap.keys() ):
         self.saveCMap( cmapName )

   def saveServicePolicy( self ):
      """Add 'service-policy type tapagg input <pmap>' under each interface
      listed in tapagg/intfconfig."""
      for intfName, pmap in self.intfConfig.intf.iteritems():
         intfMode = self.root[ IntfCliSave.IntfConfigMode ].\
                         getOrCreateModeInstance( intfName )
         self.cmds = intfMode[ 'TapAgg.intfconfig' ]
         self.cmds.addCommand( 'service-policy type tapagg input %s' % pmap )

   def save( self ):
      """Save policy-maps, then class-maps, then interface service-policies."""
      self.savePMapAll()
      self.saveCMapAll()
      self.saveServicePolicy()

@CliSave.saver( 'TapAgg::PmapConfig', 'tapagg/pmapconfig',
                requireMounts = ( 'bridging/config','tapagg/cliconfig',
                   'tapagg/intfconfig' ) )
def saveConfig( entity, root, sysdbRoot, options, requireMounts ):
   """CliSave hook for the TapAgg::PmapConfig entity at 'tapagg/pmapconfig'.

   Delegates to TapAggPmapCliSaver, which writes the TapAgg policy-maps,
   class-maps and interface service-policies under root.
   """
   saver = TapAggPmapCliSaver( entity, root, sysdbRoot, options,
                               requireMounts )
   saver.save()


