#!/usr/bin/env python
# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable-msg=protected-access

from __future__ import absolute_import, division, print_function
import ast
from collections import namedtuple
from Ark import (
   switchTimeToUtc
)
from Arnet import IpGenAddr
from BasicCli import (
   addShowCommandClass,
)
import AgentDirectory
import AgentCommandRequest
import cStringIO
import CliParser
import CliCommand
import CliMatcher
import CliPlugin.IpAddrMatcher as IpAddrMatcher
import CliPlugin.SfeFtModel as SfeFtModel
from CliPlugin.FlowTrackingCliLib import (
   getFlowGroups,
   getFlowGroupsAndIds,
   getTrackerNamesAndIds,
   isSfeAgentRunning,
   hardwareShowKw,
   exporterKw,
   exporterNameMatcher,
   trackerKw,
   trackingShowKw,
   trackerNameMatcher,
   IP_GROUP,
   groupKw,
)
from CliPlugin.FlowTrackingCounterCliLib import (
   FlowGroupCounterKey,
   FlowGroupCounterEntry,
   TemplateIdType,
   CollectorStatisticsKey,
   CollectorInfo,
   addExporters,
   FlowCounterKey,
   FlowCounterEntry,
)
import CliPlugin.FlowTrackingCounterModel as ftCounterModel
from CliPlugin.SfeCliLib import nodeSfe
from CliPlugin.VrfCli import VrfExprFactory
from CliToken.Platform import (
   platformMatcherForShow
)
from CliToken.Flow import (
   flowMatcherForShow,
)
from FlowTrackerCliUtil import (
   ftrTypeHardware,
   ipStr,
   protocolStr,
   tcpFlagStr,
)
from IpLibConsts import ALL_VRF_NAME
import LazyMount
from ShowCommand import ShowCliCommandClass
import SmashLazyMount
import SharedMem
import Smash
import Tac
import Tracing
from TypeFuture import TacLazyType
import SfeAgent

# Lazily-resolved TAC type for IP protocol numbers. Its attribute names are
# used further down to build the protocol names accepted by the CLI filter.
IpProtoType = TacLazyType( 'Arnet::IpProtocolNumber' )

# Tracing handle for this plugin, plus shorthands for trace levels 0 and 1.
traceHandle = Tracing.Handle( 'SfeFlowCliShow' )
t0 = traceHandle.trace0
t1 = traceHandle.trace1

# Module-level handles. They all start out as None and are assigned by
# Plugin() when the plugin is loaded.
activeAgentDir = None
sfeVrfIdMap = None
entityManager = None
sfeConfig = None
sfeCounters = None
ipfixStats = None
shmemEm = None

def getFlowTable( trackerName ):
   """Mount and return the Smash flow table of one hardware flow tracker.

   The table lives at 'flowtracking/<ftrTypeHardware>/flowTable/<trackerName>'
   and is mounted with 'reader' mount info through the module's shared-memory
   entity manager (shmemEm).
   """
   tablePath = 'flowtracking/%s/flowTable/%s' % ( ftrTypeHardware, trackerName )
   readerInfo = Smash.mountInfo( 'reader' )
   return shmemEm.doMount( tablePath, 'Sfe::FlowTrackerFlowTable', readerInfo )

def sfeGuard( mode, token ):
   """CLI guard for the Sfe-only keywords in this module.

   Returns None (keyword allowed) when AgentDirectory reports an 'Sfe' agent
   for mode.sysname, otherwise CliParser.guardNotThisPlatform. The token
   argument is not used.
   """
   sfeAgent = AgentDirectory.agent( mode.sysname, 'Sfe' )
   if not sfeAgent:
      return CliParser.guardNotThisPlatform
   return None

