#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Arnet
import Tac
import LazyMount
import SmashLazyMount
import ShowCommand
import CliParser
import CliCommand
import BasicCli
import BasicCliModes
import CliToken.Clear
import MplsCli
from CliPlugin.MplsIngressCountersModel import MplsLfibCounterLabelEntry
from CliPlugin.MplsIngressCountersModel import MplsLfibCounters
from CliPlugin.AleCountersCli import checkCounterFeatureSupported
from CliPlugin.AleCountersCli import checkCounterFeatureEnabled

# Tac type handles used by the handlers in this file.
FeatureId = Tac.Type( 'FlexCounters::FeatureId' )
FapId = Tac.Type( 'FlexCounters::FapId' )
# NOTE(review): RouteMetric is not referenced anywhere else in the visible part
# of this file - confirm whether it is still needed.
RouteMetric = Tac.Type( "Mpls::RouteMetric" )
MplsLabel = Tac.Type( 'Arnet::MplsLabel' )
RouteKey = Tac.Type( 'Mpls::RouteKey' )
BoundedMplsLabelStack = Tac.Type( 'Arnet::BoundedMplsLabelStack' )

# Module-level handles to mounted entities. They stay None until Plugin()
# assigns them.
# Current and snapshot counter tables read by mplsLfibSmashCounter().
mplsLfibCounterTable = None
mplsLfibSnapshotTable = None
# Holds clearLfibEntryCountersRequest, written by the clear handlers.
mplsLfibClear = None
# Read by labelUnprogrammed().
mplsHwStatus = None
# Supplies mplsLookupLabelCount to willExceedConfiguredCount().
mplsRoutingInfo = None
# NOTE(review): mounted in Plugin() but not read by any handler visible in
# this file.
fcFeatureConfigDir = None
# Its lfibRoute collection is iterated by the show and clear handlers.
transitLfib = None


def mplsLfibCountersGuard( mode, token ):
   '''
   CLI guard for the Mpls Lfib counter commands.

   Returns None when checkCounterFeatureSupported() reports that the MplsLfib
   counter feature is supported, CliParser.guardNotThisPlatform otherwise.
   The mode and token arguments are not used.
   '''
   if checkCounterFeatureSupported( FeatureId.MplsLfib ):
      return None
   return CliParser.guardNotThisPlatform

# 'counters' keyword node used by the show and clear command classes in this
# file; guarded by mplsLfibCountersGuard.
nodeCounters = CliCommand.guardedKeyword( 'counters',
   helpdesc="MPLS LFIB counters",
   guard=mplsLfibCountersGuard )

#-------------------------------------------------------------------------
# The "show mpls lfib counters [ { LABELS } ]" command
#-------------------------------------------------------------------------
# Matcher used for every label argument in this file (LABELS, START_LABEL and
# END_LABEL).
nodeLabelMatcher = Arnet.MplsLib.labelValMatcher

def willExceedConfiguredCount( labels ):
   '''
   Return True when more labels were given than the lookup allows.

   The limit is mplsRoutingInfo.mplsLookupLabelCount when mplsRoutingInfo is
   truthy, and 1 otherwise. A missing or empty labels argument counts as zero
   labels.
   '''
   limit = mplsRoutingInfo.mplsLookupLabelCount if mplsRoutingInfo else 1
   given = len( labels ) if labels else 0
   return given > limit


def createLabelStack( labels ):
   '''
   Build a BoundedMplsLabelStack from a sequence of label values.

   The labels are appended in reverse order of how they were given. A missing
   or empty labels argument yields an empty stack. The result is passed
   through Tac.const() before being returned.
   '''
   stack = BoundedMplsLabelStack()
   for value in reversed( labels or [] ):
      stack.append( value )
   return Tac.const( stack )

def mplsLfibSmashCounter( routeKey ):
   '''
   Return ( pkts, octets ) for routeKey.

   Both values are 0 when the current counter table has no entry for the key.
   When a snapshot entry exists its values are subtracted from the current
   ones; otherwise the current values are returned unchanged.
   '''
   current = mplsLfibCounterTable.counterEntry.get( routeKey )
   baseline = mplsLfibSnapshotTable.counterEntry.get( routeKey )
   if current is None:
      return 0, 0
   if baseline is None:
      return current.pkts, current.octets
   return current.pkts - baseline.pkts, current.octets - baseline.octets

