#!/usr/bin/python
# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import sys
import ExtensionMgrLib
from ExtensionMgr import logs
import PyClient
import Tracing
import Tac

# Short alias for Tracing.trace0 (presumably the level-0 trace function).
# NOTE(review): t0 is not referenced in the code visible here.
t0 = Tracing.trace0

def printHelpAndExit():
   """Print the command-line usage to stdout and exit with sys.exit( -1 ).

   main() calls this whenever the arguments are too few or the topic is not
   recognized.
   """
   # print( <single string> ) behaves identically under Python 2 and 3.
   print( '''small scripts we run daemonized, or from the %post section:
Restart a list of agents, or provide warning/errors to the operator.
The rpm name is passed with EXTENSIONMGR_EXTENSION env var
(except for restart, which takes it as an argument).
arguments:
  restart (<rpm>|all) <agent-1> [ { <agent-n> } ]
  error   <error>
  warning <warning>
  print   <text>
  agentsToRestart <agent-1> [ { <agent-n> } ]
''' )
   sys.exit( -1 )

# Module-level Sysdb access, executed at import time. The functions in this
# script read the global `status` and update entries of its `info` collection.
# System name from $SYSNAME, defaulting to "ar".
sysname = os.environ.get( "SYSNAME", "ar" )
# PyClient handle for the "Sysdb" agent of that system.
pc = PyClient.PyClient( sysname, "Sysdb" )
root = pc.agentRoot()
# ExtensionMgr status entity at ExtensionMgrLib.statusPath(); entries of
# `status.info` are matched against an rpm filename via their `filename`.
status = root.entity[ ExtensionMgrLib.statusPath() ]

def doRestart( rpm, agents ):
   """Kill the given agents and clear their pending-restart bookkeeping.

   Every agent in *agents* is killed with ``killall -9``; a failure for one
   agent is recorded and does not stop the others from being attempted.

   Args:
      rpm: filename of the extension whose status entry should be updated,
         or "all" (the form shown in the usage text) / "" to update every
         entry in ``status.info``.
      agents: process names to kill.

   For each selected status entry, ``agentsToRestart`` is cleared and, if any
   killall failed, the failure messages joined with "; " are stored in
   ``postError``.
   """
   ExtensionMgrLib.log( logs.EXTENSION_RESTARTING_AGENTS, " ".join( agents ) )
   errors = []
   for agent in agents:
      try:
         Tac.run( [ 'killall', '-9', agent ], stdout=Tac.CAPTURE )
      except Tac.SystemCommandError as e:
         errors.append( str( e ) )

   # "all" is the wildcard the usage text documents; "" is accepted as well
   # for callers that pass an empty rpm name.
   matchAll = rpm in ( "", "all" )
   for info in status.info.values():
      if not matchAll and info.filename != rpm:
         continue
      # The restart has been attempted, so nothing stays pending for this
      # entry -- including agents whose killall failed; those failures are
      # reported through postError instead.
      info.agentsToRestart.clear()
      if errors:
         info.postError = "; ".join( errors )
  
def getStatusInfo( rpm ):
   """Return the first entry of ``status.info`` whose filename is *rpm*.

   Returns None when no entry has that filename.
   """
   candidates = ( entry for entry in status.info.values()
                  if entry.filename == rpm )
   return next( candidates, None )

def setError( rpm, error ):
   """Store *error* as ``postError`` on the status entry for *rpm*.

   Silently does nothing when getStatusInfo() finds no matching entry.
   """
   entry = getStatusInfo( rpm )
   if not entry:
      return
   entry.postError = error
 
def setWarning( rpm, warning ):
   """Store *warning* as ``finalizationWarning`` on the status entry for *rpm*.

   Silently does nothing when getStatusInfo() finds no matching entry.
   """
   entry = getStatusInfo( rpm )
   if not entry:
      return
   entry.finalizationWarning = warning

def printText( text ):
   """Append *text* plus a newline to a terminal device derived from $REALTTY.

   The first space-separated token of REALTTY is taken as a device path; the
   numeric last path component is incremented by one and the resulting path
   (e.g. "/dev/pts/3" -> "/dev/pts/4") is opened in append mode.

   Nothing is written when REALTTY is unset, does not start with "/dev/", or
   its last path component is not an integer.
   NOTE(review): the reason for the +1 offset is not visible in this file.
   """
   ttyname = os.environ.get( "REALTTY" )
   if not ttyname or not ttyname.startswith( "/dev/" ):
      return
   ttyname = ttyname.split( " " )[ 0 ]
   # Split at the last "/" only, exactly like joining all but the last part.
   ttypath, _, lastPart = ttyname.rpartition( "/" )
   try:
      ttynum = int( lastPart )
   except ValueError:
      # e.g. /dev/console or /dev/ttyS0: no numeric suffix to offset. Skip,
      # like the other unusable-REALTTY cases, instead of raising.
      return
   # `with` guarantees the device is closed even if a write fails.
   with open( "%s/%d" % ( ttypath, ttynum + 1 ), "a" ) as f:
      f.write( text )
      f.write( "\n" )

def setAgentsToRestart( rpm, agents ):
   """Replace both agent lists of *rpm*'s status entry with *agents*.

   ``agentsToRestart`` and then ``affectedAgents`` are each cleared and
   refilled, keyed by position 0..len( agents ) - 1. Does nothing when
   getStatusInfo() finds no matching entry.
   """
   entry = getStatusInfo( rpm )
   if not entry:
      return

   def refill( collection ):
      # Reset the collection, then key each agent name by its position.
      collection.clear()
      for position, name in enumerate( agents ):
         collection[ position ] = name

   refill( entry.agentsToRestart )
   refill( entry.affectedAgents )

def main( argv ):
   """Parse *argv* (including argv[ 0 ]) and dispatch to the requested topic.

   Topics:
      restart <rpm|all> <agent>...  -- rpm and agents come from argv.
      error / warning / print <words>... and agentsToRestart <agent>...
         -- the rpm comes from $EXTENSIONMGR_EXTENSION.

   Too few arguments or an unknown topic print usage and exit via
   printHelpAndExit(). A missing EXTENSIONMGR_EXTENSION prints a message
   and exits with sys.exit( -1 ).
   """
   argc = len( argv )
   if argc < 3:
      printHelpAndExit()

   topic = argv[ 1 ]
   if topic == "restart":
      if argc < 4:
         printHelpAndExit()
      doRestart( argv[ 2 ], argv[ 3: ] )
      return

   # Topics that act on the extension named by EXTENSIONMGR_EXTENSION.
   # error/warning/print get the remaining words joined into one string;
   # agentsToRestart gets them as a list.
   handlers = {
      "error": lambda rpm, args: setError( rpm, " ".join( args ) ),
      "warning": lambda rpm, args: setWarning( rpm, " ".join( args ) ),
      "print": lambda rpm, args: printText( " ".join( args ) ),
      "agentsToRestart": setAgentsToRestart,
   }
   handler = handlers.get( topic )
   if handler is None:
      # Check the topic before the env var so a mistyped topic shows the
      # usage text rather than a misleading "env var not set" message.
      printHelpAndExit()

   rpm = os.environ.get( "EXTENSIONMGR_EXTENSION" )
   if not rpm:
      print( "env var EXTENSIONMGR_EXTENSION not set to extension filename" )
      sys.exit( -1 )

   handler( rpm, argv[ 2: ] )

if __name__ == '__main__':
   # Script entry point: main() receives the full argv, including argv[ 0 ].
   main( sys.argv )

