# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
import ctypes
from fcntl import ioctl
import os
import posix
import struct
import time

# The values below mirror the Linux i2c-dev userspace interface
# ( linux/i2c-dev.h and linux/i2c.h ).

# ioctl request numbers
I2C_SLAVE = 0x0703        # select the slave address used on this fd
I2C_SLAVE_FORCE = 0x0706  # as I2C_SLAVE, even if a kernel driver owns the address
I2C_SMBUS = 0x0720        # SMBus-level access

# smbus_access direction markers ( the read_write field )
I2C_SMBUS_WRITE = 0
I2C_SMBUS_READ = 1

# smbus_access transaction types ( the size field )
I2C_SMBUS_BYTE = 1
I2C_SMBUS_BYTE_DATA = 2
I2C_SMBUS_I2C_BLOCK_DATA = 8

class i2c_smbus_data( ctypes.Union ):
   """Data buffer exchanged with the kernel for one SMBus transaction.

   All three members overlay the same storage:
      byte  - single byte transfers
      word  - 16 bit transfers
      block - block transfers; element 0 carries the byte count and the
              payload follows from element 1
   """
   _fields_ = [
      ( 'byte', ctypes.c_uint8 ),
      ( 'word', ctypes.c_uint16 ),
      ( 'block', ctypes.c_uint8 * 34 ),
   ]

class i2c_smbus_ioctl_data( ctypes.Structure ):
   """Argument block handed to the I2C_SMBUS ioctl.

   read_write - transfer direction ( I2C_SMBUS_READ or I2C_SMBUS_WRITE )
   command    - SMBus command byte
   size       - transaction type ( one of the I2C_SMBUS_* size values )
   data       - pointer to the i2c_smbus_data buffer for the transfer
   """
   _fields_ = [
      ( 'read_write', ctypes.c_uint8 ),
      ( 'command', ctypes.c_uint8 ),
      ( 'size', ctypes.c_int ),
      ( 'data', ctypes.POINTER( i2c_smbus_data ) ),
   ]

