#!/usr/bin/env python
# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import argparse
import PyClient
import re
import socket
import os
from contextlib import closing

import TacMarco
import Kick

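#
# This script establishes a pyclient connection to each of a set of specified
# EOS agents to retrieve Smash collection info such as owner, counters and
# mount status, as well as the shocket state of the agent's main thread (or of
# the thread named with --thread).
#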

def getSmashEntities( root ):
   ''' Collects the mounted Smash entities reachable from root.

       Anything whose tac type is not 'Tac::Dir' is treated as an entity and is
       returned on its own. A directory is expanded by recursing into each of
       its values, and the results are concatenated in iteration order.
   '''
   if root.eval( '.tacType.fullTypeName' ) == 'Tac::Dir':
      # flatten the entities found under every child, preserving child order
      return [ entity
               for child in root.values()
               for entity in getSmashEntities( child ) ]
   # not a directory: this is the entity itself
   return [ root ]

def remoteDumpState( socketName ):
   ''' Connects to the abstract unix socket given by socketName and asks the
       thread listening on it to dump its shocket state to
       /tmp/smashinfo-<socketName>.log.

       Returns the contents of that dump file on success. Otherwise returns a
       human-readable error message describing which step failed: the socket
       exchange, an unexpected reply, or reading the dump file back.
   '''
   shocketFileName = 'smashinfo-' + socketName + '.log'
   dumpFilePath = os.path.join( '/tmp', shocketFileName )

   # Only socket work lives in this try. On python 3 socket.error is an alias of
   # OSError, so a broader try/except would also swallow (and misreport) file
   # errors, or let socket errors be reported as file errors.
   try:
      # closing() releases the socket even on python 2, where socket objects
      # are not context managers
      sock = socket.socket( socket.AF_UNIX, socket.SOCK_SEQPACKET )
      with closing( sock ):
         # a new socket is blocking by default; bound the wait so the user
         # doesn't wait indefinitely for an unresponsive thread
         sock.settimeout( 2 * 60 )
         # unix abstract sockets are prefaced with a null byte
         sock.connect( '\0' + socketName )

         dumpMsg = Kick.DebugMsg( Kick.DebugMsgCommandType.REQ_DUMP ).toBuffer()
         sock.sendall( dumpMsg )

         # size the receive buffer to a serialized debug message (the same size
         # as the request we just sent) to fetch our response
         recvBuf = bytearray( len( dumpMsg ) )
         recvLen = sock.recv_into( recvBuf )
   except ( socket.error, socket.timeout ) as e:
      return 'socket operation failed on {}, error: {}'.format( socketName, e )

   if recvLen != len( dumpMsg ):
      return ( 'unexpected message length received, expected {} bytes, '
               'received {}'.format( len( dumpMsg ), recvLen ) )

   # NOTE(review): str( bytearray ) yields the raw bytes only on python 2; on
   # python 3 it yields the repr. Confirm what Kick expects before porting.
   replyMsg = str( recvBuf )

   # we only expect to receive a debug message with the status as a response
   replyMsgType = Kick.msgType( replyMsg )
   if replyMsgType != Kick.MsgType.DEBUG:
      return 'unexpected message type received {}, expected {}'.format(
         replyMsgType, Kick.MsgType.DEBUG )

   reply = Kick.debugCommandType( replyMsg )
   if reply == Kick.DebugMsgCommandType.RESP_SUCCESS:
      try:
         with open( dumpFilePath, 'r' ) as f:
            return f.read()
      except ( IOError, OSError ) as e:
         # open() raises IOError (not OSError) on python 2, so catch both
         return 'problem reading shocket state at {}, error: {}'.format(
            dumpFilePath, e.strerror )
   if reply == Kick.DebugMsgCommandType.RESP_FAIL:
      return ( 'error requesting shocket state to be dumped to {}, response '
               'received: {}'.format( dumpFilePath, reply ) )
   return ( 'unexpected response received requesting shocket state to be '
            'dumped to {}, response {}'.format( dumpFilePath, reply ) )


def _collectShocketState( pc, rootName, agentName, thread ):
   ''' Returns the shocket state of the agent's main thread (when thread is
       empty/None) or of the named thread, or an error message string when the
       named thread is not registered.
   '''
   shmemEm = pc.root()[ rootName ][ agentName ][ agentName ].shmemEm
   if not thread:
      return shmemEm.eval( '.dumpState()' )

   socketName = shmemEm.eval( '.socketName(\'{}\')'.format( thread ) )
   if not socketName:
      currentThreads = shmemEm.eval( '.socketNamesMap()' )
      return 'thread name: {} could not be found\nregistered: {}'.format(
         thread, currentThreads )
   # non-main threads are queried out-of-band over their own unix socket
   return remoteDumpState( socketName )

def _collectSmashInfo( pc, rootName ):
   ''' Returns a dict mapping each Smash collection root path mounted by the
       agent to its 'mount status', 'owner' and 'counters'. Returns an empty
       dict when the agent has no 'Shmem' root.
   '''
   smashInfo = {}
   # the smash root point
   smashRoot = pc.root()[ rootName ].get( 'Shmem' )
   if not smashRoot:
      # no smash mounts on this agent, we are done here!
      return smashInfo

   controlSuffix = 'Control'
   controlLen = len( controlSuffix )
   # the agent's pid is fixed for the whole query; fetch it once instead of
   # once per collection
   agentPid = pc.pid()

   for entity in getSmashEntities( smashRoot ):
      # a smash collection 'foo' shows up as attribute 'foo' plus a matching
      # 'fooControl' attribute
      attributes = entity.attributes
      controls = [ s for s in attributes if s.endswith( controlSuffix ) and
                   s[ :-controlLen ] in attributes ]

      for control in controls:
         cc = '.' + control + '.'

         # retrieve root path
         rootPath = entity.eval( cc + 'rootPath()' )

         # query mount status
         mountStatus = entity.eval( cc + 'mountStatus' )

         # query owner; the second ':'-separated field is parsed as its pid
         owner = entity.eval( cc + 'owner()' )
         ownerPid = int( owner.split( ':' )[ 1 ] )

         # retrieve the collection counters
         # counters are different for reader and writer, and we expose 2
         # functions to be able to retrieve them. Unfortunately, invoking the
         # wrong function makes the remote agent throw an internal
         # 'unimplemented' exception which is caught by the PyServer activity.
         # This is not nice, so we must be careful which function we call: if
         # the collection is 'connected' we can infer that it is an active
         # reader. If 'attached', it's either a passive reader or a writer; we
         # compare the owner pid to the agent pid to figure out if it's a
         # writer, otherwise we assume passive reader.
         # This won't be needed if we had a way to figure out the mode of the
         # collection, see BUG143113 for details.
         # Also, note that we cannot simply do
         #    counters = entity.eval( '.fooControl.readerCounters()' )
         # For reasons not understood, this results in a default-constructed
         # TacSmash::ReaderCounters.  See related BUG420254.
         # Instead, evaluate the string representation of the counters.
         if mountStatus == 'attached' and agentPid == ownerPid:
            # attached and owned by this agent: we think this is a writer
            countersCmd = \
               "str( %s.%s.writerCounters() )" % ( repr( entity ), control )
         else:
            countersCmd = \
               "str( %s.%s.readerCounters() )" % ( repr( entity ), control )
         counters = pc.eval( countersCmd )

         smashInfo[ rootPath ] = { 'mount status': mountStatus,
                                   'owner': owner,
                                   'counters': counters }
   return smashInfo

def queryAgent( rootName, agentName, thread ):
   ''' Opens a pyclient connection to an agent and queries it.

       Returns a ( smashInfo, shocketState ) tuple. smashInfo maps each Smash
       collection root path to its mount status, owner and counters.
       shocketState is the dumped shocket state (or an error message).

       When thread is given, only that thread's shocket state is collected and
       smashInfo is an empty dict: the smash counters and mounts belong to the
       main thread, and we are interested in the non-main thread.
   '''
   # connect to agent
   pc = PyClient.PyClient( rootName, agentName )

   shocketState = _collectShocketState( pc, rootName, agentName, thread )
   if thread:
      # bypass dumping the smash counters and mounts belonging to the main thread
      return {}, shocketState

   return _collectSmashInfo( pc, rootName ), shocketState

def main():
   ''' Parses the command line, queries each requested agent and prints the
       matching Smash collections followed by the agent's shocket state.
   '''
   parser = argparse.ArgumentParser( description='Script that queries agents '
                                     'for smash collections info.' )
   parser.add_argument( '--sysname', default='ar',
                        help='system name (default: \'%(default)s\')' )
   parser.add_argument( '--path', default='.',
                        help='smash path regex (default: \'%(default)s\')' )
   parser.add_argument( 'agents', nargs='+',
                        help='name of agents to collect smash information from' )
   # implicit concatenation: a backslash continuation inside the literal would
   # embed the next line's indentation in the help text
   parser.add_argument( '--thread',
                        help='name of thread to collect shocket state from, '
                             'default to main-thread if none specified' )
   args = parser.parse_args()

   # compile the path regular expression
   pathPattern = re.compile( args.path )

   for agent in args.agents:
      print( '---------------------------- %s agent ----------------------------' %
             agent )
      smashInfo, shocketState = queryAgent( args.sysname, agent, args.thread )
      # sort by collection path so output is stable between runs; .items()
      # works on both python 2 and 3, unlike .iteritems()
      for smash, data in sorted( smashInfo.items() ):
         if pathPattern.search( smash ):
            print( 'Smash collection %s:' % smash )
            for key, value in data.items():
               print( '\t%s: %s' % ( key, value ) )
            print( '\n' )

      print( 'Shocket internal state for ' +
             ( args.thread if args.thread else 'main thread' ) + ':\n' )
      print( shocketState )
   print( '----------------------------------------------------------------------' )

if __name__ == '__main__':
   main()
