#!/usr/bin/env python
# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import

import os
import random
import re
import socket
import struct
import sys

from collections import namedtuple

import Smbus_pb2

# Maps each wire message type code (the 8-bit type carried in the frame
# header) to the protobuf message class used for that type.
typeToProtoBuf = dict( [
   ( Smbus_pb2.TRANSACTION_REQUEST, Smbus_pb2.TransactionRequest ),
   ( Smbus_pb2.TRANSACTION_RESPONSE, Smbus_pb2.TransactionResponse ),
] )

def connect():
   """Open a stream connection to the smbus daemon's Unix-domain socket.

   The socket path is "$PL_SOCKET_DIR/smbus", with PL_SOCKET_DIR
   defaulting to /var/run/platform.  Returns the connected socket.

   Raises socket.error if the connection cannot be made; in that case the
   socket is closed first so a failed attempt does not leak a descriptor.
   """
   address = os.path.join( os.environ.get( "PL_SOCKET_DIR", "/var/run/platform" ),
                           "smbus" )
   sock = socket.socket( socket.AF_UNIX, socket.SOCK_STREAM )
   try:
      sock.connect( address )
   except BaseException:
      # Release the descriptor before propagating the original error.
      sock.close()
      raise
   return sock

def intToString( data, length=1 ):
   """Encode the low `length` bytes of `data` as a little-endian string.

   Each byte becomes one chr() character, least-significant byte first;
   higher bits of `data` are dropped.  For 0 <= data < 256 ** length and
   length <= 4, stringToInt() recovers `data`.

   range() and str.join are used so this runs on both Python 2 and 3
   (xrange is Python-2 only) without quadratic string concatenation.
   """
   return "".join( chr( ( data >> ( 8 * i ) ) & 0xff )
                   for i in range( length ) )

def stringToInt( string ):
   """Decode a little-endian string of at most four characters to an int.

   The first character is the least-significant byte; the empty string
   decodes to 0.
   """
   assert len( string ) <= 4
   value = 0
   # Horner's rule from the most-significant (last) character down.
   for char in reversed( string ):
      value = ( value << 8 ) | ord( char )
   return value

def encodeProtoBufMessage( typ, message ):
   """Serialize a protobuf `message` and prefix it with its frame header.

   The header is the 4-byte type/length word from encodeTypeLength(),
   built from `typ` and the serialized payload size.
   """
   payload = message.SerializeToString()
   return encodeTypeLength( typ, len( payload ) ) + payload

def decodeTypeLength( data ):
   """Split a 4-byte frame header into its ( type, length ) pair.

   The header is one little-endian 32-bit word: bits 0-7 hold the message
   type and bits 8-31 the payload length.  Asserts (with a diagnostic
   message) if `data` is not exactly four bytes long.
   """
   assert len( data ) == 4, \
      "typeLength invalid (Daemonfailure?): %s; (%d chars)" % ( data, len( data ) )

   ( header, ) = struct.unpack( "<I", data )
   # header < 2**32, so the quotient already fits in 24 bits.
   length, typ = divmod( header, 0x100 )
   return typ, length

def encodeTypeLength( typ, length ):
   """Pack a message type and payload length into a 4-byte frame header.

   The layout is the one read back by decodeTypeLength: a little-endian
   32-bit word with `typ` in bits 0-7 and `length` in bits 8-31.
   """
   assert typ <= 0xFF
   assert length <= 0xFFFFFF
   header = ( length << 8 ) | typ
   return struct.pack( "<I", header )

PCIAddr = namedtuple( 'PCIAddr', 'domain bus device function' )

def decodePCIAddress( addr ):
   """Split a packed PCI address integer into a PCIAddr tuple.

   Bit layout (the same shifts encodePCIAddress uses):
      domain   : bits 16-31 (16 bits)
      bus      : bits  8-15 ( 8 bits)
      device   : bits  3-7  ( 5 bits)
      function : bits  0-2  ( 3 bits)
   """
   return PCIAddr( domain=( addr >> 16 ) & 0xffff,
                   bus=( addr >> 8 ) & 0xff,
                   device=( addr >> 3 ) & 0x1f,
                   function=addr & 0x07 )

