#!/usr/bin/env python
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import collections
import linecache
import struct
import time

class PsocProgrammer( object ):
   """Program and configure a Cypress PSoC over PMBus.

   Bootloader packets have the layout
   <0x01><cmd><2B data length, LSB first><data><2B checksum, LSB first><0x17>
   and are written to PMBus register 0xFC; responses are read from 0xFD.
   The class also contains helpers for reading/updating rail voltages via
   standard PMBus VOUT registers.
   """

   # Bootloader response status code -> description (used for logging).
   RESPONSE_CODES = {
         0x00:"Cmd successfully received and executed",
         0x01:"File not accessible",
         0x02:"Reached End of File",
         0x03:"Num of data bytes received not in expected range",
         0x04:"Data not in proper form",
         0x05:"Cmd not recognized",
         0x06:"Expected device does not match detected device",
         0x07:"Detected Bootloader version not supported",
         0x08:"Checksum does not match expected value",
         0x09:"Flash Array Invalid",
         0x0A:"Flash Row Invalid",
         0x0B:"Bootloader not ready to process data",
         0x0C:"Application invalid. Cannot set as active",
         0x0D:"Application currently marked as active",
         0x0F:"Unknown Error has occured",
         0xFF:"Operation Aborted"}

   # Decoded view of the 16-bit bootload status register (0xF3).
   # Built once here instead of on every readCyBootloadStatus() call.
   CyBootloadStatus = collections.namedtuple(
      "CyBootloadStatus", "dualBootSupport singleBootSupport active running" )

   # One entry per application image, as returned by getMetadata().
   BootloaderMetadata = collections.namedtuple(
      "BootloaderMetadata", [ "bootloaderVersion",
                              "applicationVersion",
                              "applicationDate" ] )

   def __init__( self, ham, logger ):
      """
      ham should be an object that supports:
      * read8( address ) -> int
      * read16( address ) -> int
      * write8( address, data )
      * write16( address, data )
      * read( address, count ) -> data (Python string / bytes)
      * write( address, data ) - data should be encoded as a Python string
      * sendByte( address )
      * SmbusError - exception type raised on bus errors (used by
        initSingleImage to detect bootloader mode)

      logger should be an object that supports:
      * info( *strings )
      * debug( *strings )
      """
      self.ham_ = ham
      self.logger_ = logger

   # ------------------------------------------------------------------
   # Low level register / packet helpers
   # ------------------------------------------------------------------

   def readCyBootloadStatus( self ):
      """Read and decode the bootload status register (0xF3).

      Register layout:
         bits dualBootSupport     -> 7:7;
         bits singleBootSupport   -> 6:6;
         bits active              -> 3:2;
         bits running             -> 1:0;
      """
      status = self.ham_.read16( 0xF3 )
      return self.CyBootloadStatus( ( status >> 7 ) & 0b1, ( status >> 6 ) & 0b1,
                                    ( status >> 2 ) & 0b11, status & 0b11 )

   def readBootPmbus( self, count ):
      """Read a bootloader response packet from register 0xFD."""
      return self.ham_.read( 0xFD, count )

   def writeBootPmbus( self, payLoadData ):
      """Write a packed bootloader packet to register 0xFC."""
      return self.ham_.write( 0xFC, payLoadData )

   def checkResponse( self, responseCode ):
      """Return True iff responseCode is 0x00 (success).

      Known error codes and unknown codes are logged and yield False.
      """
      message = self.RESPONSE_CODES.get( responseCode )
      if message is None:
         self.logger_.info( "Invalid Response code!" )
         return False
      if responseCode != 0x00: # log errors only
         self.logger_.info( "Code %x - %s" % ( responseCode, message ) )
         return False
      return True

   def extractBytesResponse( self, line, byteLoc, numBytes ):
      """Return numBytes bytes of a raw response, starting at the 1-based
      byte position byteLoc, combined big-endian into an int.

      line may be a byte string (Python 2 str / Python 3 bytes), a text
      string, or a sequence of ints.

      Raises ValueError if the requested range lies outside the response.
      """
      if isinstance( line, bytearray ):
         raw = line
      elif isinstance( line, bytes ):
         raw = bytearray( line )
      else:
         raw = bytearray( x if isinstance( x, int ) else ord( x ) for x in line )
      start = byteLoc - 1
      selected = raw[ start : start + numBytes ]
      if not selected:
         raise ValueError( "Response too short to extract byte %d" % byteLoc )
      value = 0
      for byte in selected:
         value = ( value << 8 ) | byte
      return value

   @staticmethod
   def _swapBytes( value, width ):
      """Reverse the byte order of value, treated as exactly width bytes.

      Unlike reverseBytes(), the width is explicit, so leading zero bytes
      are handled correctly (e.g. _swapBytes( 0x00c7, 2 ) == 0xc700).
      """
      result = 0
      for _ in range( width ):
         result = ( result << 8 ) | ( value & 0xff )
         value >>= 8
      return result

   def addPayload( self, command, dataLength, data ):
      """Build a bootloader packet as a single big-endian integer.

      Layout: <0x01><command><dataLength LSB first><data (dataLength bytes)>
      <checksum LSB first><0x17>.  Use packData( result, dataLength + 7 ) to
      serialize it.

      The checksum is 0xffff minus the sum of the command, length and data
      bytes, reduced to 16 bits.  That equals the two's complement of the
      sum including the 0x01 start byte.
      """
      acc = command + ( dataLength & 0xff ) + ( ( dataLength & 0xff00 ) >> 8 )
      remaining = data
      for _ in range( dataLength ):
         acc += remaining & 0xff
         remaining >>= 8
      # Mask to 16 bits: large payloads can make acc exceed 0xffff.
      checksum = ( 0xffff - acc ) & 0xffff
      packet = ( 0x01 << 8 ) | command                        # start, command
      packet = ( packet << 16 ) | self._swapBytes( dataLength & 0xffff, 2 )
      packet = ( packet << ( dataLength * 8 ) ) | data
      packet = ( packet << 16 ) | self._swapBytes( checksum, 2 )
      packet = ( packet << 8 ) | 0x17                         # end of packet
      return packet

   def packData( self, data, dataLen ):
      """Serialize the low dataLen bytes of data, most significant first."""
      octets = [ ( data >> ( 8 * shift ) ) & 0xff
                 for shift in range( dataLen - 1, -1, -1 ) ]
      return struct.pack( "%dB" % dataLen, *octets )

   def reverseBytes( self, data ):
      """Reverse the byte order of data, inferring width from its magnitude.

      Values <= 0xff are returned unchanged.  The width is the number of
      significant bytes, so a value whose top byte is zero is not swapped
      at its nominal width.  Use _swapBytes() when the width is known.
      """
      if data <= 0xff:
         return data
      result = 0
      while data > 0:
         result = ( result << 8 ) | ( data & 0xff )
         data >>= 8
      return result

   def extractBytes( self, line, byteLoc, numBytes ):
      """Return numBytes bytes of a hex-text line, starting at the 1-based
      byte position byteLoc, as an int (0 if the range is empty)."""
      start = 2 * ( byteLoc - 1 )
      digits = line[ start : start + 2 * numBytes ]
      if digits:
         return int( digits, 16 )
      return 0

   def sendDatalessPacket( self, cmdCode ):
      """Send a packet with no data:
      <start-0x01><cmd code><2B data length-0x0000><2B checksum><end-0x17>"""
      self.writeBootPmbus( self.packData( self.addPayload( cmdCode, 0, 0 ), 7 ) )
      return 0

   # ------------------------------------------------------------------
   # Dual image application switching
   # ------------------------------------------------------------------

   def switchApp( self ): # XXX switchApp should be a DOS command
      """Mark the currently inactive image as active.

      Returns ( retVal, error ).
      """
      status = self.readCyBootloadStatus()
      if not status.dualBootSupport:
         return 1, "Psoc image does not support application switching!"

      self.enterBootloader( "dual" )
      self.switchActiveAppl( 2 - status.active )
      self.exitBootloader()
      self.logger_.info( "Active image switched to %d \n"
                         "Power cycle board for effect" % ( 3 - status.active ) )
      return 0, ""

   def switchActiveAppl( self, activeApp ): # XXX maybe move this into switchApp
      """Send command 0x36 to make application index activeApp active.

      Returns the response status code.
      """
      payloadData = self.addPayload( command=0x36, dataLength=0x01, data=activeApp )
      self.writeBootPmbus( self.packData( payloadData, 1+7 ) )
      time.sleep( 10 ) #wait for image checksum recalculation
      response = self.readBootPmbus( -1 )
      retVal = self.extractBytesResponse( response, 2, 1 )
      self.checkResponse( retVal )
      return retVal

   # ------------------------------------------------------------------
   # Top level programming flows
   # ------------------------------------------------------------------

   def initDualImage( self ):
      """Prepare a dual-image PSoC for bootloading and enter the bootloader.

      Returns ( retVal, error, status ).  status is the bootload status read
      before entering the bootloader, or None on failure.
      """
      # We don't expect to be in bootloader mode.
      status = self.readCyBootloadStatus()
      if not status.dualBootSupport:
         # Always a 3-tuple: callers unpack ( retVal, error, status ).
         return 1, "Current image does not support application bootloading!", None
      if status.active != status.running:
         self.enterBootloader( "dual" )
         self.switchActiveAppl( status.running - 1 )
         self.exitBootloader()
         self.logger_.info(
            "Active != Running -> Active switched to %d" % status.running )
      self.enterBootloader( "dual" )
      time.sleep( 1 )
      return 0, "", status

   def loadDualImage( self, file1, file2 ):
      """Program the non-running image of a dual-image PSoC.

      file1 is used when image 2 is running, file2 when image 1 is running.
      Returns ( retVal, error ).
      """
      files = { 0:file1, 1:file2 }

      retVal, error, status = self.initDualImage()
      if retVal:
         return retVal, error

      target = files[ 2 - status.running ]
      retVal, error = self.loadApplicationImage( target, 57, "dual" )
      self.switchActiveAppl( 2 - status.running )
      self.exitBootloader()

      if retVal != 0:
         self.logger_.info( "Bootloading file:%s failed!" % target )
      else:
         self.logger_.info( "Power cycle now to complete the process" )

      return retVal, error

   def initSingleImage( self ):
      """Enter the bootloader of a single-image PSoC if not already in it.

      Returns ( retVal, error ).
      """
      # Discover if we are in bootloader mode. The bootload status
      # register can only be read if we are running the application
      # (ie. not in bootloader mode).
      inBootloaderMode = True

      for _ in range( 3 ):
         try:
            status = self.readCyBootloadStatus()
            inBootloaderMode = False
            break
         except self.ham_.SmbusError:
            time.sleep( 0.5 )
      if inBootloaderMode:
         self.logger_.info( "Psoc is already in bootloader" )
      else:
         if status.dualBootSupport:
            self.logger_.info(
               "Current image does not support single image bootloading!" )
            return 1, "Current image does not support single image bootloading!"

         self.enterBootloader( "single" )
      time.sleep( 1 )
      return 0, ""

   def loadSingleImage( self, fileName, force=True ):
      """Program a single-image PSoC from fileName.  Returns ( retVal, error )."""
      retVal, error = self.initSingleImage()
      if retVal:
         return retVal, error
      retVal, error = self.loadApplicationImage( fileName, 100, "single", force )
      if retVal == 0:
         self.logger_.info( "Successfully programmed image" )
      else:
         self.logger_.info( "Programming failed" )
         self.logger_.info( str( retVal ) + ": " + error )
      self.exitBootloader()
      return retVal, error

   def loadAcyacdImage( self, fileName, force=True ):
      """Program a PSoC from a unified .acyacd image file.

      The file's headers select "psoc_type" (single/dual, default single)
      and "chunk" (bytes per data packet, default 100).
      Returns ( retVal, error ).
      """
      if not fileName.endswith( "acyacd" ):
         return 1, "Incorrect file extention.  Requires a .acyacd file."

      # set defaults for single image PSOC
      psocType = "single"
      chunk = 100

      with open( fileName ) as bootFile:
         headers = self.parseHeaders( bootFile )

      if "psoc_type" in headers:
         psocType = headers[ "psoc_type" ]
      if "chunk" in headers:
         chunk = int( headers[ "chunk" ] )

      # Reject unknown types before touching the hardware; otherwise retVal
      # would be unbound and we would exit a bootloader we never entered.
      if psocType not in ( "single", "dual" ):
         return 1, "Unsupported psoc_type '%s' in %s." % ( psocType, fileName )

      if psocType == "single":
         retVal, error = self.initSingleImage()
         if retVal:
            return retVal, error
         retVal, error = \
             self.loadApplicationImage( fileName, chunk, psocType, force )
      else:
         retVal, error, status = self.initDualImage()
         if retVal:
            return retVal, error
         imgIdx = 2 - status.running
         retVal, error = \
             self.loadApplicationImage( fileName, chunk, psocType, False,
                                        imgIdx )
         self.switchActiveAppl( 2 - status.running )
      self.exitBootloader()

      if retVal:
         self.logger_.info( "Programming failed" )
         self.logger_.info( str( retVal ) + ": " + error )
      else:
         self.logger_.info( "Successfully programmed image" )
      return retVal, error

   def getMetadata( self, imgType ):
      """
      Assumes self.enterBootloader has already been called.

      Returns a list of BootloaderMetadata, one element for each application
      image on the psoc ( 1 for "single", 2 for "dual" ).  An element is None
      if the metadata read for that image failed.
      """
      imgCount = { "single" : 1, "dual" : 2 }[ imgType ]
      result = [ None ] * imgCount
      for imgNum in range( imgCount ):
         payloadData = self.addPayload( command=0x3C, dataLength=0x01,
                                        data=imgNum )
         self.writeBootPmbus( self.packData( payloadData, 1+7 ) )
         time.sleep( 1 )
         response = self.readBootPmbus( count=63 )
         respStatusCode = self.extractBytesResponse( response, 2, 1 )
         self.logger_.info( "Reading Metadata; image: %d" % ( imgNum + 1 ) )
         if self.checkResponse( respStatusCode ):
            result[ imgNum ] = self.BootloaderMetadata(
               self.extractBytesResponse( response, 23, 2 ),
               self.extractBytesResponse( response, 26, 2 ),
               "%d/%d/%d" % ( self.extractBytesResponse( response, 29, 1 ),
                              self.extractBytesResponse( response, 30, 1 ),
                              self.extractBytesResponse( response, 31, 2 ) ) )
      return result

   def enterBootloader( self, imgType ):
      """Enter bootloader mode and log the reported silicon/bootloader ids.

      For "single" images the PMBus bootloader_enter command (0xFB) is sent
      first.  Returns 0.
      """
      if imgType == "single":
         self.ham_.sendByte( 0xFB ) #send pmbus bootloader_enter cmd
         time.sleep( 1 )
      self.sendDatalessPacket( 0x38 ) #enter bootloader
      time.sleep( 0.05 )
      response = self.readBootPmbus( 16 )
      self.checkResponse( self.extractBytesResponse( response, 2, 1 ) )
      # Multi-byte fields arrive LSB first; swap at their fixed widths.
      self.logger_.info(
         "SiliconId: %x, SiliconRev: %x, BootLoaderVersion: %x" %
         ( self._swapBytes( self.extractBytesResponse( response, 5, 4 ), 4 ),
           self.extractBytesResponse( response, 9, 1 ),
           self._swapBytes( self.extractBytesResponse( response, 10, 3 ), 3 ) ) )
      return 0

   def exitBootloader( self ):
      """Send the exit-bootloader command (0x3B)."""
      self.sendDatalessPacket( 0x3B )

   def getFlashSize( self, fileName ):
      """Query and log the first/last flash rows of the array used by
      fileName's first data row.  Returns 0."""
      rawLine = linecache.getline( fileName, 2 )
      line = rawLine.strip( ':\r\n' )
      arrayId = self.extractBytes( line, 1, 1 )
      payloadData = self.addPayload( command=0x32, dataLength=0x01, data=arrayId )
      self.writeBootPmbus( self.packData( payloadData, 8 ) )
      time.sleep( 0.01 )
      response = self.readBootPmbus( 12 )
      self.checkResponse( self.extractBytesResponse( response, 2, 1 ) )
      self.logger_.info(
         'Addr FirstRow: %x , LastRow: %x' %
         ( self._swapBytes( self.extractBytesResponse( response, 5, 2 ), 2 ),
           self._swapBytes( self.extractBytesResponse( response, 7, 2 ), 2 ) ) )
      return 0

   @staticmethod
   def parseHeaders( openFile ):
      """Read "key: value" header lines until a blank line, EOF or the
      "#IMAGE_START_0" marker.  Keys are lower-cased; everything after
      the first ':' is the value."""
      headers = {}
      while True:
         line = openFile.readline().strip()
         if not line or line.startswith( "#IMAGE_START_0" ):
            break
         # partition keeps values that themselves contain ':' intact.
         key, _, value = line.partition( ":" )
         headers[ key.strip().lower() ] = value.strip()
      return headers

   @staticmethod
   def setImageFileStart( openFile, num ):
      """Advance openFile past the "#IMAGE_START_<num>" marker (or to a
      blank line / EOF)."""
      while True:
         line = openFile.readline().strip()
         if not line or line.startswith( "#IMAGE_START_%d" % num ):
            break

   def _checkBootloaderVersion( self, headers, imgType, force ):
      """Compare the "bootloader" header with the device's bootloader version.

      Returns None if programming may proceed, else ( retVal, error ).
      """
      if "bootloader" not in headers:
         return None
      expectedBootloaderVersion = int( headers[ "bootloader" ], 16 )
      metadata = self.getMetadata( imgType )[ 0 ]
      if metadata:
         hardwareBootloaderVersion = metadata.bootloaderVersion
         if expectedBootloaderVersion != hardwareBootloaderVersion:
            return -1, (
               "Bootloader version not compatible with application image. "
               "Application image compatible with %s, hardware contains %s."
               % ( hex( expectedBootloaderVersion ),
                   hex( hardwareBootloaderVersion ) ) )
      elif not force:
         return -1, (
            "Unable to read bootloader version. "
            "Use --force option to continue anyway." )
      else:
         self.logger_.info( "Unable to read bootloader version, "
                            "but force option set." )
      return None

   def _programRow( self, line, chunk ):
      """Write one flash row record and verify it, retrying up to 3 times.

      line is the hex text of the record (leading ':' stripped):
      <arrayId:1><rowId:2><dataLength:2><data:dataLength><checksum:1>.
      Data is sent with 0x37 (send data) packets of at most chunk bytes,
      followed by a 0x39 (program row) packet carrying the remainder.

      Returns ( 0, "" ) on success.  After 3 failed verifications the row
      is erased (0x34) and ( -1, error ) is returned.
      """
      # lineLength = ( num hex digits/2 ) - checksum byte
      lineLength = len( line ) // 2 - 1
      arrayId = self.extractBytes( line, 1, 1 )
      rowId = self.extractBytes( line, 2, 2 )
      dataLength = self.extractBytes( line, 4, 2 )
      rowChecksum = self.extractBytes( line, lineLength + 1, 1 )
      expectedChecksum = ( rowChecksum + arrayId +
                           ( rowId & 0xff ) + ( rowId >> 8 ) +
                           ( dataLength & 0xff ) +
                           ( dataLength >> 8 ) ) & 0xff
      # <arrayId><rowId LSB><rowId MSB>, shared by program/verify/erase.
      pktHead = ( arrayId << 16 ) | ( ( rowId & 0xff ) << 8 ) | ( rowId >> 8 )

      for _ in range( 3 ):
         index = 6 #skip overhead bytes
         remaining = dataLength
         while remaining > chunk:
            data = self.extractBytes( line, index, chunk )
            payloadData = self.addPayload(
                  command=0x37, dataLength=chunk, data=data )
            time.sleep( 0.01 )
            self.writeBootPmbus( self.packData( payloadData, chunk+7 ) )
            response = self.readBootPmbus( 8 )
            self.checkResponse( self.extractBytesResponse( response, 2, 1 ) )
            index += chunk
            remaining -= chunk
         if remaining:
            data = self.extractBytes( line, index, remaining )
            pktBody = ( pktHead << ( remaining * 8 ) ) | data
            payloadData = self.addPayload(
                  command=0x39, dataLength=remaining+3, data=pktBody )
            self.writeBootPmbus(
               self.packData( payloadData, remaining+3+7 ) )
            time.sleep( 0.05 )
            response = self.readBootPmbus( 8 )
            self.checkResponse( self.extractBytesResponse( response, 2, 1 ) )

         if self.isRowChecksumValid( rowId, pktHead, expectedChecksum ):
            return 0, ""

      # Erase row
      self.logger_.info( "Row %x write failed 3 times! "
                         "Erasing row to exit" % rowId )
      payloadData = self.addPayload( 0x34, 0x03, pktHead )
      self.writeBootPmbus( self.packData( payloadData, 10 ) )
      response = self.readBootPmbus( 8 )
      self.checkResponse( self.extractBytesResponse( response, 2, 1 ) )
      return -1, "Failed to write row %x" % rowId

   def loadApplicationImage( self, fileName, chunk, imgType, force=False, imgIdx=0 ):
      """Program an application image from a .cyacd or .acyacd file.

      Assumes the bootloader has already been entered.  For .acyacd files,
      imgIdx selects the "#IMAGE_START_<imgIdx>" section.  chunk is the
      maximum number of data bytes per packet.  Returns ( retVal, error ).
      """
      if not fileName.endswith( "cyacd" ):
         return 1, "Incorrect file extention. Requires .cyacd or .acyacd file."
      if chunk < 1:
         # A non-positive chunk would never make progress through a row.
         return 1, "Invalid chunk size %r" % ( chunk, )

      with open( fileName ) as bootFile:
         if fileName.endswith( "acyacd" ):
            headers = self.parseHeaders( bootFile )
            if imgIdx:
               self.setImageFileStart( bootFile, imgIdx )
         else:
            headers = {}

         failure = self._checkBootloaderVersion( headers, imgType, force )
         if failure:
            return failure

         line = bootFile.readline().strip( ':\r\n' )
         checksumType = self.extractBytes( line, 6, 1 )
         self.logger_.info(
            'File SiliconID: %x  SiliconRev: %x  ChecksumType: %x' %
            ( self.extractBytes( line, 1, 4 ),
              self.extractBytes( line, 5, 1 ),
              checksumType ) )
         if checksumType:
            return 1, "CRC 16 not supported!"

         self.syncBootloader()
         for rawLine in bootFile:
            line = rawLine.strip( ':\r\n' )
            if line.startswith( "#IMAGE_END" ):
               break
            if not line.strip():
               continue  # blank line: no row record to program
            retVal, error = self._programRow( line, chunk )
            if retVal:
               return retVal, error

      return self.verifyApplicationChecksum()

   def verifyApplicationChecksum( self ):
      """Ask the bootloader (0x31) to verify the application checksum.

      Returns ( 0, "" ) on success, ( -1, error ) otherwise.
      """
      self.sendDatalessPacket( 0x31 )
      time.sleep( 2 )
      response = self.readBootPmbus( 9 )
      self.checkResponse( self.extractBytesResponse( response, 2, 1 ) )
      responseData = self.extractBytesResponse( response, 5, 1 )
      if responseData == 0:
         self.logger_.info( 'Application checksum verification Failed!' )
         return -1, 'Application checksum verification Failed!'
      self.logger_.info( 'Application checksum verification Successful!' )
      return 0, ""

   def syncBootloader( self ):
      """Send the sync-bootloader command (0x35)."""
      self.sendDatalessPacket( 0x35 )
      time.sleep( 0.050 )

   def isRowChecksumValid( self, rowId, pktHead, flashRowChecksum ):
      """Ask the bootloader (0x3A) for a row's checksum and compare it with
      flashRowChecksum."""
      payloadData = self.addPayload( 0x3A, 0x03, pktHead )
      self.writeBootPmbus( self.packData( payloadData, 10 ) )
      time.sleep( 0.010 )
      response = self.readBootPmbus( 9 )
      returnedChecksum = self.extractBytesResponse( response, 5, 1 )
      self.checkResponse( self.extractBytesResponse( response, 2, 1 ) )
      if returnedChecksum == flashRowChecksum:
         self.logger_.debug( 'Row checksum verification Passed: Row %x' % rowId )
         return True
      self.logger_.info( 'Row checksum verification Failed!: Row %x' % rowId )
      self.logger_.info( 'Checksum: %x received: %x' %
                         ( flashRowChecksum, returnedChecksum ) )
      return False

   # ------------------------------------------------------------------
   # Rail voltage configuration
   # ------------------------------------------------------------------

   # Read in PSOC rail voltage update file.  Format is:
   # railnum,voltage,margin
   @staticmethod
   def readPsocRailConfig( fileName ):
      """Return a list of ( rail, vout, margin ) tuples, stopping at the first
      blank line."""
      config = []
      with open( fileName ) as configFile:
         while True:
            line = configFile.readline().strip()
            if not line:
               break
            params = line.split( "," )
            rail = int( params[ 0 ].strip() )
            vout = float( params[ 1 ].strip() )
            margin = float( params[ 2 ].strip() )
            config.append( ( rail, vout, margin ) )
      return config

   # Write out PSOC rail voltage update file.
   @staticmethod
   def writePsocRailConfig( fileName, railConfig ):
      """Write ( rail, vout, margin ) tuples in readPsocRailConfig's format,
      followed by a terminating blank line."""
      with open( fileName, "w" ) as dstFile:
         for ( rail, vout, margin ) in railConfig:
            dstFile.write( "%d,%.3f,%.3f\n" % ( rail, vout, margin ) )
         dstFile.write( "\n" )

   @staticmethod
   def _linearExponent( voutMode ):
      """Sign-extend the 5-bit two's complement exponent in VOUT_MODE."""
      exponent = voutMode & 0x1f
      if exponent & 0x10:
         exponent -= 0x20
      return exponent

   def _selectRail( self, rail ):
      """Select the PMBus page for rail (1-based); return its VOUT exponent."""
      # page number is from 0, rail numbering from 1
      self.logger_.debug( 'Setting page to %d' % ( rail - 1 ) )
      self.ham_.write8( 0x0, rail - 1 )
      return self._linearExponent( self.ham_.read8( 0x20 ) )

   # update Rails with specified voltage and margin.
   def updateRailVoltageMargin( self, rail, vout, margin=None, save=False ):
      """Set VOUT_COMMAND for rail to vout.

      If margin (a percentage) is given, also set the margin high/low
      registers to vout +/- margin%.  If save is set, issue STORE_USER_ALL.
      """
      vout = float( vout )
      linearExponent = self._selectRail( rail )

      voutReg = int( vout / ( pow( 2, linearExponent ) ) )
      self.logger_.debug( 'Write VOUT_COMMAND 0x%x' % voutReg )
      self.ham_.write16( 0x21, voutReg )

      # adjust margin if requested
      if margin:
         margin = float( margin )
         marginHigh = vout + ( vout * ( margin / 100.00 ) )
         marginLow = vout - ( vout * ( margin / 100.00 ) )

         voutMarginHighReg = int( marginHigh / ( pow( 2, linearExponent ) ) )
         self.ham_.write16( 0x25, voutMarginHighReg )

         voutMarginLowReg = int( marginLow / ( pow( 2, linearExponent ) ) )
         self.ham_.write16( 0x26, voutMarginLowReg )

      if save:
         # Wait for 250ms after STORE_USER_ALL ( as per PSoC PMBus User Guide ).
         self.logger_.info( "writing STORE_USER_ALL" )
         self.ham_.sendByte( 0x15 )
         time.sleep( 0.25 )

   # Read back the commanded output voltage (VOUT_COMMAND) of a rail.
   def readRailVoltage( self, rail ):
      """Return VOUT_COMMAND for rail converted to volts."""
      linearExponent = self._selectRail( rail )
      vReg = self.ham_.read16( 0x21 )
      self.logger_.debug( 'Read VOUT_COMMAND 0x%x (%d)' % ( vReg, vReg ) )
      vout = vReg * pow( 2, linearExponent )
      self.logger_.info( 'voltage %f' % vout )

      return vout