# SHOW COMMANDS
#------------------------------------------------------------
# show flow tracking hardware flow-table [detail]
# [ tracker <tracker-name> ] [ group <group-name> ]
# [ src-ip <ip> ] [ dst-ip <ip> ] [ src-port <port> ] [ dst-port <port> ]
# [ protocol <protocol> ] [ vrf <vrf> ]
#------------------------------------------------------------
class ShowFlowTable( object ):
   """Builds the SfeFtModel.TrackingModel for
   'show flow tracking hardware flow-table'."""

   def _entryVrfName( self, entry ):
      """Return the VRF name sfeVrfIdMap holds for entry.vrfId, or "unknown"
      when the id is not present in the map."""
      vrfIdToName = sfeVrfIdMap.vrfIdToName
      if entry.vrfId in vrfIdToName:
         return vrfIdToName[ entry.vrfId ].vrfName
      return "unknown"

   def showFlowDetail( self, flowDetailModel, entry ):
      """Fill flowDetailModel with the detail-only fields of a flow entry."""
      flowDetailModel.tcpFlags = tcpFlagStr( entry.tcpFlags )
      flowDetailModel.lastPktTime = entry.lastPktTime

   def isEntryMatchingFilter( self, key, entry, srcIpAddr=None,
         dstIpAddr=None, srcPort=None, dstPort=None, vrfName=None,
         protocol=None ):
      """Return True when entry passes every filter that is not None.

      Addresses and ports are compared for equality with the corresponding
      entry attribute. A VRF filter equal to ALL_VRF_NAME accepts any entry.
      The protocol filter matches either the (case-insensitive) name given by
      protocolStr() or the decimal string of entry.ipProtocolNumber. The key
      argument is not used.
      """
      exactFilters = (
         ( srcIpAddr, 'srcIpAddr' ),
         ( dstIpAddr, 'dstIpAddr' ),
         ( srcPort, 'srcPort' ),
         ( dstPort, 'dstPort' ),
      )
      for wanted, attrName in exactFilters:
         if wanted is not None and wanted != getattr( entry, attrName ):
            return False

      entryVrf = self._entryVrfName( entry )
      if vrfName is not None and vrfName not in ( entryVrf, ALL_VRF_NAME ):
         return False

      if protocol is None:
         return True
      if protocol.lower() == protocolStr( entry.ipProtocol,
                                          entry.ipProtocolNumber ).lower():
         return True
      return protocol == str( entry.ipProtocolNumber )

   def populateFlowModel( self, key, entry ):
      """Create and return a FlowModel (with its FlowKeyModel) for one flow."""
      keyModel = SfeFtModel.FlowKeyModel()
      keyModel._combinedKey = key.combinedKey
      keyModel._direction = key.direction
      keyModel.vrfName = self._entryVrfName( entry )
      keyModel.srcAddr = IpGenAddr( ipStr( entry.srcIpAddr ) )
      keyModel.dstAddr = IpGenAddr( ipStr( entry.dstIpAddr ) )
      keyModel.ipProtocolNumber = entry.ipProtocolNumber
      try:
         keyModel.ipProtocol = entry.ipProtocol
      except ValueError:
         # unassigned protocol number will not have ipProtocol Enum
         pass
      keyModel.srcPort = entry.srcPort
      keyModel.dstPort = entry.dstPort

      model = SfeFtModel.FlowModel()
      model.bytesReceived = entry.bytes
      model.pktsReceived = entry.pkts
      model.startTime = entry.startTime
      model._exportedTime = entry.exportedLastPktTime
      model._exportedPkts = entry.exportedPkts
      model._exportedBytes = entry.exportedBytes
      model.key = keyModel
      return model

   def populateFlowEntry( self, flowTable, grName, key, entry, detailDisplay=False ):
      """Return the FlowModel for one entry, with a FlowDetailModel attached
      when detailDisplay is set. flowTable and grName are not used here."""
      flowModel = self.populateFlowModel( key, entry )
      if not detailDisplay:
         return flowModel

      detailModel = SfeFtModel.FlowDetailModel()
      self.showFlowDetail( detailModel, entry )
      flowModel.flowDetail = detailModel
      return flowModel

   def populateFlowEntries( self, grName, groupModel, flowTable, detailDisplay,
                            srcIpAddr, dstIpAddr, srcPort, dstPort, vrfName,
                            protocol ):
      """Append to groupModel.flows a FlowModel for every entry of the group
      that passes the filters. Only IP_GROUP has entries (ipFlowEntry); any
      other group name contributes nothing."""
      groupEntries = flowTable.ipFlowEntry if grName == IP_GROUP else {}
      for key, entry in groupEntries.items():
         if self.isEntryMatchingFilter( key, entry, srcIpAddr, dstIpAddr,
                                        srcPort, dstPort, vrfName, protocol ):
            flowModel = self.populateFlowEntry( flowTable, grName, key, entry,
                                                detailDisplay=detailDisplay )
            groupModel.flows.append( flowModel )

   def fetchTrackingModel( self ):
      """Return a new TrackingModel with the agent running state filled in."""
      model = SfeFtModel.TrackingModel()
      model.running = isSfeAgentRunning( entityManager, activeAgentDir )
      model.softwareFlowTable = True
      return model

   def showFlowTable( self, mode, args ):
      """Command handler body: build the TrackingModel from the parsed args.

      args may carry TRACKER_NAME and GROUP_NAME (to restrict which trackers
      and groups are shown), the per-flow filters SRC_IPV4, DST_IPV4,
      SRC_PORT, DST_PORT, VRF and PROTOCOL, and the 'detail' keyword.
      """
      trackerName = args.get( 'TRACKER_NAME' )
      groupName = args.get( 'GROUP_NAME' )
      # Same order as the trailing parameters of populateFlowEntries().
      entryFilters = [ args.get( argName ) for argName in
                       ( 'SRC_IPV4', 'DST_IPV4', 'SRC_PORT', 'DST_PORT',
                         'VRF', 'PROTOCOL' ) ]

      trackingModel = self.fetchTrackingModel()
      trackingModel._detail = 'detail' in args
      for trName in sfeConfig.hwFtConfig:
         if trackerName and trackerName != trName:
            continue

         flowTable = getFlowTable( trName )
         if not flowTable:
            continue

         trackerModel = SfeFtModel.TrackerModel()
         trackingModel.trackers[ trName ] = trackerModel
         trackerModel.numFlows = 0

         for grName in getFlowGroups( ftrTypeHardware, trName ):
            if groupName and groupName != grName:
               continue
            groupModel = SfeFtModel.GroupModel()
            trackerModel.groups[ grName ] = groupModel
            self.populateFlowEntries( grName, groupModel, flowTable,
                                      trackingModel._detail, *entryFilters )
            trackerModel.numFlows += len( groupModel.flows )
            if groupName:
               # The requested group has been handled; stop scanning groups.
               break
         if trackerName:
            # The requested tracker has been handled; stop scanning trackers.
            break
      return trackingModel

