#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import sys
import MicrosemiLib

class Microsemi( object ):
   base = MicrosemiLib.MicrosemiBase()
   speeds = ( "2.5G", "5G", "8G", "16G" )
   majorState = ( "DETECT", "POLLING", "CFG", "L0", "RECOVERY",
                  "DISABLE", "LOOPBACK", "HOT_RESET", "TX_LOS", "L1" )

   def unbind( self,
               DSP,            # DSPs 1,2,3 exists on eval board
               partition = 0,
               flags=0x2 ):
      status = self.base.doGas( self.base.MRPC_P2P_UNBIND | int( partition ) << 8 |
                                int( DSP ) << 16 | int( flags ) << 24,
                                self.base.MRPC_PORTPARTP2P )
      return status

   def bind( self,
             port,
             DSP=1, # DSPs 1,2,3 exists on eval board
             partition = 0 ):
      return self.base.bind( port, DSP, partition )

   def LTSSM( self ):
      numPorts = 128
      for port in self.base.allPorts():
         if self.base.doGas( numPorts << 24 | int( port ) << 8 |
                             self.base.MRPC_LTSSM_LogDump,
                         self.base.MRPC_DIAG_PORT_LTSSM_LOG ) == 0:
            log0 = 0
            states = dict()

            for offset in range( self.base.GAS_OUTPUT_DATA,
                                 self.base.GAS_OUTPUT_DATA + numPorts * 4, 4 ):
               log0 = self.base.read32( offset )
               states[ self.majorState[ self.base.getBits( log0, 10, 7 ) ] ] = 1
            print "Port %d" % int( port ), "Speed ", \
               self.speeds[ self.base.getBits( log0, 14, 13 ) ], \
               "Major State", states.keys()

   def bifurcation( self ):
      for stack in ( 1, 3, 4 ):
         assert self.base.doGas( stack << 8, self.base.MRPC_STACKBIF ) == 0
         value = self.base.readGroup4( 0x404 )
         for index in range( 0, 7 ):
            if value[ index ] != 0:
               print "Port %d.%d: %x" % ( stack, index, value[ index ] )

   def firmwareInfo( self ):
      print "Active Firmware Address             %08x" % self.base.read32(
         self.base.GAS_FirmwareAddress )
      print "Active Firmware Version             %08s" % self.base.getVersion(
         self.base.GAS_FirmwareVersion )
      print "Active Config Version               %08s" % self.base.getVersion(
         self.base.GAS_ConfigVersion )
      print "Active Firmware Config Address      %08x" % self.base.read32(
         self.base.GAS_VendorTableRevision )

      print "Inactive Firmware Address           %08x" % self.base.read32(
         self.base.GAS_InactiveFirmwareAddress )
      print "Inactive Firmware Version           %08s" % self.base.getVersion(
         self.base.GAS_InactiveFirmwareVersion )
      print "Inactive Config Version             %08s" % self.base.getVersion(
         self.base.GAS_InactiveConfigVersion )
      print "Inactive Firmware Config Address    %08x" % self.base.read32(
         self.base.GAS_InactiveVendorTableRevision )

      print "Config Revision                     %08x" % self.base.read32(
         self.base.GAS_VendorTableRevision )

   def firmwareToggle( self ):
      # Toggle Data Partition only
      self.base.doGas( 0x010002, self.base.MRPC_FWDNLD )
      self.reset()

   def firmwareDownload( self, fileName=None, pciSpec=None ):
      fileName = fileName or self.base.configImage
      self.base.firmwareDownload( fileName, pciSpec )
      print "Firmware download completed"

   def updateConfig( self, fileName=None ):
      self.base.updateConfig( fileName, True )

   def updateFirmware( self ):
      self.base.updateFirmware( True )

   def reset( self ):
      self.base.doGas( 0, self.base.MRPC_RESET )

   def imageInfo( self, fileName ):
      ( _, imgType, loadAddr,
        version, vendor, revision ) = self.base.getImageInfo( fileName )
      if imgType == 4:
         fileType = "configuration"
      else:
         fileType = "firmware"
      print "Type: %d (%s)" % ( imgType, fileType )
      print "Load addr: %08X" % loadAddr
      print "Version: %x" % version
      print "Vendor: %x" % vendor
      print "Revision: %x " % revision

   def temp( self ):
      self.base.doGas( 1, self.base.MRPC_DIETEMP )
      self.base.doGas( 2, self.base.MRPC_DIETEMP )
      print self.base.read32( self.base.GAS_OUTPUT_DATA ) / 100

   def lnkstat( self ):
      states = self.base.linksStates()

      for physPortId in states:
         if states[ physPortId ][ 'linkrate' ] == 255:
            break
         linkUp = states[ physPortId ][ 'linkrate' ] >> 7
         linkrate = states[ physPortId ][ 'linkrate' ] & 0x7f

         if states[ physPortId ][ 'usp' ]:
            usp = "usp"
         else:
            usp = "dsp"
         print "[%02d] part:%02d.%02d w:cfg[x%02d]-neg[x%02d]\
         stk:%d.%d %s dl_active:%d Rate: %s LTSSM: %s" % (
            physPortId, states[ physPortId ][ 'partID' ],
            states[ physPortId ][ 'logPortId' ],
            states[ physPortId ][ 'cfgLinkWidth' ],
            states[ physPortId ][ 'negLinkWidth' ],
            states[ physPortId ][ 'stkID' ] >> 4,
            states[ physPortId ][ 'stkID' ] & 0xf,
            usp,
            linkUp, self.speeds[ linkrate ],
              self.majorState[ int( states[ physPortId ][ 'LTSSM' ] ) & 0xff ] )

   def ntInfo( self ):
      managementEndpoints = self.base.allMicrosemiDevices()
      managementEndpoints = [ d for d in managementEndpoints if d.devfn() == 1 ]

      for d in managementEndpoints:
         self.base.managementEndpoint_ = d.address()
         self.base.microsemiFunctionBar = None

         print "Management endpoint is ", self.base.managementEndpoint_

         ( partNumber, partId, _, _ ) = self.base.readGroup8(
            self.base.MRPC_NTB_BASE )
         ntMap0 = self.base.read32( self.base.MRPC_NTB_BASE + 4 )
         ntMap1 = self.base.read32( self.base.MRPC_NTB_BASE + 8 )
         ( requesterID, _ ) = self.base.readGroup16( self.base.MRPC_NTB_BASE + 12 )

         print "partId=%d, numPartitions=%d, NTMap@0x%08x%08x, requesterID=0x%x" % (
            partId, partNumber, ntMap1, ntMap0, requesterID )


         for nt in range( 0, 2 ):
            part = self.base.NTPartitionGetInfo( nt )
            print "NT%d: locked=%d, NTStat=%x, Opc=%x, Control=%x, BarOffset=%x" % (
               nt, part[ 'lockedId' ], part[ 'ntStat' ], part[ 'ntOpc' ],
               part[ 'ntControl' ], part[ 'barOffset' ] )
            print "NT%d: Err: %d, ErrIndex %d, ntRequesterErr=%x, ntTableErr=%x" % (
               nt, part[ 'ntError' ], part[ 'ntErrorIndex' ],
               part[ 'ntRequesterError' ], part[ 'ntTableError' ] )
            print "NT%d: enabled=%s, requester0=%d, proxy0=%d, enabled=%s,"\
                  "requester1 = % d, proxy1 = % d" % (
               partId,
               part[ 'Requester0' ][ 'enabled' ],
               part[ 'Requester0' ][ 'RequesterID' ],
               part[ 'Requester0' ][ 'NTProxy' ],
               part[ 'Requester1' ][ 'enabled' ],
               part[ 'Requester1' ][ 'RequesterID' ],
               part[ 'Requester1' ][ 'NTProxy' ] )

            # Read BAR Setup

            for barNo in range( 0, 6 ):
               bar = self.base.NTBBarGetInfo( nt, barNo )
               if 'valid' in bar:
                  sys.stdout.write( "BAR%d  " % barNo )
                  sys.stdout.write( "%s" % bar[ 'mode' ] )
                  if not bar[ 'prefetch' ]:
                     sys.stdout.write( "Non-" )
                  sys.stdout.write( "prefetchable " )
                  sys.stdout.write( "%s " % bar[ 'mappingType' ] )
                  if "NT-Direct" in bar[ 'mappingType' ]:
                     print "BaseAddr=%016x, Size=%08xB, Pos=0x%016x, destPart=%d" % (
                        bar[ 'BaseAddr' ], bar[ 'NtTranslationSize' ],
                        bar[ 'NtTranslationPosition' ],
                        bar[ 'NtDestinationPartition' ] )
                  else:
                     print ""
               else:
                  if 'dw' in bar and bar[ 'dw' ] != [ 0, 0, 0, 0 ]:
                     print "Unconfigured BAR%d: " % barNo, bar

   def ntEnable( self ):
      self.base.ntEnable()

   def ntReset( self ):
      self.base.ntReset()

M = Microsemi()

if len( sys.argv ) > 1 and sys.argv[ 1 ] in dir( M ):
   getattr( M, sys.argv[ 1 ] )( *sys.argv[ 2: ] )
else:
   print "Unknown command - list: ", [
      x for x in dir( M ) if callable(
         getattr( M, x ) ) and not x.startswith( "_" ) ]