def labelUnprogrammed( routeKey ):
   '''
   Return the unprogrammed state for routeKey as seen in mplsHwStatus.

   True when mplsHwStatus.route has no (truthy) entry for the key. Otherwise
   the route's own unprogrammed flag is returned if it is set. If it is not
   set, the flag of the adjacency found through route.adjBaseKey is returned,
   falling back to the route's flag when no adjacency is found.
   '''
   hwRoute = mplsHwStatus.route.get( routeKey )
   if not hwRoute:
      return True

   routeFlag = hwRoute.unprogrammed
   if routeFlag:
      return routeFlag

   adjacency = mplsHwStatus.adjacencyBase( hwRoute.adjBaseKey )
   return adjacency.unprogrammed if adjacency else routeFlag

def addCounterInfoForRouteKey( mplsLfibCounters, mplsLfibMultiLabelCounters,
                               routeKey ):
   '''
   Build the counter entry for routeKey and store it in one of the two dicts.

   mplsLfibCounters receives entries for the empty label stack (keyed by
   MplsLabel.null) and for single-label stacks (keyed by the top label).
   mplsLfibMultiLabelCounters receives every other entry, keyed by the
   stack's cliShowString(). Both dicts are modified in place; nothing is
   returned.
   '''
   stack = routeKey.labelStack
   emptyStack = BoundedMplsLabelStack()

   # The hardware status is only consulted for a non-empty label stack.
   unprogrammed = labelUnprogrammed( routeKey ) if stack != emptyStack else False
   pkts, octets = mplsLfibSmashCounter( routeKey )
   entry = MplsLfibCounterLabelEntry( totalPackets=pkts,
                                      totalBytes=octets,
                                      unprogrammed=unprogrammed )

   if stack == emptyStack:
      # Empty label stack indicates non-countable labels
      mplsLfibCounters[ MplsLabel.null ] = entry
   elif stack.stackSize == 1:
      mplsLfibCounters[ stack.top() ] = entry
   else:
      mplsLfibMultiLabelCounters[ stack.cliShowString() ] = entry

def showMplsLfibCounters( mode, args ):
   '''
   Return Mpls Lfib counter data.

   Without a LABELS argument, an entry is reported for every route key in
   transitLfib.lfibRoute plus the entry for the empty-label-stack RouteKey(),
   which this file uses for non-countable labels. With LABELS, only the entry
   for that label stack is reported.

   Returns None (after adding an error to mode) when the MplsLfib counter
   feature is not enabled, an empty MplsLfibCounters (plus an error) when more
   labels are given than willExceedConfiguredCount() allows, and a populated
   MplsLfibCounters otherwise.
   '''
   if not checkCounterFeatureEnabled( FeatureId.MplsLfib ):
      mode.addError( "hardware counter feature mpls lfib in should be "
                     "enabled first" )
      return None

   labels = args.get( 'LABELS' )

   # limit num of labels
   if willExceedConfiguredCount( labels ):
      mode.addError( MplsCli.MAX_LABEL_EXCEEDED )
      return MplsLfibCounters()

   mplsLfibCounters = {}
   mplsLfibMultiLabelCounters = {}
   if not labels:
      for routeKey in transitLfib.lfibRoute:
         addCounterInfoForRouteKey( mplsLfibCounters, mplsLfibMultiLabelCounters,
                                    routeKey )
      # Add counter information for all non-countable labels.
      addCounterInfoForRouteKey( mplsLfibCounters, mplsLfibMultiLabelCounters,
                                 RouteKey() )
   else:
      # Only look up the requested label stack. Without this 'else' the lookup
      # also ran on the no-labels path with an empty stack, repeating the
      # non-countable lookup above and rewriting its MplsLabel.null entry.
      addCounterInfoForRouteKey( mplsLfibCounters, mplsLfibMultiLabelCounters,
                                 RouteKey( createLabelStack( labels ) ) )

   return MplsLfibCounters( counters=mplsLfibCounters,
                            multiLabelCounters=mplsLfibMultiLabelCounters )

