#!/usr/bin/env python
# Copyright (c) 2007-2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliSave, Vlan, Tracing, MultiRangeRule, EthIntfUtil
from IntfCliSave import IntfConfigMode
from CliMode.Stp import MstMode

__defaultTraceHandle__ = Tracing.Handle( 'StpCli' )

# Import Stp TAC object accessors, which create the objects if needed.
from StpCliUtil import *
from StpConst import *
from Toggles.StpToggleLib import toggleStpSuperRootEnabled

class StpConfigMode( MstMode, CliSave.Mode ):
   """CliSave mode for the spanning tree MST configuration sub-mode.

   Mixes the MstMode CLI mode with CliSave.Mode so the sub-mode can carry
   its own command sequence in the saved config.
   """

   def __init__( self, param ):
      # Initialise each base explicitly with the same param, CLI mode first.
      for base in ( MstMode, CliSave.Mode ):
         base.__init__( self, param )

   def skipIfEmpty( self ):
      # Always True: presumably tells CliSave to omit this sub-mode when it
      # has no commands (per the hook's name) -- see CliSave.Mode.
      return True

# Command sequence for the spanning tree global-level config commands
# (e.g. 'spanning-tree mode', 'spanning-tree hello-time'). before= makes the
# block come out ahead of any interface-level config.
CliSave.GlobalConfigMode.addCommandSequence( 'Stp.global', 
                                             before=[ IntfConfigMode ] )

# Child mode for the MST configuration sub-mode (StpConfigMode), also placed
# ahead of interface-level config. Its 'Stp.config' sequence receives the
# 'name', 'revision' and 'instance ... vlan ...' commands.
# NOTE(review): the child mode is registered before its command sequence is
# added; presumably CliSave relies on that order -- keep it.
CliSave.GlobalConfigMode.addChildMode( StpConfigMode, 
                                       before=[ IntfConfigMode ] )
StpConfigMode.addCommandSequence( 'Stp.config' )

# Per-interface spanning tree commands ('spanning-tree portfast',
# 'spanning-tree cost', ...) go into this sequence of each interface mode.
IntfConfigMode.addCommandSequence( 'Stp.intf' )

# When someone types "show running-config", the Cli code walks the Sysdb
# object tree and calls every function registered with CliSave.saver for the
# matching object. Below, saveStpConfig is registered for
# 'stp/input/config/cli'. When called, it walks the entire spanning tree
# config object tree and generates all non-default config (plus the default
# forms of the commands when saveAll is requested).

@CliSave.saver( 'Stp::Input::Config', 'stp/input/config/cli',
                requireMounts = ( 'bridging/input/config/cli',
                                  'bridging/config' ) )
