# Copyright (c) 2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import BasicCli
import Cell
import CliCommand
from CliModel import Dict, Int, Model, Str
import CliParser
import CliPlugin.TechSupportCli as TechSupportCli
import CliToken.Clear
import ShowCommand
import TableOutput
import Tac
import Tracing
import LazyMount

# Tracing handle for this plugin, registered under the name "PciCli".
__defaultTraceHandle__ = Tracing.Handle( "PciCli" )

# Module-level state. Everything below starts out as None and is assigned by
# Plugin() when the plugin is loaded:
#   entityManager       - the entity manager passed to Plugin()
#   pciDeviceConfigDir  - eager mount of hardware/cell/<cellId>/pciDeviceMap/config
#   sysDevices          - lazy mount of hardware/pciDeviceMap/config
#   cellDevices         - lazy mount of hardware/cell/<cellId>/pciDeviceMap/config
#   pciDeviceStatusDir  - lazy mount of cell/<cellId>/hardware/pciDeviceStatusDir
#   pcieSwitchStatusDir - lazy mount of hardware/pcieSwitch/status/system
#   cellCliRequest      - writable lazy mount of
#                         cell/<cellId>/hardware/pciDeviceMap/cliRequest
#   sysCliRequest       - writable lazy mount of hardware/pciDeviceMap/cliRequest
entityManager = None
pciDeviceConfigDir = None
sysDevices = None
cellDevices = None
pciDeviceStatusDir = None
pcieSwitchStatusDir = None
cellCliRequest = None
sysCliRequest = None

def pciGuard( mode, token ):
   """Guard function for the 'pci' keyword.

   Returns None when pciDeviceConfigDir is set and its pciDeviceConfig
   collection has at least one entry; otherwise returns
   CliParser.guardNotThisPlatform. Neither mode nor token is consulted.
   """
   hasPciDevices = bool( pciDeviceConfigDir ) and \
                   len( pciDeviceConfigDir.pciDeviceConfig ) > 0
   if not hasPciDevices:
      return CliParser.guardNotThisPlatform
   return None

# Guarded 'pci' keyword shared by the 'show pci', 'show pci detail' and
# 'show pci tree' commands below; pciGuard decides whether it is available.
nodePci = CliCommand.guardedKeyword( 'pci',
      helpdesc='Display PCIe devices with error counters', guard=pciGuard )

#--------------------------------------------------------------------------------
# show pci
#--------------------------------------------------------------------------------
def getPcieSwitchSmbusErrors():
   """Return a list of ( name, registerAccessErrorCounter ) tuples, one for
   each entry in pcieSwitchStatusDir."""
   smbusErrors = []
   for switchStatus in pcieSwitchStatusDir.values():
      smbusErrors.append( ( switchStatus.name,
                            switchStatus.registerAccessErrorCounter ) )
   return smbusErrors

def getPciDevices():
   """Return a dict mapping PCI ID strings to PCI device status objects.

   The ID is built from each device's address as "bus:device.function", with
   bus and device as two-digit hex and function as hex. Devices are read from
   pciDeviceStatusDir.pciDeviceStatus.
   """
   devicesById = {}
   for status in pciDeviceStatusDir.pciDeviceStatus.values():
      addr = status.addr
      idFields = ( addr.busNumber, addr.device, addr.function )
      devicesById[ "%02x:%02x.%x" % idFields ] = status
   return devicesById

def getPciErrors():
   """Return the summary error counters of every PCI device.

   The result is a list of
   ( name, pciId, correctableError, uncorrectableNonFatalError,
     uncorrectableFatalError )
   tuples, ordered by the PCI ID string.
   """
   devicesById = getPciDevices()
   # IDs are unique dict keys, so sorting the ( id, device ) pairs orders by ID
   # alone and never has to compare the device objects.
   return [ ( dev.name, pciId, dev.correctableError,
              dev.uncorrectableNonFatalError, dev.uncorrectableFatalError )
            for pciId, dev in sorted( devicesById.items() ) ]