def encodePCIAddress( pci ):
   """Parse a "[domain:]bus:device.function" string into a packed integer.

   All components are hexadecimal, e.g. "0000:00:1f.3" or "00:1f.3"; a
   missing domain is treated as 0.  The result is
   domain << 16 | bus << 8 | device << 3 | function, the same layout
   decodePCIAddress unpacks.

   Calls sys.exit() if the string cannot be parsed, or if device (5 bits)
   or function (3 bits) does not fit its field.
   """
   hexDigit = '[0-9a-fA-F]'
   # [<domain>:]bus:device.function
   # 0000:00:00.0
   # The '.' separator is escaped: unescaped, it matched any character.
   bdfRe = r'(?:(%s{1,4}):)?(%s{1,2}):(%s{1,2})\.(%s)$' % ( ( hexDigit, ) * 4 )
   m = re.match( bdfRe, pci )
   if not m:
      sys.exit( "Unable to parse PCI address" )
   domain, bus, device, function = m.groups()
   domain, bus, device, function = [ int( x, 16 ) for x in
                                     ( domain or '0', bus, device, function ) ]
   # The regex admits device up to 0xff and function up to 0xf, but they
   # share a single byte (5 + 3 bits); oversized values would silently
   # corrupt the bus/device fields.
   if device > 0x1f or function > 0x7:
      sys.exit( "PCI device or function out of range" )
   # Must match /src/PlutoUtil/PlutoSmbusDaemon/SmbusUtils.h
   return ( domain << 16 ) | ( bus << 8 ) | ( device << 3 ) | function


Addr = namedtuple( 'Addr', ( 'readDelay', 'writeDelay', 'bug30005ExtraDelays',
                             'addrSize', 'busTimeout', 'writeNoStopReadCurrent',
                             'accelId', 'busId', 'deviceId', 'register' ) )

def decodeAddress( address ):
   """Unpack a 64-bit SMBus address word into an Addr namedtuple.

   Every field is extracted, including the delay and bug30005ExtraDelays
   bits that encodeAddress does not set explicitly.
   """
   # Field layout from SmbusAham.tac, as ( name, lowest bit, width ):
   layout = (
      ( 'register',                0, 32 ), # offset : extern U32
      ( 'deviceId',               32,  7 ), # U7 : value[32:38]
      ( 'busId',                  39,  7 ), # U7 : value[39:45]
      ( 'accelId',                46,  7 ), # U7 : value[46:52]
      ( 'writeNoStopReadCurrent', 53,  1 ), # bool : value[53:53]
      ( 'busTimeout',             54,  3 ), # Hardware::SmbusTimeout, 54-56
      ( 'addrSize',               57,  2 ), # Hardware::SmbusAddrSize, 57-58
      ( 'bug30005ExtraDelays',    59,  1 ), # bool : value[59:59]
      ( 'writeDelay',             60,  2 ), # Hardware::SmbusDelay, 60-61
      ( 'readDelay',              62,  2 ), # Hardware::SmbusDelay, 62-63
   )
   fields = dict( ( name, ( address >> low ) & ( ( 1 << width ) - 1 ) )
                  for name, low, width in layout )
   return Addr( **fields )