def saveStpConfig( entity, root, sysdbRoot, options,
                   requireMounts ):
   """Generate the running-config commands for the spanning tree config.

   Global commands go to the 'Stp.global' sequence, MST region commands to
   the StpConfigMode sub-mode ('Stp.config'), and per-interface commands to
   each interface's 'Stp.intf' sequence. Per-VLAN / per-MST-instance
   settings that share a value are aggregated into a single command that
   carries a range list (e.g. "5-10,13,15-17").

   entity        -- the spanning tree config at 'stp/input/config/cli'.
   root          -- the CliSave root mode.
   sysdbRoot     -- passed through to saveStpIntfConfig.
   options       -- save options. saveAll also emits default values;
                    saveAllDetail additionally emits per-instance default
                    bridge priorities and covers switchport-eligible ports.
   requireMounts -- the mounts declared in the decorator above.
   """
   cmds = root[ 'Stp.global' ]
   saveAll = options.saveAll
   saveAllDetail = options.saveAllDetail

   # We always save the configured protocol version. At FCS, our default
   # was 'rstp'. However, we planned to change this to 'mstp' later, so
   # we have always saved the stp mode since day one.  This way, there
   # is no ambiguity about what the default is when reading the startup
   # config.  Now that we support mstp, it is the default.
   cmds.addCommand( 'spanning-tree mode %s' %
                    stpVersionCliString( entity ) )

   # The 'spanning-tree vlan-id <vlan list>' command is only saved in its 'no'
   # form.
   for s in Vlan.vlanSetToCanonicalStringGen( entity.disabledVlan.keys(), 40 ):
      cmds.addCommand( 'no spanning-tree vlan-id ' + s )

   # The MST region settings live in the MST configuration sub-mode.
   mstMode = root[ StpConfigMode ].getSingletonInstance()
   _saveMstRegionConfig( entity, mstMode[ 'Stp.config' ], saveAll )

   if entity.bridgeHelloTime != entity.defaultBridgeHelloTime or saveAll:
      cmds.addCommand( 'spanning-tree hello-time %d' % entity.bridgeHelloTime )

   if entity.maxAge != entity.defaultMaxAge or saveAll:
      cmds.addCommand( 'spanning-tree max-age %d' % entity.maxAge )

   if ( entity.bridgeForwardDelay != entity.defaultBridgeForwardDelay or
        saveAll ):
      cmds.addCommand( 'spanning-tree forward-time %d' %
                       entity.bridgeForwardDelay )

   # Previously written as two branches that emitted the identical command.
   if entity.txHoldCount != entity.defaultTxHoldCount or saveAll:
      cmds.addCommand( 'spanning-tree bpdu tx hold-count %d' %
                       entity.txHoldCount )

   if entity.maxHops != entity.defaultMaxHops or saveAll:
      cmds.addCommand( 'spanning-tree max-hops %d' % entity.maxHops )

   if entity.mstPvstBorder != entity.mstPvstBorderDefault:
      cmds.addCommand( 'spanning-tree mst pvst border' )
   elif saveAll:
      cmds.addCommand( 'no spanning-tree mst pvst border' )

   # Global bpduguard config for portfast interfaces.
   if entity.portfastBpduguard != entity.portfastBpduguardDefault:
      cmds.addCommand( 'spanning-tree edge-port bpduguard default' )
   elif saveAll:
      cmds.addCommand( 'no spanning-tree edge-port bpduguard default' )

   # Global bpdufilter config for portfast interfaces.
   if entity.portfastBpdufilter != entity.portfastBpdufilterDefault:
      cmds.addCommand( 'spanning-tree edge-port bpdufilter default' )
   elif saveAll:
      cmds.addCommand( 'no spanning-tree edge-port bpdufilter default' )

   if not entity.bridgeAssurance:
      cmds.addCommand( 'no spanning-tree transmit active' )
   elif saveAll:
      cmds.addCommand( 'spanning-tree transmit active' )

   # Global portchannel guard config.
   if entity.portChannelGuardEnabled:
      cmds.addCommand( 'spanning-tree portchannel guard misconfig' )
   elif saveAll:
      cmds.addCommand( 'no spanning-tree portchannel guard misconfig' )

   if not entity.rateLimitEnabled:
      cmds.addCommand( 'no spanning-tree bpduguard rate-limit default' )
   elif saveAll:
      cmds.addCommand( 'spanning-tree bpduguard rate-limit default' )

   if entity.rateLimitMaxCount:
      if entity.rateLimitInterval:
         cmds.addCommand(
            'spanning-tree bpduguard rate-limit count %d interval %d' %
            ( entity.rateLimitMaxCount, entity.rateLimitInterval ) )
      else:
         cmds.addCommand( 'spanning-tree bpduguard rate-limit count %d' %
                          entity.rateLimitMaxCount )

   if entity.spanTreeLogging == 'off':
      cmds.addCommand( 'no logging event spanning-tree global' )
   elif saveAll:
      cmds.addCommand( 'logging event spanning-tree global' )

   # Unlike the other root commands, this one is not an alias.
   if toggleStpSuperRootEnabled():
      if entity.superRoot:
         cmds.addCommand( "spanning-tree root super" )
      elif saveAll:
         cmds.addCommand( "no spanning-tree root super" )

   # Per-VLAN (PVST) settings, aggregated into one command per value:
   #   bridge priority -> [ vlan ids ]
   #   intf -> { cost / port priority -> [ vlan ids ] }
   pvstBridgePriority = {}
   pvstPortPriority = {}
   pvstExternalCost = {}

   # Per-MST-instance settings, aggregated the same way by instance id.
   mstpBridgePriority = {}
   mstpInternalCost = {}
   mstpPortPriority = {}

   for stpiName in _stpiNamesToSave( entity, requireMounts, saveAll ):
      stpiConfig = entity.stpiConfig.get( stpiName )
      if not stpiConfig:
         if not saveAll:
            continue
         # Create stpiConfig so that defaults are displayed for it.
         stpiConfig = _newDefaultStpiConfig( stpiName )

      for mstiName in sorted( stpiConfig.mstiConfig.keys(), cmpMstiNames ):
         mstiConfig = stpiConfig.mstiConfig.get( mstiName )
         if not mstiConfig:
            continue
         instanceId = mstiConfig.instanceId
         priority = mstiConfig.bridgePriority

         # The cist vs other msti gets a little hairy here, because we use it
         # for both rstp and for Mst0 in mstp mode.  For the bridge priority,
         # specifying no instance and instance 0 both alias to the same
         # MstiConfig (the 'Cist' instance).  The path cost commands are not
         # aliased: the version without an instance sets the external path
         # cost (and default internal cost) and the version with an instance
         # sets the internal path cost.  There is no internal path cost for
         # rstp, so this works out well.  Port priority likewise has a
         # per-port form (saved by saveStpIntfConfig) and a per-instance form
         # (collected below).
         if mstiName == CistName and entity.forceProtocolVersion == 'rstp':
            if priority != BridgePriorityDefault or saveAll:
               cmds.addCommand( 'spanning-tree priority %d' % priority )
         elif isPvstInstName( mstiName ):
            if ( priority != BridgePriorityDefault or saveAllDetail or
                 ( saveAll and entity.forceProtocolVersion == 'rapidPvstp' ) ):
               pvstBridgePriority.setdefault( priority, [] ).append( instanceId )
         else:
            if ( priority != BridgePriorityDefault or saveAllDetail or
                 ( saveAll and entity.forceProtocolVersion == 'mstp' ) ):
               mstpBridgePriority.setdefault( priority, [] ).append( instanceId )

         mstiPortConfigs = mstiConfig.mstiPortConfig
         for intfKey in mstiPortConfigs.keys():
            mstiPortConfig = mstiPortConfigs.get( intfKey )
            if not mstiPortConfig:
               continue

            portName = mstiPortConfig.name
            cost = mstiPortConfig.intPathCost
            portPriority = mstiPortConfig.portPriority
            saveIntPathCost = ( cost != PathCostUnknown )
            savePortPriority = ( portPriority != PortPriorityUnconfigured )
            if not ( saveIntPathCost or savePortPriority ):
               # Don't instantiate the ModeInstance if we have no commands
               # to generate.
               continue

            # The commands for this interface are added further down, once
            # grouped by instance range; the mode instance is created here as
            # it always has been.
            root[ IntfConfigMode ].getOrCreateModeInstance( portName )

            if saveIntPathCost:
               # We should never have an internal path cost on a PVST
               # (per-VLAN) instance.
               assert not isPvstInstName( mstiName )
               mstpInternalCost.setdefault( portName, {} ).setdefault(
                  cost, [] ).append( instanceId )

            if savePortPriority:
               if isPvstInstName( mstiName ):
                  portPriorities = pvstPortPriority
               else:
                  portPriorities = mstpPortPriority
               portPriorities.setdefault( portName, {} ).setdefault(
                  portPriority, [] ).append( instanceId )

      for stpiPortName in stpiConfig.stpiPortConfig.keys():
         stpiPortConfig = stpiConfig.stpiPortConfig[ stpiPortName ]
         saveExtPathCost = ( stpiPortConfig.extPathCost != PathCostUnknown )
         if not saveExtPathCost:
            # Don't instantiate the ModeInstance if we have no commands
            # to generate.
            continue

         mode = root[ IntfConfigMode ].getOrCreateModeInstance(
            stpiPortConfig.name )
         intfCmds = mode[ 'Stp.intf' ]

         # For mst, the external path cost is set via the 'spanning-tree cost'
         # command.  For pvst instances, it is set via the
         # 'spanning-tree vlan <vlanId> cost' command.
         if isPvstInstName( stpiName ):
            vlanId = pvstInstNameToVlanId( stpiName )
            pvstExternalCost.setdefault( stpiPortConfig.name, {} ).setdefault(
               stpiPortConfig.extPathCost, [] ).append( vlanId )
         else:
            # There is only one external cost per port for all mstis
            intfCmds.addCommand( 'spanning-tree cost %d' %
                                 stpiPortConfig.extPathCost )

   # Emit the aggregated ranges, in the same order as always: PVST bridge
   # priority, PVST cost, PVST port priority, then the MST equivalents.
   _addGroupedCmds( cmds, pvstBridgePriority,
                    'spanning-tree vlan-id %s priority %d' )
   _addPerIntfGroupedCmds( root, pvstExternalCost,
                           'spanning-tree vlan %s cost %d' )
   _addPerIntfGroupedCmds( root, pvstPortPriority,
                           'spanning-tree vlan %s port-priority %d' )
   _addGroupedCmds( cmds, mstpBridgePriority,
                    'spanning-tree mst %s priority %d' )
   _addPerIntfGroupedCmds( root, mstpInternalCost,
                           'spanning-tree mst %s cost %d' )
   _addPerIntfGroupedCmds( root, mstpPortPriority,
                           'spanning-tree mst %s port-priority %d' )

   # Per-port config.
   bridgingConfig = requireMounts[ 'bridging/config' ]
   if saveAllDetail:
      cfgPorts = EthIntfUtil.allSwitchportNames( bridgingConfig,
                                                 includeEligible=True )
   elif saveAll:
      # We allow L2 configuration on routed ports as well.
      # Display STP configs for switchport as well as routed ports present
      # in entity.portConfig.
      swPorts = EthIntfUtil.allSwitchportNames( bridgingConfig )
      cfgPorts = set( swPorts + entity.portConfig.keys() )
   else:
      cfgPorts = entity.portConfig

   for portName in cfgPorts:
      portConfig = entity.portConfig.get( portName )
      if not portConfig:
         if not saveAll:
            continue
         portConfig = Tac.newInstance( 'Stp::Input::PortConfig', portName )
      saveStpIntfConfig( portConfig, root, sysdbRoot, saveAll )