def updateErrorRequestTimes():
   """Stamp scanErrorRequest with the current time on the cell and system CLI
   request objects.

   A request is only written when the matching device config ( cellDevices or
   sysDevices ) is truthy and the entity manager is not locally read-only.
   """
   requestsByDevices = ( ( cellDevices, cellCliRequest ),
                         ( sysDevices, sysCliRequest ) )
   for devices, cliRequest in requestsByDevices:
      if devices and not entityManager.locallyReadOnly():
         cliRequest.scanErrorRequest = Tac.now()

def prepShowPci( mode, args ):
   """prepareFunction of PciCmd ( 'show pci' ).

   Calls updateErrorRequestTimes(), which stamps the scanErrorRequest times.
   The mode and args arguments are not used.
   """
   updateErrorRequestTimes()

class PciErrorSum( Model ):
   # Summary error counters of a single PCI device. Instances are the values
   # of PciErrors.pciIds and are built by PciCmd.handler from the tuples
   # returned by getPciErrors().
   name = Str( help="PCI device name" )
   correctableErrors = Int( help="Correctable error counter" )
   nonFatalErrors = Int( help="Uncorrectable non-fatal error counter" )
   fatalErrors = Int( help="Uncorrectable fatal error counter" )

class PciErrors( Model ):
   # CLI model of 'show pci' ( it is the cliModel of PciCmd ).
   # pciIds is keyed by the PCI ID string; switchs is keyed by the PCIe switch
   # name and holds that switch's SMBus error count.
   pciIds = Dict( valueType=PciErrorSum, keyType=str,
                  help="Summary of error counters for each PCI device" )
   switchs = Dict( valueType=int,
                   help="Summary of smBus errors on PCIe switch" )

   def render( self ):
      """Print the per-device error table, ordered by PCI ID, followed by a
      PCIe switch SMBus error table when any switch entries exist."""
      # Every print below takes exactly one parenthesized argument, which is
      # the same output whether print is a statement or a function.
      deviceTable = TableOutput.createTable(
            ( "Name", "PciId", "CorrErr", "NonFatalErr", "FatalErr" ) )
      for pciId in sorted( self.pciIds ):
         errorSum = self.pciIds[ pciId ]
         deviceTable.newRow( errorSum.name, pciId, errorSum.correctableErrors,
                             errorSum.nonFatalErrors, errorSum.fatalErrors )
      print( deviceTable.output() )

      if len( self.switchs ) == 0:
         return

      # Blank line separating the two tables.
      print( "" )

      switchTable = TableOutput.createTable( ( "PcieSwitch", "SMBusERR" ) )
      for switchName in self.switchs:
         switchTable.newRow( switchName, self.switchs[ switchName ] )
      print( switchTable.output() )

class PciCmd( ShowCommand.ShowCliCommandClass ):
   # 'show pci': summary error counters per PCI device, plus SMBus error counts
   # per PCIe switch when any switch status entries exist.
   syntax = 'show pci'
   data = {
      'pci' : nodePci,
   }

   cliModel = PciErrors
   privileged = True
   prepareFunction = prepShowPci

   @staticmethod
   def handler( mode, args ):
      """Build the PciErrors model from getPciErrors() and, when
      pcieSwitchStatusDir has entries, getPcieSwitchSmbusErrors()."""
      pciDict = {}
      for name, pciId, corrErrs, nonFatalErrs, fatalErrs in getPciErrors():
         pciDict[ pciId ] = PciErrorSum( name=name,
                                         correctableErrors=corrErrs,
                                         nonFatalErrors=nonFatalErrs,
                                         fatalErrors=fatalErrs )

      smbusDict = {}
      if pcieSwitchStatusDir.values():
         # The helper returns ( switch name, error count ) pairs.
         smbusDict.update( getPcieSwitchSmbusErrors() )

      return PciErrors( pciIds=pciDict, switchs=smbusDict )

BasicCli.addShowCommandClass( PciCmd )

