# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import re
import BasicCli
import Cell
import CliCommand
import CliMatcher
import CliParser
import LazyMount
import ShowCommand
import Tracing
from CliToken.Clear import clearKwNode
from CliToken.Platform import platformMatcherForShow
from CliToken.Platform import platformMatcherForClear

# Tracing handle for this plugin. t5 is a short alias for its trace5 method
# and is used below to log counters that have no registered smbus device.
traceHandle = Tracing.Handle( 'SmbusCli' )
t5 = traceHandle.trace5

class Opcode( object ):
   """Named integer codes: rx is 1 and tx is 2.

   NOTE(review): presumably receive/transmit opcodes; nothing in this file
   reads them, so confirm against external users before changing.
   """
   rx, tx = 1, 2

# Module-level mount points. All start out as None and are assigned by
# Plugin() at the bottom of this file.
sysdbRoot = None         # entityManager.root()
smbusTopology = None     # mount of "hardware/smbus/topology"
counterDir = None        # mount of "hardware/counter/smbus"
clearCounterDir = None   # mount of "hardware/counter/cli/smbus"
smbusConfig = None       # mount of "hardware/smbus"

class SmbusDevice( object ):
   """Plain record describing one smbus device found in the topology.

   Derives from object so it is a new-style class, like the other classes
   in this file.

   Attributes:
      accelId: id of the accelerator the device's bus hangs off.
      busId: id of the bus on that accelerator.
      deviceId: device id as read from the topology.
      deviceAddr: printable address string used as the CLI "Address" column.
      deviceName: printable name; defaults to "SmbusDevice".
   """
   def __init__( self, accelId, busId, deviceId, deviceAddr,
                 deviceName="SmbusDevice" ):
      self.accelId = accelId
      self.busId = busId
      self.deviceId = deviceId
      self.deviceAddr = deviceAddr
      self.deviceName = deviceName

def counterGuard( mode, token ):
   """CLI guard for tokens that need a matching hardware counter entry.

   Returns None when sysdbRoot[ "hardware" ] has a "counter" entry and that
   entry in turn has one named after `token`; otherwise returns
   CliParser.guardNotThisPlatform. `mode` is not used.
   """
   hardware = sysdbRoot[ "hardware" ]
   if "counter" not in hardware.entryState.iterkeys():
      return CliParser.guardNotThisPlatform
   if token not in hardware[ "counter" ].entryState.iterkeys():
      return CliParser.guardNotThisPlatform
   return None

def isModular():
   """Return True unless Cell.cellType() reports "fixed"."""
   # The comparison already yields a bool; no conditional expression needed.
   return Cell.cellType() != "fixed"

# "show platform ?" only shows "smbus" option when Smbus agent is running
matcherSmbus = CliMatcher.KeywordMatcher(
   'smbus',
   helpdesc='Smbus-device info' )

# Guarded 'smbus' token shared by the show and clear commands below.
# counterGuard returns CliParser.guardNotThisPlatform unless an entry named
# after the token exists under sysdbRoot hardware/counter.
nodeSmbus = CliCommand.Node(
   matcher=matcherSmbus,
   guard=counterGuard )

def parseBus( busName ):
   """Split a "bus<accel>:<bus>" name into its two numeric fields.

   Returns the two digit groups as strings, e.g. ( "1", "2" ) for "bus1:2".
   The pattern is only anchored at the start, so trailing text is ignored.
   Returns ( None, None ) when busName does not start with that pattern.
   """
   match = re.match( r"bus(\d+)\:(\d+)", busName )
   return match.groups() if match else ( None, None )