# Command class for "show mpls lfib counters [ { LABELS } ]". The handler is
# showMplsLfibCounters and the declared CLI model is MplsLfibCounters.
class ShowMplsLfibCountersCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show mpls lfib counters [ { LABELS } ]'
   # Maps each token in the syntax string to its matcher/keyword node.
   data = {
         'mpls' : MplsCli.mplsNodeForShow,
         'lfib' : MplsCli.matcherLfib,
         'counters' : nodeCounters,
         'LABELS' : nodeLabelMatcher,
         }
   privileged = True
   cliModel = MplsLfibCounters

   handler = showMplsLfibCounters

# Register the show command class with BasicCli.
BasicCli.addShowCommandClass( ShowMplsLfibCountersCmd )
#-------------------------------------------------------------------------
# The "clear mpls lfib counters [ START_LABEL [ END_LABEL ] ]" command
#-------------------------------------------------------------------------

def clearMplsLfibCounters( mode, args, multiLabel=False ):
   '''
   Fill the clear request with the label stacks whose counters should be
   cleared.

   Adds an error to mode and does nothing else when the MplsLfib counter
   feature is not enabled. Otherwise the request collection is emptied and
   then populated:

   - START_LABEL given: every single-label route in transitLfib.lfibRoute whose
     top label lies in [ START_LABEL, END_LABEL ] (END_LABEL defaults to
     START_LABEL).
   - no START_LABEL, multiLabel False: the empty label stack plus the label
     stack of every route.
   - no START_LABEL, multiLabel True: only routes whose stackSize is above 1.
   '''
   if not checkCounterFeatureEnabled( FeatureId.MplsLfib ):
      mode.addError( "hardware counter feature mpls lfib in should be "
                     "enabled first" )
      return

   firstLabel = args.get( 'START_LABEL' )
   lastLabel = args.get( 'END_LABEL' )

   request = mplsLfibClear.clearLfibEntryCountersRequest
   request.clear()

   if firstLabel is not None:
      # Clear range of labels; a lone START_LABEL is a range of one.
      if lastLabel is None:
         lastLabel = firstLabel
      for key in transitLfib.lfibRoute:
         stack = key.labelStack
         # Only clear single label route counters when specifying range
         if stack.stackSize != 1:
            continue
         topLabel = key.topLabel
         if not ( topLabel >= firstLabel and topLabel <= lastLabel ):
            continue
         if stack not in request:
            request[ stack ] = True
      return

   # Clear all labels
   if not multiLabel:
      # Clear the non-countable labels
      request[ BoundedMplsLabelStack() ] = True
   for key in transitLfib.lfibRoute:
      stack = key.labelStack
      # 'clear mpls lfib counters' clears all counters
      # 'clear mpls lfib counters multi-label' clears only multi-label counters
      if multiLabel and stack.stackSize <= 1:
         continue
      if stack not in request:
         request[ stack ] = True

# Command class for "clear mpls lfib counters [ START_LABEL [ END_LABEL ] ]".
# The handler is clearMplsLfibCounters.
# NOTE(review): the class name mentions MultiLabel although the syntax below
# has no 'multi-label' token - the name looks historical; confirm.
class ClearMplsLfibCountersMultiLabelStartEndCmd( CliCommand.CliCommandClass ):
   syntax = 'clear mpls lfib counters [ START_LABEL [ END_LABEL ] ]'
   # Maps each token in the syntax string to its matcher/keyword node.
   data = {
         'clear' : CliToken.Clear.clearKwNode,
         'mpls' : MplsCli.mplsMatcherForClear,
         'lfib' : MplsCli.matcherLfib,
         'counters' : nodeCounters,
         'START_LABEL' : nodeLabelMatcher,
         'END_LABEL' : nodeLabelMatcher,
         }

   handler = clearMplsLfibCounters

# Register the range-based clear command on EnableMode.
BasicCliModes.EnableMode.addCommandClass(
   ClearMplsLfibCountersMultiLabelStartEndCmd )

#-------------------------------------------------------------------------
# The "clear mpls lfib counters multi-label [ { LABELS } ]" command
#-------------------------------------------------------------------------
# 'multi-label' keyword node, guarded by MplsCli.mplsMultiLabelLookupGuard.
nodeMultiLabel = CliCommand.guardedKeyword( "multi-label",
   helpdesc="Specify a multi-label entry (labels ordered top-most to bottom-most)",
   guard=MplsCli.mplsMultiLabelLookupGuard )