#-------------------------------------------------------------------------------
# The "show pci detail" command, in privileged mode.
#-------------------------------------------------------------------------------
def displayPciDetailErrors():
   """Return a PciDetailErrors model holding the detailed error counters of
   every PCI device returned by getPciDevices(), keyed by PCI ID."""
   # Each counter is stored in the model under a key that is identical to the
   # name of the attribute it is read from on the device status object.
   counterNames = (
      "correctableError",
      "receiverError",
      "badTlp",
      "badDllp",
      "replayNumRollover",
      "replayTimerTimeout",
      "advisoryNonFatal",
      "uncorrectableNonFatalError",
      "poisonedTlp",
      "ecrcError",
      "unsupportedRequestError",
      "completionTimeout",
      "completerAbort",
      "unexpectedCompletion",
      "acsError",
      "uncorrectableFatalError",
      "trainingError",
      "dataLinkProtocolError",
      "receiverOverflow",
      "flowControlProtocolError",
      "malformedTlp",
   )
   devSummary = {}
   for pciId, pciDevice in getPciDevices().items():
      deviceName = pciDevice.name
      counters = {}
      for counterName in counterNames:
         counters[ counterName ] = getattr( pciDevice, counterName )
      devSummary[ pciId ] = PciDetailError( name=deviceName,
                                            pciErrors=counters )
   return PciDetailErrors( devices=devSummary )

class PciDetailError( Model ):
   # Detailed error counters of a single PCI device. Instances are the values
   # of PciDetailErrors.devices. pciErrors maps a counter name ( the keys set
   # in displayPciDetailErrors(), e.g. "badTlp" ) to its count.
   name = Str( help="Name of pci device" )
   pciErrors = Dict( valueType=int, keyType=str,
                  help="Error counts" )

class PciDetailErrors( Model ):
   # CLI model of 'show pci detail' ( it is the cliModel of PciDetailCmd ).
   # devices is keyed by the PCI ID string.
   devices = Dict( keyType=str, valueType=PciDetailError,
         help="Summary of pci device errors" )

   def render( self ):
      """Print one section per device, ordered by PCI ID: a "name pciId"
      heading, a 40-dash rule, then one line per error counter."""
      # ( format, pciErrors key ) for each output line, in print order. The
      # three formats without a leading space are the summary counters; the
      # lines after each of them are printed indented by one space.
      counterLines = (
         ( "%u Correctable Error", "correctableError" ),
         ( " %u Receiver Error", "receiverError" ),
         ( " %u Bad TLP", "badTlp" ),
         ( " %u Bad DLLP", "badDllp" ),
         ( " %u Replay Number Rollover", "replayNumRollover" ),
         ( " %u Replay Timer Time-out", "replayTimerTimeout" ),
         ( " %u Advisory Non-Fatal", "advisoryNonFatal" ),
         ( "%u Uncorrectable Non-Fatal Error", "uncorrectableNonFatalError" ),
         ( " %u Poisoned TLP Received", "poisonedTlp" ),
         ( " %u ECRC Check Failed", "ecrcError" ),
         ( " %u Unsupported Request", "unsupportedRequestError" ),
         ( " %u Completion Time-out", "completionTimeout" ),
         ( " %u Completion Abort", "completerAbort" ),
         ( " %u Unexpected Completion", "unexpectedCompletion" ),
         ( " %u ACS Error", "acsError" ),
         ( "%u Uncorrectable Fatal Error", "uncorrectableFatalError" ),
         ( " %u Training Error", "trainingError" ),
         ( " %u DLL Protocol Error", "dataLinkProtocolError" ),
         ( " %u Receiver Overflow", "receiverOverflow" ),
         ( " %u Flow Control Protocol Error", "flowControlProtocolError" ),
         ( " %u Malformed Tlp", "malformedTlp" ),
      )
      # Every print takes exactly one parenthesized argument, which is the same
      # output whether print is a statement or a function.
      for pciId in sorted( self.devices ):
         detail = self.devices[ pciId ]
         print( "%s %s" % ( detail.name, pciId ) )
         print( "-" * 40 )
         for lineFormat, counterKey in counterLines:
            print( lineFormat % detail.pciErrors[ counterKey ] )

