#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AleCountersCli
import Arnet
import BasicCli
import CliCommand
import CliExtensions
import CliMatcher
import CliParser
import CliPlugin.IpAddrMatcher as IpAddrMatcher
import CliPlugin.IpGenAddrMatcher as IpGenAddrMatcher
from CliPlugin.SrTePolicyCli import srTeTunnelCountersHook
import CliToken
import LazyMount
import MplsCli
from MplsDebugCli import tunnelTypeValMatcher
from MplsTypeLib import tunnelTypeXlate
import ShowCommand
import SmashLazyMount
import Tac
import TunnelCli
import TunnelModels

from MplsTunnelCountersModel import (
   MplsTunnelCountersEntry,
   MplsTunnelCounters,
)

# Mount handles filled in by Plugin() when the CLI plugin is loaded;
# they stay None until then.
mplsTunnelClear = mplsTunnelCounterTable = mplsTunnelSnapshotTable = None

# Flex-counter key type used to index the counter tables, and the feature id
# checked when deciding whether the counters commands are available.
CounterIndex = Tac.Type( 'FlexCounters::CounterIndex' )
MplsTunnelFeatureId = Tac.Type( 'FlexCounters::FeatureId' ).MplsTunnel

def mplsTunnelCountersGuard( mode, token ):
   '''
   CLI guard for the MPLS tunnel counters commands.

   Returns CliParser.guardNotThisPlatform when the platform does not support
   the MPLS tunnel flex-counter feature, an explanatory message when it is
   supported but not configured, and None (command allowed) otherwise.
   '''
   featureId = MplsTunnelFeatureId
   if not AleCountersCli.checkCounterFeatureSupported( featureId ):
      return CliParser.guardNotThisPlatform
   if not AleCountersCli.checkCounterFeatureConfigured( featureId ):
      return "'hardware counter feature mpls tunnel' should be enabled first"
   return None

# Platform-specific hook: given a tunnelId, its first registered extension
# reports whether a hardware counter is actually programmed for that tunnel.
mplsTunnelCounterActiveHook = CliExtensions.CliHook()

def getMplsTunnelCountersFromSmash( tunnelId ):
   '''
   Read the current hardware counters for one MPLS tunnel.

   Only the first extension registered on mplsTunnelCounterActiveHook is
   consulted. If there is no such extension, or it reports the counter as
   inactive, ( None, None ) is returned. Otherwise the result of
   AleCountersCli.getCurrentCounter over the counter and snapshot tables is
   returned; callers unpack it as ( txPackets, txBytes ).
   '''
   activeHook = next( iter( mplsTunnelCounterActiveHook.extensions() ), None )
   if activeHook is None or not activeHook( tunnelId ):
      return None, None
   return AleCountersCli.getCurrentCounter( CounterIndex( tunnelId ),
                                            mplsTunnelCounterTable,
                                            mplsTunnelSnapshotTable )
   
# Register the smash counter lookup with the SR-TE policy CLI counters hook.
srTeTunnelCountersHook.addExtension( getMplsTunnelCountersFromSmash )
#-------------------------------------------------------------------------
# The "show mpls tunnel counters" command
#-------------------------------------------------------------------------
# 'counters' is guarded by mplsTunnelCountersGuard, so it is only accepted
# when the MPLS tunnel counter feature is supported and configured.
countersAfterTunnelNode = CliCommand.guardedKeyword(
      'counters', helpdesc='Tunnel egress hardware counters',
      guard=mplsTunnelCountersGuard )

# Keywords introducing the optional filters and output format.
typeMatcher = CliMatcher.KeywordMatcher(
      'type', helpdesc='Match tunnel type' )
indexMatcher = CliMatcher.KeywordMatcher(
      'index', helpdesc='Match tunnel index' )
nexthopMatcher = CliMatcher.KeywordMatcher(
      'nexthop', helpdesc='Match tunnel nexthop' )
interfaceMatcher = CliMatcher.KeywordMatcher(
      'interface', helpdesc='Match tunnel interface' )
tableOutputMatcher = CliMatcher.KeywordMatcher(
      'table-output', helpdesc='Provide results in a table format' )

# Value matcher for the nexthop filter.
nexthopValMatcher = IpGenAddrMatcher.IpGenAddrMatcher( MplsCli.nhStr )