class TwoByteIdpromHam( object ):
   """Byte-level access to an I2C EEPROM that takes a two byte address.

   The device is reached through the Linux i2c-dev node for `busNumber`
   using SMBus ioctls.  Reads set the 16 bit address pointer and then fetch
   bytes one at a time; writes are sent as I2C block transfers whose command
   byte is the high address byte and whose first payload byte is the low
   address byte.

   The object owns a file descriptor.  Call close() when finished, or use
   the object as a context manager.
   """

   # Largest payload the kernel accepts in one SMBus block transfer
   # ( I2C_SMBUS_BLOCK_MAX in linux/i2c.h ).
   maxBlockBytes = 32

   def __init__( self, busNumber, addr, force=False ):
      """Open the bus device node and select the slave address.

      busNumber - number of the i2c bus ( as in /dev/i2c-<busNumber> )
      addr      - slave address of the EEPROM on that bus
      force     - select the address with I2C_SLAVE_FORCE instead of
                  I2C_SLAVE

      Raises OSError if no device node can be opened and IOError/OSError if
      the address cannot be selected.
      """
      self.addr = addr
      self.busNumber = busNumber
      self.fd = -1
      self.filename = ''
      self.addrSize = 1

      # open fd
      self.open()

      # setSlaveAddr
      try:
         ioctl( self.fd, I2C_SLAVE_FORCE if force else I2C_SLAVE, self.addr )
      except Exception:
         # Do not leak the descriptor when the object is never handed out.
         self.close()
         raise

   def __enter__( self ):
      return self

   def __exit__( self, excType, excValue, excTb ):
      self.close()
      return False

   @staticmethod
   def _toByteArray( data ):
      """Return `data` as a bytearray.

      Text strings are encoded as latin-1 so that every character maps to
      the byte with the same ordinal, which is the inverse of the chr()
      mapping used by readSequenceStr.  Bytes, bytearrays and sequences of
      ints are converted directly.
      """
      if isinstance( data, ( bytes, bytearray ) ):
         return bytearray( data )
      if isinstance( data, type( u'' ) ):
         return bytearray( data.encode( 'latin-1' ) )
      return bytearray( data )

   def open( self ):
      """Open the i2c device node for self.busNumber read/write.

      Tries /dev/i2c/<n> first and then /dev/i2c-<n>.  On success self.fd
      and self.filename describe the opened node.  If neither can be opened
      the OSError from the last attempt is raised and self.fd is -1.
      """
      lastError = None
      candidates = ( '/dev/i2c/%d' % self.busNumber,
                     '/dev/i2c-%d' % self.busNumber )
      for filename in candidates:
         self.filename = filename
         try:
            self.fd = posix.open( filename, os.O_RDWR )
            return
         except OSError as e:
            lastError = e
      self.fd = -1
      raise lastError

   def close( self ):
      """Close the device node.  Safe to call more than once."""
      fd, self.fd = self.fd, -1
      if fd >= 0:
         os.close( fd )

   # pylint: disable-msg=W0201
   def smbusAccess( self, read_write, command, size, data=None ):
      """Run one SMBus transaction through the I2C_SMBUS ioctl.

      read_write - I2C_SMBUS_READ or I2C_SMBUS_WRITE
      command    - SMBus command byte
      size       - transaction type ( I2C_SMBUS_BYTE, ... )
      data       - i2c_smbus_data buffer to use; a fresh one is allocated
                   when omitted

      Returns a ctypes pointer to the data buffer.  Raises Exception with
      the device file name when the ioctl fails with IOError.
      """
      ar = i2c_smbus_ioctl_data()
      ar.read_write = read_write
      ar.command = command
      ar.size = size
      if data is None:
         data = i2c_smbus_data()
      ar.data = ctypes.pointer( data )

      try:
         ioctl( self.fd, I2C_SMBUS, ar, 1 )
      except IOError as e:
         raise Exception( 'Smbus access to %s failed. Linux reports %r' %
            ( self.filename, str( e ) ) )

      return ar.data

   def write8( self, command, data ):
      """Send an SMBus "write byte data" of `data` under `command`."""
      _data = i2c_smbus_data()
      _data.byte = data # pylint: disable-msg=W0201
      self.smbusAccess( I2C_SMBUS_WRITE, command, I2C_SMBUS_BYTE_DATA, _data )

   def setAddressPointer( self, addr ):
      """Send the 16 bit address `addr`, high byte first."""
      # The high byte travels as the SMBus command and the low byte as its
      # single data byte.
      self.write8( ( addr >> 8 ) & 0xff, addr & 0xff )
      # NOTE(review): presumably gives the part time to settle before the
      # following read - confirm against the EEPROM data sheet.
      time.sleep( 0.01 )

   def receiveByte( self ):
      """Read one byte with an SMBus "receive byte" and return it as an int."""
      data = self.smbusAccess( I2C_SMBUS_READ, 0, I2C_SMBUS_BYTE )
      return data.contents.byte & 0xff

   def read( self, addr ):
      """Return the byte stored at the 16 bit address `addr`."""
      self.setAddressPointer( addr )
      return self.receiveByte()

   def readSequenceStr( self, address, count ):
      """Read `count` bytes starting at `address`.

      The address pointer is set once and the bytes are then fetched one at
      a time.  Returns a str holding one character per byte ( chr() of the
      byte value ).
      """
      self.setAddressPointer( address )
      data = bytearray( count )
      for offset in range( count ):
         data[ offset ] = self.receiveByte()
      return ''.join( [ chr( x ) for x in data ] )

   # Assuming 1 byte addressed part below
   # For a 2 byte addressed part populate:
   #    dataArray[0] with address[7:0] and
   #    addr with address[15:8]
   def writeI2cBlock( self, addr, dataArray ):
      """Write `dataArray` as one I2C block transfer under command `addr`.

      addr      - command byte, 0..0xff
      dataArray - sequence of at most maxBlockBytes byte values

      Raises ValueError when either argument is out of range and
      IOError/OSError when the ioctl fails.
      """
      # pylint: disable-msg=W0201
      numBytes = len( dataArray )
      if numBytes > self.maxBlockBytes:
         raise ValueError( 'block of %d bytes exceeds the %d byte limit' %
                           ( numBytes, self.maxBlockBytes ) )
      if not 0 <= addr <= 0xff:
         raise ValueError( 'command byte %r is outside 0..0xff' % ( addr, ) )
      ar = i2c_smbus_ioctl_data()
      ar.read_write = I2C_SMBUS_WRITE
      ar.command = addr
      data = i2c_smbus_data()
      # Byte count itself not transmitted, used by
      # controller to track buffer size
      data.block[ 0 ] = numBytes
      for offset, value in enumerate( dataArray ):
         data.block[ 1 + offset ] = value
      ar.size = I2C_SMBUS_I2C_BLOCK_DATA
      ar.data = ctypes.pointer( data )
      ioctl( self.fd, I2C_SMBUS, ar )

   def write( self, addr, value ):
      """Write the single byte `value` ( int, 0..255 ) at address `addr`.

      Raises ValueError for any other value.
      """
      if not isinstance( value, int ) or not 0 <= value <= 255:
         raise ValueError( 'value must be an int in 0..255, got %r' %
                           ( value, ) )
      # NOTE(review): presumably lets a previous EEPROM write cycle finish
      # before the next transfer - confirm against the data sheet.
      time.sleep( 0.25 )
      self.writeI2cBlock( ( addr >> 8 ) & 0xff, [ addr & 0xff, value ] )

   def writeChunk( self, address, data ):
      """Write `data` starting at `address` in a single block transfer.

      data - bytes, bytearray, sequence of ints, or a str whose characters
             all have ordinals below 256.  Together with the low address
             byte it must fit in one block ( see writeI2cBlock ).
      """
      time.sleep( 0.25 )
      addrHigh = ( address >> 8 ) & 0xff
      addrLow = address & 0xff
      dataArray = bytearray( [ addrLow ] ) + self._toByteArray( data )
      self.writeI2cBlock( addrHigh, dataArray )

   def writeSequenceStr( self, address, data ):
      """Write `data` starting at `address`, at most 16 bytes per transfer.

      Transfers are split so that none crosses a 16 byte aligned boundary.
      For an aligned `address` this yields plain 16 byte chunks.
      NOTE(review): the split assumes the part wraps writes inside a page
      whose size is a multiple of 16 - confirm against the data sheet.
      """
      chunkSize = 16
      payload = self._toByteArray( data )
      pos = 0
      while pos < len( payload ):
         offset = address + pos
         room = chunkSize - ( offset % chunkSize )
         self.writeChunk( offset, payload[ pos : pos + room ] )
         pos += room

