# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import Plugins
import SmbusUtil
import PlatformDesc
import AgentDirectory
import sys
import os
from time import asctime
from FpgaUtil import printToConsole

# Command codes read from the power supply over SMBus. The values match the
# standard PMBus STATUS_WORD, MFR_MODEL, MFR_ID and MFR_SERIAL commands.
STATUS_WORD = 0x79
MFR_MODEL = 0x9a
MFR_ID = 0x99
MFR_SERIAL = 0x9e

# Log file that t0() mirrors every message into. Mode 'w' truncates the log
# left by any previous run. If the file cannot be opened (IOError), file
# logging is disabled and t0() only writes to stdout/console.
try:
   logFile = open( '/var/log/psuFirmware.log', 'w' )
except IOError:
   logFile = None

def t0( arg, console=False ):
   '''Log a message, prefixing every line with the current time.

   The message is printed to stdout and, if the log file could be opened,
   appended to it. With console=True it is also sent to the console through
   printToConsole() when stdout is not a terminal.

   arg: the (possibly multi-line) message text.
   console: also mirror the message to the console.
   '''
   curtime = asctime()
   msg = '\n'.join( '%s %s' % ( curtime, line ) for line in arg.split( '\n' ) )
   # The upgrade can be run from either the console or an ssh session, with
   # stdout either redirected to a file or outputting to the terminal. isatty()
   # is used to check whether we're outputting to the terminal to avoid printing
   # the same message twice
   if console and not sys.stdout.isatty():
      printToConsole( msg )
   # Single-argument print( ... ) behaves identically under Python 2 and 3.
   print( msg )
   # Flush eagerly: if the process dies mid-upgrade, messages already emitted
   # must not be lost in a buffer (stdout may be redirected to a file).
   sys.stdout.flush()
   if logFile is not None:
      logFile.write( msg + '\n' )
      logFile.flush()

class PowerSupply( object ):
   '''This is a description of the power supply we're upgrading.

   Concrete models subclass this, set the class attributes below, and
   implement shouldUpgrade() and upgradeSupply(). An instance talks to the
   supply in one slot over SMBus.
   '''
   # Model string the device must report in MFR_MODEL for identify() to match.
   mfrModel = None
   # Base SMBus device id; the slot's smbusDesc.offset is added to it.
   deviceIdBase = None
   addrSize = 1
   readDelayMs = 'delay1ms'
   writeDelayMs = 'delay1ms'
   # Created once at class-definition time and shared by all instances.
   factory = SmbusUtil.Factory()
   # Default firmware image used by upgrade() when no file name is given.
   firmwareFileName = None
   # Time in minutes the upgrade process takes
   upgradeTime = 3
   # Seconds to wait, after upgradeSupply(), for a supply that was powered
   # before the upgrade to stop reporting power loss.
   powerRestoreTimeout = 10

   def __init__( self, slotId, smbusDesc, scdPciAddr ):
      self.slotId = slotId
      self.smbusDesc = smbusDesc
      self.scdPciAddr = scdPciAddr
      self.helper = self.factory.device( smbusDesc.accelId, smbusDesc.busId,
                                         self.deviceIdBase + smbusDesc.offset,
                                         self.addrSize,
                                         readDelayMs=self.readDelayMs,
                                         writeDelayMs=self.writeDelayMs,
                                         busTimeout='busTimeout1000ms',
                                         pciAddress=scdPciAddr )

   def identify( self ):
      '''Return True if the device's MFR_MODEL equals this class's mfrModel.

      A failed read yields "" (see getMfrModel), which never matches a real
      model string.
      '''
      return self.getMfrModel() == self.mfrModel

   def getMfrModel( self ):
      '''Return the MFR_MODEL string read from the device, or "" on failure.'''
      try:
         return self.helper.readString( MFR_MODEL )
      except Exception: # pylint: disable-msg=W0703
         # Best effort: an unreadable slot simply matches no model. Catching
         # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
         # through.
         return ""

   def powerLoss( self ):
      '''Return non-zero if STATUS_WORD reports a power problem.

      Tests mask 0x0840 of STATUS_WORD; presumably the PMBus POWER_GOOD#
      (bit 11) and OFF (bit 6) flags -- confirm against the PMBus spec.
      '''
      statusWord = self.helper.read16( STATUS_WORD )
      return statusWord & 0x0840

   def upgrade( self, firmwareFileName=None, force=False ):
      '''Upgrade this supply's firmware when needed, or always if force.

      firmwareFileName: image to flash; defaults to the class attribute.
      force: skip the shouldUpgrade() check.

      Raises AssertionError if the firmware file does not exist, and
      Exception (carrying the slot and cause) if the upgrade itself fails.
      '''
      if firmwareFileName is None:
         firmwareFileName = self.firmwareFileName
      # Explicit check rather than `assert` so it still runs under `python -O`
      # and handles a missing (None) file name. AssertionError is kept so
      # callers see the same exception type as before.
      if not firmwareFileName or not os.path.isfile( firmwareFileName ):
         raise AssertionError( "Firmware file not found: %s" %
                               ( firmwareFileName, ) )
      try:
         self.getPowerSupplyInfo()
      except Exception as e: # pylint: disable-msg=W0703
         # Informational only; failing to read it must not block the upgrade.
         t0( "Could not read power supply %d info: %s" % ( self.slotId, e ) )
      if not ( force or self.shouldUpgrade() ):
         t0( "This firmware does not need to be upgraded" )
         return
      t0( "-----------------------------------------------------\n"
          "Upgrading the firmware in power supply %d.\n"
          "This process can take up to %d minutes.\n"
          "Please do not reboot your switch.\n"
          "-----------------------------------------------------" %
          ( self.slotId, self.upgradeTime ), console=True )
      try:
         poweredBeforeUpgrade = not self.powerLoss()
         self.upgradeSupply( firmwareFileName )
         if poweredBeforeUpgrade:
            Tac.waitFor( lambda: not self.powerLoss(),
                         timeout=self.powerRestoreTimeout )
         t0( "Power supply %d upgraded successfully" % self.slotId,
             console=True )
      except Exception as e:
         t0( "%s: %s" % ( type( e ).__name__, e ) )
         t0( "Power supply %d upgrade failed. Please try again or "
             "contact support" % self.slotId, console=True )
         # Re-raise, which should cancel all remaining upgrades: if one upgrade
         # fails we cannot proceed, since two failed upgrades may cause the
         # system to shut down permanently. The type stays Exception, but the
         # cause is now carried in the message instead of being discarded.
         raise Exception( "Power supply %d upgrade failed: %s" %
                          ( self.slotId, e ) )

   def getPowerSupplyInfo( self ):
      '''Log the supply's MFR_ID, MFR_MODEL and MFR_SERIAL strings.

      Read errors propagate to the caller.
      '''
      t0( "Checking power supply %d firmware" % self.slotId )
      mfrId = self.helper.readString( MFR_ID )
      t0( "Manufacturer: %s" % mfrId )
      mfrModel = self.helper.readString( MFR_MODEL )
      t0( "Model: %s" % mfrModel )
      mfrSerial = self.helper.readString( MFR_SERIAL )
      t0( "Serial: %s" % mfrSerial )

   def shouldUpgrade( self ):
      '''Return True if the installed firmware needs upgrading (subclass hook).'''
      raise NotImplementedError

   def upgradeSupply( self, firmwareFileName ):
      '''Flash firmwareFileName onto the supply (subclass hook).'''
      raise NotImplementedError

