#!/usr/bin/env python
# Copyright (c) 2007-2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

#-------------------------------------------------------------------------------
# This module implements Stp clear commands
#
# clear spanning-tree counters [ interface <interface> ]
# clear spanning-tree counters session
# clear spanning-tree detected-protocols [ interface <interface> ]
#
#-------------------------------------------------------------------------------
'''Clear commands supported for Spanning Tree'''

import BasicCli
import LazyMount
import ConfigMount
import Intf.Log
# Import Stp TAC object accessors, which create TAC objects if needed.
from StpCliUtil import Tac
from StpCliUtil import Tracing
from StpCliUtil import ( stpStatus, stpStatusIs, stpConfigIs,
      stpInputConfig, stpInputConfigIs, stpCounterName, stpPortCounterDir,
      stpPortCounterDirIs )

# Trace handle named 'StpCli' for this module.  NOTE(review): the bare
# Tracing.trace0 / Tracing.trace9 calls below presumably locate it through
# this exact module-level name -- confirm before renaming.
__defaultTraceHandle__ = Tracing.Handle( 'StpCli' )

# Alias for the CLI enable mode.  It is not referenced by the code in this
# module as shown; presumably used when these handlers are registered as
# commands -- TODO confirm.
enableMode = BasicCli.EnableMode

#-------------------------------------------------------------------------------
# clear spanning-tree counters [ interface <interface> ]
#-------------------------------------------------------------------------------
def snapshot( portStatus ):
   '''Return an "Stp::PortCounters" Tac value copying portStatus.counters.

   The bpduRx, bpduTx, bpduTaggedError and bpduOtherError fields are copied
   from the port's current counters, and the value is stamped with
   Tac.now().  The counters are read before the timestamp is taken.
   '''
   src = portStatus.counters
   rx, tx = src.bpduRx, src.bpduTx
   taggedErr, otherErr = src.bpduTaggedError, src.bpduOtherError
   return Tac.Value( "Stp::PortCounters",
                     bpduRx=rx, bpduTx=tx,
                     bpduTaggedError=taggedErr, bpduOtherError=otherErr,
                     timestamp=Tac.now() )
   
def clearCounters( mode, args ):
   '''Handle "clear spanning-tree counters [ interface <interface> ]".

   Stores a snapshot() of the affected ports' current BPDU counters in the
   stpPortCounterDir() directory, keyed by stpCounterName( stpiName, port ).

   - Without an INTERFACE argument, every port of every STP instance present
     in stpStatus() is snapshotted.
   - With an INTERFACE argument, only that interface is snapshotted, in each
     instance that has port status for it.  Instances lacking it are
     reported with Tracing.trace0.

   mode: the CLI mode (unused).
   args: parsed CLI arguments; the optional 'INTERFACE' value supplies .name.
   '''
   target = args.get( 'INTERFACE' )
   status = stpStatus()
   counterDir = stpPortCounterDir()

   if target is None:
      Intf.Log.logClearCounters( "spanning tree", "all interfaces" )
   else:
      targetName = target.name
      Tracing.trace9( "Clear counters for ", targetName )
      Intf.Log.logClearCounters( "spanning tree", "interface " + targetName )

   for stpiName in status.stpiStatus.keys():
      # Entries are looked up again with get() and skipped when missing.
      stpiStatus = status.stpiStatus.get( stpiName )
      if not stpiStatus:
         continue
      ports = stpiStatus.stpiPortStatus
      if target is None:
         # Every port this instance currently reports.
         for portName in ports.keys():
            portStatus = ports.get( portName )
            if portStatus:
               counterName = stpCounterName( stpiName, portName )
               counterDir.portCounters[ counterName ] = snapshot( portStatus )
      else:
         # Only the requested interface.
         portStatus = ports.get( targetName, None )
         counterName = stpCounterName( stpiName, targetName )
         if not portStatus:
            Tracing.trace0( "PortStatus", counterName, "doesn\'t exist" )
         else:
            counterDir.portCounters[ counterName ] = snapshot( portStatus )