#--------------------------------------------------------------------------------
# show pci detail
#--------------------------------------------------------------------------------
class PciDetailCmd( ShowCommand.ShowCliCommandClass ):
   # 'show pci detail': detailed error counters for every PCI device,
   # returned as a PciDetailErrors model.
   syntax = 'show pci detail'
   data = {
      'pci' : nodePci,
      'detail' : 'Display PCIe devices with Advanced Error Reporting error counters',
   }
   cliModel = PciDetailErrors
   privileged = True

   @staticmethod
   def handler( mode, args ):
      # Unlike PciCmd, which sets prepareFunction, this command stamps the
      # scan request times from inside the handler before building the model.
      updateErrorRequestTimes()
      return displayPciDetailErrors()

BasicCli.addShowCommandClass( PciDetailCmd )

#--------------------------------------------------------------------------------
# show pci tree
#--------------------------------------------------------------------------------
def displayPciTree( domainRoot, displayStr ):
   """Recursively print the PCI tree rooted at domainRoot.

   domainRoot must provide a name and a childBus mapping of child nodes.
   displayStr is the text already accumulated for the current output line.

   A node without children ends the line: the accumulated text plus the node
   name is printed, followed by a blank line. Otherwise the children are
   visited in sorted key order; the first child continues the current line
   after '--', and every later child starts a new line with '+-' aligned
   under the column where the first child's connector began.
   """
   childBus = domainRoot.childBus
   displayStr += domainRoot.name
   # Plain emptiness test rather than comparing against a list literal, so it
   # holds for any sized container that values() may return.
   if not childBus.values():
      # The explicit '\n' adds a blank line after each leaf. A single
      # parenthesized argument prints the same under Python 2 and Python 3.
      print( displayStr + '\n' )
      return
   siblingIndent = ' ' * len( displayStr )
   linePrefix = displayStr + '--'
   for _, dev in sorted( childBus.items() ):
      displayPciTree( dev, linePrefix )
      linePrefix = siblingIndent + '+-'

class PciTreeCmd( ShowCommand.ShowCliCommandClass ):
   # Hidden 'show pci tree' command: prints the PCI tree that hangs off the
   # 'DomainRoot0' entry of the cell's PCI device config.
   syntax = 'show pci tree'
   data = {
      'pci' : nodePci,
      'tree' : 'Print user space view of Hardware PCI tree',
   }
   hidden = True
   privileged = True

   @staticmethod
   def handler( mode, args ):
      """Print the PCI tree, or a one-line message when 'DomainRoot0' is
      missing from cellDevices.pciDeviceConfig."""
      # NOTE(review): it is not visible here whether a missing key yields None
      # or raises KeyError for this collection, so both are treated as
      # "not found".
      try:
         domainRoot = cellDevices.pciDeviceConfig[ 'DomainRoot0' ]
      except KeyError:
         domainRoot = None
      if domainRoot is None:
         print( 'Unable to find DomainRoot in Pci tree' )
         # Stop here: walking a missing root would fail on None.childBus.
         return
      displayPciTree( domainRoot, '' )

BasicCli.addShowCommandClass( PciTreeCmd )
                                                      
#--------------------------------------------------------------------------------
# clear pci
#--------------------------------------------------------------------------------
def waitForClearCounterResponse( requestTime ):
   """Wait up to 5 seconds for pciDeviceStatusDir.clearCounterResponse to
   become greater than requestTime.

   On Tac.Timeout a warning is printed instead of propagating the exception.
   """
   def responseIsNewer():
      return pciDeviceStatusDir.clearCounterResponse > requestTime

   try:
      Tac.waitFor( responseIsNewer, timeout=5.0,
                   description="clear counter response to be updated",
                   warnAfter=None, maxDelay=0.1, sleep=True )
   except Tac.Timeout:
      # A single parenthesized argument prints the same whether print is a
      # statement or a function.
      print( "Warning: PCIe error counters may not have reset yet" )