def showFlowTable( mode, args ):
   """Handler for 'show flow tracking hardware flow-table': delegates to a
   fresh ShowFlowTable instance and returns its model."""
   return ShowFlowTable().showFlowTable( mode, args )

# CLI nodes used by ShowTrackingFilterExpression below. Every node is created
# with maxMatches=1.

# 'src-ip' keyword.
nodeSrcIp = CliCommand.Node(
   matcher=CliMatcher.KeywordMatcher( 'src-ip',
                                      "Flow source IP address" ),
   maxMatches=1 )

# IPv4 address value; the same node backs both SRC_IPV4 and DST_IPV4.
nodeIPv4 = CliCommand.Node(
                        matcher=IpAddrMatcher.IpAddrMatcher(
                                 helpdesc='IPv4 address' ),
                        maxMatches=1 )

# 'dst-ip' keyword.
nodeDstIp = CliCommand.Node(
   matcher=CliMatcher.KeywordMatcher( 'dst-ip',
                                      "Flow destination IP address" ),
   maxMatches=1 )

# 'src-port' keyword.
nodeSrcPort = CliCommand.Node(
   matcher=CliMatcher.KeywordMatcher( 'src-port',
                                      "Flow source port" ),
   maxMatches=1 )

# Port value in the range 0..65535; the same node backs both SRC_PORT and
# DST_PORT.
nodePort = CliCommand.Node(
   matcher=CliMatcher.IntegerMatcher( 0, 65535,
                                      helpdesc='IP port' ),
   maxMatches=1 )