def encodeAddress( bus, deviceId, register, busTimeout=1,
                   accelId=0, addrSize=1, writeNoStopReadCurrent=0 ):
   """Pack SMBus transaction parameters into the 64-bit address word.

   Layout (see decodeAddress):
      register               : bits  0-31
      deviceId               : bits 32-38
      bus                    : bits 39-45
      accelId                : bits 46-52
      writeNoStopReadCurrent : bit  53
      busTimeout             : bits 54-56
      addrSize               : bits 57-58
   Only this subset of fields is encoded; the delay fields are not set.

   Every value is asserted to be non-negative and to fit its field, so an
   oversized or negative argument fails loudly instead of corrupting
   neighbouring fields.
   """
   assert 0 <= register < ( 2 ** 32 )
   assert 0 <= deviceId < ( 2 ** 7 )
   assert 0 <= bus < ( 2 ** 7 )
   assert 0 <= accelId < ( 2 ** 7 )
   assert 0 <= writeNoStopReadCurrent < 2
   assert 0 <= busTimeout < ( 2 ** 3 )
   assert addrSize == 1 or addrSize == 2 or addrSize == 4
   # NOTE(review): addrSize is a 2-bit field (57-58), so addrSize=4 encodes
   # as 1 << 59 -- it sets bug30005ExtraDelays and leaves addrSize 0.
   # Confirm the intended Hardware::SmbusAddrSize encoding for 4 before
   # relying on it; behaviour is left unchanged here.

   address = register
   address |= ( deviceId << 32 )
   address |= ( bus << 39 )
   address |= ( accelId << 46 )
   address |= ( writeNoStopReadCurrent << 53 )
   address |= ( busTimeout << 54 )
   address |= ( addrSize << 57 )
   return address

def _recvExactly( sock, size ):
   """Read from sock until `size` bytes arrive or the peer closes.

   Returns the bytes received; the result is shorter than `size` only if
   the connection was closed first.
   """
   chunks = []
   remaining = size
   while remaining > 0:
      chunk = sock.recv( remaining )
      if not chunk:
         break
      chunks.append( chunk )
      remaining -= len( chunk )
   return b"".join( chunks )

def sendAndRecv( sock, transReq ):
   """ Sends transReq to daemon, then receives transResponse and returns it.
       Sets xid field of transReq.

   A random xid is stamped on the request; the response must echo it.
   Raises IOError if the connection closes mid-response, if the xid does
   not match, or if the response reports failure.  Asserts if the header is
   malformed or its type is not TRANSACTION_RESPONSE.
   """
   xid = random.randint( 0, 2 ** 31 - 1 )
   transReq.xid = xid
   message = encodeProtoBufMessage( Smbus_pb2.TRANSACTION_REQUEST, transReq )
   sock.sendall( message )

   # recv() on a stream socket may return fewer bytes than requested, so
   # read the header and body in full.  A short header (peer closed) still
   # reaches decodeTypeLength's length assertion and its diagnostic.
   typeLength = _recvExactly( sock, 4 )
   typ, length = decodeTypeLength( typeLength )
   assert typ == Smbus_pb2.TRANSACTION_RESPONSE
   bits = _recvExactly( sock, length )
   if len( bits ) != length:
      raise IOError( "Connection closed after %d of %d response bytes" %
                     ( len( bits ), length ) )
   TransResponse = Smbus_pb2.TransactionResponse
   transactionResponse = TransResponse.FromString( bits ) # pylint: disable-msg=E1101
   if xid != transactionResponse.xid:
      raise IOError( "Got back bad xid. Expecting: %d received: %d" %
                     ( xid, transactionResponse.xid ) )
   if transactionResponse.failure:
      raise IOError( "transaction failed" )
   return transactionResponse

