#!/usr/bin/env python
# Copyright (c) 2007, 2008, 2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

#-------------------------------------------------------------------------------
# Stp TAC object accessors for Stp CLI, creating the TAC objects if necessary.
#-------------------------------------------------------------------------------

import struct
import Tac, Tracing, Vlan
from StpConst import *
import Arnet

# Trace handle for this module; presumably picked up by the module-level
# Tracing.traceN calls below -- TODO confirm against the Tracing library.
__defaultTraceHandle__ = Tracing.Handle( 'StpCli' )

# Name used for MST instance 0 (and the Rstp instance) in the msti collections;
# see stpMstiInstName.
CistName = "Cist"
# CLI display name for CistName (see mstiCliName).
CistCliName = "MST0"
# Name of the stpi that holds the rstp/mstp msti configs.
MstStpiName = "Mst"
# Prefix of per-vlan (rapid-pvst) stpi/msti names, e.g. "Vl10".
PvstInstNamePrefix = "Vl"
# The 'disabled' value of the HwEpoch::RestrictionStatus::Mode Tac enum; not
# referenced by the code in this section.
EpochDisabled = Tac.Type( "HwEpoch::RestrictionStatus::Mode" ).disabled

# Module-level handles.  Each is installed by the matching xxxIs() function
# and read by the matching accessor, which asserts it has been set.
# _hwEpochStatus is only ever set (hwEpochStatusIs) in the code visible here.
_stpConfig = None
_stpTxRxInputConfig = None
_stpCliInfo = None
_stpInputConfig = None
_stpInputConfigReq = None
_stpStatus = None
_stpPortCounterDir = None
_bridgingConfig = None
_hwEpochStatus = None

def stpConfigIs( config ):
   """Install config as the Stp config later returned by stpConfig()."""
   global _stpConfig
   _stpConfig = config

def stpConfig():
   """Return the Stp config installed by stpConfigIs(); asserts it is set."""
   current = _stpConfig
   assert current is not None
   return current

def stpInputConfigIs( config ):
   """Install config as the Stp input config returned by stpInputConfig()."""
   global _stpInputConfig
   _stpInputConfig = config

def stpInputConfig():
   """Return the Stp input config installed by stpInputConfigIs().

   Asserts that it has been set.
   """
   current = _stpInputConfig
   assert current is not None
   return current

def stpTxRxInputConfig():
   """Return the config installed by stpTxRxInputConfigIs().

   Asserts that it has been set.
   """
   current = _stpTxRxInputConfig
   assert current is not None
   return current

def stpTxRxInputConfigIs( config ):
   """Install config as the value returned by stpTxRxInputConfig()."""
   global _stpTxRxInputConfig
   _stpTxRxInputConfig = config

def stpInputConfigReqIs( configReq ):
   """Install configReq as the value returned by stpInputConfigReq()."""
   global _stpInputConfigReq
   _stpInputConfigReq = configReq

def stpInputConfigReq():
   """Return the object installed by stpInputConfigReqIs().

   Asserts that it has been set.
   """
   current = _stpInputConfigReq
   assert current is not None
   return current

def stpStatusIs( status ):
   """Install status as the Stp status returned by stpStatus()."""
   global _stpStatus
   _stpStatus = status

def stpStatus():
   """Return the Stp status installed by stpStatusIs(); asserts it is set."""
   current = _stpStatus
   assert current is not None
   return current

def stpCliInfo():
   """Return the object installed by stpCliInfoIs(); asserts it is set."""
   current = _stpCliInfo
   assert current is not None
   return current

def stpCliInfoIs( status ):
   """Install status as the value returned by stpCliInfo()."""
   global _stpCliInfo
   _stpCliInfo = status

def stpPortCounterDirIs( counterDir ):
   """Install counterDir as the value returned by stpPortCounterDir()."""
   global _stpPortCounterDir
   _stpPortCounterDir = counterDir

def stpPortCounterDir():
   """Return the object installed by stpPortCounterDirIs().

   Asserts that it has been set.
   """
   current = _stpPortCounterDir
   assert current is not None
   return current