# 'dst-port' keyword.
nodeDstPort = CliCommand.Node(
   matcher=CliMatcher.KeywordMatcher( 'dst-port',
                                      "Flow destination port" ),
   maxMatches=1 )

# Protocol names offered by the CLI: the attribute names of IpProtoType with a
# leading 'ipProto' prefix removed where present.
protocols = [ ( p[ len( 'ipProto' ) : ] if p.startswith( 'ipProto' ) else p )
              for p in IpProtoType.attributes ]

# 'protocol' keyword.
nodeProtocol = CliCommand.Node(
   matcher=CliMatcher.KeywordMatcher( 'protocol',
                                      "Flow IP protocol" ),
   maxMatches=1 )

# Protocol value, matched against the names in `protocols`.
nodeProtocolValue = CliCommand.Node(
   matcher=CliMatcher.DynamicNameMatcher( protocols,
                                          "IP protocol",
                                          passContext=True ),
   maxMatches=1 )

# 'detail' keyword.
nodeDetail = CliCommand.Node(
   matcher=CliMatcher.KeywordMatcher( 'detail',
                                      "Detailed flow information" ),
   maxMatches=1 )

# Flow group name value; IP_GROUP is the only name supplied to the matcher.
allGroupNameMatcher = CliCommand.Node(
   matcher=CliMatcher.DynamicNameMatcher(
      [ IP_GROUP ],
      "Flow group name",
      passContext=True ),
   maxMatches=1 )

# Optional filter clause of 'show flow tracking hardware flow-table'.
# Alternatives: tracker, group, src-ip, dst-ip, src-port, dst-port, protocol,
# VRF and detail.
# NOTE(review): the '[ { a | b | ... } ]' wrapping presumably lets the
# alternatives be given in any order, each at most once because of
# maxMatches=1 -- confirm against the CliCommand expression grammar.
class ShowTrackingFilterExpression( CliCommand.CliExpression ):
   expression = '''[ { ( tracker TRACKER_NAME )
                     | ( group GROUP_NAME )
                     | ( src-ip SRC_IPV4 )
                     | ( dst-ip DST_IPV4 )
                     | ( src-port SRC_PORT )
                     | ( dst-port DST_PORT )
                     | ( protocol PROTOCOL )
                     | ( VRF )
                     | ( detail ) } ]'''
   # Maps each token of `expression` to its matcher/node. The upper-case keys
   # are the ones the handler reads back out of `args`.
   data = {
      'tracker' : CliCommand.Node( trackerKw, maxMatches=1 ),
      'TRACKER_NAME' : trackerNameMatcher,
      'group' : CliCommand.Node( groupKw, maxMatches=1 ),
      'GROUP_NAME' : allGroupNameMatcher,
      'src-ip' : nodeSrcIp,
      'SRC_IPV4' : nodeIPv4,
      'dst-ip' : nodeDstIp,
      'DST_IPV4' : nodeIPv4,
      'src-port' : nodeSrcPort,
      'SRC_PORT' : nodePort,
      'dst-port' : nodeDstPort,
      'DST_PORT' : nodePort,
      'protocol' : nodeProtocol,
      'PROTOCOL' : nodeProtocolValue,
      # VRF expression that also accepts the default VRF and the "all" VRF.
      'VRF' : VrfExprFactory( helpdesc='Flow VRF',
                              inclDefaultVrf=True,
                              inclAllVrf=True,
                              maxMatches=1 ),
      'detail' : nodeDetail,
   }