def read( sock, pci, accelId, bus, deviceId, register, count=1,
          readCurrent=False, readType=Smbus_pb2.READ, backend=None ):
   """Read from an SMBus device and return the first response's data.

   With no socket, only the IOPORT backend is accepted: one byte is read
   from /dev/port at offset `register`.
   NOTE(review): `count` is ignored on that path -- confirm intent.

   Otherwise a TransactionRequest is sent through sendAndRecv(), which
   raises IOError on failure.  register=None selects a receive-byte
   transaction (no register, writeNoStopReadCurrent set), which only
   supports count == 1; readCurrent also sets writeNoStopReadCurrent.
   An explicit `backend` selects the master type; otherwise KERNEL_DEV is
   used when accelId is None and SCD when it is set.
   """
   if not sock:
      assert backend == Smbus_pb2.IOPORT
      with open( '/dev/port', 'rb' ) as port:
         port.seek( register )
         return port.read( 1 )

   recvByte = register is None
   assert count == 1 or not recvByte

   transaction = Smbus_pb2.TransactionRequest()
   request = transaction.request.add() # pylint: disable-msg=E1101
   noStop = 1 if recvByte or readCurrent else 0
   request.address = encodeAddress( bus, deviceId, register or 0,
                                    accelId=accelId or 0,
                                    writeNoStopReadCurrent=noStop )
   request.pci = pci
   request.type = readType
   if backend is not None:
      request.masterType = backend
   elif accelId is None:
      request.masterType = Smbus_pb2.KERNEL_DEV
   else:
      request.masterType = Smbus_pb2.SCD
   request.count = count

   reply = sendAndRecv( sock, transaction )
   return reply.response[ 0 ].data

def write( sock, pci, accelId, bus, deviceId, register, data,
           writeType=Smbus_pb2.WRITE, backend=None, pec=None ):
   """Write `data` to an SMBus device.

   With no socket, only the IOPORT backend is accepted, and only data[ 0 ]
   is written to /dev/port at offset `register`.
   NOTE(review): the rest of `data` is ignored on that path -- confirm intent.

   Otherwise a TransactionRequest carrying `data` (count = len( data )) is
   sent through sendAndRecv(), which raises IOError on failure.  A truthy
   `pec` sets the request's pec flag.  An explicit `backend` selects the
   master type; otherwise KERNEL_DEV is used when accelId is None and SCD
   when it is set.  Returns None.
   """
   if not sock:
      assert backend == Smbus_pb2.IOPORT
      with open( '/dev/port', 'wb' ) as port:
         port.seek( register )
         port.write( data[ 0 ] )
      return

   transaction = Smbus_pb2.TransactionRequest()
   request = transaction.request.add() # pylint: disable-msg=E1101
   request.address = encodeAddress( bus, deviceId, register or 0,
                                    accelId=accelId or 0 )
   request.pci = pci
   request.type = writeType
   if backend is not None:
      request.masterType = backend
   elif accelId is None:
      request.masterType = Smbus_pb2.KERNEL_DEV
   else:
      request.masterType = Smbus_pb2.SCD
   request.count = len( data )
   request.data = data
   if pec:
      request.pec = True

   sendAndRecv( sock, transaction )

def processCall( sock, pci, accelId, bus, deviceId, register, data, count=2,
                 callType=Smbus_pb2.PROCESS_CALL, backend=None ):
   """Perform an SMBus process-call transaction and return the reply data.

   Sends `data` to `register` on `deviceId` through sendAndRecv(), which
   raises IOError on failure, with `count` as the request count.  Returns
   the data field of the first response.  Like read() and write(), an
   explicit `backend` selects the master type; otherwise KERNEL_DEV is
   used when accelId is None and SCD when it is set.
   """
   tr = Smbus_pb2.TransactionRequest()
   request = tr.request.add() # pylint: disable-msg=E1101
   # accelId must go by keyword: encodeAddress's fourth positional
   # parameter is busTimeout.  None (kernel device) encodes as 0, as in
   # read() and write().
   request.address = encodeAddress( bus, deviceId, register,
                                    accelId=accelId or 0 )
   request.pci = pci
   request.type = callType

   # 'is not None' (matching read/write) so an explicit backend whose enum
   # value is 0 is still honoured.
   if backend is not None:
      request.masterType = backend
   else:
      if accelId is None:
         request.masterType = Smbus_pb2.KERNEL_DEV
      else:
         request.masterType = Smbus_pb2.SCD
   request.count = count
   request.data = data

   reply = sendAndRecv( sock, tr )
   return reply.response[ 0 ].data