class ClearPciCmd( CliCommand.CliCommandClass ):
   # 'clear pci', registered in enable mode: stamps clearCounterRequest on the
   # cell and system CLI request objects and waits for each response.
   syntax = 'clear pci'
   data = {
      'clear' : CliToken.Clear.clearKwNode,
      'pci' : CliCommand.guardedKeyword( 'pci',
         helpdesc="Clear PCIe devices' error counters", guard=pciGuard ),
   }

   @staticmethod
   def handler( mode, args ):
      """For the cell and then the system device config, when it is truthy,
      set clearCounterRequest to the current time and wait for the clear
      counter response."""
      requestsByDevices = ( ( cellDevices, cellCliRequest ),
                            ( sysDevices, sysCliRequest ) )
      for devices, cliRequest in requestsByDevices:
         if not devices:
            continue
         cliRequest.clearCounterRequest = Tac.now()
         # Wait on the value read back from the request object, not on a
         # locally cached timestamp.
         waitForClearCounterResponse( cliRequest.clearCounterRequest )

BasicCli.EnableMode.addCommandClass( ClearPciCmd )

#------------------------------------------------------
# Commands for 'show tech-support'.
#------------------------------------------------------
def _showTechCmds():
   """Return the commands this plugin contributes to 'show tech-support'.

   The list is empty unless pciDeviceConfigDir is set and its pciDeviceConfig
   collection has at least one entry.
   """
   if not pciDeviceConfigDir or len( pciDeviceConfigDir.pciDeviceConfig ) == 0:
      return []
   return [ 'show pci', 'show pci tree', 'bash lspci' ]

# Register _showTechCmds as a 'show tech-support' command callback. The
# timestamp string is passed as the first argument of the registration call.
# NOTE(review): presumably the timestamp fixes where these commands appear
# relative to other registered callbacks -- confirm against TechSupportCli.
timeStamp = '2010-10-15 10:37:37'
TechSupportCli.registerShowTechSupportCmdCallback( timeStamp, _showTechCmds )


#------------------------------------------------------
# Plugin method
#------------------------------------------------------

def Plugin( em ):
   """Plugin entry point: record the entity manager and mount every path
   this module reads or writes.

   pciDeviceConfigDir is mounted directly through em.mount(); all other
   handles go through LazyMount.mount(). The two cliRequest paths are mounted
   with mode "w", the PCIe switch status directory with "ri", and everything
   else with "r".
   """
   global entityManager
   global pciDeviceConfigDir, pciDeviceStatusDir, pcieSwitchStatusDir
   global cellDevices, sysDevices
   global cellCliRequest, sysCliRequest

   entityManager = em

   cellId = Cell.cellId()
   configType = "Hardware::PciDeviceConfigDir"
   requestType = "Hardware::PciDeviceCliRequest"
   # The per-cell config path is mounted twice: directly for
   # pciDeviceConfigDir and lazily for cellDevices.
   cellConfigPath = "hardware/cell/%d/pciDeviceMap/config" % cellId

   pciDeviceConfigDir = em.mount( cellConfigPath, configType, "r" )
   pciDeviceStatusDir = LazyMount.mount(
         em, "cell/%d/hardware/pciDeviceStatusDir" % cellId,
         "Hardware::PciDeviceStatusDir", "r" )
   pcieSwitchStatusDir = LazyMount.mount(
         em, "hardware/pcieSwitch/status/system", "Tac::Dir", "ri" )
   cellDevices = LazyMount.mount( em, cellConfigPath, configType, "r" )
   sysDevices = LazyMount.mount(
         em, "hardware/pciDeviceMap/config", configType, "r" )
   cellCliRequest = LazyMount.mount(
         em, "cell/%d/hardware/pciDeviceMap/cliRequest" % cellId,
         requestType, "w" )
   sysCliRequest = LazyMount.mount(
         em, "hardware/pciDeviceMap/cliRequest", requestType, "w" )
