#!/usr/bin/env python
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import os
import re
import signal
import threading
import QuickTrace
import Tracing

# Tracing facility for this module; t0/t1 are its level-0/level-1 trace calls.
traceHandle = Tracing.Handle( 'CliSubprocMgr' )
t0, t1 = traceHandle.trace0, traceHandle.trace1

# QuickTrace counterparts; qv wraps a value for QuickTrace formatting.
qt0, qt1, qv = QuickTrace.trace0, QuickTrace.trace1, QuickTrace.Var

# File where the pids of CLI-spawned children are persisted so lingering
# subprocesses can be found and cleaned up after a restart.
DB_PATH = '/var/run/CliSubproc.db'

class SubprocMgr( object ):
   # Managed subprocesses created by CLI
   # In case of ConfigAgent restart, we need to be able to find lingering
   # subprocesses and cleanup them up.
   #
   # Note all child processes are created with pdeathsig() set to SIGKILL,
   # but this is cleared whenever the subprocess changes its credentials
   # such as setuid binaries. So we keep track of child pids in a file
   # and use it to cleanup after restart.

   def __init__( self, dbfile=DB_PATH ):
      self.dbfile_ = dbfile
      self.lock_ = threading.Lock()
      self.children_ = set()
      self.ppidRe_ = re.compile( r"PPid:\s+(\d+)" )
      self.doCleanup()

   def addPid( self, pid ):
      with self.lock_:
         qt1( "add child", qv( pid ) )
         t1( "add child", pid )
         self.children_.add( pid )
         self.syncDb()

   def removePids( self, pids ):
      # This is called by individual threads
      with self.lock_:
         for pid in pids:
            qt1( "remove child", qv( pid ) )
            t1( "remove child", pid )
            self.children_.discard( pid )
         self.syncDb()

   def syncDb( self ):
      # write existing children to database
      try:
         with file( self.dbfile_, "w" ) as f:
            f.write( ' '.join( str( x ) for x in self.children_ ) )
      except IOError, e:
         qt0( "fail to write dbfile:", qv( e.strerror ) )

   def _isOrphan( self, pid ):
      try:
         with file( "/proc/%s/status" % pid ) as f:
            for line in f:
               if line.startswith( "PPid:" ):
                  ppid = 'unknown'
                  m = self.ppidRe_.match( line )
                  if m:
                     ppid = m.group( 1 )
                  qt1( "child", qv( pid ), "ppid", qv( ppid ) )
                  t1( "child", pid, "ppid", ppid )
                  return ppid == '1'
      except IOError:
         # just exited?
         pass

      return False

   def _kill( self, pid ):
      try:
         cmdline = file( "/proc/%s/cmdline" % pid ).\
                   read().replace( '\0', ' ' ).strip()
         qt0( "kill child", qv( pid ), repr( cmdline ) )
         t0( "kill child", pid, repr( cmdline ) )
         os.killpg( pid, signal.SIGKILL )
      except OSError:
         pass

   def doCleanup( self ):
      try:
         pids = file( self.dbfile_ ).read()
         for pid in pids.split():
            pid = int( pid )
            # only kill pid if its ppid is 1
            if self._isOrphan( pid ):
               self._kill( pid )
      except IOError:
         pass
      # clean up the file
      self.syncDb()