def bridgingConfigIs( config ):
   """Install config as the bridging config returned by bridgingConfig()."""
   global _bridgingConfig
   _bridgingConfig = config

def bridgingConfig():
   """Return the bridging config installed by bridgingConfigIs().

   Asserts that it has been set.
   """
   current = _bridgingConfig
   assert current is not None
   return current

def pvstInstName( vlanId ):
   """Return the pvst stpi/msti name for vlanId, e.g. 10 -> "Vl10"."""
   return "".join( ( PvstInstNamePrefix, str( vlanId ) ) )

def isPvstInstName( name ):
   """Return True if name begins with the pvst instance prefix."""
   prefixLen = len( PvstInstNamePrefix )
   return name[ : prefixLen ] == PvstInstNamePrefix

def pvstInstNameToVlanId( name ):
   """Inverse of pvstInstName(): return the integer vlanId after the prefix.

   Raises ValueError if the remainder of name is not an integer.
   """
   return int( name[ len( PvstInstNamePrefix ) : ] )

def stpVersionCliString( config ):
   """Return the CLI keyword for config.forceProtocolVersion.

   'none', 'rstp', 'mstp' and 'backup' are spelled the same on the CLI;
   'rapidPvstp' is displayed as 'rapid-pvst'.

   Raises AssertionError for any other protocol version.
   """
   ver = config.forceProtocolVersion
   if ver in ( 'none', 'rstp', 'mstp', 'backup' ):
      return ver
   if ver == 'rapidPvstp':
      return 'rapid-pvst'
   # Raise explicitly rather than `assert False`: assert statements are
   # stripped under -O, which would make this silently return None.
   raise AssertionError( "unknown protocol version %s" % ver )
   
def stpMstiInstName( instId=None ):
   """Return the msti name used for instance instId.

   None means no "mst <instanceId>" was given, i.e. the Rstp instance.  Both
   None and 0 map to CistName; any other id N maps to "Mst<N>" (e.g. "Mst1").
   These names are also the keys under which the instances are created.
   """
   if instId is None or instId == 0:
      return CistName
   return "Mst" + str( instId )

def stpMstiConfig( instId, create=False ):
   """Return the MstiConfig for instance instId in the MstStpiName stpi.

   instId None selects the Rstp instance, which is stored with instanceId 0.
   With create=True, the stpi and the MstiConfig are created if missing.
   Returns None when they do not exist and create is False, or when
   newMember() returns None (a trace0 message is logged).
   """
   inputConfig = stpInputConfig()
   mstStpi = inputConfig.stpiConfig.get( MstStpiName, None )
   if mstStpi is None:
      if not create:
         return None
      mstStpi = inputConfig.stpiConfig.newMember( MstStpiName )

   instName = stpMstiInstName( instId )
   numericId = 0 if instId is None else instId

   mstiConfigs = mstStpi.mstiConfig
   if instName in mstiConfigs:
      result = mstiConfigs[ instName ]
   elif not create:
      return None
   else:
      Tracing.trace2( "Creating MstiConfig", instName )
      # NOTE(review): creating the 65th Mst instance could fail here; callers
      # may need to check and complain.
      result = mstiConfigs.newMember( instName, numericId )
      if result is None:
         Tracing.trace0( "Unable to create MstiConfig", instName )
         return None
   assert result.instanceId == numericId
   assert result.name == instName
   return result

def stpPortConfigByName( intfName, create=False ):
   """Return the Stp PortConfig for interface intfName.

   Looks the port up in stpInputConfig().portConfig.  If it is missing, a
   new member is created when create is true; otherwise None is returned.
   None is also returned if newMember() yields None (logged at trace0).
   """
   portConfigs = stpInputConfig().portConfig
   if intfName in portConfigs:
      found = portConfigs[ intfName ]
   elif not create:
      return None
   else:
      Tracing.trace2( "Creating PortConfig", intfName )
      found = portConfigs.newMember( intfName )
      if found is None:
         Tracing.trace0( "Unable to create PortConfig", intfName )
         return None
   assert found.name == intfName
   return found

