#!/usr/bin/env python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import Ark
import CliPlugin.EthIntfCli as EthIntfCli
import CliPlugin.IntfCli as IntfCli
import CliPlugin.IntfQueueCounterLib as IntfQueueCounterLib
from CliPlugin.QueueCountersModel import QueueCountersRate, EgressQueueCounters, \
     EgressQueueDestinationTypeCounters, EgressQueueTrafficClassCounters, \
     EgressQueueDropPrecedenceCounters, Counters
from CliPlugin.QosCli import showQosInterface
from CliPlugin.AleCountersModel import LatePollEvent, FastPollStats
from CliPlugin.AleCountersModel import PollInfo, FastPollInfo, FastPollInfoDetail

def vtepCounter( idx, counterTable, snapshotTable ):
   """
   Return the VTEP decap counters for `idx` relative to the last clear.

   The result is the tuple (pkts, octets, bumPkts, bumOctets, dropPkts,
   dropOctets). Each value is the running counter found in
   counterTable.vtepDecapCounter minus the value saved in
   snapshotTable.vtepDecapCounter when the clear command was issued.

   - No snapshot entry for `idx`: nothing is subtracted, so the raw running
     values are returned.
   - No running entry for `idx`: every value is 0. This is usually a transient
     state seen while the entry is being cleaned up.
   """
   fields = ( 'pkts', 'octets', 'bumPkts', 'bumOctets', 'dropPkts', 'dropOctets' )
   running = counterTable.vtepDecapCounter.get( idx )
   baseline = snapshotTable.vtepDecapCounter.get( idx )
   if running is None:
      return tuple( 0 for _ in fields )
   deltas = []
   for field in fields:
      base = 0 if baseline is None else getattr( baseline, field )
      deltas.append( getattr( running, field ) - base )
   return tuple( deltas )

def _addIntfQueueRates( interfaces, intfId, counterAccessor,
                        trafficClassLabelFunc ):
   """
   Create the entry for `intfId` in `interfaces` and fill in its queue rates.

   Unicast and multicast queues are reported under separate trafficClasses
   dicts. In the counter's intfQueueRates the unicast queues come first and
   the multicast queues follow, so multicast queue N is read at index
   numUnicastQueues + N. Each traffic class records its enqueued/dropped
   packet and bit rates under drop precedence "DP0".
   """
   interfaces[ intfId ] = EgressQueueDestinationTypeCounters()
   intfModel = interfaces[ intfId ]
   intfModel.ucastQueues = EgressQueueTrafficClassCounters()
   intfModel.mcastQueues = EgressQueueTrafficClassCounters()
   counter = counterAccessor.counter( intfId )
   numUnicastQueues = counterAccessor.numUnicastQueues( intfId )
   numMulticastQueues = counterAccessor.numMulticastQueues( intfId )
   # ( trafficClasses dict, number of queues, offset into intfQueueRates )
   queueGroups = (
      ( intfModel.ucastQueues.trafficClasses, numUnicastQueues, 0 ),
      ( intfModel.mcastQueues.trafficClasses, numMulticastQueues,
        numUnicastQueues ),
   )
   for tcs, numQueues, indexOffset in queueGroups:
      for queueId in range( numQueues ):
         queueRates = counter.intfQueueRates[ indexOffset + queueId ]
         if trafficClassLabelFunc is None:
            tcId = "TC" + str( queueId )
         else:
            tcId = trafficClassLabelFunc( intfId, queueId )
         if tcId in tcs:
            # If the traffic class label already exists, skip it. This is
            # probably a mapped traffic class, and the rates for both traffic
            # classes are combined in the mapped queue.
            continue
         tcs[ tcId ] = EgressQueueDropPrecedenceCounters()
         tcs[ tcId ].dropPrecedences[ "DP0" ] = Counters()
         ctrs = tcs[ tcId ].dropPrecedences[ "DP0" ]
         ctrs.enqueuedPacketsRate = queueRates.pktsRate
         ctrs.enqueuedBitsRate = queueRates.bitsRate
         ctrs.droppedPacketsRate = queueRates.pktsDropRate
         ctrs.droppedBitsRate = queueRates.bitsDropRate

