# Copyright (c) 2009 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import sys
import re

"""
ArHal - Arista Hardware Adaptation Layer

This layer adapts fields, registers, tables, and modules, enabling
writing and reading of registers, as well as supporting a rich set of
string-keyed printing options such as 'NonZero'.  The main use is for
Petra, but it should be general purpose enough for any register set, with
the exception of HalTable, which is geared somewhat towards Petra's table
implementation.

There's limited support for multi-word registers and, as coded, it
only supports little-endian.


Example:

class HalFooStatusReg( HalReg ):
   def __init__( self, rep=0 ):
      HalReg.__init__( self, rep )
   field = {
      'pktCount' : ( 0xfffff, 4 ),
      'byteCount' : ( 0xf, 24 ),
   }

class HalFoo( HalModule ):
   def __init__( self, mmio, baseAddr ):
      HalModule.__init__( self, mmio, baseAddr )
      addr = 0x10000
      entries = 16
      bitSize = 32
      table = HalTable( 'status', mmio, addr, entries, bitSize, HalFooStatusReg )
      object.__setattr__( self, 'status', table )

foo = HalFoo(Mmio(), 0x100000)
r = foo.status[0] # read reg
r.pktCount = 3 # modify
foo.status[0] = r # write reg

print foo.status[3,8:10] # print selected reg entries
print foo.status['NonZero'] # print non-zero entries
print foo.status['NonZero', 9,10:15] # print selected non-zero entries
print foo.status[3]['NonZero'] # print non-zero fields of reg

for reg in foo.status[3,8:10]:
   print reg.pktCount
   
One oddity is that 'for i in foo.status[3]' doesn't work, for the same reason
that 'x = range(8); for i in x[3]' doesn't work.



"""

####### Hack to get numbers to print in hex on CLI ########

class hlong( long ):
   """A long that prints as both decimal and hex.

   It exists only so that values print in hex on the CLI:
   str( hlong( 255 ) ) is '255 (0x000000ff)'.

   One oddity is that '%d' % hlong(x) also goes through __str__, so to
   print a plain decimal write '%d' % int( hlong(x) ).
   """
   def __str__( self ):
      # int( self ) is formatted with '%d' because, as noted above, an
      # hlong itself would be rendered through this __str__.
      return "%d (0x%08x)" % ( int(self), self )


def _displayhook( s ):
   """Interactive-prompt display hook that also shows integers in hex.

   Installed as sys.displayhook below.
   - Plain ints print as 'decimal (0x%08x)'.
   - longs and hlongs print as 'decimal (0x%016x)'.
   - Everything else is printed with repr().

   Like Python's default hook, a None result prints nothing, and every
   displayed value is saved in the builtin '_'.
   """
   if s is None:
      # Matches the default hook: calls that return None should not echo
      # 'None' at the prompt.
      return
   try:
      import __builtin__ as builtinsModule # Python 2
   except ImportError:
      import builtins as builtinsModule # Python 3
   thisType = type( s )
   if thisType == int:
      print( "%d (0x%08x)" % ( s, s ) )
   elif thisType in ( long, hlong ):
      print( "%d (0x%016x)" % ( long( s ), s ) )
   else:
      print( repr( s ) )
   # Keep '_' working at the prompt, as the default hook does.
   builtinsModule._ = s

sys.displayhook = _displayhook


############# HAL stuff ###############

def _formatIndexAndEntry( index, entry ):
   """Render one entry's text prefixed by its hex index.

   - The first line of entry follows an '[0x<index>]:' label padded to
     8 columns plus 3 spaces.
   - Every following line, including the empty one left after a trailing
     newline, is indented by 11 spaces.
   - Each output line is newline-terminated.
   """
   lines = entry.split( '\n' )
   label = '[0x%x]:' % index
   pieces = [ '%-8s   %s\n' % ( label, lines[ 0 ] ) ]
   pieces.extend( '           ' + line + '\n' for line in lines[ 1: ] )
   return ''.join( pieces )