def getMplsTunnelCounterEntryModel( tunnelId=None, endpoint=None,
                                    nexthop=None, intfId=None ):
   '''
   Build the counters model entry for a single MPLS tunnel.

   Returns None when:
     - the tunnel type is in MplsCli.showTunnelFibIgnoredTunnelTypes,
     - the tunnel has no entry in MplsCli.tunnelFib,
     - endpoint is given and differs from the tunnel endpoint, or
     - nexthop/intfId is given and no via of the tunnel has all of its via
       models matching the given nexthop/interface (a tunnel with no vias
       never matches).
   Otherwise returns an MplsTunnelCountersEntry with the tx counters from
   getMplsTunnelCountersFromSmash, the tunnel index, type and endpoint, and
   an IpVia for every via model of every via (not only the matching ones).
   '''
   tunnelType = TunnelCli.getTunnelTypeFromTunnelId( tunnelId )
   if tunnelType in MplsCli.showTunnelFibIgnoredTunnelTypes:
      return None

   tunnelFibEntry = MplsCli.tunnelFib.entry.get( tunnelId )
   if not tunnelFibEntry:
      return None

   tunnelEndpoint = TunnelCli.getEndpointFromTunnelId( tunnelId )
   if endpoint and tunnelEndpoint != endpoint:
      return None

   filtering = bool( nexthop or intfId )
   vias = []
   viaMatched = False
   for via in tunnelFibEntry.tunnelVia.itervalues():
      viaModels = list( TunnelCli.getMplsViaModelFromTunnelVia( via ) )
      # A via that resolves to no via models has nothing to compare against,
      # so it must not satisfy a nexthop/interface filter vacuously.
      singleViaMatched = bool( viaModels ) or not filtering
      for viaModel in viaModels:
         if nexthop:
            singleViaMatched &= ( viaModel.nexthop == nexthop )
         if intfId:
            singleViaMatched &= ( viaModel.interface == intfId )
         vias.append( TunnelModels.IpVia( nexthop=viaModel.nexthop,
                                          interface=viaModel.interface,
                                          type='ip' ) )
      viaMatched |= singleViaMatched

   if not viaMatched:
      return None

   txPackets, txBytes = getMplsTunnelCountersFromSmash( tunnelId )
   return MplsTunnelCountersEntry(
         txPackets=txPackets, txBytes=txBytes,
         tunnelIndex=TunnelCli.getTunnelIndexFromId( tunnelId ),
         tunnelType=TunnelModels.tunnelTypeStrDict[ tunnelType ],
         endpoint=tunnelEndpoint, vias=vias )

def getMplsTunnelCounterModel( args ):
   '''
   Collect counter entries for the tunnels selected by the parsed CLI args.

   The command syntax allows at most one of these optional filters:
     TYPE [ INDEX ] - only tunnels of that type (and that index)
     ENDPOINT       - only tunnels with that endpoint
     NEXTHOP        - tunnels matched on via nexthop
     INTF           - tunnels matched on via interface
   Per-tunnel filtering is done by getMplsTunnelCounterEntryModel.

   Returns a dict mapping tunnelId -> MplsTunnelCountersEntry.
   '''
   allTunnelIds = MplsCli.tunnelFib.entry
   tunnelIds = allTunnelIds
   endpoint = nexthop = intfId = None

   if 'TYPE' in args:
      tunnelType = tunnelTypeXlate[ args[ 'TYPE' ] ]
      if 'INDEX' in args:
         tunnelIds = [ TunnelCli.getTunnelIdFromIndex( tunnelType,
                                                       args[ 'INDEX' ] ) ]
      else:
         tunnelIds = [ tId for tId in allTunnelIds
                       if tunnelType ==
                          TunnelCli.getTunnelTypeFromTunnelId( tId ) ]
   elif 'ENDPOINT' in args:
      # Test the value key we actually read, like the other branches do.
      endpoint = Arnet.IpGenPrefix( str( args[ 'ENDPOINT' ] ) )
   elif 'NEXTHOP' in args:
      nexthop = Arnet.IpGenAddr( str( args[ 'NEXTHOP' ] ) )
   elif 'INTF' in args:
      intf = args[ 'INTF' ]
      if intf:
         intfId = Tac.Value( "Arnet::IntfId", str( intf ) )

   entries = {}
   for tunnelId in tunnelIds:
      entryModel = getMplsTunnelCounterEntryModel(
            tunnelId=tunnelId, endpoint=endpoint,
            nexthop=nexthop, intfId=intfId )
      # The entry helper returns None for filtered-out or unknown tunnels.
      if entryModel is not None:
         entries[ tunnelId ] = entryModel
   return entries