def createSmbusDeviceMap():
   """Walk smbusTopology and return { address string : SmbusDevice }.

   Every node except "Chassis" is visited, and on each node every bus whose
   name starts with "bus", "cardBus" or "isolatorBus". The address of a
   device is "<accel>/<bus>/0x<device id>" (two digits each). On modular
   systems it is prefixed with "Supervisor<N>/" or "<node name>/".

   A device whose name starts with "powerSupply" is entered twice, with the
   "0x00" in its address replaced by the FRU (device id + 0x50) and PMBus
   (device id + 0x58) addresses. Every other device is entered once.
   """
   cardPrefixes = ( "Linecard", "Fabric", "Switchcard" )
   busPrefixes = ( "bus", "cardBus", "isolatorBus" )
   # The system type cannot change while we walk the topology
   modular = isModular()
   smbusDeviceMap = {}
   for nodeType, node in smbusTopology.node.iteritems():
      if nodeType == "Chassis":
         continue

      supeAccelId = None
      supeBusId = None
      # Traverse line/fabric card smbus mux
      if ( nodeType.startswith( cardPrefixes ) and
           "supe1Bus" in node.hwSmbus and
           "supe2Bus" in node.hwSmbus ):
         # The upstream bus addresses are the same for active and standby
         # supes, so whichever side is connected will do.
         upstream = node.hwSmbus[ "supe1Bus" ].connectionEndpoint[
                       0 ].otherEndConnectedTo
         if not upstream:
            upstream = node.hwSmbus[ "supe2Bus" ].connectionEndpoint[
                          0 ].otherEndConnectedTo
         # Neither supe bus may be connected. Leave the supe ids unset in
         # that case instead of dereferencing a missing endpoint; the card
         # buses then fall back to their own endpoint lookup below.
         if upstream and upstream.name.startswith( "bus" ):
            supeAccelId, supeBusId = parseBus( upstream.name )

      for busAddr, bus in node.hwSmbus.iteritems():
         if not busAddr.startswith( busPrefixes ):
            continue

         if busAddr.startswith( "bus" ):
            accelId, busId = parseBus( busAddr )
         elif supeAccelId is not None:
            # Line/Fabric Smbus devices accessed by supe
            # i.e. cardBus or isolatorBus
            accelId, busId = supeAccelId, supeBusId
         elif 0 in bus.connectionEndpoint:
            # Card bus connected directly to supervisor accel (no mux). There
            # may be multiple buses on a card
            endpoint = bus.connectionEndpoint[ 0 ].otherEndConnectedTo
            if not endpoint:
               continue
            accelId, busId = parseBus( endpoint.name )
         else:
            continue
         if accelId is None:
            # The bus name did not parse as "bus<accel>:<bus>"
            continue

         # The prefix depends only on the node and the bus, so work it out
         # once per bus rather than once per device.
         prefix = ""
         if modular:
            if nodeType in ( "1", "2" ):
               # Supervisor
               prefix = "Supervisor%s/" % nodeType
            elif ( nodeType.startswith( "Fabric" ) or
                   ( nodeType.startswith( "Linecard" ) and
                     busAddr.startswith( ( "cardBus", "isolator" ) ) ) ):
               # Fabric cards, and linecard devices behind a cardBus or
               # isolator bus, are accessed by the current supe
               prefix = "Supervisor%d/" % Cell.activeCell()
            else:
               # Linecards
               prefix = "%s/" % nodeType

         for deviceName, device in bus.device.iteritems():
            deviceId = device.deviceId
            deviceAddr = prefix + "%02d/%02d/0x%02x" % (
               int( accelId ), int( busId ), int( deviceId ) )

            if deviceName.startswith( "powerSupply" ):
               # The power supply is represented with Smbus
               # address 0x0 on the topology. For fixed systems,
               # the power supply has the FRU and PMBUS address
               # offsets fixed at 0x50 and 0x58, respectively.
               for offset in ( 0x50, 0x58 ):
                  newAddr = deviceAddr.replace(
                     "0x00", "0x%02x" % ( deviceId + offset ) )
                  smbusDeviceMap[ newAddr ] = SmbusDevice(
                     accelId, busId, deviceId, newAddr, deviceName )
               continue

            if re.match( r"^([0-9]*)$", deviceName ):
               # For tempsensors and Sol chip, use "modelName" instead
               deviceName = "SmbusDevice"
               if device.modelName:
                  deviceName = device.modelName
               elif device.api:
                  deviceName = device.api.split( "-" )[ 0 ]
            elif "PowerController" in deviceName:
               # Abbreviate PowerController to fit the CLI column
               deviceName = deviceName.replace( "PowerController", "DPM" )
            elif deviceName.startswith( ( "Ethernet", "Xcvr" ) ):
               # Xcvr devices are represented with Smbus address 0x0
               # on the topology. The 0x0 address is assigned
               # by the following code in the FDL
               # xcvrCtrl.xcvrController.smbusDeviceBase = ( xcvrName, 0x0 )
               deviceAddr = deviceAddr.replace( "0x00", "0x50" )
            smbusDeviceMap[ deviceAddr ] = SmbusDevice(
               accelId, busId, deviceId, deviceAddr, deviceName )
   return smbusDeviceMap