def _flattenSlices( entries, slices ):
   """Expand an index expression into a flat list of indices.

   slices is either a single index/slice or a tuple of them, as produced
   by x[3, 8:10].

   Slices are resolved against a sequence of length entries with normal
   Python semantics:
   - negative bounds count from the end;
   - bounds past the end are clipped;
   - a negative step walks backwards.

   Anything that is not a slice is passed through unchanged.
   """
   if not isinstance( slices, tuple ):
      slices = [ slices ]
   flatten = []
   for s in slices:
      if isinstance( s, slice ):
         # A step of 0 has always been treated as 1 here; keep that rather
         # than letting slice.indices() raise ValueError.
         normalized = slice( s.start, s.stop, s.step or 1 )
         flatten.extend( range( *normalized.indices( entries ) ) )
      else: # singleton
         flatten.append( s )
   return flatten
         
class HalVectorField( object ):
   """View onto an array-valued bit field of a register.

   field is ( mask, offset, strides ), where strides is a list of
   ( stride, entries ) pairs, outermost dimension first.  Element i of the
   outer dimension starts at bit offset + stride*i.

   - Indexing through every dimension yields the masked integer value.
   - Indexing fewer dimensions yields another HalVectorField.

   All reads and writes go through caller.rep.
   """
   def __init__( self, caller, field ):
      self.caller = caller
      self.field = field

   def _outer( self ):
      """Return ( mask, offset, stride, entries, innerStrides ) for the
      outermost dimension."""
      mask, offset, strides = self.field
      stride, entries = strides[ 0 ]
      return mask, offset, stride, entries, strides[ 1: ]

   def __getitem__( self, index ):
      mask, offset, stride, entries, inner = self._outer()
      # A negative index would silently read a neighbouring field's bits.
      assert 0 <= index < entries
      offset += stride*index
      if not inner:
         return ( self.caller.rep & mask << offset ) >> offset
      return HalVectorField( self.caller, ( mask, offset, inner ) )

   def __setitem__( self, index, value ):
      assert len( self.field[2] )==1 # only the last dimension holds values
      ( mask, offset, [(stride, entries)] ) = self.field
      assert ( value & ~mask )==0 # value must fit in the field
      assert 0 <= index < entries
      offset += stride*index
      self.caller.rep &= ~(mask << offset)
      self.caller.rep |= value << offset

   def __len__( self ):
      return self._outer()[ 3 ]

   def __iter__( self ):
      # Explicit iterator: the legacy __getitem__ protocol would stop with
      # AssertionError instead of IndexError.
      for i in range( len( self ) ):
         yield self[ i ]

   def __str__( self ):
      return ''.join( str( self[ i ] ) + ' ' for i in range( len( self ) ) )
      
def _HalRegSliceIterator( self ):
   """Yield the register values selected by a _HalRegSlice, in order.

   When self.slice is a tuple whose first element is 'NonZero', entries
   whose int() value is 0 are skipped.
   """
   selector = self.slice
   nonZeroOnly = bool( type( selector ) == tuple and selector[ 0 ] == 'NonZero' )
   if nonZeroOnly:
      selector = selector[ 1: ]
   for index in _flattenSlices( self.entries, selector ):
      value = self.reg[ index ]
      keep = int( value ) != 0 or not nonZeroOnly
      if keep:
         yield value

class _HalRegSlice( object ):
   """A selection of entries from a register array, e.g. reg[3, 8:10].

   - Iterating yields the selected register values.
   - str() renders each selected value with its index.

   A slice tuple starting with 'NonZero' skips entries whose integer value
   is 0.
   """
   def __init__( self, reg, entries, slice ):
      self.reg = reg          # indexable source of the entries
      self.entries = entries  # length used to resolve open-ended slices
      self.slice = slice      # index expression, optionally led by 'NonZero'

   def _selected( self ):
      """Yield ( index, value ) for every entry the slice selects."""
      selector = self.slice
      showNonZeroOnly = type( selector ) == tuple and selector[ 0 ] == 'NonZero'
      if showNonZeroOnly:
         selector = selector[ 1: ]
      for i in _flattenSlices( self.entries, selector ):
         data = self.reg[ i ]
         if int( data ) != 0 or not showNonZeroOnly:
            yield i, data

   def __iter__( self ):
      return ( data for _, data in self._selected() )

   def __str__( self ):
      return ''.join( _formatIndexAndEntry( i, str( data ) )
                      for i, data in self._selected() )