# CLI command class for 'show flow tracking hardware flow-table [ FILTER ]'.
# The 'flow-table' keyword is guarded by sfeGuard; the command is handled by
# showFlowTable() and renders an SfeFtModel.TrackingModel.
class ShowFlowTrackingFlowTable( ShowCliCommandClass ):
   syntax = 'show flow tracking hardware flow-table [ FILTER ] '
   data = {
      'flow' : flowMatcherForShow,
      'tracking' : trackingShowKw,
      'hardware' : hardwareShowKw,
      'flow-table' : CliCommand.guardedKeyword(
         'flow-table',
         helpdesc='Flow table',
         guard=sfeGuard ),
      'FILTER' : ShowTrackingFilterExpression
   }

   handler = showFlowTable
   cliModel = SfeFtModel.TrackingModel

# Register the command with the show-command infrastructure.
addShowCommandClass( ShowFlowTrackingFlowTable )

#--------------------------
# Per-tracker totals summed over all of its flow groups; this is the return
# type of addFlowGroups().
FgCount = namedtuple( 'FgCount', [ 'flows', 'expiredFlows', 'packets' ] )

def addFlowGroups( trModel, trName, ftId, v4Flows ):
   """Add one FlowGroupCounters model per flow group of tracker trName to
   trModel.flowGroups and return the summed counters as an FgCount.

   v4Flows is stored as activeFlows of the IP_GROUP model. Groups without an
   entry in sfeCounters.flowGroupCounter are reported from a default
   FlowGroupCounterEntry. trModel.clearTime is advanced when the tracker's
   clear-time counter entry carries a later lastClearedTime.
   """
   groupsSeen = 0
   totalFlows = 0
   totalExpired = 0
   totalPackets = 0
   for fgName, fgId in getFlowGroupsAndIds( ftrTypeHardware, trName ):
      groupsSeen += 1
      t1( 'process flow group', fgName )
      fgModel = ftCounterModel.FlowGroupCounters()
      if fgName == IP_GROUP:
         fgModel.activeFlows = v4Flows

      counterKey = FlowGroupCounterKey( ftId, fgId )
      fgCounts = sfeCounters.flowGroupCounter.get( counterKey )
      if not fgCounts:
         t0( 'No counters for', counterKey.smashString() )
         fgCounts = FlowGroupCounterEntry()

      groupEntry = fgCounts.flowEntry
      fgModel.flows = groupEntry.flows
      fgModel.expiredFlows = groupEntry.expiredFlows
      fgModel.packets = groupEntry.packets
      totalFlows += groupEntry.flows
      totalExpired += groupEntry.expiredFlows
      totalPackets += groupEntry.packets

      trModel.flowGroups[ fgName ] = fgModel

   # The tracker-level clear time is stored under the maxTemplateId key.
   clearTimeKey = FlowGroupCounterKey( ftId, TemplateIdType.maxTemplateId )
   clearEntry = sfeCounters.flowGroupCounter.get( clearTimeKey )
   if clearEntry and clearEntry.key == clearTimeKey:
      # clearTime should be most recent lastClearedTime from either
      # sfeCounters or ipfixStats
      clearedUtc = switchTimeToUtc( clearEntry.flowEntry.lastClearedTime )
      if clearedUtc > trModel.clearTime:
         trModel.clearTime = clearedUtc

   if not groupsSeen:
      t0( 'WARNING: no flow groups for tracker', trName )

   return FgCount( totalFlows, totalExpired, totalPackets )