def accelGeneration( hostType, hostId ):
   """Look up the engineGenerationId for smbusConfig[ hostType ][ hostId ].

   Returns False when either level of the lookup is missing or falsy.
   """
   hostDir = smbusConfig.get( hostType )
   accelerator = hostDir.get( hostId ) if hostDir else None
   if accelerator:
      return accelerator.engineGenerationId
   return False

def sameGeneration( gen1, gen2 ):
   """Truthy only when gen1 is set, marked valid, and equal to gen2.

   The result is meant for boolean context. It is gen1 itself when gen1 is
   falsy, gen1.valid when that is falsy, and otherwise the result of
   gen1 == gen2.
   """
   if not gen1:
      return gen1
   valid = gen1.valid
   if not valid:
      return valid
   return gen1 == gen2

def smbusCounters():
   """Yield a Counter for each device under counterDir.

   A host's devices are skipped unless the host's generation matches the
   generation of the corresponding accelerator in smbusConfig.
   """
   for hostType, hostDir in counterDir.iteritems():
      for hostId, host in hostDir.iteritems():
         currentGen = accelGeneration( hostType, hostId )
         if sameGeneration( currentGen, host.generation ):
            for deviceId, device in host.device.iteritems():
               yield Counter( hostType, hostId, deviceId, device )

def getCounterInfo():
   """Return a list of ( deviceName, addr, counter ) tuples for display.

   Registered smbus devices come first, sorted by address; their counter is
   None when no counter exists at that address. Counters with no registered
   device follow, also sorted by address. Those are named "powerSupply" on
   modular systems when the address looks like
   Supervisor<n>/<n>/<n>/0x54, 0x5c or 0x3c, and "SmbusDevice" otherwise.
   """
   # smbusDevices is a collection of all registered smbus devices in the system
   smbusDevices = createSmbusDeviceMap()
   # counters is a collection of all smbus counters in the system. There are
   # some counters that appear without a registered smbus device, so we must
   # expose both
   counters = { counter.addr : counter for counter in smbusCounters() }

   output = [ ( smbusDevices[ addr ].deviceName, addr, counters.get( addr ) )
              for addr in sorted( smbusDevices ) ]

   # Both values are the same for every counter, so compute them once
   # instead of once per unregistered counter.
   supeRe = re.compile( r'Supervisor(\d+)\/(\d+)\/(\d+)\/0x(54|5c|3c)' )
   modular = isModular()
   for addr in sorted( counters ):
      if addr in smbusDevices:
         # Already reported in the registered-device pass above
         continue
      t5( "Found unregistered smbus counter at address", addr )
      if modular and supeRe.match( addr ):
         deviceName = "powerSupply"
      else:
         deviceName = "SmbusDevice"
      output.append( ( deviceName, addr, counters[ addr ] ) )
   return output