def _HalVectorRegIterator( vectorReg ):
   """Generate vectorReg[i] for every i of its outermost dimension.

   vectorReg.reg is ( offset, Reg, strides ); the entry count comes from
   the first ( stride, entries ) pair in strides.
   """
   _offset, _Reg, strides = vectorReg.reg
   _stride, outerEntries = strides[ 0 ]
   for position in range( outerEntries ):
      yield vectorReg[ position ]

class _HalVectorReg( object ):
   """A (possibly multi-dimensional) array of registers in a HalModule.

   reg is ( offset, Reg, strides ):
   - offset is the address of element 0;
   - Reg wraps each value read from caller.mmio;
   - strides lists ( stride, entries ) per dimension, outermost first.

   Indexing peels off one dimension.  On the last dimension:
   - an int index reads one register;
   - a slice, tuple or 'NonZero' returns a _HalRegSlice.
   """
   def __init__( self, caller, name, reg ):
      self.caller = caller  # owner whose .mmio performs the reads/writes
      self.name = name
      self.reg = reg

   def _outer( self ):
      """Return ( offset, Reg, stride, entries, innerStrides )."""
      offset, Reg, strides = self.reg
      stride, entries = strides[ 0 ]
      return offset, Reg, stride, entries, strides[ 1: ]

   def __iter__( self ):
      return _HalVectorRegIterator( self )

   def __len__( self ):
      return self._outer()[ 3 ]

   def __getitem__( self, index ):
      offset, Reg, stride, entries, inner = self._outer()
      if not inner: # final dimension, read the reg
         if index=='NonZero': # let x['NonZero'] equal x['NonZero',:]
            index = ( 'NonZero', slice(None) )
         if isinstance( index, (int,long) ): # singleton
            # Reject negative indices: they would address memory before
            # the start of the array.
            assert 0 <= index < entries
            rep = self.caller.mmio.read( offset + stride*index )
            return Reg( rep )
         return _HalRegSlice( self, entries, index )
      # Still more dimensions to pop off.
      assert 0 <= index < entries
      return _HalVectorReg( self.caller, self.name,
                            ( offset + stride*index, Reg, inner ) )

   def __setitem__( self, index, value ):
      assert len( self.reg[2] )==1 # only the last dimension holds registers
      ( offset, Reg, [(stride, entries)] ) = self.reg
      assert 0 <= index < entries
      self.caller.mmio.write( offset + stride*index, int(value) )

   def __str__( self ):
      strides = self.reg[ 2 ]
      return self.name + ''.join( '[%d]' % entries for _, entries in strides )
         