class ShowMplsTunnelCountersCmd( ShowCommand.ShowCliCommandClass ):
   '''
   show mpls tunnel counters [ type TYPE [ index INDEX ] | endpoint ENDPOINT
                               | nexthop NEXTHOP | interface INTF ]
                             [ table-output ]

   Displays MPLS tunnel egress hardware counters, optionally filtered.
   '''
   syntax = ( 'show mpls tunnel counters '
              '[ ( type TYPE [ index INDEX ] ) '
              '| ( endpoint ENDPOINT )'
              '| ( nexthop NEXTHOP ) '
              '| ( interface INTF ) ] '
              '[ table-output ]' )
   data = {
         'mpls': MplsCli.mplsNodeForShow,
         'tunnel': MplsCli.tunnelAfterMplsMatcherForShow,
         'counters': countersAfterTunnelNode,
         'type': typeMatcher,
         'TYPE': tunnelTypeValMatcher,
         'index': indexMatcher,
         'INDEX': TunnelCli.tunnelIndexMatcher,
         'endpoint': MplsCli.endpointMatcher,
         'ENDPOINT': IpGenAddrMatcher.IpGenAddrOrPrefixExprFactory(
            ipOverlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO,
            ip6Overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO,
            allowAddr=True ),
         'nexthop': nexthopMatcher,
         'NEXTHOP': nexthopValMatcher,
         'interface': interfaceMatcher,
         'INTF' : MplsCli.intfValMatcher,
         'table-output': tableOutputMatcher,
   }
   cliModel = MplsTunnelCounters

   @staticmethod
   def handler( mode, args ):
      '''Return the counters model with entries ordered by endpoint.'''
      def endpointKey( entry ):
         # Entries without an endpoint ("" key) sort before all others.
         return entry.endpoint.stringValue if entry.endpoint else ""

      # sorted() works whether values() returns a list or a view, unlike
      # calling .sort() on the result.
      entries = sorted( getMplsTunnelCounterModel( args ).values(),
                        key=endpointKey )
      return MplsTunnelCounters( entries=entries,
                                 _tableOutput='table-output' in args )

BasicCli.addShowCommandClass( ShowMplsTunnelCountersCmd )

#------------------------------------------
# clear mpls tunnel counters
#------------------------------------------
class ClearMplsTunnelCountersCmd( CliCommand.CliCommandClass ):
   '''
   clear mpls tunnel counters [ type TYPE [ index INDEX ] | endpoint ENDPOINT
                                | nexthop NEXTHOP | interface INTF ]

   Requests a clear of the hardware counters of the selected MPLS tunnels.
   '''
   syntax = ( 'clear mpls tunnel counters '
              '[ ( type TYPE [ index INDEX ] ) '
              '| ( endpoint ENDPOINT )'
              '| ( nexthop NEXTHOP ) '
              '| ( interface INTF ) ]' )
   # Only tokens that appear in the syntax above; clear has no table-output.
   data = {
         'clear': CliToken.Clear.clearKwNode,
         'mpls': MplsCli.mplsMatcherForClear,
         'tunnel': MplsCli.tunnelAfterMplsMatcherForShow,
         'counters': countersAfterTunnelNode,
         'type': typeMatcher,
         'TYPE': tunnelTypeValMatcher,
         'index': indexMatcher,
         'INDEX': TunnelCli.tunnelIndexMatcher,
         'endpoint': MplsCli.endpointMatcher,
         'ENDPOINT': IpGenAddrMatcher.IpGenAddrOrPrefixExprFactory(
            ipOverlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO,
            ip6Overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO,
            allowAddr=True ),
         'nexthop': nexthopMatcher,
         'NEXTHOP': nexthopValMatcher,
         'interface': interfaceMatcher,
         'INTF' : MplsCli.intfValMatcher,
   }

   @staticmethod
   def handler( mode, args ):
      '''
      Replace any pending clear requests with one per matching tunnel and
      print how many entries were cleared (or that none were found).
      '''
      entries = getMplsTunnelCounterModel( args )
      if not entries:
         print( "% No tunnel counter entries found" )
         return

      requests = mplsTunnelClear.clearTunnelCountersRequest
      # Drop stale requests so only the current selection is cleared.
      requests.clear()
      for tunnelId in entries:
         requests[ tunnelId ] = True
      n = len( entries )
      print( "%d tunnel counter entr%s cleared successfully" % (
             n, 'y' if n == 1 else 'ies' ) )