class Counter( object ):
   """Snapshot of one device's smbus counters, taken at construction.

   Raw counts are copied from `device` and kept private. The public read
   methods report them relative to the values recorded by the last clear().
   `addr` is the key used to match this counter with a registered smbus
   device. On fixed systems it is deviceId; on modular systems it is
   "<hostId>/<deviceId>", prefixed with "Supervisor" when hostType is "cell".
   """
   def __init__( self, hostType, hostId, deviceId, device ):
      self.hostType = hostType
      self.hostId = hostId
      self.deviceId = deviceId
      self._rxCount = device.rxCount
      self._txCount = device.txCount
      self.rxRate = device.rxRate
      self.txRate = device.txRate
      self._timeoutErrorCount = device.timeoutErrorCount
      self._ackErrorCount = device.ackErrorCount
      self._busConflictErrorCount = device.busConflictErrorCount
      self.hostName = "{}/{}".format( self.hostType, self.hostId )
      if not isModular():
         self.addr = self.deviceId
      else:
         prefix = "Supervisor" if self.hostType == "cell" else ""
         self.addr = prefix + "%s/%s" % ( self.hostId, self.deviceId )

   def getAllCounts( self ):
      # Returns [ rx, tx, timeout errors, ack errors, bus conflict errors ],
      # each reduced by the value saved at the last clear. The saved values
      # only apply if the clear record was taken in the accelerator's
      # current generation; otherwise the raw counts are returned.
      rxBase = txBase = timeoutBase = ackBase = busBase = 0
      currentGen = accelGeneration( self.hostType, self.hostId )
      clearedHost = clearCounterDir.host.get( self.hostName )
      if clearedHost and sameGeneration( currentGen, clearedHost.generation ):
         cleared = clearedHost.device.get( self.deviceId )
         if cleared:
            rxBase = cleared.lastRxClear
            txBase = cleared.lastTxClear
            timeoutBase = cleared.lastTimeoutClear
            ackBase = cleared.lastAckClear
            busBase = cleared.lastBusConflictClear
      return [ self._rxCount - rxBase,
               self._txCount - txBase,
               self._timeoutErrorCount - timeoutBase,
               self._ackErrorCount - ackBase,
               self._busConflictErrorCount - busBase ]

   def readCount( self ):
      # ( rx, tx ) counts since the last clear
      rxCount, txCount = self.getAllCounts()[ : 2 ]
      return rxCount, txCount

   def errorCount( self ):
      # ( timeout, ack, bus conflict ) error counts since the last clear
      timeoutErrors, ackErrors, busErrors = self.getAllCounts()[ 2 : ]
      return timeoutErrors, ackErrors, busErrors

   def clear( self ):
      # Save the current raw counts as the new baseline for this device and
      # stamp the host record with the accelerator's current generation.
      hostRecord = clearCounterDir.newHost( self.hostName )
      baseline = hostRecord.newDevice( self.deviceId )
      baseline.lastTxClear = self._txCount
      baseline.lastRxClear = self._rxCount
      baseline.lastTimeoutClear = self._timeoutErrorCount
      baseline.lastAckClear = self._ackErrorCount
      baseline.lastBusConflictClear = self._busConflictErrorCount
      hostRecord.generation = accelGeneration( self.hostType, self.hostId )

def doShowSmbusCounters():
   banner = "{0:>73}\n".format( "Est. Rx" )
   banner += "{0:<18} {1:<24} {2:>20} {3:>10}\n".format(
                "Device", "Address", "Rx bytes", "bytes/sec" )
   banner += "{0:<18} {1:<24} {2:>20} {3:>10}\n".format(
                "-" * 18, "-" * 24, "-" * 20, "-" * 10 )

   rxRet = []
   txRet = []
   for deviceName, addr, counter in getCounterInfo():
      if counter is None:
         # counter may be none if theres an inactive device
         rxCount = txCount = rxRate = txRate = 0
      else:
         rxCount, txCount = counter.readCount()
         rxRate = counter.rxRate
         txRate = counter.txRate
      rxRet.append( "{0:<18} {1:<24} {2:>20} {3:>10.2f}\n".format(
         deviceName, addr.ljust( 24 ), rxCount, rxRate ) )

      txRet.append( "{0:<18} {1:<24} {2:>20} {3:>10.2f}\n".format(
         deviceName, addr.ljust( 24 ), txCount, txRate ) )
   print "%s%s\n%s%s" % ( banner,
                          "".join( rxRet ),
                          banner.replace( "Rx", "Tx" ),
                          "".join( txRet ) ),