#-------------------------------------------------------------------------------
# clear spanning-tree counters session
#-------------------------------------------------------------------------------
def clearCountersSession( mode, args ):
   '''Handle "clear spanning-tree counters session".

   Snapshots the current BPDU counters of every port in every STP instance,
   as clearCounters() does without an interface.  The snapshots are stored
   only in this CLI session's data rather than in stpPortCounterDir().  The
   session data is a plain dict under 'Stp.sessionPortCounter', created on
   first use and keyed by stpCounterName( stpiName, port ).

   mode: the CLI mode; mode.session provides the per-session storage.
   args: parsed CLI arguments (unused).
   '''
   session = mode.session
   baselines = session.sessionData( 'Stp.sessionPortCounter', None )
   if baselines is None:
      baselines = {}
      session.sessionDataIs( 'Stp.sessionPortCounter', baselines )

   status = stpStatus()
   for instName in status.stpiStatus.keys():
      inst = status.stpiStatus.get( instName )
      if inst:
         ports = inst.stpiPortStatus
         for portName in ports.keys():
            portStatus = ports.get( portName )
            if portStatus:
               key = stpCounterName( instName, portName )
               baselines[ key ] = snapshot( portStatus )

#-------------------------------------------------------------------------------
# clear spanning-tree detected-protocols [ interface <interface> ]
#-------------------------------------------------------------------------------
def clearProtocols( mode, args ):
   '''Handle "clear spanning-tree detected-protocols [ interface <interface> ]".

   Increments mcheckVer on the port entries of the CLI input config
   (stpInputConfig().portConfig).

   - Without an INTERFACE argument, every port entry is updated.
   - With an INTERFACE argument, only that interface's entry is updated.
     If it has none, this is reported with Tracing.trace0.

   NOTE(review): presumably the Stp agent reacts to the mcheckVer change by
   re-running protocol detection (mcheck) on the port -- confirm.

   mode: the CLI mode (unused).
   args: parsed CLI arguments; the optional 'INTERFACE' value supplies .name.
   '''
   intf = args.get( 'INTERFACE' )
   config = stpInputConfig()
   if intf is None:
      # clear protocol on all ports
      Tracing.trace9( "Clear protocols for all ports" )
      for key in config.portConfig.keys():
         # Re-fetch with get() and skip entries that have gone away since
         # keys() was taken, as the counter handlers in this module do.
         # Indexing with [] would presumably raise KeyError in that case.
         portConfig = config.portConfig.get( key )
         if not portConfig:
            continue
         portConfig.mcheckVer += 1
   else:
      intfName = intf.name
      Tracing.trace9( "Clear protocols for ", intfName )
      portConfig = config.portConfig.get( intfName, None )
      if portConfig:
         portConfig.mcheckVer += 1
      else:
         Tracing.trace0( "PortConfig", intfName, "doesn\'t exist" )

#-------------------------------------------------------------------------------
# Have the Cli Agent mount all needed state from sysdb
#-------------------------------------------------------------------------------
def Plugin( entityManager ):
   '''Mount the Sysdb state these commands use and register it with the
   StpCliUtil accessors.

   Mounts, in order:
     stp/status            Stp::Status          LazyMount, "r"
     stp/config            Stp::Config          LazyMount, "r"
     stp/input/config/cli  Stp::Input::Config   ConfigMount, "w"
     stp/counter           Stp::PortCounterDir  LazyMount, "w"
   '''
   stpStatusIs(
      LazyMount.mount( entityManager, "stp/status", "Stp::Status", "r" ) )
   stpConfigIs(
      LazyMount.mount( entityManager, "stp/config", "Stp::Config", "r" ) )
   stpInputConfigIs(
      ConfigMount.mount( entityManager, "stp/input/config/cli",
                         "Stp::Input::Config", "w" ) )
   stpPortCounterDirIs(
      LazyMount.mount( entityManager, "stp/counter",
                       "Stp::PortCounterDir", "w" ) )