def _saveMstRegionConfig( entity, configCmds, saveAll ):
   """Add the MST region commands ('name', 'revision', 'instance N vlan L')
   to configCmds when the region config is non-default or saveAll is set."""
   mstConfigSpec = entity.mstConfigSpec
   defaultMstConfigSpec = Tac.Value(
      "Stp::MstConfigSpec", regionId='', configRevision=0, vidToMstiMap='' )
   if mstConfigSpec != defaultMstConfigSpec or saveAll:
      if mstConfigSpec.regionId:
         configCmds.addCommand( 'name %s' % mstConfigSpec.regionId )
      elif saveAll:
         configCmds.addCommand( 'no name' )

      if mstConfigSpec.configRevision != 0 or saveAll:
         configCmds.addCommand( 'revision %d' % mstConfigSpec.configRevision )

      if mstConfigSpec.vidToMstiMap != '':
         currentMstConfig = MstConfig( entity )
         currentMstConfig.revert()
         vlanMap = currentMstConfig.vlanMapDict()
         idMap = inverseVlanMap( vlanMap, includeInst0=False )
         for instId in sorted( idMap.keys() ):
            # prefix already ends in a space; append the VLAN list directly so
            # the line is not emitted with a double space and stays within
            # the 80-column budget given to the range generator.
            prefix = 'instance %d vlan ' % instId
            for vlanSubStr in Vlan.vlanSetToCanonicalStringGen(
                  idMap[ instId ], 80 - len( prefix ) ):
               configCmds.addCommand( prefix + vlanSubStr )