def clearMplsLfibCountersMultiLabel( mode, args ):
   '''
   Clear Mpls Lfib counters for a particular label stack or all multi-label route
   counters if no label stack is specified

   An error is added to mode, and the pending clear request is left untouched,
   when the MplsLfib counter feature is not enabled or when more labels are
   given than willExceedConfiguredCount() allows.
   '''
   # Same precondition and message as the other counter handlers in this file.
   # It used to run only on the delegated no-labels path, so a specific label
   # stack was written to the request even with the feature disabled.
   if not checkCounterFeatureEnabled( FeatureId.MplsLfib ):
      mode.addError( "hardware counter feature mpls lfib in should be "
                     "enabled first" )
      return

   labels = args.get( 'LABELS' )
   if willExceedConfiguredCount( labels ):
      mode.addError( MplsCli.MAX_LABEL_EXCEEDED )
      return

   if labels is None:
      # Clear all multi-label route counters. clearMplsLfibCounters() resets
      # the request collection itself before filling it.
      clearMplsLfibCounters( mode, args, multiLabel=True )
      return

   # Clear counter for specific label stack.
   clearReq = mplsLfibClear.clearLfibEntryCountersRequest
   clearReq.clear()
   clearReq[ createLabelStack( labels ) ] = True

# Command class for "clear mpls lfib counters multi-label [ { LABELS } ]".
# The handler is clearMplsLfibCountersMultiLabel.
class ClearMplsLfibCountersMultiLabelCmd( CliCommand.CliCommandClass ):
   syntax = "clear mpls lfib counters multi-label [ { LABELS } ]"
   # Maps each token in the syntax string to its matcher/keyword node.
   data = {
      'clear' : CliToken.Clear.clearKwNode,
      'mpls' : MplsCli.mplsMatcherForClear,
      'lfib' : MplsCli.matcherLfib,
      'counters' : nodeCounters,
      'multi-label' : nodeMultiLabel,
      'LABELS' : nodeLabelMatcher,
      }

   handler = clearMplsLfibCountersMultiLabel

# Register the multi-label clear command on EnableMode.
BasicCliModes.EnableMode.addCommandClass( ClearMplsLfibCountersMultiLabelCmd )

def Plugin( em ):
   '''
   Plugin entry point: mount the entities used by the commands in this file
   and store the resulting handles in the module-level globals.

   The clear config is mounted writable ( "w" ); the hardware status, routing
   info and feature config dir are mounted read-only. The two counter tables
   and the transit LFIB are Smash mounts that share one 'reader' mount info.
   '''
   global mplsLfibCounterTable, mplsLfibSnapshotTable, mplsLfibClear
   global mplsHwStatus, mplsRoutingInfo, fcFeatureConfigDir, transitLfib

   mplsLfibClear = LazyMount.mount(
      em,
      'hardware/counter/mplsLfib/clear/config',
      "Ale::FlexCounter::MplsLfibCliClearConfig",
      "w" )
   mplsHwStatus = LazyMount.mount(
      em,
      "routing/hardware/mpls/status",
      "Mpls::Hardware::Status",
      "r" )
   mplsRoutingInfo = LazyMount.mount(
      em,
      "routing/mpls/routingInfo/status",
      "Mpls::RoutingInfo",
      "r" )
   fcFeatureConfigDir = LazyMount.mount(
      em,
      "flexCounter/featureConfigDir/cliAgent",
      "Ale::FlexCounter::FeatureConfigDir",
      'r' )

   smashReader = SmashLazyMount.mountInfo( 'reader' )

   # Mount the Mpls Lfib current counter smash.
   currentPath = 'flexCounters/counterTable/MplsLfib/%u' % ( FapId.allFapsId )
   mplsLfibCounterTable = SmashLazyMount.mount(
      em, currentPath, "Ale::FlexCounter::MplsLfibCounterTable", smashReader )

   # Mount the Mpls Lfib snapshot counter smash.
   snapshotPath = 'flexCounters/snapshotTable/MplsLfib/%u' % ( FapId.allFapsId )
   mplsLfibSnapshotTable = SmashLazyMount.mount(
      em, snapshotPath, "Ale::FlexCounter::MplsLfibCounterTable", smashReader )

   # Mount the transit LFIB whose routes the show/clear handlers iterate.
   transitLfib = SmashLazyMount.mount(
      em, "mpls/transitLfib", "Mpls::LfibStatus", smashReader )
