#!/usr/bin/env python
# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, print_function

import argparse
import socket
import sys

import PLSmbusUtil
import Smbus_pb2

# Maps the name given to --backend on the command line to the Smbus_pb2
# backend constant that main() passes as `backend=` to PLSmbusUtil calls.
# Names missing from this table are looked up with .get() and yield None.
backendLookup = dict(
   simulated=Smbus_pb2.SIMULATED,
   scd=Smbus_pb2.SCD,
   kernel=Smbus_pb2.KERNEL_DEV,
   ioport=Smbus_pb2.IOPORT,
   celestica=Smbus_pb2.CELESTICA,
   plm=Smbus_pb2.PLM,
)

def dumpBytes( data, fmt='dot', step=16 ):
   """Dump data in the same fashion as i2cdump.

   data: sequence of integer byte values (0..0xff).  Nothing is printed
         when it is empty or None.
   fmt:  'repr' prints the whole buffer as the repr() of the equivalent
         character string; any other value prints a table with one row of
         `step` hex bytes followed by their printable-ASCII rendering.
   step: number of bytes per table row; must be positive.

   Raises ValueError if step is not positive (it would never advance).
   """
   if not data:
      return

   if fmt == 'repr':
      print( repr( ''.join( chr( b ) for b in data ) ) )
      return

   if step <= 0:
      raise ValueError( "step must be positive, got %r" % ( step, ) )

   def printable( c ):
      """Return c as a character, or '.' if it is not printable ASCII."""
      return '.' if c < 0x20 or c > 0x7e else chr( c )

   # The header is sized by `step` so it lines up with the data rows.  The
   # ASCII-column header uses a single hex digit per column (wrapping every
   # 16) so it stays one character wide, like the ASCII column itself.
   print( '   ',
          ' '.join( '%-2x' % v for v in range( step ) ),
          '  ',
          ''.join( '%x' % ( v & 0xf ) for v in range( step ) ) )

   for begin in range( 0, len( data ), step ):
      sub = data[ begin : begin + step ]
      print( '%02x:' % begin,
             ' '.join( '%02x' % b for b in sub ),
             '  ',
             ''.join( printable( b ) for b in sub ) )