def _stpiNamesToSave( entity, requireMounts, saveAll ):
   """Return the sorted stpi names whose config should be saved.

   Normally these are just the configured instances. With saveAll, the PVST
   instances of every configured VLAN and the MST instance are added so that
   their defaults are shown too.
   """
   if not saveAll:
      return sorted( entity.stpiConfig.keys(), cmpStpiNames )
   vlanConfig = requireMounts[ 'bridging/input/config/cli' ].vlanConfig
   pvstInstNames = [ pvstInstName( vlanId ) for vlanId in vlanConfig.keys() ]
   mstInstNames = [ MstStpiName ]
   return sorted( set( entity.stpiConfig.keys() + pvstInstNames +
                       mstInstNames ), cmpStpiNames )


def _newDefaultStpiConfig( stpiName ):
   """Return a fresh, all-default StpiConfig for stpiName with one MstiConfig
   member: the VLAN's instance for a PVST name, otherwise instance 0."""
   stpiConfig = Tac.newInstance( 'Stp::Input::StpiConfig', stpiName )
   if isPvstInstName( stpiName ):
      instName, instId = stpiName, pvstInstNameToVlanId( stpiName )
   else:
      instName, instId = stpMstiInstName(), 0
   stpiConfig.mstiConfig.newMember( instName, instId )
   return stpiConfig