def stpPortConfig( mode, create=False ):
   """Return the Stp PortConfig for the interface of CLI mode.

   Thin wrapper around stpPortConfigByName(), keyed by mode.intf.name.
   """
   intfName = mode.intf.name
   return stpPortConfigByName( intfName, create )

def stpStpiPortConfig( mode, stpiName, create=False ):
   """Return the StpiPortConfig for mode's interface within stpi stpiName.

   With create=True, the StpiConfig and the StpiPortConfig are created if
   missing.  Otherwise None is returned when either does not exist.  None is
   also returned if newMember() yields None for the port (logged at trace0).
   """
   stpiConfigs = stpInputConfig().stpiConfig
   stpiConfig = stpiConfigs.get( stpiName, None )
   if stpiConfig is None:
      if not create:
         return None
      stpiConfig = stpiConfigs.newMember( stpiName )

   portConfigs = stpiConfig.stpiPortConfig
   intfName = mode.intf.name
   if intfName in portConfigs:
      found = portConfigs[ intfName ]
   elif not create:
      return None
   else:
      Tracing.trace2( "Creating PortConfig", intfName )
      found = portConfigs.newMember( intfName )
      if found is None:
         Tracing.trace0( "Unable to create PortConfig", intfName )
         return None
   assert found.name == intfName
   return found

def stpMstiPortConfig( mode, instId, create=False ):
   """Return the MstiPortConfig for mode's interface in MST instance instId.

   The MstiConfig is obtained, or created when create is true, via
   stpMstiConfig().  Returns None if it or the port config does not exist
   and create is False, or if newMember() yields None (logged at trace0).
   """
   intfName = mode.intf.name
   mstiConfig = stpMstiConfig( instId, create )
   if mstiConfig is None:
      return None

   portConfigs = mstiConfig.mstiPortConfig
   if intfName in portConfigs:
      found = portConfigs[ intfName ]
   elif not create:
      return None
   else:
      Tracing.trace2( "Creating MstiPortConfig", mstiConfig.name, intfName )
      found = portConfigs.newMember( intfName )
      if found is None:
         Tracing.trace0( "Unable to create MstiPortConfig ", mstiConfig.name,
                         intfName )
         return None
   assert found.name == intfName
   return found


def stpMstiConfigDefaultName():
   """Return the name of MstiConfig instance 0 for the configured Stp flavor.

   Under 'mstp' this is the name of Mst instance 0.  For any other
   forceProtocolVersion it is the Rstp instance's name,
   stpMstiInstName( None ).  A customer who has configured both flavors may
   have two "instance 0" configs.  stpMstiInstName currently maps both None
   and 0 to CistName.
   """
   mstpConfigured = stpConfig().forceProtocolVersion == 'mstp'
   return stpMstiInstName( 0 if mstpConfigured else None )

def vlanPortConfig( vlan, portName ):
   """Return the Bridging::VlanPortConfig for portName in vlan, or None.

   None is returned when the vlan has no config, is not in adminState
   'active', or has no entry for the port.
   """
   brConfig = bridgingConfig()
   assert brConfig is not None
   vlanConfig = brConfig.vlanConfig.get( vlan )
   if vlanConfig is None or vlanConfig.adminState != 'active':
      return None
   return vlanConfig.intf.get( portName )

def parseVidToMstiMap( mapStr ):
   """Parse a string in the MstConfigSpec.vidToMstiMap format into a dict.

   mapStr is a sequence of 4-byte entries.  Each entry is a network-order
   unsigned 16-bit vlanId followed by an unsigned 16-bit instId (struct
   format '!HH').

   Returns {vlanId: instId}.  Any vlan not in the dict is implicitly mapped
   to instId 0.  An empty (or None) mapStr yields an empty dict.

   Raises struct.error if the length of mapStr is not a multiple of 4.
   """
   vidToMsti = {}
   if not mapStr:
      return vidToMsti
   entry = struct.Struct( '!HH' )
   # unpack_from reads each entry in place instead of slicing a copy.
   for offset in range( 0, len( mapStr ), entry.size ):
      vlanId, mstiId = entry.unpack_from( mapStr, offset )
      vidToMsti[ vlanId ] = mstiId
   return vidToMsti

