#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from TableOutput import createTable
from CliModel import Int, Str, Dict, Model

class AllocTrackTypeModel( Model ):
   """Allocation statistics for a single tracked type."""
   size = Int( help='Size of type in bytes' )
   # Added to size to obtain the memory charged per allocated object.
   memoryAllocationOverhead = Int( help='Memory allocation overhead in bytes' )
   currentAllocations = Int( help='Number of objects currently allocated' )
   # Cumulative count; includes objects that have since been freed.
   totalAllocations = Int( help='Total number of allocations (calls to new)' )
   highestAllocations = Int( help='High watermark of currentAllocations' )

class AllocTrackModel( Model ):
   """Allocation statistics for a set of types, rendered as a CLI table.

   render() prints either the full statistics of every entry in ``types``
   (_fullOrDelta == 'full') or, per type, how one allocation metric changed
   between ``_baseTypes`` and ``types`` (_fullOrDelta == 'delta').
   """
   # Tracked types, keyed by type name.
   types = Dict( keyType=str,
                 valueType=AllocTrackTypeModel,
                 help='Dictionary of types' )
   # Full mode: 'typeName', 'currentMemory', 'highWatermarkMemory' or
   # 'totalAllocations'. Delta mode: 'transientMemory' selects the freed-object
   # count; any other value selects the current allocation count.
   _sortOrder = Str( help='Order in which type info should be rendered',
                     default='typeName' )
   # Stripped from the front of type names before display when present.
   _typePrefix = Str( optional=True, help='Type name prefix to skip in display' )
   _maxNameLen = Int( help='Maximum number of type name characters to render',
                      default=100 )
   _tableWidth = Int( help='Character width of rendered table', default=200 )
   # 0 means no limit.
   _limit = Int( help='Number of type entries to display', default=0 )
   # Either 'full' or 'delta'.
   _fullOrDelta = Str( help='Render full model or delta from _baseTypes',
                       default='full' )
   # Baseline for delta mode; when unset it is treated as an empty baseline.
   _baseTypes = Dict( optional=True, keyType=str, valueType=AllocTrackTypeModel,
                      help='Dictionary of type information to diff against' )

   def render( self ):
      """Print the model: full statistics or a delta against _baseTypes."""
      if self._fullOrDelta == 'full':
         self._render()
      else:
         assert self._fullOrDelta == 'delta'
         self._renderDiff()

   def _render( self ):
      """Print size, allocation counts and memory use of every tracked type.

      Rows follow _sortOrder: by name ('typeName') or by descending current
      memory, high-watermark memory or total allocations. Ties are broken by
      type name so the output order does not depend on dict ordering. The
      rows (at most _limit of them when it is non-zero) are followed by totals.
      """
      def entrySize( typeInfo ):
         # Memory charged per object: type size plus allocation overhead.
         return typeInfo.size + typeInfo.memoryAllocationOverhead

      # Primary sort keys, negated so that larger values come first.
      primaryKeys = {
         'typeName' : lambda typeInfo : 0,
         'currentMemory' :
            lambda typeInfo : -entrySize( typeInfo ) * typeInfo.currentAllocations,
         'highWatermarkMemory' :
            lambda typeInfo : -entrySize( typeInfo ) * typeInfo.highestAllocations,
         'totalAllocations' : lambda typeInfo : -typeInfo.totalAllocations,
      }
      primaryKey = primaryKeys[ self._sortOrder ]

      def sortKey( item ):
         typeName, typeInfo = item
         return ( primaryKey( typeInfo ), typeName )

      sortedTypes = sorted( self.types.items(), key=sortKey )

      rows = []
      for typeName, typeInfo in sortedTypes:
         perObject = entrySize( typeInfo )
         rows.append( [ self._displayTypeName( typeName ),
                        self._formatEntrySize( typeInfo ),
                        typeInfo.totalAllocations,
                        typeInfo.currentAllocations,
                        typeInfo.currentAllocations * perObject,
                        typeInfo.highestAllocations,
                        typeInfo.highestAllocations * perObject ] )
      # Row indices holding byte counts, rendered with a B/KB/MB/... suffix.
      memoryColumns = [ 4, 6 ]
      numValueColumns = 5

      table = createTable( ( ( 'type', 'l' ),
                             ( 'size +', 'r', ( 'overhead', ) ),
                             ( 'total', 'c', ( 'allocations', ) ),
                             ( 'current', 'c', ( 'count', 'memory' ) ),
                             ( 'high watermark', 'c', ( 'count', 'memory' ) ) ),
                           tableWidth=self._tableWidth )
      limit = self._limit if self._limit else None
      displayRows = rows[ : limit ]
      for row in displayRows:
         table.newRow( *self._formatRowMemory( row, memoryColumns ) )
      table.newRow()
      if len( displayRows ) != len( rows ):
         # When we display only a subset of rows, include a total of
         # the displayed rows in addition to the total of all rows.
         totalsRow = ( [ 'TOTAL displayed', '' ] +
                       self._columnTotals( displayRows, numValueColumns ) )
         table.newRow( *self._formatRowMemory( totalsRow, memoryColumns ) )
      totalsRow = [ 'TOTAL', '' ] + self._columnTotals( rows, numValueColumns )
      table.newRow( *self._formatRowMemory( totalsRow, memoryColumns ) )
      print( table.output() )

   def _renderDiff( self ):
      """Print how one allocation metric changed from _baseTypes to types.

      The metric is the number of already-freed allocations
      (totalAllocations - currentAllocations) when _sortOrder is
      'transientMemory', and currentAllocations otherwise. A type missing
      from _baseTypes starts at 0. Only types whose metric changed are
      listed, largest increase first (ties by type name), followed by totals.
      """
      # Any sort order other than 'transientMemory' -- including the
      # 'typeName' default -- reports the current allocation count instead
      # of failing with a KeyError.
      transient = self._sortOrder == 'transientMemory'
      # _baseTypes is optional; treat a missing baseline as empty.
      baseTypes = self._baseTypes or {}

      def metric( typeInfo ):
         if transient:
            return typeInfo.totalAllocations - typeInfo.currentAllocations
         return typeInfo.currentAllocations

      # ( typeName, typeInfo, [ begin, end, delta ] ) for each changed type;
      # each metric is computed exactly once.
      # NOTE(review): types present only in _baseTypes are neither listed nor
      # counted in the totals; confirm types never disappear from `types`.
      changedTypes = []
      for typeName, typeInfo in self.types.items():
         baseTypeInfo = baseTypes.get( typeName )
         begin = metric( baseTypeInfo ) if baseTypeInfo is not None else 0
         end = metric( typeInfo )
         if end != begin:
            changedTypes.append( ( typeName, typeInfo,
                                   [ begin, end, end - begin ] ) )
      changedTypes.sort( key=lambda entry : ( -entry[ 2 ][ 2 ], entry[ 0 ] ) )

      # Prepare sorted data for display
      hdr = ( ( 'type', 'l' ),
              ( 'size +', 'r', ( 'overhead', ) ),
              ( 'transient' if transient else 'current', 'c',
                ( 'begin', 'end', 'delta' ) ) )
      table = createTable( hdr, tableWidth=self._tableWidth )
      rows = [ [ self._displayTypeName( typeName ),
                 self._formatEntrySize( typeInfo ) ] + metrics
               for typeName, typeInfo, metrics in changedTypes ]
      numValueColumns = 3
      limit = self._limit if self._limit else None
      displayRows = rows[ : limit ]
      for row in displayRows:
         table.newRow( *row )
      table.newRow()
      if len( displayRows ) != len( rows ):
         totalsRow = ( [ 'TOTAL displayed', '' ] +
                       self._columnTotals( displayRows, numValueColumns ) )
         table.newRow( *totalsRow )
      totalsRow = [ 'TOTAL', '' ] + self._columnTotals( rows, numValueColumns )
      table.newRow( *totalsRow )
      print( table.output() )

   def _formatEntrySize( self, typeInfo ):
      """Return the 'size + overhead' cell text for one type."""
      return '%d + %2d' % ( typeInfo.size, typeInfo.memoryAllocationOverhead )

   def _columnTotals( self, rows, numValueColumns ):
      """Return the per-column sums of the numeric cells of rows.

      Columns 0 and 1 (name and size text) are skipped. An empty rows list
      yields numValueColumns zeros, so a totals row always spans the table.
      """
      if not rows:
         return [ 0 ] * numValueColumns
      return [ sum( column ) for column in zip( *[ row[ 2 : ] for row in rows ] ) ]

   def _displayTypeName( self, typeName ):
      """Return typeName without _typePrefix, truncated to _maxNameLen chars."""
      if self._typePrefix is not None and typeName.startswith( self._typePrefix ):
         typeName = typeName[ len( self._typePrefix ) : ]
      return typeName[ : self._maxNameLen ]

   def _formatRowMemory( self, row, memoryFields ):
      """Return a copy of row with the cells at memoryFields shown as sizes."""
      return [ self._formatMemory( x ) if i in memoryFields else x
               for i, x in enumerate( row ) ]

   def _formatMemory( self, size, precision=2 ):
      """Format a byte count with a 1024-based unit suffix.

      The value is divided by 1024 until it is below 1024 or the largest
      suffix ('TB') is reached, e.g. 1536 -> '1.50 KB'.
      """
      suffixes = [ 'B ', 'KB', 'MB', 'GB', 'TB' ]
      suffixIndex = 0
      s = size
      while s >= 1024 and suffixIndex < len( suffixes ) - 1:
         suffixIndex += 1
         s = s / 1024.0
      return '%.*f %s' % ( precision, s, suffixes[ suffixIndex ] )
