#!/usr/bin/python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import os
import sys
# pylint: disable-msg=c-extension-no-member
import CliHeapcheck

# This keeps the state/counters for the memory stats that can be collected under a
# cli debugging mode. It also encapsulates the APIs used to collect those stats.
# It is used by the context manager that follows (MemoryStatsContext).
class MemoryStatsState( object ):
   """State/counters for the memory stats collected in 'cli debug memory' mode.

   Every tracked statistic is stored as a 3-slot list indexed by snapshot slot:
     [0] = previous command run, snapped at context entry
     [1] = this command run, snapped at context entry
     [2] = this command run, snapped at context exit
   The get*Stats() methods turn those into a "handler to handler" (h2h) delta
   (slot 1 - slot 0) and a command delta (slot 2 - slot 1). Once more than
   WARMUP_ITERS runs have happened, they also accumulate both deltas into running
   totals and include those totals in the returned text.
   """

   # Runs to skip before accumulating cumulative totals (lets caches warm up).
   WARMUP_ITERS = 5

   def __init__( self ):
      # setup calls to libc's mallinfo / malloc_trim
      import ctypes
      self.libc = ctypes.cdll.LoadLibrary( "libc.so.6" )

      class MallinfoRet( ctypes.Structure ):
         # Layout of glibc's 'struct mallinfo'. The fields are C ints, so the
         # reported values can wrap around for very large heaps.
         _fields_ = [
            ( "arena", ctypes.c_int ),     # Non-mmapped space allocated (bytes)
            ( "ordblks", ctypes.c_int ),   # Number of free chunks
            ( "smblks", ctypes.c_int ),    # Number of free fastbin blocks
            ( "hblks", ctypes.c_int ),     # Number of mmapped regions
            ( "hblkhd", ctypes.c_int ),  # Space allocated in mmapped regions (bytes)
            ( "usmblks", ctypes.c_int ),   # Maximum total allocated space (bytes)
            ( "fsmblks", ctypes.c_int ),   # Space in freed fastbin blocks (bytes)
            ( "uordblks", ctypes.c_int ),  # Total allocated space (bytes)
            ( "fordblks", ctypes.c_int ),  # Total free space (bytes)
            ( "keepcost", ctypes.c_int ),  # Top-most, releasable space (bytes)
         ]

      self.mallinfo = self.libc.mallinfo
      self.mallinfo.argtypes = []
      self.mallinfo.restype = MallinfoRet
      # malloc_trim( size_t pad ) takes exactly one argument. Declaring it makes
      # ctypes reject a call that omits it instead of passing an undefined value.
      self.malloc_trim = self.libc.malloc_trim
      self.malloc_trim.argtypes = [ ctypes.c_size_t ]
      self.malloc_trim.restype = ctypes.c_int
      # snapshots; see the class docstring for the meaning of each slot
      self.free = [ 0, 0, 0 ]      # mallinfo fordblks (bytes)
      self.alloced = [ 0, 0, 0 ]   # mallinfo uordblks (bytes)
      self.rss = [ 0, 0, 0 ]       # resident field of /proc/self/statm (pages)
      self.objs = [ 0, 0, 0 ]      # len( gc.get_objects() )
      # total counters (cumulative of the deltas), only accumulated once more
      # than WARMUP_ITERS iterations have run (warmup of any caches)
      self.iter = 0 # count the iterations of debugged command runs
      self.freeD_cum = 0 # cumulative delta of free mem
      self.freeD_h2h_cum = 0 # same as above, but for handler to handler case
      self.rssD_cum = 0
      self.rssD_h2h_cum = 0
      self.objsD_cum = 0
      self.objsD_h2h_cum = 0
      self.allocedD = 0 # only set if sum of free/alloc deltas is not null
      self.msg = "" # what we msg to the user, might happen later (added to prompt)

   # APIs to collect snapshots of various memory stats

   def snapMalloc( self, inst ):
      """Record malloc free/allocated byte counts into snapshot slot `inst`."""
      m = self.mallinfo()
      self.free[ inst ] = m.fordblks
      self.alloced[ inst ] = m.uordblks

   def snapRss( self, inst ):
      """Record the resident size (2nd field of /proc/self/statm) into `inst`."""
      with open( "/proc/self/statm" ) as f:
         c = f.read()
      # split() (not split( " " )) tolerates any run of whitespace between fields
      self.rss[ inst ] = int( c.split()[ 1 ] )

   def snapObjs( self, inst ):
      """Record the number of gc-tracked Python objects into slot `inst`."""
      import gc
      self.objs[ inst ] = len( gc.get_objects() )

   # APIs to return the stats

   def getMallocStats( self ):
      """Return the free-memory h2h/command deltas as text.

      Also sets self.allocedD to the allocated-bytes delta when free and
      allocated deltas do not cancel out exactly, and 0 otherwise.
      """
      freeD = self.free[ 2 ] - self.free[ 1 ]
      allocedD = self.alloced[ 2 ] - self.alloced[ 1 ]
      # a non-zero sum means memory moved outside the malloc free/alloc pools
      if freeD + allocedD != 0:
         self.allocedD = allocedD
      else:
         self.allocedD = 0
      freeD_h2h = self.free[ 1 ] - self.free[ 0 ]
      if self.iter > self.WARMUP_ITERS:
         self.freeD_cum += freeD
         self.freeD_h2h_cum += freeD_h2h
         return "Free: %4d / %-5d %4d / %-6d " % ( freeD_h2h, self.freeD_h2h_cum,
                                                   freeD, self.freeD_cum )
      else:
         return "Free: %6d %6d " % ( freeD_h2h, freeD )

   def getRssStats( self ):
      """Return the RSS h2h/command deltas (and totals after warmup) as text."""
      rssD = self.rss[ 2 ] - self.rss[ 1 ]
      rssD_h2h = self.rss[ 1 ] - self.rss[ 0 ]
      if self.iter > self.WARMUP_ITERS:
         self.rssD_cum += rssD
         self.rssD_h2h_cum += rssD_h2h
         return "RSS: %3d / %-4d %3d / %-5d " % ( rssD_h2h, self.rssD_h2h_cum,
                                                  rssD, self.rssD_cum )
      else:
         return "RSS: %6d %6d " % ( rssD_h2h, rssD )

   def getObjsStats( self ):
      """Return the object-count h2h/command deltas (and totals) as text."""
      objsD = self.objs[ 2 ] - self.objs[ 1 ]
      objsD_h2h = self.objs[ 1 ] - self.objs[ 0 ]
      if self.iter > self.WARMUP_ITERS:
         self.objsD_cum += objsD
         self.objsD_h2h_cum += objsD_h2h
         return "objs: %3d / %-4d %3d / %-5d " % ( objsD_h2h, self.objsD_h2h_cum,
                                                   objsD, self.objsD_cum )
      else:
         return "objs: %6d %6d " % ( objsD_h2h, objsD )

   def shift( self ):
      """Move this run's entry snapshots (slot 1) to slot 0 and count the run.

      Save this run so the next run does not overwrite it and we can still
      provide call-to-call (h2h) diffs.
      """
      self.free[ 0 ] = self.free[ 1 ]
      self.alloced[ 0 ] = self.alloced[ 1 ]
      self.rss[ 0 ] = self.rss[ 1 ]
      self.objs[ 0 ] = self.objs[ 1 ]
      self.iter += 1

   def collect( self ):
      """Force a Python GC pass, then ask glibc to release free heap memory."""
      import gc
      gc.collect()
      # pad=0: keep no extra slack at the top of the heap when trimming
      self.malloc_trim( 0 )