def stpVlanMstiConfig( stpiConfig, vlanId ):
   """Return the existing MstiConfig in stpiConfig that handles vlanId.

   The msti name depends on stpConfig().forceProtocolVersion:
     'mstp'       -- the instance that mstConfigSpec.vidToMstiMap maps
                     vlanId to (instance 0 if unmapped)
     'rstp'       -- the Rstp instance, stpMstiInstName( None )
     'rapidPvstp' -- the per-vlan instance, pvstInstName( vlanId )
   Returns None for any other version (e.g. 'none' or 'backup'), for a None
   stpiConfig, or when the MstiConfig does not already exist.  Nothing is
   created.
   """
   if stpiConfig is None:
      return None
   config = stpConfig()
   version = config.forceProtocolVersion
   if version == 'mstp':
      vidToInstId = parseVidToMstiMap( config.mstConfigSpec.vidToMstiMap )
      instName = stpMstiInstName( vidToInstId.get( vlanId, 0 ) )
   elif version == 'rstp':
      instName = stpMstiInstName( None )
   elif version == 'rapidPvstp':
      instName = pvstInstName( vlanId )
   else:
      # No msti applies.  Without this branch instName would be unbound
      # below and the lookup would raise UnboundLocalError.
      return None
   return stpiConfig.mstiConfig.get( instName )

def stpVlanStpiConfigStatus( vlanId ):
   """Return (stpiConfig, stpiStatus) for the stpi that handles vlanId.

   The stpi is chosen by stpConfig().forceProtocolVersion: MstStpiName for
   'rstp' and 'mstp', and the per-vlan pvstInstName( vlanId ) for
   'rapidPvstp'.  Each element is looked up (not created) in
   stpConfig().stpiConfig and stpStatus().stpiStatus respectively, and may be
   None if absent.  Any other version yields (None, None).
   """
   config = stpConfig()
   status = stpStatus()
   version = config.forceProtocolVersion
   if version == 'rapidPvstp':
      stpiName = pvstInstName( vlanId )
   elif version in ( 'rstp', 'mstp' ):
      stpiName = MstStpiName
   else:
      return (None, None)
   return (config.stpiConfig.get( stpiName ),
           status.stpiStatus.get( stpiName ))

def pvstVlanMstiConfig( vlanId, create=True ):
   """Return the pvst MstiConfig for vlanId.

   Both the StpiConfig and its MstiConfig are named pvstInstName( vlanId ),
   and the MstiConfig is created with vlanId as its instance id.  When
   create is true each is created if missing; otherwise None is returned if
   either does not exist.
   """
   inputConfig = stpInputConfig()
   stpiName = pvstInstName( vlanId )
   stpiConfig = inputConfig.stpiConfig.get( stpiName )
   if stpiConfig is None:
      if not create:
         return None
      stpiConfig = inputConfig.stpiConfig.newMember( stpiName )

   mstiConfig = stpiConfig.mstiConfig.get( stpiName )
   if mstiConfig is not None or not create:
      return mstiConfig
   return stpiConfig.mstiConfig.newMember( stpiName, vlanId )

def pvstVlanMstiPortConfig( mode, vlanId, create=True ):
   """Return the pvst MstiPortConfig for mode's interface in vlan vlanId.

   The MstiConfig comes from pvstVlanMstiConfig().  When create is true the
   necessary objects are created.  Otherwise None is returned if any of them
   does not exist.  None is also returned if newMember() yields None (logged
   at trace0).
   """
   mstiConfig = pvstVlanMstiConfig( vlanId, create )
   if mstiConfig is None:
      return None
   intfName = mode.intf.name
   existing = mstiConfig.mstiPortConfig.get( intfName )
   if existing:
      return existing
   if not create:
      return None
   Tracing.trace2( "Creating MstiPortConfig", mstiConfig.name, intfName )
   created = mstiConfig.mstiPortConfig.newMember( intfName )
   if created is None:
      Tracing.trace0( "Unable to create MstiPortConfig ", mstiConfig.name,
                      intfName )
   return created