def addTrackers( model, trFilter, expFilter ):
   """Fill model with per-tracker counters and the overall flow counters.

   A TrackerCounters model is added to model.trackers for every hardware
   tracker that matches trFilter (all trackers when trFilter is empty) and
   whose flow table can be mounted. expFilter is handed to addExporters().
   The top-level flows, expiredFlows, packets and clearTime come from the
   sfeCounters.flowsCounter entry with key 0; activeFlows is the sum over the
   trackers that were added.
   """
   # NOTE(review): allTrackerFlows is accumulated but never read afterwards.
   allTrackerFlows = 0
   activeFlowSum = 0
   for trName, ftId in getTrackerNamesAndIds( ftrTypeHardware ):
      if trFilter and trName != trFilter:
         t1( 'tracker', trName, 'did not match filter' )
         continue
      t1( 'process tracker', trName )

      flowTable = getFlowTable( trName )
      if not flowTable:
         continue

      trModel = ftCounterModel.TrackerCounters()
      v4Flows = len( flowTable.ipFlowEntry )
      trModel.activeFlows = v4Flows

      # Seed the clear time from the ipfix statistics; addFlowGroups() may
      # replace it with a more recent value.
      statsKey = CollectorStatisticsKey( trName, "", CollectorInfo() )
      statsEntry = ipfixStats.stats.get( statsKey )
      if statsEntry and statsEntry.key == statsKey:
         trModel.clearTime = switchTimeToUtc( statsEntry.lastClearedTime )

      groupTotals = addFlowGroups( trModel, trName, ftId, v4Flows )
      trModel.flows = groupTotals.flows
      trModel.expiredFlows = groupTotals.expiredFlows
      allTrackerFlows += groupTotals.flows
      activeFlowSum += trModel.activeFlows
      trModel.packets = groupTotals.packets

      addExporters( trModel, trName, expFilter, sfeConfig, ipfixStats, True )
      model.trackers[ trName ] = trModel

   overall = sfeCounters.flowsCounter.get( FlowCounterKey( 0 ) )
   if not overall:
      overall = FlowCounterEntry()
   if overall.flowEntry.lastClearedTime != 0:
      model.clearTime = switchTimeToUtc( overall.flowEntry.lastClearedTime )
   model.activeFlows = activeFlowSum
   model.flows = overall.flowEntry.flows
   model.expiredFlows = overall.flowEntry.expiredFlows
   model.packets = overall.flowEntry.packets

def showCountersCmd( mode, args ):
   """Handler for 'show flow tracking hardware counters'.

   Returns an FtrCounters model. Tracker data is only collected when the Sfe
   agent is reported as running; otherwise just the running state and the
   softwareFlowTable flag are set.
   """
   counters = ftCounterModel.FtrCounters()
   counters.running = isSfeAgentRunning( entityManager, activeAgentDir )
   counters.softwareFlowTable = True
   if not counters.running:
      return counters

   addTrackers( counters,
                args.get( 'TRACKER_NAME' ),
                args.get( 'EXPORTER_NAME' ) )
   return counters

# CLI command class for
#   show flow tracking hardware counters
#      [ tracker TRACKER_NAME [ exporter EXPORTER_NAME ] ]
# The 'counters' keyword is guarded by sfeGuard; the command is handled by
# showCountersCmd() and renders an ftCounterModel.FtrCounters model.
class ShowFtCounters( ShowCliCommandClass ):
   syntax = '''show flow tracking hardware counters 
               [ tracker TRACKER_NAME [ exporter EXPORTER_NAME ] ]'''

   data = {
      'flow' : flowMatcherForShow,
      'tracking' : trackingShowKw,
      'hardware' : hardwareShowKw,
      'counters' : CliCommand.guardedKeyword(
         'counters',
         helpdesc='Show flow tracking hardware counters',
         guard=sfeGuard ),
      'tracker' : trackerKw,
      'TRACKER_NAME' : trackerNameMatcher,
      'exporter' : exporterKw,
      'EXPORTER_NAME' : exporterNameMatcher,
   }

   handler = showCountersCmd
   cliModel = ftCounterModel.FtrCounters

# Register the command with the show-command infrastructure.
addShowCommandClass( ShowFtCounters )