class TwoByteIdprom( object ):
   """Reads and writes the prefdl and fdl stored in a two byte addressed
   ID EEPROM.

   As used by this class, the prefdl starts at offset 0 and the fdl starts
   at the offset given by the prefdl length.  Each section begins with an
   8 byte header of two big-endian 32 bit words: a format number and a
   length.  The prefdl length includes its own header.
   """

   def __init__( self, busNumber, addr ):
      self.ham = TwoByteIdpromHam( busNumber, addr )

   @staticmethod
   def _toBytes( data ):
      """Return `data` as bytes.

      Text strings are encoded as latin-1, which maps each character back
      to the byte with the same ordinal.
      """
      if isinstance( data, bytes ):
         return data
      if isinstance( data, bytearray ):
         return bytes( data )
      return data.encode( 'latin-1' )

   def _readHeader( self, offset ):
      """Return ( format, length ) from the 8 byte header at `offset`."""
      # struct needs a bytes-like object; the ham hands back a str.
      raw = self._toBytes( self.ham.readSequenceStr( offset, 8 ) )
      return struct.unpack( ">LL", raw )

   def prefdlLength( self ):
      """Return the prefdl length from its header, or -1 if the header
      format is not 3."""
      ( fformat, length ) = self._readHeader( 0x0 )
      if fformat != 3:
         return -1
      return length

   def fdlLength( self ):
      """Return the fdl length from its header, or -1 if the prefdl header
      is invalid or the fdl header format is neither 2 nor 3."""
      offset = self.prefdlLength()
      if offset < 0:
         # Without a valid prefdl there is no known fdl offset; do not go
         # and read from a negative address.
         return -1
      ( fformat, length ) = self._readHeader( offset )
      if fformat not in [ 2, 3 ]:
         return -1
      return length

   def readPrefdl( self ):
      """Return the prefdl body ( without its header ), or "" after
      printing an error when the stored length is invalid."""
      # read the prefdl header to determine the length of the prefdl
      length = self.prefdlLength()
      # The length covers the 8 byte header, so anything shorter is bogus.
      if length >= 8:
         # read off the rest of the prefdl
         return self.ham.readSequenceStr( 8, length - 8 )
      print( "Error: invalid prefdl Length. The prefdl may be corrupted." )
      return ""

   def readFdl( self ):
      """Return the fdl body ( without its header ), or "" after printing
      an error when either stored length is invalid."""
      prefdlLength = self.prefdlLength()
      fdlLength = self.fdlLength()
      if ( prefdlLength > 0 ) and ( fdlLength >= 8 ):
         return self.ham.readSequenceStr( prefdlLength + 8, fdlLength - 8 )
      print( "Error: invalid fdl Length. The fdl may be corrupted." )
      return ""

   def readAll( self, length=0 ):
      """Return the EEPROM contents from offset 0.

      With a positive `length` exactly that many bytes are returned.
      Otherwise prefdlLength + fdlLength bytes are returned, or "" after
      printing an error when either length is invalid.
      """
      if length > 0:
         return self.ham.readSequenceStr( 0, length )

      prefdlLength = self.prefdlLength()
      fdlLength = self.fdlLength()
      if ( prefdlLength > 0 ) and ( fdlLength > 0 ):
         totalLength = prefdlLength + fdlLength
         return self.ham.readSequenceStr( 0, totalLength )

      if prefdlLength < 0:
         print( "Error: invalid prefdl Length. The prefdl may be corrupted." )
      if fdlLength < 0:
         print( "Error: invalid fdl Length. The fdl may be corrupted." )
      return ""

   def unlock( self ):
      pass

   def lock( self ):
      pass

   def writePrefdl( self, prefdlStr ):
      """Write a new prefdl with body `prefdlStr`, keeping any existing fdl.

      prefdlStr - bytes, or a str whose characters all have ordinals
                  below 256

      A fresh header ( format 3 ) is generated.  If a valid fdl is present
      it is read back first and rewritten directly after the new prefdl.
      """
      # copy off fdl before writing new prefdl to prevent overwriting
      prefdlLength = self.prefdlLength()

      # build new prefdl header
      body = self._toBytes( prefdlStr )
      newPrefdlLength = len( body ) + 8 # length = prefdl + 8 byte header
      newHeader = struct.pack( ">LL", 0x3, newPrefdlLength )
      newEepromContents = newHeader + body
      fdlLength = self.fdlLength()
      if fdlLength > 0:
         # NOTE(review): readFdl and readAll treat fdlLength as already
         # covering the 8 byte fdl header, so the "+ 8" here copies 8 extra
         # trailing bytes.  Kept as is - confirm the intended format.
         oldFdl = self.ham.readSequenceStr( prefdlLength, fdlLength + 8 )
         newEepromContents += self._toBytes( oldFdl )
      self.unlock()
      self.ham.writeSequenceStr( 0x0, newEepromContents )
      self.lock()