def mstiCliName( name ):
   """Return the upper-cased CLI display name for msti name.

   CistName is displayed as CistCliName; e.g. "Mst1" -> "MST1".
   """
   cliName = CistCliName if name == CistName else name
   return cliName.upper()
   
def mstiStatusCliName( mstiStatus ):
   """Return the CLI display name (see mstiCliName) of mstiStatus.name."""
   return mstiCliName( mstiStatus.name )

def isCist( mstiStatus ):
   """Return True if mstiStatus is named CistName or is a pvst instance."""
   name = mstiStatus.name
   if name == CistName:
      return True
   return name.startswith( PvstInstNamePrefix )

def vidToMstiMapEntry( vid, mstiId ):
   """Pack one vidToMstiMap entry.

   The entry is vid then mstiId, each a network-order unsigned 16-bit
   integer (4 bytes in total).
   """
   entryFormat = '!HH'
   return struct.pack( entryFormat, vid, mstiId )

class MstConfig( object ):
   """Pending MST configuration that is applied atomically.

   This class encapsulates the logic for reading and atomically applying
   changes to the mst configuration, specifically the
   Stp::Config::mstConfigSpec attribute.  It holds a pending region id,
   configuration revision and vlan -> msti map, plus the set of msti ids in
   use.  revert() loads these from stpConfig.mstConfigSpec, and apply()
   writes them back as a single new Stp::MstConfigSpec value.  Vlans absent
   from the pending map are implicitly mapped to instance 0.
   """

   def __init__( self, stpConfig ):
      # stpConfig: the object whose mstConfigSpec attribute revert() reads
      # and apply() replaces.
      self.stpConfig = stpConfig
      self.default()

   def default( self ):
      """Reset the pending config to empty defaults (no vlan mappings)."""
      self.pendingRegionId = ''
      self.pendingConfigRevision = 0
      self.pendingVidToMstiMap = {}
      self.pendingMsti = set()

   def revert( self ):
      """Discard pending changes and reload them from stpConfig."""
      current = self.stpConfig.mstConfigSpec
      self.pendingRegionId = current.regionId
      self.pendingConfigRevision = current.configRevision
      self.pendingVidToMstiMap = self.mapStringToMap( current.vidToMstiMap )
      self.pendingMsti = set( self.pendingVidToMstiMap.values() )

   def apply( self ):
      """Write the pending config to stpConfig.mstConfigSpec at once."""
      newSpec = Tac.Value( "Stp::MstConfigSpec",
                           regionId=self.pendingRegionId,
                           configRevision=self.pendingConfigRevision,
                           vidToMstiMap=
                              self.mapToMapString( self.pendingVidToMstiMap ) )
      self.stpConfig.mstConfigSpec = newSpec

   def regionIdIs( self, regionId ):
      self.pendingRegionId = regionId

   def regionId( self ):
      return self.pendingRegionId

   def configRevisionIs( self, rev ):
      self.pendingConfigRevision = rev

   def configRevision( self ):
      return self.pendingConfigRevision

   def vlanMapIs( self, vlanIds, instId ):
      """Map every vlan in vlanIds to instId in the pending config.

      instId 0 removes the vlans' explicit mappings, so they fall back to
      instance 0, and recomputes the msti set from the remaining mappings.
      A non-zero instId is added to the msti set, even when vlanIds is
      empty.  Any instance left with no vlans because its vlans moved to
      instId is dropped from the set.
      """
      if instId != 0:
         displaced = set()
         for vlanId in vlanIds:
            oldInstId = self.pendingVidToMstiMap.get( vlanId, 0 )
            if oldInstId not in ( 0, instId ):
               displaced.add( oldInstId )
            self.pendingVidToMstiMap[ vlanId ] = instId
         if displaced:
            # Remapping removes the vlans' old mappings.  As in the delete
            # case, drop instances that no longer have any vlan mapped.
            stillMapped = set( self.pendingVidToMstiMap.values() )
            self.pendingMsti -= displaced - stillMapped
         self.pendingMsti.add( instId )
      else:
         for vlanId in vlanIds:
            # A vlan that was not mapped is fine; nothing to remove.
            self.pendingVidToMstiMap.pop( vlanId, None )
         # Rather than reference-counting the entries in pendingMsti, just
         # recalculate the set whenever mappings are deleted.
         self.pendingMsti = set( self.pendingVidToMstiMap.values() )

   def vlanMap( self, vlanId ):
      """Return the pending instance for vlanId (0 if unmapped)."""
      return self.pendingVidToMstiMap.get( vlanId, 0 )

   def vlanMapDict( self ):
      """Return the pending {vlanId: instId} dict (not a copy)."""
      return self.pendingVidToMstiMap

   def inst( self, instId ):
      """Return True if instId is in the pending msti set."""
      return instId in self.pendingMsti

   def insts( self ):
      """Return the number of unique msti instances in the pending config."""
      return len( self.pendingMsti )

   def instDel( self, instId ):
      """Delete instance instId: drop all mappings to it and the msti itself."""
      doomed = [ vlanId for vlanId, mstiId in self.pendingVidToMstiMap.items()
                 if mstiId == instId ]
      for vlanId in doomed:
         del self.pendingVidToMstiMap[ vlanId ]
      self.pendingMsti.discard( instId )

   def mapStringToMap( self, str ):
      """Decode a packed vidToMstiMap string (see parseVidToMstiMap)."""
      return parseVidToMstiMap( str )

   def mapToMapString( self, map ):
      """Encode {vlanId: instId} into the packed vidToMstiMap format.

      Entries are emitted in ascending vlanId order, so the encoding of a
      given map is deterministic.
      """
      # join is linear; the previous `+=` loop was quadratic.
      return b''.join( vidToMstiMapEntry( vlanId, map[ vlanId ] )
                       for vlanId in sorted( map ) )