class HalReg( object ):
   """A register value with named bit fields.

   Subclasses set the class attribute field, which maps each field name to
   either:
   - ( mask, offset ) for a scalar field, or
   - ( mask, offset, strides ) for an array field, returned as a
     HalVectorField.

   The raw value is kept in self.rep.  Reading reg.name extracts the
   field's bits from rep; assigning reg.name = value replaces them.
   HalReg itself performs no mmio access.
   """
   def __init__( self, rep=0 ):
      # Bypass __setattr__, which would otherwise treat 'rep' as a field.
      object.__setattr__( self, 'rep', rep )

   # For historical reasons (i.e., this was built for dune)
   # a HalReg is one 32-bit word. Other sizes can be derived.
   words = 1
   wordByteSize = 4

   # Empty by default, so a HalReg without its own field table reports
   # unknown names as AttributeError instead of recursing in __getattr__.
   field = {}

   def __getattr__( self, key ):
      """Return field key: an int for a scalar field, a HalVectorField for
      an array field.  Raises AttributeError for unknown names."""
      try:
         field = self.field[ key ]
      except KeyError:
         # AttributeError (not KeyError) keeps getattr() with a default,
         # hasattr() and copy.copy/deepcopy working.
         raise AttributeError( "'%s' object has no attribute '%s'" %
                               ( type( self ).__name__, key ) )
      if len( field )==2: # no strides
         mask, offset = field
         return ( self.rep & mask << offset ) >> offset
      return HalVectorField( self, field )

   def __setattr__( self, key, value ):
      """Set key if it is already an instance attribute (e.g. rep).

      Otherwise treat key as a scalar field and update its bits in rep.
      """
      if key in self.__dict__:
         object.__setattr__( self, key, value )
         return
      mask, offset = self.field[ key ]
      assert (value & ~mask)==0 # value must fit in the field
      self.rep &= ~(mask << offset)
      self.rep |= value << offset

   def __int__( self ):
      return self.rep

   def __long__( self ):
      return self.rep

   # Abuse Python's __getitem__ to make "print reg['NonZero']"
   # print only non-zero fields
   def __getitem__( self, index ):
      assert index=='NonZero'
      reg = type(self)( self.rep )
      # Despite its name, this flag makes __str__ hide fields whose value
      # is zero.
      object.__setattr__(reg, 'hideNonZeroFields', True)
      return reg

   def __str__( self ):
      """One line for rep, then one line per field.

      Zero-valued scalar fields are omitted when the register came from
      reg['NonZero'].
      """
      hideZero = 'hideNonZeroFields' in self.__dict__
      lines = [ '%-22s %-5d (0x%08x)\n' % ( 'rep', int( self.rep ), self.rep ) ]
      for name in self.field:
         val = self.__getattr__( name )
         if type( val ) == HalVectorField:
            lines.append( '%-30s %s\n' % ( name, val ) )
         elif val!=0 or not hideZero:
            lines.append( '%-22s %-5d (0x%08x)\n' % ( name, val, val ) )
      return ''.join( lines )

class HalReg64( HalReg ):
   """A HalReg spanning two 4-byte words (64 bits).

   HalModule uses words and wordByteSize to access the register as
   consecutive words, with the lowest-addressed word least significant.
   """
   wordByteSize = 4
   words = 2

def _HalTableSliceIterator( self ):
   """Yield the table entries selected by a _HalTableSlice, in order.

   When self.slice is a tuple whose first element is 'NonZero', entries
   whose int() value is 0 are left out.
   """
   selector = self.slice
   nonZeroOnly = bool( type( selector ) == tuple and selector[ 0 ] == 'NonZero' )
   if nonZeroOnly:
      selector = selector[ 1: ]
   table = self.table
   for index in _flattenSlices( table.entries, selector ):
      entry = table[ index ]
      keep = int( entry ) != 0 or not nonZeroOnly
      if keep:
         yield entry

class _HalTableSlice( object ):
   """A selection of HalTable entries, e.g. table[3, 8:10].

   - Iterating yields the selected entries.
   - str() renders each selected entry with its index.

   A slice tuple starting with 'NonZero' drops entries whose integer
   value is 0.
   """
   def __init__( self, table, slice ):
      self.table = table  # HalTable-like: has .entries and supports [i]
      self.slice = slice  # index expression, optionally led by 'NonZero'

   def _selected( self ):
      """Yield ( index, entry ) for every table entry the slice selects."""
      selector = self.slice
      showNonZeroOnly = type( selector ) == tuple and selector[ 0 ] == 'NonZero'
      if showNonZeroOnly:
         selector = selector[ 1: ]
      for i in _flattenSlices( self.table.entries, selector ):
         data = self.table[ i ]
         if int( data ) != 0 or not showNonZeroOnly:
            yield i, data

   def __iter__( self ):
      return ( data for _, data in self._selected() )

   def __str__( self ):
      return ''.join( _formatIndexAndEntry( i, str( data ) )
                      for i, data in self._selected() )

def _HalTableIterator( table ):
   """Generate table[0] .. table[table.entries - 1] in order."""
   total = table.entries
   for position in range( total ):
      entry = table[ position ]
      yield entry