BasicCli.EnableMode.addCommandClass( ClearMplsTunnelCountersCmd )

#-------------------------------------------------------------------------
# The "show mpls tunnel interface counters" command
#-------------------------------------------------------------------------
#class ShowMplsTunnelInterfaceCountersCmd( CliCommand.CliCommandClass ):
#   syntax = 'show mpls tunnel interface [ INTF ] counters egress'
#   data = {
#      'mpls': MplsCli.mplsNodeForShow,
#      'tunnel': MplsCli.tunnelAfterMplsMatcherForShow,
#      'interface': CliCommand.guardedKeyword( 'interface',
#                      helpdesc="Per-interface MPLS tunnel information",
#                      guard=mplsTunnelCountersGuard ),
#      'counters': 'Aggregate MPLS tunnel counters',
#      'INTF': MplsCli.intfValMatcher,
#      'egress': 'Aggregate egress MPLS tunnel counters',
#   }
#   cliModel = MplsTunnelInterfaceCounters
#
#   @staticmethod
#   def handler( mode, args ):
#      tunnelInterfaceCounters = MplsTunnelInterfaceCounters()
#      interfaces = tunnelInterfaceCounters.interfaces
#
#      intfId = None
#      intf = args.get( 'INTF' )
#      if intf is not None:
#         intfId = Tac.Value( "Arnet::IntfId", str( intf ) )
#
#      # Generate per-interface aggregate stats for mpls tunnels
#      for tunnelId, tunnelEntry in MplsCli.tunnelFib.entry.iteritems():
#         for via in tunnelEntry.tunnelVia:
#            viaModels = TunnelCli.getMplsViaModelFromTunnelVia( via )
#            for viaModel in viaModels:
#               if not viaModel.interface:
#                  continue
#               # Skip if need to filter by interface and not matched
#               if intfId and viaModel.interface != intfId:
#                  continue
#               txPackets, txBytes = getMplsTunnelCountersFromSmash( tunnelId )
#               if txPackets and txBytes:
#                  interface = viaModel.interface
#                  if interface in interfaces:
#                     interfaces[ interface ].txPackets += txPackets
#                     interfaces[ interface ].txBytes += txBytes
#                  else:
#                     interfaces[ interface ] = \
#                        MplsTunnelInterfaceCountersEntry(
#                           txPackets=txPackets, txBytes=txBytes )
#
#      return tunnelInterfaceCounters
#
# Command disabled, see BUG252582
# Todor: I've converted the above command, but it may need a little work.
#
# BasicCli.registerLegacyShowCommandClass( ShowMplsTunnelInterfaceCountersCmd )

def Plugin( em ):
   '''
   Plugin entry point: set up the mounts used by the commands above.

   - mplsTunnelClear: writable mount of the tunnel counter clear config.
   - mplsTunnelCounterTable / mplsTunnelSnapshotTable: reader smash mounts of
     the flex-counter and snapshot counter tables.
   '''
   global mplsTunnelClear, mplsTunnelCounterTable, mplsTunnelSnapshotTable

   mplsTunnelClear = LazyMount.mount(
      em, 'hardware/counter/tunnel/clear/config',
      'Ale::FlexCounter::TunnelCliClearConfig', 'w' )

   # Reference to Nexthop feature ID is deliberate here as we are
   # reusing Nexthop counter smashes by design
   enumValType = Tac.Type( 'FlexCounters::FeatureIdEnumVal' )
   nexthopFeatureVal = enumValType.enumVal(
      Tac.Type( 'FlexCounters::FeatureId' ).Nexthop )
   allFaps = Tac.Type( 'FlexCounters::FapId' ).allFapsId
   flexDir = '/%u/%u' % ( nexthopFeatureVal, allFaps )

   readerInfo = SmashLazyMount.mountInfo( 'reader' )

   def mountCounterTable( pathPrefix ):
      # Both tables share the feature/FAP subdirectory and the table type.
      return SmashLazyMount.mount( em, pathPrefix + flexDir,
                                   "FlexCounters::CounterTable", readerInfo )

   mplsTunnelCounterTable = mountCounterTable( 'flexCounter/counterTable' )
   mplsTunnelSnapshotTable = mountCounterTable(
      'flexCounter/snapshotCounterTable' )