def main():
   """Command-line entry point: run one SMBus transaction through PlutoSmbus.

   Parses sys.argv, connects to the PlutoSmbus agent, performs the requested
   COMMAND against bus/deviceId/register and prints whatever was read back.
   If the agent cannot be reached, only the "ioport" backend continues (with
   sock=None, reported as falling back to raw access); any other backend
   exits with status 1.

   Numeric arguments are parsed with int( x, 0 ): they are decimal unless
   written with a 0x (hex), 0o (octal) or 0b (binary) prefix.

   Returns 0 on success.  Invalid arguments exit with status 2 through
   parser.error() -- explicit checks rather than asserts, which `python -O`
   would strip.
   """

   def intType( x ):
      """Parse an integer literal, honouring 0x/0o/0b prefixes."""
      return int( x, 0 )

   commandList = [ "read8", "read16", "reads",
                   "write8", "write16", "writes",
                   "recvByte", "sendByte",
                   "processCall", "blockProcessCall",
                   "dump" ]
   backendNames = ", ".join( sorted( backendLookup ) )
   numHelp = " (decimal, or hex with a 0x prefix)"

   usage = ( "plsmbus [-h] COMMAND [--backend BACKEND] [--pec] [--pci PCI] "
             "[--scd ACCELID] [--writeNoStopReadCurrent] bus deviceId "
             "[register] [data [data ...]] \n"
             "where COMMAND = {" + "|".join( commandList ) + "}" )

   parser = argparse.ArgumentParser( description="Perform Smbus transactions",
                                     usage=usage )
   parser.add_argument( "command", choices=commandList )
   parser.add_argument( "--backend", default="",
                        help="Backend to use, one of: " + backendNames )
   parser.add_argument( "--pec", action="store_true", help="Send PEC" )
   parser.add_argument( "--pci", default=0, help="A BDF PCI address "
                        "[<domain>:]bus:device.function" )
   parser.add_argument( "--scd", dest="accelId", type=intType,
                        default=None, help="SCD accel id" + numHelp )
   parser.add_argument( "bus", type=intType, help="Bus Id" + numHelp )
   parser.add_argument( "deviceId", type=intType,
                        help="Device Address" + numHelp )
   parser.add_argument( "register", type=intType,
                        help="Command/Register" + numHelp,
                        nargs='?', default=None )
   parser.add_argument( "data", nargs="*", type=intType,
                        help="Data values" + numHelp )
   parser.add_argument( "--writeNoStopReadCurrent", action="store_true",
                        default=False,
                        help="enable write no stop read current" )

   args = parser.parse_args()

   def require( cond, msg ):
      """Exit with a usage error (status 2) unless cond holds."""
      if not cond:
         parser.error( "%s: %s" % ( args.command, msg ) )

   def requireBytes( values ):
      """Exit with a usage error unless every value fits in one byte."""
      require( all( 0 <= v <= 0xff for v in values ),
               "data bytes must be in the range 0..0xff" )

   def bytesToString( values ):
      """Pack byte values into a character string, as passed to PLSmbusUtil."""
      return "".join( chr( v ) for v in values )

   pciAddr = 0
   if args.pci:
      pciAddr = PLSmbusUtil.encodePCIAddress( args.pci )

   # Reject misspelled backend names instead of silently requesting none.
   if args.backend and args.backend not in backendLookup:
      parser.error( "unknown backend %r; choose from: %s"
                    % ( args.backend, backendNames ) )
   backend = backendLookup.get( args.backend )

   # Every command except recvByte needs a register; sendByte sends the
   # register value itself as its single byte.
   require( args.command == "recvByte" or args.register is not None,
            "a register argument is required" )

   sock = None
   try:
      sock = PLSmbusUtil.connect()
   except socket.error:
      print( "Could not connect to PlutoSmbus" )
      if backend == Smbus_pb2.IOPORT:
         print( "Falling back to raw access" )
      else:
         sys.exit( 1 )

   if args.command == "read8":
      result = PLSmbusUtil.read( sock, pciAddr, args.accelId,
                                 args.bus, args.deviceId, args.register, count=1,
                                 backend=backend )
      print( "0x%x" % PLSmbusUtil.stringToInt( result ) )

   elif args.command == "read16":
      result = PLSmbusUtil.read( sock, pciAddr, args.accelId,
                                 args.bus, args.deviceId, args.register, count=2,
                                 backend=backend )
      print( "0x%x 0x%x" % ( PLSmbusUtil.stringToInt( result[0] ),
                            PLSmbusUtil.stringToInt( result[1] ) ) )

   elif args.command == "recvByte":
      result = PLSmbusUtil.read( sock, pciAddr, args.accelId, args.bus,
                                 args.deviceId, None, backend=backend )
      print( "0x%x" % PLSmbusUtil.stringToInt( result ) )

   elif args.command == "reads":
      if not args.data:
         # No count given: SMBus block read.
         result = PLSmbusUtil.read( sock, pciAddr, args.accelId, args.bus,
                                    args.deviceId, args.register,
                                    readType=Smbus_pb2.BLOCK_READ,
                                    backend=backend )
      elif len( args.data ) == 1:
         # A single data value is the number of bytes to read.
         result = PLSmbusUtil.read( sock, pciAddr, args.accelId, args.bus,
                                    args.deviceId, args.register,
                                    readCurrent=args.writeNoStopReadCurrent,
                                    count=args.data[0], backend=backend )
      else:
         parser.error( "reads: expected at most one data value (byte count)" )

      print( " ".join( hex( ord( x ) ) for x in result ) )

   elif args.command == "write8":
      require( len( args.data ) == 1, "expected exactly one data byte" )
      requireBytes( args.data )
      data = PLSmbusUtil.intToString( args.data[0] )
      PLSmbusUtil.write( sock, pciAddr, args.accelId, args.bus,
                         args.deviceId, args.register, data,
                         backend=backend, pec=args.pec )

   elif args.command == "write16":
      if len( args.data ) == 1:
         require( 0 <= args.data[0] <= 0xffff,
                  "16-bit value must be in the range 0..0xffff" )
         data = PLSmbusUtil.intToString( args.data[0], length=2 )
      elif len( args.data ) == 2:
         requireBytes( args.data )
         data = bytesToString( args.data )
      else:
         parser.error( "write16: improper 16 bit data provided; give one "
                       "16-bit value or two bytes" )

      PLSmbusUtil.write( sock, pciAddr, args.accelId, args.bus,
                         args.deviceId, args.register, data,
                         backend=backend, pec=args.pec )

   elif args.command == "sendByte":
      require( not args.data,
               "takes no data; the byte to send is the register argument" )
      require( 0 <= args.register <= 0xff,
               "byte must be in the range 0..0xff" )
      PLSmbusUtil.write( sock, pciAddr, args.accelId, args.bus,
                         args.deviceId, args.register, "",
                         backend=backend, pec=args.pec )

   elif args.command == "writes":
      require( args.data, "at least one data byte is required" )
      requireBytes( args.data )
      PLSmbusUtil.write( sock, pciAddr, args.accelId, args.bus,
                         args.deviceId, args.register, bytesToString( args.data ),
                         writeType=Smbus_pb2.BLOCK_WRITE,
                         backend=backend, pec=args.pec )

   elif args.command == "processCall":
      require( len( args.data ) == 2, "expected exactly two data bytes" )
      requireBytes( args.data )
      result = PLSmbusUtil.processCall( sock, pciAddr, args.accelId, args.bus,
                                        args.deviceId, args.register,
                                        bytesToString( args.data ),
                                        backend=backend )
      print( " ".join( hex( ord( x ) ) for x in result ) )

   elif args.command == "blockProcessCall":
      require( args.data, "expected a byte count followed by data bytes" )
      requireBytes( args.data )
      # First data value is the count passed through; the rest is the payload.
      count = args.data[0]
      result = PLSmbusUtil.processCall( sock, pciAddr, args.accelId, args.bus,
                                        args.deviceId, args.register,
                                        bytesToString( args.data[1:] ),
                                        count=count,
                                        callType=Smbus_pb2.BLOCK_PROCESS_CALL,
                                        backend=backend )
      print( " ".join( hex( ord( x ) ) for x in result ) )

   elif args.command == "dump":
      # `register` is the first register to dump; the optional single data
      # value is the last one (inclusive).  `args.data` is a list
      # (nargs="*"), so its first element is used.
      require( len( args.data ) <= 1,
               "expected at most one data value (last register)" )
      begin = args.register
      last = args.data[0] if args.data else 0xff
      require( begin <= last, "first register must not exceed the last" )
      res = []
      for reg in range( begin, last + 1 ):
         try:
            byte = PLSmbusUtil.read( sock, pciAddr, args.accelId,
                                     args.bus, args.deviceId, reg, count=1,
                                     backend=backend )
            res.append( PLSmbusUtil.stringToInt( byte ) )
         except Exception: # pylint: disable=broad-except
            # Best effort: show unreadable registers as 0xff and keep going.
            res.append( 0xff )
      # NOTE: dumpBytes numbers rows from 0, so the printed offsets are
      # relative to `begin`, not absolute register numbers.
      dumpBytes( res, 'dot' )

   return 0

if __name__ == "__main__":
   # Use main()'s return value as the process exit status.
   raise SystemExit( main() )
