#!/usr/bin/env python
# Copyright (c) 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
#
from ctypes import CDLL, c_bool, c_uint32, c_uint64, c_char_p, \
   c_void_p, byref, create_string_buffer, memset, c_size_t, POINTER
import os, sys, re, inspect
import Tracing

# Trace function (trace0) of the 'libdmamem' Tracing handle; every diagnostic
# message emitted by this module goes through t0.
t0 = Tracing.Handle( 'libdmamem' ).trace0       # pylint: disable-msg=E1101

class dmamem( object ):
   """ ctypes wrapper around libdmamem.so.

   Instantiating the class allocates a named chunk via dmamem_alloc_func()
   and maps it into this process. The class/static methods expose the other
   library entry points: map/unmap, free, dump, find, limits and IOMMU
   device registration.
   """
   lib = CDLL( 'libdmamem.so' )
   # NOTE(review): only three parameters are declared here, but __init__
   # passes a fourth (flags). ctypes forwards the extra argument of a CDLL
   # function with its default int conversion. Confirm against the C
   # prototype before extending argtypes.
   lib.dmamem_alloc_func.restype = c_uint64
   lib.dmamem_alloc_func.argtypes = [ c_char_p, c_size_t, c_char_p ]
   lib.dmamem_map.restype = c_void_p
   lib.dmamem_map.argtypes = [ c_char_p, c_bool, c_size_t, c_size_t ]
   lib.dmamem_map_all.restype = c_void_p
   lib.dmamem_find.restype = c_char_p
   lib.dmamem_find.argtypes = [ c_uint32, POINTER( c_uint32 ) ]
   lib.dmamem_dump_begin.restype = c_bool
   lib.dmamem_dump_next.restype = c_bool

   # Size of the buffer handed to dmamem_dump_next() for each dump line.
   DUMP_MAX_LINE_SIZE = 1024

   def __init__( self, name, size, offset=0, readOnly=False, highIova=False ):
      """ Allocates a dmamem chunk and mmaps it at a given offset.

      name     -- chunk name passed to dmamem_alloc_func() and map().
      size     -- chunk size in bytes.
      offset   -- byte offset into the chunk at which the mapping starts;
                  must be a multiple of 0x1000 (asserted by map()).
      readOnly -- OR dmamem_flag_readonly() into the allocation flags.
      highIova -- OR dmamem_flag_64bit() into the allocation flags.

      Sets self.pa (allocation address + offset), self.va (mapped address)
      and self.uint32Arr (the mapped bytes viewed as an array of uint32).
      Raises MemoryError if the allocation returns 0 or mapping fails.
      """
      agent = os.getenv( 'AGENT_PROCESS_NAME' )
      if not agent:
         # Path to the calling file (<stdin> if there is none). context=0
         # yields the same filename without reading source lines for every
         # frame on the stack.
         agent = '@' + inspect.stack( 0 )[ 1 ][ 1 ]
      self.name = name
      self.size = size
      self.offset = offset
      t0( 'alloc: %s, size: %d' % ( name, size ) )
      flags = 0
      if readOnly:
         flags |= dmamem.lib.dmamem_flag_readonly()
      if highIova:
         flags |= dmamem.lib.dmamem_flag_64bit()
      pa = dmamem.lib.dmamem_alloc_func( name, size, agent, flags )
      if pa == 0:
         t0( 'nomem: %s, size: %d' % ( name, size ) )
         raise MemoryError
      # map() already raises MemoryError on a NULL mapping; the check below
      # is a defensive guard only.
      va = dmamem.map( name, True, offset=offset )
      if va is None:
         t0( 'map failed: %s, pid: %d, dma: %x'  % ( name, os.getpid(), pa ) )
         raise MemoryError
      self.pa = pa + offset
      self.va = va
      mappedSize = size - offset
      # Floor division: 1-3 trailing bytes, if any, are not covered by the
      # uint32 view.
      self.uint32Arr = ( c_uint32 * ( mappedSize // 4 ) ).from_address(
            long( self.va ) )
      t0( 'mem: %s, size: %d, offset: 0x%x, dma: %x, va: %x' %
          ( name, mappedSize, offset, pa, va ) )

   def zeroChunk( self ):
      """ Zeroes self.size - self.offset bytes starting at the mapped address. """
      memset( self.uint32Arr, 0, self.size - self.offset )

   @staticmethod
   def iommu_map( domain, bus, devfn, iommu_prot ):
      """ Calls dmamem_iommu_map() and returns its result unchanged. """
      return dmamem.lib.dmamem_iommu_map( domain, bus, devfn, iommu_prot )

   @staticmethod
   def register_device( domain, bus, devfn ):
      """ Calls dmamem_register_device() and returns its result unchanged. """
      return dmamem.lib.dmamem_register_device( domain, bus, devfn )

   @staticmethod
   def iommu_unmap( domain, bus, devfn ):
      """ Calls dmamem_iommu_unmap() and returns its result unchanged. """
      return dmamem.lib.dmamem_iommu_unmap( domain, bus, devfn )

   @staticmethod
   def unregister_device( domain, bus, devfn ):
      """ Calls dmamem_unregister_device() and returns its result unchanged. """
      return dmamem.lib.dmamem_unregister_device( domain, bus, devfn )

   @classmethod
   def map( cls, name, writable, offset=0, size=0 ):
      """ Maps chunk `name` into this process and returns its address.

      offset must be a multiple of 0x1000. offset and size are passed to
      dmamem_map() unchanged (size=0 presumably means "the rest of the
      chunk"; TODO confirm in libdmamem). Raises MemoryError when the library
      returns NULL.
      """
      # mmap needs to be page aligned
      assert( offset % 0x1000 == 0 )
      va = dmamem.lib.dmamem_map( name, writable, offset, size )
      if va is None:
         t0( 'map failed: %s, pid: %d, writable: %d, off %x size %x' % \
            ( name, os.getpid(), writable, offset, size ) )
         raise MemoryError
      t0( 'map: %s, writable: %d, va: %x off:%x size: %x' % (
            name, writable, va, offset, size ) )
      return va

   @classmethod
   def map_all( cls ):
      """ Calls dmamem_map_all(); returns the address, or raises MemoryError
      when it is NULL. """
      va = dmamem.lib.dmamem_map_all()
      if va is None:
         t0( 'map_all failed: pid: %d' % os.getpid() )
         raise MemoryError
      t0( 'map_all: va: %s' % va )
      return va

   @classmethod
   def unmap_all( cls ):
      """ Calls dmamem_unmap_all(); its return value is ignored. """
      dmamem.lib.dmamem_unmap_all()

   @classmethod
   def unmap( cls, name, va, size = 0 ):
      """ Calls dmamem_unmap( name, va, size, 0 ); a non-zero return raises
      MemoryError. """
      rc = dmamem.lib.dmamem_unmap( name, c_void_p( va ),
                                    size, 0 ) # pylint: disable-msg=W0212
      if rc:
         t0( 'unmap failed: %s, pid: %d, va: %x'  % ( name, os.getpid(), va ) )
         raise MemoryError
      t0( 'unmap: %s, va: %x'  % ( name, va ) )

   @classmethod
   def sizeof( cls, name ):
      """ Returns dmamem_sizeof( name ).

      NOTE(review): no restype is declared, so ctypes returns the result as
      a C int.
      """
      return dmamem.lib.dmamem_sizeof( name )

   @classmethod
   def free( cls, name, ignoreError=False ):
      """ Frees chunk `name`; raises NameError( name ) if dmamem_free()
      returns non-zero, unless ignoreError is set. """
      t0( 'free: %s' % name )
      rc = dmamem.lib.dmamem_free( name )
      if ( rc != 0 ) and not ignoreError:
         raise NameError( name )

   @classmethod
   def free_all( cls, nameRegEx=None ):
      """ Frees every dumped chunk whose name fully matches nameRegEx.

      The name is the second whitespace-separated field of each dump() line.
      With no regex, every chunk is freed.
      """
      t0( 'free_all' )
      pattern = re.compile( nameRegEx or '^.*$' )
      for line in dmamem.dump():
         name = line.split()[ 1 ]
         match = pattern.match( name )
         # Require the match to cover the whole name, not just a prefix.
         if match and match.end() == len( name ):
            dmamem.free( name )

   # Valid DMA address limits: (start, end)
   @classmethod
   def limits( cls ):
      """ Returns ( start, end ) from dmamem_get_limits(), or ( 0, 0 ) if the
      call returns non-zero. """
      startDmaAddr, endDmaAddr = c_uint32(), c_uint32()
      if dmamem.lib.dmamem_get_limits( byref(startDmaAddr),
              byref(endDmaAddr) ) != 0:
         return ( 0, 0 )
      return ( startDmaAddr.value, endDmaAddr.value )

   @classmethod
   def _dump( cls, callback ):
      """ Collects lines from `callback` between dmamem_dump_begin/end.

      callback( buf, DUMP_MAX_LINE_SIZE ) is called until it returns false,
      and each buffer value is appended. dmamem_dump_end() is always called
      once dmamem_dump_begin() succeeded. Returns [] if begin fails.
      """
      buf = create_string_buffer( dmamem.DUMP_MAX_LINE_SIZE )

      result = []
      if dmamem.lib.dmamem_dump_begin():
         try:
            while callback( byref(buf), dmamem.DUMP_MAX_LINE_SIZE ):
               result.append( buf.value )
         finally:
            dmamem.lib.dmamem_dump_end()
      return result

   @classmethod
   def dump( cls ):
      """ Returns all lines produced by dmamem_dump_next(). """
      return dmamem._dump( dmamem.lib.dmamem_dump_next )

   @classmethod
   def dump_iommu( cls ):
      """ Returns the entries of /sys/kernel/dmamem/devices. """
      return os.listdir("/sys/kernel/dmamem/devices")

   @classmethod
   def find( cls, dmaAddr ):
      """ Looks up dmaAddr via dmamem_find().

      Returns ( name, base ): name is the library's string result (None for
      NULL) and base is the value it wrote to the out-parameter.
      """
      baseDmaAddr = c_uint32()
      name = dmamem.lib.dmamem_find( dmaAddr, byref(baseDmaAddr) )
      return ( name, baseDmaAddr.value )

def main( ):
   """ Command-line entry point; options are read from sys.argv.

   (no args)     print every dmamem.dump() line, then a summary line with
                 the DMA address range, the bytes allocated and, when known,
                 the reserved region.
   -f REGEX ...  free every chunk whose name re.match()es (prefix match) any
                 of the given patterns.
   -i            list the entries of /sys/kernel/dmamem/devices.

   Exits with status -1 if dmamem.limits() reports ( 0, 0 ) or the arguments
   are not recognised.
   """
   argv = sys.argv[ 1: ]

   dmaStart, dmaEnd = dmamem.limits()
   if dmaEnd == 0 and dmaStart == 0:
      sys.exit( -1 )

   chunks = dmamem.dump()
   if not argv:
      used = 0
      for l in chunks:
         # The 5th whitespace-separated field is summed as the chunk size
         # (parsed as decimal).
         used += int( l.split()[ 4 ] )
         print( l )

      reservationStr = ""
      if "SIMULATION_VMID" in os.environ:
         # dmamem region in simulation starts at 1 page so we don't give users
         # 0 dma address
         size = dmaEnd - dmaStart - 0x1000
         reservationStr = 'Reserved simulation bytes: 0x%08x' % size
      else:
         with open( "/proc/dmamem" ) as f:
            physLimits = f.read().split( "\n" )[ 0 ]
         # 4 groups: phys start, phys end, size and allocator identifier
         lims = re.match( r"([0-9a-f]+)\t([0-9a-f]+)\t([0-9a-f]+)\t([A-Z]+)",
                          physLimits )
         # Only a CMA reservation is reported. An unparsable first line
         # leaves the reservation string empty instead of raising
         # AttributeError on a None match.
         if lims and lims.group( 4 ) == "CMA":
            reservationStr = 'Reserved CMA bytes: 0x%s' % lims.group( 3 )

      print( 'DMA address range (0x%08x-0x%08x). Allocated bytes 0x%08x. %s' %
             ( dmaStart, dmaEnd, used, reservationStr ) )

   elif argv[ 0 ] == '-f':
      for l in chunks:
         chunkname = l.split()[ 1 ]
         for arg in argv[ 1: ]:
            if re.match( arg, chunkname ):
               dmamem.free( chunkname )
               break
   elif argv[ 0 ] == '-i':
      print( 'Registered pci devices:' )
      for dev in dmamem.dump_iommu():
         print( dev )

   else:
      print( "usage: dmamem [ -f chunknames .. | -i ]" )
      sys.exit( -1 )

if __name__ == '__main__':
   # Export both switches as "1" before running the CLI. NOTE(review): these
   # are presumably consumed by libdmamem.so (this file never reads them),
   # which is already loaded at this point because the class body ran at
   # import. Confirm the library reads them lazily.
   for envVar in ( "DMAMEM_NO_CLEANUP", "DMAMEM_STD_ERR" ):
      os.environ[ envVar ] = "1"
   main()