#---------------------------------------------------------------------------------
# show platform sfe flow tracking counters
#---------------------------------------------------------------------------------
def doShowSfeFtCounters( mode, args ):
   """Handler for 'show platform sfe flow tracking counters'.

   Sends the "Ftcnt" command to the Sfe agent over its command socket and
   parses the reply as a Python literal. When the reply is a dict it is
   loaded into a PlatformCountersModel. Any other reply (unparsable text, an
   error message, or a literal that is not a dict) is reported through
   mode.addError() and an empty PlatformCountersModel is returned.
   """
   buff = cStringIO.StringIO()
   AgentCommandRequest.runSocketCommand( mode.entityManager, SfeAgent.name(),
                                         "sfe", "Ftcnt", stringBuff=buff,
                                         timeout=50, keepalive=True )
   output = buff.getvalue()

   model = SfeFtModel.PlatformCountersModel()
   try:
      ftCounters = ast.literal_eval( output )
   except ( SyntaxError, ValueError, TypeError ):
      # literal_eval raises SyntaxError only for text that is not valid
      # Python. Text that parses but is not a literal (for example a one-word
      # error message) raises ValueError, and some malformed literals raise
      # TypeError. All of these mean the reply is not counter data.
      ftCounters = None

   if not isinstance( ftCounters, dict ):
      # Only a dict can be handed to setAttrsFromDict(); show the raw agent
      # output to the user instead of failing.
      mode.addError( output )
      return model

   model.setAttrsFromDict( ftCounters )
   return model

# CLI command class for 'show platform sfe flow tracking counters'. The
# command is marked privileged, is handled by doShowSfeFtCounters() and
# renders an SfeFtModel.PlatformCountersModel.
class ShowSfeFtCountersCmd( ShowCliCommandClass ):
   syntax = 'show platform sfe flow tracking counters'
   data = {
      'platform' : platformMatcherForShow,
      'sfe' : nodeSfe,
      'flow' : flowMatcherForShow,
      'tracking' : trackingShowKw,
      # A plain string is given here instead of a matcher object.
      # NOTE(review): presumably used as the keyword's help text -- confirm.
      'counters' : 'Show flow tracker counters',
   }

   handler = doShowSfeFtCounters
   cliModel = SfeFtModel.PlatformCountersModel
   privileged = True

# Register the command with the show-command infrastructure.
addShowCommandClass( ShowSfeFtCountersCmd )

#--------------------------
def Plugin( em ):
   """Plugin entry point: remember the entity manager and set up the mounts
   that the show commands in this module read from.
   """
   global entityManager, sfeConfig, activeAgentDir
   global sfeVrfIdMap, ipfixStats, sfeCounters, shmemEm

   entityManager = em

   # LazyMount mounts.
   sfeConfig = LazyMount.mount(
      em, 'hardware/flowtracking/config/hardware', 'HwFlowTracking::Config', 'r' )
   activeAgentDir = LazyMount.mount(
      em, 'flowtracking/activeAgent', 'Tac::Dir', 'ri' )

   # SmashLazyMount mounts, each with its own 'reader' mount info.
   sfeVrfIdMap = SmashLazyMount.mount(
      em, "vrf/vrfIdMapStatus", "Vrf::VrfIdMap::Status",
      SmashLazyMount.mountInfo( 'reader' ) )
   ipfixStats = SmashLazyMount.mount(
      em, 'flowtracking/hardware/ipfix/statistics',
      'Smash::Ipfix::CollectorStatistics',
      SmashLazyMount.mountInfo( 'reader' ) )
   sfeCounters = SmashLazyMount.mount(
      em, 'flowtracking/hardware/counters',
      'Smash::FlowTracker::FtCounters',
      SmashLazyMount.mountInfo( 'reader' ) )

   # Shared-memory entity manager used by getFlowTable() for per-tracker
   # flow table mounts.
   shmemEm = SharedMem.entityManager( sysdbEm=em )