def inverseVlanMap( vlanMap, includeInst0=True ):
   """Invert a {vlanId: instId} map into {instId: set of vlanIds}.

   Instance 0 is special: it implicitly owns every vlan in
   [ValidVlanIdMin, ValidVlanIdMax] that vlanMap does not mention.  When
   includeInst0 is true, that set is stored under key 0.
   NOTE(review): a vlan that vlanMap maps explicitly to 0 is removed from
   the implicit set, and the key-0 entry is then overwritten, so such a vlan
   does not appear in idMap[ 0 ].
   """
   unmapped = set( xrange( ValidVlanIdMin, ValidVlanIdMax + 1 ) )
   instToVlans = {}
   for vlanId, instId in vlanMap.iteritems():
      instToVlans.setdefault( instId, set() ).add( vlanId )
      unmapped.discard( vlanId )
   if includeInst0:
      instToVlans[ 0 ] = unmapped
   return instToVlans
   
def printMstConfig( regionId, rev, vlanMap, digest=None ):

   idMap = inverseVlanMap( vlanMap )
   instCount = len( idMap )

   print "Name      [%s]" % regionId
   print "Revision  %-4d Instances configured %d" % (rev, instCount)
   if digest is not None:
      print "Digest        ",
      str = "0x"
      for i in range( 16 ):
         val = getattr( digest, 'configDigest%d' % i )
         str += "%02X" % val
      print str
   else:
      print "\nInstance  Vlans mapped"
      print "-" * 8, "-" * 71
      instIds = idMap.keys()
      instIds.sort()
      for instId in instIds:
         instIdPrinted = False
         for vlanSubStr in Vlan.vlanSetToCanonicalStringGen( idMap[ instId ], 70 ):
            instStr = " " * 9 if instIdPrinted else "%-9d" % instId
            print instStr, vlanSubStr
            instIdPrinted = True
      print "-" * 80