class MemoryStatsContext( object ):
   """Context manager that records before/after memory usage metrics based on
   the kind of metrics configured via the 'cli debug memory ...' cmd.

   There are 2 before/afters: before and after running the command handler, and
   before running command handler n and before running command handler n+1, which
   we abbreviate h2h. So we print deltas, and we can also integrate those deltas to
   weed out the noise that happens without forced garbage collection, but we only
   start the integration after a few commands to prime any caches.
   For the heapcheck, there is only a before/after command execution, and the result
   is printed by tcmalloc directly to stdout instead of the prompt/logs.

   Nothing is recorded until cfgIs() has been called with a truthy config.
   """
   # TODO: for consistent results: prevent activity thread or other cli sessions from
   # interfering: Tac.ActivityLockHolder (will not do for h2h stats)
   def __init__( self, mode ):
      # 'mode' is accepted for caller compatibility; it is not used here.
      self.cfg = None
      self.state = None
      self.hc = None

   def cfgIs( self, cfg ):
      """Install a new metrics config and reset all collected state."""
      self.cfg = cfg
      self.state = MemoryStatsState()

   def __enter__( self ):
      """Snapshot the configured metrics (slot 1) before the handler runs."""
      cfg = self.cfg
      if cfg:
         state = self.state
         # keep the previous run's entry snapshot for h2h deltas
         state.shift()
         if cfg.heapcheck:
            CliHeapcheck.heapcheckStart()
         if cfg.gc:
            state.collect()
         if cfg.mallinfo:
            state.snapMalloc( 1 )
         if cfg.rss:
            state.snapRss( 1 )
         if cfg.pyObjects:
            state.snapObjs( 1 )
      return self

   def __exit__( self, _exceptionType, _value, _traceback ):
      """Snapshot metrics again (slot 2), build state.msg and emit it if needed.

      Exceptions raised by the command are never suppressed.
      """
      cfg = self.cfg
      if not cfg:
         return
      state = self.state
      if cfg.heapcheck:
         state.msg = ""
         CliHeapcheck.heapcheckStop()
      else:
         state.msg = "h2h/[cum] cmd/[cum] "
      if cfg.gc:
         state.collect()
      if cfg.mallinfo:
         state.snapMalloc( 2 )
         state.msg += state.getMallocStats()
      if cfg.rss:
         state.snapRss( 2 )
         state.msg += state.getRssStats()
      if cfg.pyObjects:
         state.snapObjs( 2 )
         state.msg += state.getObjsStats()
      if state.allocedD:
         state.msg += "Alloc: %d" % state.allocedD
      # normally we print stats into the prompt, but when non-interactive we dont
      # print a prompt, so print directly here (unless explicit desire to write
      # the stats to the agent log file instead).
      if cfg.log:
         out = state.msg + "\n"
         # os.write() needs bytes on Python 3; on Python 2 str already is bytes
         if not isinstance( out, bytes ):
            out = out.encode( "utf-8" )
         os.write( 2, out )
      elif not sys.stdin.isatty():
         print( state.msg )