class HalTable( object ):
   """An array of identically-typed entries accessed through mmio.

   - Entry i is read with mmio.read( base + i, words ).
   - Entry i is written with mmio.write( base + i, words, value ).

   words is entryBitSize rounded up to whole 32-bit words.  Values that
   are read are wrapped in the type given at construction.
   """
   def __init__( self, name, mmio, base, entries, entryBitSize, type ):
      self.name = name
      self.mmio = mmio
      self.base = base
      self.entries = entries
      self.type = type
      self.entryBitSize = entryBitSize

   def _words( self ):
      """Number of 32-bit words per entry (entryBitSize rounded up)."""
      # Floor division keeps this an integer on Python 3 as well.
      return ( self.entryBitSize + 31 ) // 32

   def __getitem__( self, index ):
      if index=='NonZero': # let x['NonZero'] equal x['NonZero',:]
         index = ( 'NonZero', slice(None) )
      if isinstance( index, (int,long) ):
         # A negative index would address memory before self.base.
         assert 0 <= index < self.entries
         rep = self.mmio.read( self.base+index, self._words() )
         return self.type( rep )
      return _HalTableSlice( self, index )

   def __setitem__( self, index, reg ):
      assert 0 <= index < self.entries
      self.mmio.write( self.base+index, self._words(), int(reg) )

   def __iter__( self ):
      return _HalTableIterator( self )

   def __len__( self ):
      return self.entries

   def __str__( self ):
      return '%s[%d]' % ( self.name, self.entries )

def _halWordLayout( Reg ):
   """Return ( wordByteSize, words ) for a register type.

   Plain int registers are a single 4-byte word; HalReg types declare
   wordByteSize and words as class attributes.
   """
   if Reg==int:
      return 4, 1
   return Reg.wordByteSize, Reg.words