def _groupSortKey( item ):
   """Sort key for a ( metric, idList ) pair: the list's first id.

   Slicing to at most one element orders an empty list before any non-empty
   one, treats two empty lists as equal, and otherwise compares the first
   ids -- exactly what the old cmp-style comparator did, without the
   Python-2-only cmp argument to sorted().
   """
   return item[ 1 ][ :1 ]


def _generateIntervals( groupedList ):
   """Turn a { metric: [ ids ] } dict into ( metric, rangeString ) pairs.

   Each id list is rendered as a canonical multi-range string (e.g.
   "5-10,13,15-17"), and the pairs are ordered by their first id so the
   generated commands are sorted on the id list.
   """
   return [ ( metric, MultiRangeRule.multiRangeToCanonicalString( ids ) )
            for metric, ids in sorted( groupedList.items(),
                                       key=_groupSortKey ) ]


def _addGroupedCmds( cmds, groupedList, fmt ):
   """Add fmt % ( rangeString, metric ) to cmds for each metric."""
   for metric, rangeStr in _generateIntervals( groupedList ):
      cmds.addCommand( fmt % ( rangeStr, metric ) )


def _addPerIntfGroupedCmds( root, perIntfGroups, fmt ):
   """Like _addGroupedCmds, for a { intfName: { metric: [ ids ] } } dict;
   commands go into each interface's 'Stp.intf' sequence."""
   for portName in perIntfGroups:
      mode = root[ IntfConfigMode ].getOrCreateModeInstance( portName )
      _addGroupedCmds( mode[ 'Stp.intf' ], perIntfGroups[ portName ], fmt )