def showCurrentMstConfig( stpConfig, digest=None ):
   """Print the committed MST config of stpConfig via printMstConfig().

   A fresh MstConfig is reverted to stpConfig's current mstConfigSpec, so
   any pending, unapplied changes held elsewhere are not shown.
   """
   snapshot = MstConfig( stpConfig )
   snapshot.revert()
   printMstConfig( snapshot.regionId(), snapshot.configRevision(),
                   snapshot.vlanMapDict(), digest )

def mstiCliNameToNumber( name ):
   """Return the instance number from a CLI msti name, e.g. "MST12" -> 12.

   Assumes a three-character prefix such as "MST".  Raises ValueError if the
   remainder is not an integer.
   """
   digits = name[ 3 : ]
   return int( digits )

def cmpMstiNames( m1, m2 ):
   """cmp-style comparator that orders msti names by instance number.

   Intended for use with sort.
   """
   cliNames = [ mstiCliName( m ) for m in ( m1, m2 ) ]
   num1, num2 = [ mstiCliNameToNumber( c ) for c in cliNames ]
   return cmp( num1, num2 )

def stpiNameToNumber( s ):
   """Return the numeric sort key for an stpi name.

   The MST stpi (MstStpiName) maps to 0.  A pvst stpi named with
   PvstInstNamePrefix, in its original or upper-case spelling (e.g. "Vl10"
   or "VL10"), maps to the integer following the prefix.

   Raises AssertionError for any other name.
   """
   if s == MstStpiName:
      return 0
   if s.startswith( ( PvstInstNamePrefix, PvstInstNamePrefix.upper() ) ):
      # Derive the slice from the prefix instead of hard-coding its length.
      return int( s[ len( PvstInstNamePrefix ) : ] )
   # Raise explicitly: `assert False` is stripped under -O, which would
   # silently return None and corrupt the sort order.
   raise AssertionError( "unexpected stpi name %s" % s )
   
def cmpStpiNames( s1, s2 ):
   """cmp-style comparator for sorting stpi names via stpiNameToNumber()."""
   keys = [ stpiNameToNumber( s ) for s in ( s1, s2 ) ]
   return cmp( keys[ 0 ], keys[ 1 ] )

def stpCounterName( stpiName, portName ):
   """Return the counter-collection key for portName in stpi stpiName.

   The key has the form "<stpiName>:<portName>".
   """
   return ":".join( ( stpiName, portName ) )


def allMstiPortStatusPairs( status ):
   """Return a list of all valid (mstiStatus, mstiPortStatus) pairs.

   Helper for various CLI commands.  Stpis are visited in cmpStpiNames
   order, each stpi's mstis in cmpMstiNames order, and each msti's ports in
   Arnet.sortIntf order.  Entries whose lookup yields a false value are
   skipped.
   """
   pairs = []
   for stpiName in sorted( status.stpiStatus.keys(), cmpStpiNames ):
      stpiStatus = status.stpiStatus.get( stpiName )
      if not stpiStatus:
         continue
      mstiStatuses = stpiStatus.mstiStatus
      for mstiName in sorted( mstiStatuses.keys(), cmpMstiNames ):
         mstiStatus = mstiStatuses.get( mstiName )
         if not mstiStatus:
            continue
         portStatuses = mstiStatus.mstiPortStatus
         for portName in Arnet.sortIntf( portStatuses ):
            portStatus = portStatuses.get( portName )
            if portStatus:
               pairs.append( (mstiStatus, portStatus) )
   return pairs

def hwEpochStatusIs( status ):
   """Record status as the module's HwEpoch status (_hwEpochStatus)."""
   global _hwEpochStatus
   _hwEpochStatus = status

def loopGuardAllowed( epochStatus ):
   """Return whether loop guard may be configured.

   Loop guard is always enabled, so epochStatus is ignored.
   """
   allowed = True
   return allowed

def loopGuardCliGuard( mode, token ):
   """CLI guard for loop-guard tokens.

   Loop guard is always enabled, so this never blocks a token (it always
   returns None).
   """
   blockReason = None
   return blockReason