def doShowSmbusErrorCounters():
   banner = "{0:>53} {1:>7} {2:>13}\n".format( "Timeout", "Ack", "Bus Conflict" )
   banner += "{0:<18} {1:<24} {2:>9} {3:>7} {4:>13}\n".format(
                "Device", "Address", "Errors", "Errors", "Errors" )
   banner += "{0:<18} {1:<24} {2:>9} {3:>7} {4:>13}\n".format(
                "-" * 18, "-" * 24, "-" * 9, "-" * 7, "-" * 13 )

   errorRet = []
   for deviceName, addr, counter in getCounterInfo():
      if counter:
         timeoutErrorCount, ackErrorCount, busConflictErrorCount = \
            counter.errorCount()
      else:
         timeoutErrorCount = ackErrorCount = busConflictErrorCount = 0
      errorRet.append( "{0:<18} {1:<24} {2:>9} {3:>7} {4:>13}\n".format(
         deviceName, addr.ljust( 24 ), timeoutErrorCount,
         ackErrorCount, busConflictErrorCount ) )

   print "%s%s" % ( banner, "".join( errorRet ) )

#-----------------------------------------------------------------------------
# show platform smbus counters [ errors ]
#-----------------------------------------------------------------------------
class ShowPlatformSmbusCounters( ShowCommand.ShowCliCommandClass ):
   # With the optional 'errors' keyword the error-counter table is printed;
   # without it, the Rx/Tx byte-counter tables are printed.
   syntax = 'show platform smbus counters [ errors ]'
   data = {
      'platform' : platformMatcherForShow,
      'smbus' : nodeSmbus,
      'counters' : 'Hardware-access counters for each device',
      'errors' : 'Error info',
   }

   @staticmethod
   def handler( mode, args ):
      if 'errors' in args:
         doShowSmbusErrorCounters()
      else:
         doShowSmbusCounters()

# Register the show command with the CLI
BasicCli.addShowCommandClass( ShowPlatformSmbusCounters )

#-----------------------------------------------------------------------------
# clear platform smbus counters
#-----------------------------------------------------------------------------
class ClearPlatformSmbusCounters( CliCommand.CliCommandClass ):
   # Drops every existing clear record, then saves the current raw counts
   # of each live counter as its new baseline.
   syntax = 'clear platform smbus counters'
   data = {
      'clear' : clearKwNode,
      'platform' : platformMatcherForClear,
      'smbus' : nodeSmbus,
      'counters' : 'Hardware-access counters for each device',
   }

   @staticmethod
   def handler( mode, args ):
      # Start from an empty set of records so only the baselines written
      # below remain.
      clearCounterDir.host.clear()
      for liveCounter in smbusCounters():
         liveCounter.clear()

# Register the clear command on BasicCli.EnableMode
BasicCli.EnableMode.addCommandClass( ClearPlatformSmbusCounters )

def Plugin( entityManager ):
   """Plugin entry point: bind the module-level mount points.

   Stores entityManager.root() in sysdbRoot and sets up LazyMount mounts for
   the smbus counters, topology, CLI clear records and config.
   """
   global sysdbRoot, counterDir, smbusTopology, smbusConfig, clearCounterDir

   def mount( path, typeName, flags ):
      return LazyMount.mount( entityManager, path, typeName, flags )

   sysdbRoot = entityManager.root()
   counterDir = mount( "hardware/counter/smbus", "Tac::Dir", "ri" )
   smbusTopology = mount( "hardware/smbus/topology",
                          "Hardware::SmbusTopology", "r" )
   clearCounterDir = mount( "hardware/counter/cli/smbus",
                            "HardwareCounter::ClearCounterDir", "w" )
   smbusConfig = mount( "hardware/smbus", "Tac::Dir", "ri" )