def _addIntfQueueScheduling( intfModel, supportedIntf, trafficClassLabelFunc ):
   """
   Add bandwidth, load interval and per-traffic-class scheduling info to an
   interface entry built by _addIntfQueueRates.

   Scheduling info comes from the operational QoS tx-queue list of
   `supportedIntf`. A tx queue whose label has no rate entry is skipped.
   """
   intfModel.bandwidth = supportedIntf.bandwidth()
   intfModel.loadInterval = IntfCli.getActualLoadIntervalValue(
      supportedIntf.config().loadInterval )
   intfQos = showQosInterface( supportedIntf )
   queueTcsList = ( intfModel.ucastQueues.trafficClasses,
                    intfModel.mcastQueues.trafficClasses )
   for txQueueQos in intfQos.txQueueQosModel.txQueueList:
      # The label does not depend on the queue type, so compute it once.
      if trafficClassLabelFunc is None:
         tcId = "TC{}".format( txQueueQos.txQueue )
      else:
         tcId = trafficClassLabelFunc( supportedIntf.name,
                                       int( txQueueQos.txQueue ) )
      for tcs in queueTcsList:
         if tcId not in tcs:
            continue
         tc = tcs[ tcId ]
         if txQueueQos.operationalSchedMode.schedulingMode == "roundRobin":
            tc.schedMode = "weightedRoundRobin"
            tc.wrrBw = txQueueQos.operationalWrrBw
         else:
            tc.schedMode = "strictPriority"
         tc.shapeRate = txQueueQos.operationalShapeRate

def showIntfQueueCounterRates( mode, intf=None, mod=None,
                               trafficClassLabelFunc=None ):
   """
   Build a QueueCountersRate model with egress queue rates for Ethernet
   interfaces.

   Args:
      mode: CLI mode. Used to select interfaces and to emit warnings.
      intf, mod: optional interface / module filters, passed to the IntfCli
         interface selectors.
      trafficClassLabelFunc: optional callable ( intfName, queueId ) -> label.
         When None, queues are labelled "TC<queueId>".

   Returns:
      The QueueCountersRate model. If no interface supports queue counter
      rates, a warning is added to `mode` and an empty model is returned.
   """
   counterAccessor = IntfQueueCounterLib.getCounterAccessor()
   queueCounters = QueueCountersRate()
   queueCounters.egressQueueCounters = EgressQueueCounters()
   interfaces = queueCounters.egressQueueCounters.interfaces
   intfs = IntfCli.counterSupportedIntfs( mode, intf, mod,
                                          intfType=EthIntfCli.EthPhyIntf )
   if intfs:
      intfs = [ i for i in intfs if i.name.startswith( "Ethernet" ) ]
   if not intfs:
      mode.addWarning(
         "Queue counter rates not supported on %s" % ( intf or "any interface", ) )
      return queueCounters

   for intfObj in intfs:
      _addIntfQueueRates( interfaces, intfObj.name, counterAccessor,
                          trafficClassLabelFunc )

   # Use a distinct loop name so the `intf` parameter is not rebound.
   supportedIntfs = IntfCli.countersRateSupportedIntfs( mode, intf=intf, mod=mod )
   for supportedIntf in supportedIntfs:
      if supportedIntf.name in interfaces:
         _addIntfQueueScheduling( interfaces[ supportedIntf.name ], supportedIntf,
                                  trafficClassLabelFunc )
   return queueCounters

def populatePollInfo( aleCountersConfig, aleCountersStatus ):
   """
   Build a PollInfo model from an Ale::Counters::CliConfig and an
   Ale::Counters::Status object.

   Both objects are passed in so that the same code works for different polling
   agents (e.g. StrataCounters, and possibly SandCounters).

   Values are copied without conversion. Unlike populateFastPollInfo, the
   timestamp is not passed through Ark.switchTimeToUtc and fetch times are not
   scaled by 1000.
   NOTE(review): confirm that this matches what the PollInfo model expects.
   """
   config, status = aleCountersConfig, aleCountersStatus
   fetchCount = status.fetchCount
   # Guard against division by zero before any fetch has happened.
   avgFetch = status.totalFetchTime / fetchCount if fetchCount else 0.0

   model = PollInfo()
   # Poll statistics.
   model.pollInterval = config.periodPoll
   model.lastPollTimestamp = status.timestampLastPoll
   model.totalPollCount = status.pollCount
   model.latePollCount = status.latePollCount
   # Fetch statistics.
   model.totalFetchCount = fetchCount
   model.averageFetchTime = avgFetch
   model.maximumFetchTime = status.maxFetchTime
   model.reEnqueueCount = status.reEnqueueCount
   model.retryDequeueCount = status.retryDequeueCount
   return model