class PowerSupplyPlugin( object ):
   '''Registry of PowerSupply classes, kept in registration order.

   getPowerSupplies() instantiates each registered class per slot as
   cls( slotId, smbusDesc, scdPciAddr ).
   '''
   def __init__( self ):
      # Attribute name kept as-is: code elsewhere may reference it directly.
      self.powerSupplies_ = list()

   def registerPowerSupply( self, powerSupply ):
      '''Append powerSupply (a PowerSupply class) to the registry.'''
      self.powerSupplies_ += [ powerSupply ]

   def powerSupplies( self ):
      '''Return the registry list itself (not a copy), oldest entry first.'''
      return self.powerSupplies_

def getPowerSupplies( platform ):
   '''Return the PowerSupply instances identified on platform.

   All 'PowerSupplyPlugin' plugins are loaded to fill a PowerSupplyPlugin
   registry. Each slot in platform.powerSupplySlotSmbusDescs() is numbered
   from 1. For each slot, every registered class is tried in order, and the
   first whose identify() succeeds claims the slot. Slots that no class
   identifies are left out of the result.
   '''
   registry = PowerSupplyPlugin()
   Plugins.loadPlugins( 'PowerSupplyPlugin', context=registry )
   identified = []
   slots = platform.powerSupplySlotSmbusDescs()
   for slot, desc in enumerate( slots, 1 ):
      for candidateClass in registry.powerSupplies():
         candidate = candidateClass( slot, desc, platform.scdPciAddr() )
         if not candidate.identify():
            continue
         identified.append( candidate )
         break
   return identified

def upgradeFirmware( slotToUpgrade=None, firmwareFileName=None, force=False ):
   '''Upgrade the firmware of the identified power supplies.

   slotToUpgrade: 1-based slot number to restrict the upgrade to, or None
      for every identified supply.
   firmwareFileName: image to flash; None lets each supply use its class
      default.
   force: upgrade even when the supply reports it is not needed.

   Raises Exception if AgentDirectory.agents( 'ar' ) reports running agents.
   Any exception from PowerSupply.upgrade() propagates and aborts the
   remaining upgrades.
   '''
   if AgentDirectory.agents( 'ar' ):
      t0( "Cannot upgrade power supply firmware while agents are running" )
      raise Exception( "Cannot upgrade power supply firmware while agents "
                       "are running" )

   platform = PlatformDesc.getPlatform()
   if platform is None:
      t0( "Platform not found. Skipping upgrading power supplies" )
      return
   t0( "Platform: {}".format( platform.name ) )

   powerSupplies = getPowerSupplies( platform )
   if not powerSupplies:
      t0( "No power supplies identified" )
      return

   targets = [ ps for ps in powerSupplies
               if slotToUpgrade is None or ps.slotId == slotToUpgrade ]
   if not targets:
      # Without this, a wrong or unidentified slot number silently did
      # nothing. %s (not %d) so a non-int slot argument cannot raise here.
      t0( "Power supply %s not identified. Skipping upgrade" %
          slotToUpgrade )
      return

   for ps in targets:
      if not ps.powerLoss():
         # Never take down the only powered supply. Status is re-read on every
         # iteration because earlier upgrades can change it. Only identified
         # supplies are counted.
         numSuppliesOn = sum( 1 for p in powerSupplies if not p.powerLoss() )
         if numSuppliesOn < 2:
            t0( "Skipping upgrading power supply %d because it is the only\n"
                "power supply that is powered on" % ps.slotId )
            continue
      ps.upgrade( firmwareFileName, force )