def saveStpIntfConfig( portConfig, root, sysdbRoot, saveAll ):
   """Add the spanning tree commands for one interface's port config.

   Commands go into the 'Stp.intf' sequence of the interface mode named by
   portConfig.name. Non-default settings are always saved; with saveAll the
   default form of most commands is saved too. When there is nothing to
   save, the interface mode instance is not created.

   portConfig -- per-port spanning tree config (Stp::Input::PortConfig).
   root       -- the CliSave root mode.
   sysdbRoot  -- not used by this function.
   saveAll    -- also emit commands for default-valued settings.
   """
   adminEdgeChanged = ( portConfig.adminEdgePort !=
                        portConfig.adminEdgePortDefault )
   autoEdgeChanged = portConfig.autoEdgePort != portConfig.autoEdgePortDefault
   networkChanged = portConfig.networkPort != portConfig.networkPortDefault
   linkTypeChanged = ( portConfig.adminPointToPointMac !=
                       portConfig.defaultAdminPointToPoint )
   portPriorityChanged = portConfig.portPriority != PortPriorityDefault
   extPathCostSet = portConfig.extPathCost != PathCostUnknown

   # Don't instantiate the ModeInstance if we have no commands to generate.
   worthSaving = any( ( adminEdgeChanged, autoEdgeChanged, linkTypeChanged,
                        portConfig.bpduguard != 'bpduguardDefault',
                        portConfig.guard != 'guardDefault',
                        portPriorityChanged,
                        portConfig.bpdufilter != 'bpdufilterDefault',
                        portConfig.rateLimitEnabled != 'rateLimitDefault',
                        portConfig.rateLimitMaxCount != 0,
                        portConfig.rateLimitInterval != 0,
                        extPathCostSet, networkChanged,
                        portConfig.spanTreeLogging != 'useGlobal',
                        saveAll ) )
   if not worthSaving:
      return

   mode = root[ IntfConfigMode ].getOrCreateModeInstance( portConfig.name )
   addCmd = mode[ 'Stp.intf' ].addCommand

   def addSetting( isNonDefault, cmd, defaultCmd ):
      # Save cmd for a non-default setting, or defaultCmd under saveAll.
      if isNonDefault:
         addCmd( cmd )
      elif saveAll:
         addCmd( defaultCmd )

   def addChoice( value, choices, defaultCmd ):
      # Save the command paired with value in choices. If value matches
      # none of them, save defaultCmd (when one is given) under saveAll.
      for choice, cmd in choices:
         if value == choice:
            addCmd( cmd )
            return
      if saveAll and defaultCmd is not None:
         addCmd( defaultCmd )

   # The default spanning tree config is adminEdge==false, autoEdge==true,
   # which means the default config is:
   #
   # no spanning-tree portfast    # adminEdge
   # spanning-tree portfast auto  # autoEdge
   #
   # and the non-default config is:
   #
   # spanning-tree portfast    # adminEdge
   # no spanning-tree portfast auto  # autoEdge
   #
   addSetting( adminEdgeChanged, 'spanning-tree portfast',
               'no spanning-tree portfast' )
   addSetting( autoEdgeChanged, 'no spanning-tree portfast auto',
               'spanning-tree portfast auto' )
   if networkChanged:
      addCmd( 'spanning-tree portfast network' )

   if linkTypeChanged:
      ptpMode = portConfig.adminPointToPointMac
      if ptpMode == 'forceTrue':
         linkType = 'point-to-point'
      elif ptpMode == 'forceFalse':
         linkType = 'shared'
      addCmd( 'spanning-tree link-type %s' % linkType )
   elif saveAll:
      addCmd( 'no spanning-tree link-type' )

   # Interface bpduguard / bpdufilter config.
   addChoice( portConfig.bpduguard,
              ( ( 'bpduguardEnabled', 'spanning-tree bpduguard enable' ),
                ( 'bpduguardDisabled', 'spanning-tree bpduguard disable' ) ),
              'no spanning-tree bpduguard' )
   addChoice( portConfig.bpdufilter,
              ( ( 'bpdufilterEnabled', 'spanning-tree bpdufilter enable' ),
                ( 'bpdufilterDisabled', 'spanning-tree bpdufilter disable' ) ),
              'no spanning-tree bpdufilter' )

   if extPathCostSet:
      addCmd( 'spanning-tree cost %d' % portConfig.extPathCost )
   elif saveAll:
      addCmd( 'no spanning-tree cost' )
   if portPriorityChanged or saveAll:
      addCmd( 'spanning-tree port-priority %d' % portConfig.portPriority )

   addChoice( portConfig.guard,
              ( ( 'rootguardEnabled', 'spanning-tree guard root' ),
                ( 'guardDisabled', 'spanning-tree guard none' ) ),
              'no spanning-tree guard' )
   addChoice( portConfig.rateLimitEnabled,
              ( ( 'rateLimitOn', 'spanning-tree bpduguard rate-limit enable' ),
                ( 'rateLimitOff',
                  'spanning-tree bpduguard rate-limit disable' ) ),
              'no spanning-tree bpduguard rate-limit' )

   maxCount = portConfig.rateLimitMaxCount
   if maxCount:
      interval = portConfig.rateLimitInterval
      if interval:
         addCmd( 'spanning-tree bpduguard rate-limit count %d interval %d' %
                 ( maxCount, interval ) )
      else:
         addCmd( 'spanning-tree bpduguard rate-limit count %d' % maxCount )

   # 'use-global' is only written out under saveAll.
   loggingMode = portConfig.spanTreeLogging
   if loggingMode == 'useGlobal':
      useGlobalCmd = 'logging event spanning-tree use-global'
   else:
      useGlobalCmd = None
   addChoice( loggingMode,
              ( ( 'on', 'logging event spanning-tree' ),
                ( 'off', 'no logging event spanning-tree' ) ),
              useGlobalCmd )