def doRead( busId, seepromId, offset, length, prefdl=True ):
   """Read from the ID EEPROM at slave address `seepromId` on bus `busId`.

   With `prefdl` true the prefdl body is returned and `offset` / `length`
   are ignored; otherwise `length` bytes starting at `offset` are returned.
   """
   idprom = TwoByteIdprom( busId, seepromId )
   try:
      if prefdl:
         return idprom.readPrefdl()
      return idprom.ham.readSequenceStr( offset, length )
   finally:
      # The accessor is local to this call, so release its device node
      # here instead of leaking one descriptor per call.
      fd, idprom.ham.fd = idprom.ham.fd, -1
      if fd >= 0:
         os.close( fd )

def doWrite( busId, seepromId, offset, data, prefdl=True ):
   """Write to the ID EEPROM at slave address `seepromId` on bus `busId`.

   With `prefdl` true `data` is written as the new prefdl body and `offset`
   is ignored; otherwise `data` is written verbatim starting at `offset`.
   """
   idprom = TwoByteIdprom( busId, seepromId )
   try:
      if prefdl:
         idprom.writePrefdl( data )
      else:
         idprom.ham.writeSequenceStr( offset, data )
   finally:
      # The accessor is local to this call, so release its device node
      # here instead of leaking one descriptor per call.
      fd, idprom.ham.fd = idprom.ham.fd, -1
      if fd >= 0:
         os.close( fd )