def populateFastPollInfo( aleCountersConfig, aleCountersStatus ):
   """
   Build the FastPollInfo Cli model from the fast-poll config and status.

   The last-poll timestamp is converted with Ark.switchTimeToUtc. The average
   and maximum fetch times are multiplied by 1000 (presumably seconds to
   milliseconds; confirm against the FastPollInfo model).
   """
   status = aleCountersStatus
   fetchCount = status.fetchCount
   if fetchCount:
      avgFetch = status.totalFetchTime / fetchCount * 1000
   else:
      # No fetches yet: avoid dividing by zero.
      avgFetch = 0.0

   model = FastPollInfo()
   # Poll statistics.
   model.pollFastInterval = aleCountersConfig.periodFastPoll
   model.lastPollTimestamp = Ark.switchTimeToUtc( status.timestampLastPoll )
   model.totalPollCount = status.pollCount
   model.latePollCount = status.latePollCount
   # Fetch statistics.
   model.totalFetchCount = fetchCount
   model.averageFetchTime = avgFetch
   model.maximumFetchTime = status.maxFetchTime * 1000
   model.reEnqueueCount = status.reEnqueueCount
   model.retryDequeueCount = status.retryDequeueCount
   return model

def populateFastPollInfoDetail(
      aleCountersStatus, latePollHistory, hourSnapshotDir, daySnapshotDir ):
   """
   Build the FastPollInfoDetail Cli model.

   Args:
      aleCountersStatus: current fast-poll status. Provides the running counters
         that the snapshots are subtracted from.
      latePollHistory: object whose `latePoll` collection holds recent late-poll
         events (timestamp, delay).
      hourSnapshotDir, daySnapshotDir: objects whose `snapshot` collection holds
         earlier copies of the status counters.

   The model includes:
   - The late-poll history.
   - Last-hour statistics, when any hour snapshot exists.
   - Last-day statistics, only when the oldest day snapshot predates the oldest
     hour snapshot.

   Delays and fetch times are multiplied by 1000, as in populateFastPollInfo
   (presumably seconds to milliseconds).
   """
   detail = FastPollInfoDetail()

   def latePollEvent( timestamp, delay ):
      out = LatePollEvent()
      out.timestamp = Ark.switchTimeToUtc( timestamp )
      out.delay = delay
      return out

   detail.latePollHistory = [
         latePollEvent( event.timestamp, event.delay * 1000 )
         for event in latePollHistory.latePoll.values() ]

   def snapshotToStats( snapshots, snapshotName ):
      """
      Compute the statistics accumulated since the first entry of `snapshots`.

      `snapshots` is a non-empty list. Its first entry is treated as the
      earliest snapshot (assumes the collection iterates oldest first).
      """
      earliestSnapshot = snapshots[ 0 ]
      stats = FastPollStats()
      stats.statsStartTimestamp = Ark.switchTimeToUtc(
         earliestSnapshot.timestampSnapshot )
      stats.totalPollCount = (
            aleCountersStatus.pollCount - earliestSnapshot.pollCount )
      stats.latePollCount = (
            aleCountersStatus.latePollCount - earliestSnapshot.latePollCount )
      avgFetch = 0.0
      fetchCount = aleCountersStatus.fetchCount - earliestSnapshot.fetchCount
      if fetchCount:
         fetchTime = (
               aleCountersStatus.totalFetchTime - earliestSnapshot.totalFetchTime )
         avgFetch = fetchTime / fetchCount * 1000
      stats.averageFetchTime = avgFetch
      # The snapshot maxima and the per-user maximum come from the same status
      # fields, so they share units. Compare them raw and scale the result once.
      # Previously an already-scaled value was compared against an unscaled one.
      snapshotMax = max( s.maxFetchTime for s in snapshots )
      userMax = aleCountersStatus.maxFetchTimePerUser.get( snapshotName, 0 )
      stats.maximumFetchTime = max( snapshotMax, userMax ) * 1000.
      return stats

   if not hourSnapshotDir.snapshot:
      return detail
   # Materialise the values once. This also keeps indexing valid on Python 3,
   # where dict-style values() views are not subscriptable.
   hourSnapshots = list( hourSnapshotDir.snapshot.values() )
   detail.statsLastHour = snapshotToStats( hourSnapshots, 'hour' )
   if daySnapshotDir.snapshot:
      daySnapshots = list( daySnapshotDir.snapshot.values() )
      # Report day stats only when they reach further back than the hour stats.
      if daySnapshots[ 0 ].timestampSnapshot < hourSnapshots[ 0 ].timestampSnapshot:
         detail.statsLastDay = snapshotToStats( daySnapshots, 'day' )
   return detail