class HalModule( object ):
   """A block of registers at baseAddr, accessed through mmio.

   Subclasses describe registers in the class attribute reg, which maps a
   name to either:
   - ( offset, Reg ) for a single register, or
   - ( offset, Reg, strides ) for a register array (see _HalVectorReg).

   offset is relative to baseAddr.  Reading module.name reads the register
   from mmio and wraps it in Reg; assigning module.name = value writes it.
   Multi-word registers are accessed one word at a time, with the lowest
   address holding the least significant word (little-endian only).
   """
   def __init__( self, mmio, baseAddr ):
      # object.__setattr__ bypasses our __setattr__, which treats unknown
      # names as registers.
      object.__setattr__( self, 'mmio', mmio )
      object.__setattr__( self, 'baseAddr', baseAddr )

   reg = {}

   def __getattr__( self, key ):
      """Read register key, or return a _HalVectorReg for an array.

      Raises AttributeError for names that are not registers.
      """
      try:
         reg = self.reg[ key ]
      except KeyError:
         # AttributeError keeps getattr() with a default, hasattr() and
         # copy working for names that are not registers.
         raise AttributeError( "'%s' object has no attribute '%s'" %
                               ( type( self ).__name__, key ) )
      if len( reg )==2: # no strides
         ( offset, Reg ) = reg
         wordByteSize, words = _halWordLayout( Reg )
         wordBitSize = wordByteSize<<3
         rep = 0
         for i in range(words-1,-1,-1): # this implies we're little-endian
            rep <<= wordBitSize
            rep |= self.mmio.read( self.baseAddr + offset + i*wordByteSize )
         return Reg( rep )
      ( offset, Reg, strides ) = reg
      return _HalVectorReg( self, key, ( self.baseAddr+offset, Reg, strides ) )

   def __setattr__( self, key, value ):
      """Set an existing instance attribute, else write register key."""
      if key in self.__dict__:
         object.__setattr__( self, key, value )
         return
      ( offset, Reg ) = self.reg[ key ]
      wordByteSize, words = _halWordLayout( Reg )
      wordBitSize = wordByteSize<<3
      mask = (1<<wordBitSize) - 1
      v = int(value)
      for i in range(0, words): # this implies we're little-endian
         self.mmio.write( self.baseAddr + offset + i*wordByteSize, v & mask )
         v >>= wordBitSize

   def __str__( self ):
      """List sub-modules, tables, module lists and registers.

      Addresses are divided by 4 (presumably byte address -> 32-bit word
      address).
      """
      lines = []
      for key in sorted( self.__dict__ ):
         val = self.__dict__[ key ]
         if isinstance( val, HalModule ):
            baseAddr = val.baseAddr
            addr = hex( baseAddr//4 ) if isinstance( baseAddr, (int,long) ) else ''
            lines.append( '%-20s %s\n' % ( key, addr ) )
         elif isinstance( val, HalTable ):
            lines.append( '%-40s %s\n' % ( str( val ), hex( val.base ) ) )
         elif type( val )==list:
            lines.append( '%s[%d]\n' % ( key, len(val) ) )

      for key in sorted( self.reg ):
         reg = self.reg[ key ]
         if len(reg)==2: # no strides
            addr = hex( (self.baseAddr + reg[0])//4 )
            lines.append( '%-40s %s\n' % ( key, addr ) )
         else:
            ( offset, Reg, strides ) = reg
            name = key + ''.join( '[%d]' % entries for _, entries in strides )
            addr = (self.baseAddr + offset)//4
            lines.append( '%-40s 0x%04x\n' % ( name, addr ) )
      return ''.join( lines )



def addrToName(hal,name='hal', map=None):
   """
   Useful utility for working with Dune folks who like to think in numerical
   registers.  Builds a dict mapping each register's word address to the
   logical name (a Python expression rooted at name) that reaches it.

   The walk covers:
   - sub-modules, recursively;
   - lists of modules;
   - every element of register arrays.

   HalTables are not included, and sub-modules whose baseAddr is not an
   integer are skipped.

   Example:

   > names = addrToName( linecard.petras[0].hal )
   > print names[0x5401]
   hal.nif[1].srd_interrupt_reg

   > print eval(names[0x5401])
   rep                    0     (0x00000000)
   srd_epb_interrupt      0     (0x00000000)
   srd_lane_interrupt     0     (0x00000000)
   srd_ipu_interrupt      0     (0x00000000)

   exec(names[0x5401] + '= 2')

   The code assumes a word is 4 bytes ( as it is for Petra ).

   If map is given it is filled in place and returned; otherwise a new
   dict is created for every call.
   """
   if map is None:
      # A fresh dict per top-level call; a shared {} default would make
      # results accumulate across calls (and across different hals).
      map = {}
   for key, val in hal.__dict__.items():
      if isinstance( val, HalModule ):
         if isinstance( val.baseAddr, (int,long) ):
            addrToName( val, name+'.'+key, map )
      elif isinstance( val, HalTable ):
         pass # tables are not mapped
      elif type( val )==list:
         for i, module in enumerate( val ):
            if isinstance( module, HalModule ) and \
                  not isinstance( module.baseAddr, (int,long) ):
               continue # same rule as for direct sub-modules
            addrToName( module, '%s.%s[%d]' % ( name, key, i ), map )
   if isinstance( hal, HalModule ):
      for key, value in hal.reg.items():
         if len( value )==3:
            # Expand every index combination of the register array into
            # [ byte offset, '[i][j]...' suffix ] pairs.
            dimension0 = [ [ 0, '' ] ]
            for stride, entries in value[2]:
               if stride==0:
                  print( "skipping "+name+'.'+key )
                  continue
               dimension1 = range( 0, stride*entries, stride )
               dimension0 = [ [ i+j, k+'[%d]' % l ] for i, k in dimension0
                              for l, j in enumerate( dimension1 ) ]
            for offset, offsetKey in dimension0:
               addr = ( hal.baseAddr + value[0] + offset )//4
               map[ addr ] = name+'.'+key+offsetKey
         else:
            map[ ( hal.baseAddr + value[0] )//4 ] = name+'.'+key
   return map
