# Copyright (c) 2006-2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import Tac, os, Tracing, threading
import time, re, os.path
import ctypes

t0 = Tracing.trace0
t1 = Tracing.trace1
t2 = Tracing.trace2

def markDiagsDirty( config ):
   """Flag 'config' as modified by storing the current Tac time in modTime."""
   stamp = Tac.now()
   config.modTime = stamp

def diagsDirty( config ):
   """Return True when 'config' carries a positive modTime stamp."""
   isDirty = config.modTime > 0
   return isDirty

# Flip to True to inject random delays (see mix) when chasing race conditions.
debugging = False
def mix(mult=1.0):
   """Sleep for a random time in [0, mult) seconds, but only when the
   module-level 'debugging' flag is set; otherwise do nothing."""
   import random
   if debugging:
      time.sleep( random.random() * mult )
   

def splitInTwo( path, char, nth ):
   """Cut 'path' in two at the 'nth' occurrence of 'char'.

   Returns ( head, tail ) with the separator itself dropped.  When 'path'
   holds fewer than 'nth' separators the result is ( path, None ).

   splitInTwo( 'a/b/c/d', '/', 1 ) => ('a','b/c/d')
   splitInTwo( 'a/b/c/d', '/', 2 ) => ('a/b','c/d')
   splitInTwo( 'a/b/c/d', '/', 4 ) => ('a/b/c/d',None)
   splitInTwo( 'a/b/c/d', '/', 9999 ) => ('a/b/c/d',None)
   """
   pieces = path.split( char, nth )
   if len( pieces ) > nth:
      # pieces[ nth ] is the unsplit remainder after the nth separator.
      head = char.join( pieces[ :nth ] )
      tail = pieces[ nth ]
      return head, tail
   return path, None
   

def _lookup( cwd, path, root, state='/' ):
   # 1. Handle the case of looking at a '/'
   if path.startswith( '/' ):
      if state == '^':
         if len( path ) > 1:
            return _lookup( root, path[1:], root )
         else:
            return root, None
      else:
         return lookup( cwd, path[1:], root )

   # 2. dir/rest
   spath = path.split( '/', 1 )

   # 2a. '..'
   if spath[0] == '..':
      if cwd == root:
         parent = root
      else:
         parent = cwd.parent
      if len(spath) == 1:
         return parent, None
      else:
         return _lookup( parent, path[3:], root )

   i = 1
   while True:
      first, rest = splitInTwo( path, '/', i )

      #2b. "dir/..."
      if first in cwd.subdir:
         child = cwd.subdir.get( first )
         if rest is None:
            return ( child, None )
         else:
            return _lookup( child, path[len(first)+1:], root )

      if rest is None:
         break

      i += 1

   # 2c "notadir..."
   return (cwd, path)

def lookup( cwd, path, root ):
   """Resolve 'path' starting from 'cwd' and return a ( dir, rest ) pair.

   A path beginning with '/' is resolved from 'root' instead of 'cwd'.
   When the whole path names a directory, 'rest' is None.  Otherwise
   'dir' is the deepest directory that matched and 'rest' is the
   remaining, unmatched part of the path string."""
   return _lookup( cwd, path, root, state='^' )

def testFromPath( cwd, path, root ):
   """Find the test or suite that 'path' names, relative to 'cwd'.

   Returns None when 'path' resolves to a directory.  Otherwise the
   unmatched tail of the path is looked up in the test collection and then
   the suite collection of the last directory matched."""
   found = lookup( cwd, path, root )
   if found is None:
      return None
   # directory is of type Diags::Dir
   directory, leaf = found
   if leaf is None:
      return None
   return directory.test.get( leaf ) or directory.suite.get( leaf )

def makepath(path):
   """
    Taken from ASPN python cookbook

        creates missing directories for the given path and
        returns a normalized absolute version of the path.

    - if the given path already exists in the filesystem
      the filesystem is not modified.

    - otherwise makepath creates directories along the given path
      using the dirname() of the path. You may append
      a '/' to the path if you want it to be a directory path.

    - a directory that appears between the existence check and the
      creation attempt (another process or thread creating it) is not
      treated as an error.

    Raises OSError if the directories cannot be created for any reason
    other than already existing.

    from holger@trillke.net 2002/03/18
   """

   import errno
   from os import makedirs
   from os.path import normpath, dirname, exists, abspath

   dpath = normpath(dirname(path))
   if not exists(dpath):
      try:
         makedirs(dpath)
      except OSError as e:
         # Lost a race with a concurrent creator: the directory is there
         # now, which is all we wanted.
         if e.errno != errno.EEXIST:
            raise
   return normpath(abspath(path))
 

def _appendToArgv( argv, runOrTest, shellQuote=False ):
   i = 0
   while True:
      a = runOrTest.argv.get( i )
      if not a:
         break
      if shellQuote:
         a = Tac.shellQuote( a )
      argv.append(a)
      i += 1
   return argv

def testArgs( test ):
   """Return the arguments registered on 'test' as a new, unquoted list."""
   return _appendToArgv( [], test )

def findCellId( test ):
   """Walk up from test.parent through enclosing Diags::Dir entities and
   return the first non-zero cellId found.

   Returns 0 when a directory's cellId equals its activeSupeCellId, or
   when the walk leaves the Diags::Dir hierarchy without finding one."""
   directory = test.parent
   while directory.tacType.fullTypeName == "Diags::Dir":
      cellId = directory.cellId
      if cellId == directory.activeSupeCellId:
         # Should be Diags::Dir::activeSupeCellId
         break
      if cellId != 0:
         return cellId
      directory = directory.parent
   return 0

def runArgs( runObj ):
   """Build the command line (a list of strings) used to execute 'runObj'.

   The command always starts with 'sudo'.  When the test belongs to a cell
   other than the local one, the command is routed through 'issh' to that
   cell's 127.1.<hi>.<lo> address.  The test's own arguments come next,
   followed by the arguments of this particular run."""
   argv = [ "sudo" ]
   # BUG3214 All diags are run as root for now. At the point that we explicitly
   # whitelist commands that can run as root, adding 'sudo' will probably become
   # the burden of the person registering the command.
   testCellId = findCellId( runObj.test )
   import Cell
   remote = ( testCellId != 0 and testCellId != Cell.cellId() )
   if remote:
      argv += ['issh','127.1.%d.%d' % (0xff & (testCellId >> 8), testCellId & 0xff )]
   # BUG - else we should check this is the active supervisor

   # Only quote when the command line is re-parsed by the remote shell at
   # the far end of issh.  A local command is exec'd directly from this
   # list, so quoting there would hand literal quote characters to the diag.
   _appendToArgv( argv, runObj.test, shellQuote=remote )
   _appendToArgv( argv, runObj, shellQuote=remote )
   return argv

def setTestArgs( test, args ):
   """Store 'args' into test.argv at indices 0, 1, 2, ...

   'args' may be a sequence of arguments or a single string, which is
   split on single spaces.  Entries already present at higher indices are
   left untouched."""
   if type( args ) is str:
      args = args.split(" ")
   for index, arg in enumerate( args ):
      test.argv[ index ] = arg

def resultString( result ):
   """Map a run result to the short tag used in log lines.

   Known results get a fixed four-letter tag; anything else is cut down
   to its first four characters."""
   tags = ( ( "passed", "PASS" ),
            ( "failed", "FAIL" ),
            ( "skipped", "SKIP" ),
            ( "interrupted", "INTR" ) )
   for name, tag in tags:
      if result == name:
         return tag
   return result[:4]

def testPath( test ):
   p = test.parent.fullName + "/" + test.name
   p = re.sub( '^.*/hardware/diags/', '/', p, 1 )
   return p

class TestOptions(object):
   """Read-only bag of options shared by the runners and recorders.

   Options are fixed at construction time; assigning to any attribute
   afterwards raises AttributeError.  Reading an attribute that was never
   set yields None instead of raising."""

   def __init__( self, stopOnError=None, verbose=None,
                 log=None, logDir=None, root=None ):
      self._setOptions( root=root, logDir=logDir, log=log,
                        verbose=verbose, stopOnError=stopOnError )

   def _setOptions( self, **kargs ):
      # Write straight into the instance dictionary: normal attribute
      # assignment is blocked by __setattr__ below.
      for key, value in kargs.items():
         self.__dict__[ key ] = value

   def __setattr__( self, name, value ):
      raise AttributeError( "TestOptions cannot be changed" )

   def __getattr__( self, name ):
      # Only reached when normal attribute lookup has failed.
      try:
         return self.__dict__[ name ]
      except KeyError:
         return None


class RunRecorder( object ):

   def __init__( self, r, options=None ):
      assert r
      self.options = options or TestOptions()
      self.run = r
      testPathName = r.test.parent.fullName + "/" + r.test.name
      testPathName = re.sub( '^.*/hardware/diags/', '/', testPathName, 1 )
      self.testPath_ = testPathName

      # Open up <logdir>/<testname> for append, creating directories if necessary
      self.logdirfile_ = None
      if options.logDir:
         self.outpath_ = os.path.join( options.logDir, self.testPath_[1:] )
         try:
            self.outpath_ = makepath( self.outpath_ )
            self.logdirfile_ = file( self.outpath_, 'a' )
         except IOError, e:
            print "unable to open file '%s':" % self.outpath_
            print e
      t2( "RunRecorder construction complete", Tac.threadId().pthread )

   def formatOutputLine( self, msg ):
      return "%5s: %s.%s %s" % ( 
         msg, self.testPath_, self.run.id, time.strftime( "%X %x %Z" ) )

   def start( self ):
      s = self.formatOutputLine( "START" )
      self.run.startTime = time.time()
      options = self.options
      if options.verbose:
         print s
      if options.log:
         print >> options.log, s
      if self.logdirfile_:
         print >> self.logdirfile_, s

   def finish( self, result, output ):
      assert result in ( 'none', 'passed', 'failed', 'skipped', 'interrupted' )
      r = self.run
      # Now do the generic post-test stuff, updating finish time, result, passcount
      # and writing output to the logdir 
      r.finishTime = time.time()
      r.result = result
      if result == "passed":
         r.passCount += 1
      if result == "failed":
         r.failCount += 1
      options = self.options
      if options.logDir or options.log or options.verbose:
         s = self.formatOutputLine( resultString(result) )
      if self.logdirfile_:
         r.output = self.outpath_
         if output:
            self.logdirfile_.write( output )
         print >> self.logdirfile_, s
      if options.verbose:
         if result == "failed" and output:
            import cStringIO
            for i in cStringIO.StringIO( output ):
               print ">", i,
         print s
      if options.log:
         if output:
            print >> options.log, output
         print >> options.log, s
         
      self.run = None
      t2( "run finished and set to None", Tac.threadId().pthread )

def nextRun( test ):
   """Create and return a new run of 'test' under the lowest unused run id.

   Ids are probed upward from 0 until one is found that is not a key of
   test.run; the run is then created with test.newRun."""
   i = 0
   # 'in' rather than the deprecated has_key().
   while i in test.run:
      i += 1
   return test.newRun( i, test.self )

def _run( r, options ):
   """Execute run 'r' once and block until it completes.

   A runner is built for 'r' with a completion callback that releases a
   semaphore; the calling thread then waits on that semaphore.  Returns
   the runner's result."""
   done = threading.Semaphore( 0 )
   t1( "created semaphore", done )

   def _signalDone( _r, result, output ):
      # NOTE - the ScriptRunner prints all output to stdout,
      # so we can ignore the output here
      t1( "releasing semaphore", done )
      done.release()

   # Keep a reference to the runner until the wait below is over.
   runner = TestRunner( r, options, _signalDone )
   t2( '_run.TestRunner is constructed' )
   # We need to make sure activity thread runs.  This should really be automatic
   Tac.activityThread().start()
   t1( "acquiring semaphore", done )
   done.acquire()
   # Stop activity thread if started
   t2( "acquired semaphore", done )
   return runner.result()

def run( r, loops=1, stopOnError=True, verbose=False, log=None, logdir=None ):
   t0( "running", testPath(r), Tac.threadId().pthread)
   if type( log ) is str:
      try:
         log = file( log, 'a' )
      except IOError, e:
         print "unable to open file '%s':" % log
         print e
         # Continue on without a log file
         log = None
   rootpath = r.fullName
   m = re.match( '^.*/hardware/diags', rootpath )
   root = Tac.entity( m.group() )
   if r.test.showLog == True:
      import sys
      log = sys.stdout
   options = TestOptions( verbose=verbose, stopOnError=stopOnError, log=log,
                          logDir=logdir, root=root )
   for i in xrange(loops):
      result = _run( r, options )

      if result == "interrupted":
         print "Aborting", r.test.name, "after", i+1, "iterations due to interrupt"
         break

      if stopOnError and result == "failed":
         print "Aborting", r.test.name, "after", i+1, "iterations due to failed run"
         break

# So the way this works when the top level is a suite:
#   1. run creates a SuiteRunner which calls Semaphore.v as the finish function
#   2. the main script calls Semaphore.p which blocks until v (or sigint)

#   3a. The SuiteRunner if serial creates a test runner for the first
#       test and passes it a finish function of handleOneTest.
#       handleOneTest creates a test runner for the next test with a
#       finish function equal to handleOneTest.  When the last test
#       finishes, then the status of the suite is updated and the
#       finish function passed into the ctor is called.

#   3b. The SuiteRunner if parallel creates a test runner for
#       each component test and passes it a finish function of handleOneTest.
#       handleOneTest waits until all tests have finished, then the
#       status of the suite is updated and the finish function passed
#       into the ctor is called.
#      
#   3c. The ScriptRunner forks the child process with the right args,
#       creates a ChildWatcher for that child and returns.
#       When the ChildWatcher sees that the child has exited, our process
#       is done.  When finished, test results and output are handled,
#       the run result is updated and the runner's finish function is called.


class ChildWatcher( Tac.File ):
   """Watches a pipe descriptor and reports the exit of child 'pid'.

   Each time the descriptor is reported readable, the child is polled
   with a non-blocking waitpid; once it has been reaped, 'finish' is
   called with the status value waitpid returned."""
   # Figure out the reference loops

   def __init__( self, pipefd, pid, finish, parent ):
      """pipefd -- descriptor to watch for readability.
      pid    -- process id of the child to reap.
      finish -- bound method called with the waitpid status.
      parent -- object that gets a 'childWatcher_' reference to this
                watcher."""
      # Stored on the parent before anything else happens.
      # NOTE(review): presumably this keeps the watcher alive - confirm.
      parent.childWatcher_ = self
      # Held through Tac.WeakBoundMethod rather than as a direct reference.
      # NOTE(review): presumably to avoid a parent <-> watcher reference
      # cycle, with ignoreDead making the call a no-op once the target is
      # gone - confirm against Tac.
      self.finish_ = Tac.WeakBoundMethod( finish, ignoreDead=True )
      self.pid_ = pid
      Tac.File.__init__( self, pipefd )
      mix( .1 )
      self.fdEntity_.notifyOnReadable = True
      #mix( .03 )
      t2( "ChildWatcher", self.notifier_.name,"is constructed" )

   @Tac.handler( 'readableCount' )
   def handleReadableCount( self ):
      """Poll the child; report its status through finish_ once reaped."""
      try:
         # WNOHANG: never block here.
         pid, rc = os.waitpid( self.pid_, os.WNOHANG )
         t1( "Child", self.pid_, "completed with rc", rc )
      except OSError, e:
         # Enter this on multiple attribute notification, I think.
         # errno 10 is ECHILD on Linux (no such child to wait for); it is
         # ignored, anything else is re-raised.
         if e.errno != 10:
            t1( "waitpid got errno", e.errno, e )
            raise
         return
      if pid == 0: 
         # Child has not exited yet: ask to be notified again.
         self.fdEntity_.notifyOnReadable = True
         return
      self.fdEntity_.notifyOnReadable = False
      # rc is the raw waitpid status word, not a decoded exit code.
      self.finish_( rc )
      mix( .005 )                       # longer delay here prevents the error

   def close( self ):
      """Close the watched file, tracing before and after."""
      t0( "ChildWatcher.close", self.notifier_.name, Tac.threadId().pthread )
      self.closeFile()
      Tac.File.close( self )
      t0( "ChildWatcher.close", self.notifier_.name, "done" )

# prctl(2) option number 1 (named PR_SET_PDEATHSIG in <sys/prctl.h>).
PR_SETDEATHSIG = 1
SIGKILL = 9
# Lazily-loaded handle on the C library; use getLibc() to obtain it.
libC = None
def getLibc():
   """Return the shared ctypes handle for libc.so.6, loading it and
   declaring the signature of prctl on first use."""
   global libC
   if libC is not None:
      return libC
   libC = ctypes.CDLL( 'libc.so.6' )
   prctl = libC.prctl
   # int prctl( int option, unsigned long x 4 )
   prctl.argtypes = [ ctypes.c_int ] + 4 * [ ctypes.c_ulong ]
   prctl.restype = ctypes.c_int
   return libC

def dieWhenParentDies():
   """Ask the kernel (through prctl) to send this process SIGKILL when its
   parent dies.

   If the parent process id is already 1 after the request, the process
   kills itself straight away."""
   status = getLibc().prctl( PR_SETDEATHSIG, SIGKILL, 0, 0, 0 )
   assert status == 0
   orphaned = ( os.getppid() == 1 )
   if orphaned:
      os.kill( os.getpid(), SIGKILL )

def _fork( argv, stdin, stdout, stderr, preexec_fn ):
   """Fork a child that runs 'argv' and return the child's pid.

   stdin, stdout and stderr may each be a descriptor number or an object
   with fileno(); they become descriptors 0, 1 and 2 of the child.
   'preexec_fn' is called in the child after the descriptors are set up.

   The child forks once more: the grandchild execs 'argv', while the
   child waits for it and exits with a code derived from its status.
   Only the original caller returns from this function; both forked
   processes end in exec or os._exit."""
   # Accept either raw descriptors or file-like objects.
   if type(stdin) != int:
      stdin = stdin.fileno()
   if type(stdout) != int:
      stdout = stdout.fileno()
   if type(stderr) != int:
      stderr = stderr.fileno()
   pid = os.fork()
   if pid == 0:
      # First-level child.
      try:
         dieWhenParentDies()
         os.dup2( stdin, 0 )
         os.dup2( stdout, 1 )
         os.dup2( stderr, 2 )
         preexec_fn()
         #pylint: disable-msg=W0702
         try:
            # Ideally we'd just use:
            #  os.execvp( argv[0], argv )
            # However, we currently run all of our diags using 'sudo' (BUG3214),
            # which unfortunately starts by closing all file descriptors. This
            # screws up our use of pipes to determine when the test script has
            # exited. So instead we're forced to fork again, and have the parent
            # monitor the child and exit with the same return code. This way
            # the pipe stays opened until the diag actually exits.
            diagpid = os.fork()
            if diagpid == 0:
               # Grandchild: becomes the diag itself.
               dieWhenParentDies()
               os.execvp( argv[0], argv )
            else:
               # Wait for the child to exit, then exit with the same return code
               ( pid, status ) = os.waitpid( diagpid, 0 )
               if os.WIFSIGNALED( status ):
                  # Killed by a signal: use the negated signal number.
                  # NOTE(review): an exit status is truncated to 8 bits by
                  # the OS, so the watcher will not see a negative value -
                  # confirm that this is acceptable.
                  rc = -os.WTERMSIG( status )
               else:
                  assert os.WIFEXITED( status )
                  rc = os.WEXITSTATUS( status )
               #pylint: disable-msg=W0212
               os._exit( rc )
         except:
            # Reached when the second fork, the exec or the wait above
            # raises, in whichever of the two processes it happened.
            print 'Failed to exec test script', argv[0]
            #pylint: disable-msg=W0212
            os._exit( 1 )
      except:
         # Reached when setting up the first-level child fails
         # (dieWhenParentDies, dup2 or preexec_fn).
         print 'Failed second fork', argv[0]
         #pylint: disable-msg=W0212
         os._exit( 1 )
   # Only the original (parent) process gets here.
   t1( "Child", pid, "forked with argv", argv )
   return pid
   
def fork( argv, stdin, stdout, stderr, preexec_fn ):
   """Call _fork with the same arguments while garbage collection is
   disabled (gcctx.gc_disabled) and a Tac.ActivityLockHolder is held.
   Returns the pid of the forked child."""
   import gcctx
   with gcctx.gc_disabled():
      with Tac.ActivityLockHolder():
         return _fork( argv, stdin, stdout, stderr, preexec_fn )

class ScriptRunner(object):
   """Run an individual test script in a child process.

   The constructor starts the script and returns without waiting for it.
   When the child exits, its output is collected, the result is recorded
   and the 'finish' callback is invoked."""
   def __init__(self, runObj, options, finish ):
      """runObj  -- the run to execute.
      options -- TestOptions handed to the RunRecorder.
      finish  -- callable invoked as finish( runObj, result, output )
                 once the script has exited."""
      self.run_ = runObj
      self.recorder_ = RunRecorder( runObj, options )
      self.finish_ = finish
      
      import tempfile
      (fd, name) = tempfile.mkstemp( '.out', runObj.name + "." )
      self.tmpFd_ = fd
      self.tmpName_ = name
      # output goes to temp file
      stdin = file( "/dev/null", "r")
      # stdout and stderr share the temp file, so they end up interleaved.
      stdout = fd
      stderr = fd
      # This pipe carries no data.  The child closes its copy of the read
      # end (closePipe runs in the child before exec); the parent closes
      # its copy of the write end below and watches the read end.
      # NOTE(review): presumably the read end turns readable (EOF) once
      # the forked processes holding the write end have exited - confirm.
      (rpipe, wpipe) = os.pipe()
      def closePipe():
         os.close(rpipe)

      argv = runArgs( runObj )
      self.recorder_.start()
      child = fork( argv,
                    stdin=stdin,
                    stdout=stdout,
                    stderr=stderr,
                    preexec_fn=closePipe)
      mix( .003 )
      os.close( wpipe )
      self.child_ = child
      # ChildWatcher also stores itself on this object as childWatcher_.
      self.childWatcher_ = \
            ChildWatcher( rpipe, child, self._handleChildExit, self )

   # returnCode is the result of waitPid()
   def _handleChildExit( self, returnCode ):
      """Collect the script's output, record the result and notify the
      finish callback.  A returnCode of 0 means 'passed', anything else
      'failed'."""
      mix( .010 )
      # Rewind the temp file and read back everything the script wrote.
      os.lseek( self.tmpFd_, 0, 0 )
      output = ""
      while True:
         newOutput = os.read( self.tmpFd_, 4096 )
         if not newOutput:
            break
         output += newOutput
      os.close( self.tmpFd_ )
      os.unlink( self.tmpName_ )
      result = (returnCode == 0) and "passed" or "failed"
      self.recorder_.finish( result, output )
      self.finish_( self.run_, result, output )
      # Drop our reference to the watcher now that the child is gone.
      del self.childWatcher_

   def result( self ):
      """Return the result currently stored on the run."""
      return self.run_.result

   def __del__( self ):
      # Trace only, to make the runner's lifetime visible when debugging.
      t0( "ScriptRunner.__del__" )

def TestRunner( runObj, options, finish, blocking=False ):
   """Build the runner that suits 'runObj' and return it.

   A run whose parent is a Diags::Script gets a ScriptRunner; every other
   run gets a SuiteRunner.  'blocking' is accepted but not used."""
   isScript = ( runObj.parent.tacType.fullTypeName == "Diags::Script" )
   runnerClass = ScriptRunner if isScript else SuiteRunner
   return runnerClass( runObj, options, finish )


class SuiteRunner(object):
   """Run a suite of tests.

   Enabled nodes are started in sorted node-id order.  They run one at a
   time, unless the suite's 'execution' attribute is "parallel", in which
   case they are all started at once.  When no node is left to start or
   pending, the suite result is recorded and the 'finish' callback is
   called."""
   def __init__(self, runObj, options, finish ):
      """runObj  -- the suite run to execute.
      options -- TestOptions shared with every member runner.
      finish  -- callable invoked as finish( runObj, result, "" ) when the
                 suite is done; may be false to skip the notification."""
      assert runObj
      # Re-entrant lock: handleCompletion takes it and then calls _advance,
      # whose _go takes it again.
      self.lock = threading.RLock()
      self.run_ = runObj
      self.recorder_ = RunRecorder( runObj, options )
      self.recorder_.start()
      assert self.recorder_.run
      self.options_ = options
      self.finish_ = finish 
      # Node ids still to be started, in ascending order.
      self.nodesToDo_ = sorted( runObj.test.node.keys() )
      # Member run -> its runner, for members started but not yet finished.
      self.nodesPending_ = {}
      t2( 'Advancing SuiteRunner', runObj )
      with self.lock:
         self._advance()
      t2( 'Starting the recorder', runObj )

   def _result( self ):
      """Return the suite result: "passed" unless an enabled node has a
      run that failed or was interrupted, in which case the result of the
      first such node encountered is returned."""
      # Compute whether we passed or failed
      result = "passed"
      runObj = self.run_
      test = self.run_.parent
      for tid in test.node:
         if test.node[tid].enabled:
            nodeRun = runObj.run.get(tid)
            if not nodeRun:
               # Node never got a run (for example its test was not found).
               continue
            nodeResult = nodeRun.result
            if nodeResult == "passed" or nodeResult == "skipped":
               continue
            elif nodeResult == "failed" or nodeResult == "interrupted":
               result = nodeResult
               break
      return result

   def _advance( self ):
      """Start the next node(s); finish the suite once nothing is left."""
      t1( "advancing", self.run_ )
      while True:
         node = self._nextNode()
         if not node:
            break
         else:
            started = self._go( node )
            # Serial suites stop after starting one test; parallel suites
            # keep going until every node has been started.
            if started and not self.run_.parent.execution == "parallel":
               break

      if not (self.nodesToDo_ or self.nodesPending_):
         t1( "finished", self.run_, Tac.threadId().pthread )
         result = self._result()
         # NOTE(review): recorder_.finish() clears recorder_.run, so
         # reaching this point twice for one suite would trip this assert.
         assert self.recorder_.run
         self.recorder_.finish(result, "")
         #assert self.recorder_.run
         if self.finish_:
            self.finish_( self.run_, result, "" )  # suites have no output

   def _nextNode( self ):
      """Pop node ids off nodesToDo_ until an enabled node turns up and
      return it; return None when the list runs out."""
      runObj = self.run_
      while True:
         if not self.nodesToDo_:
            break
         nodeId = self.nodesToDo_.pop( 0 )
         node = runObj.test.node[nodeId]
         if node.enabled:
            return node
      return None

   def _go( self, node ):
      """Return True if we started a test"""
      t1( "running node", node, "of", self.run_ )
      runObj = self.run_
      nodeId = node.id

      # A node either references its test directly or names it by path.
      t = node.test
      if not t:
         t = testFromPath( runObj.parent, node.path, self.options_.root )
      if not t:
         return False
         # If we can't look up the test by name, then simply return

      # memberRun points to a run of the component test object.  Its
      # parent is the component test. The same run is used for every
      # iteration of this test during this run of the suite.  A
      # pointer to runOfNode is stored in the Suite's run, keyed by
      # the current nodeId.
      memberRun = runObj.run.get( nodeId )
      if not memberRun:
         memberRun = nextRun( t )
         runObj.run[nodeId] = memberRun

      if not node.enabled:
         # _nextNode only hands out enabled nodes, so this branch is not
         # reached through _advance.
         memberRun.result = 'skipped' # this is not tested
      else:
         with self.lock:
            # NOTE(review): the member is added to nodesPending_ only after
            # its runner's constructor returns.  If that constructor calls
            # handleCompletion synchronously (a nested SuiteRunner with
            # nothing to run would), the 'del' there raises KeyError -
            # confirm whether this can happen.
            child = TestRunner( memberRun,
                                self.options_,
                                self.handleCompletion )
            t2( 'Suite TestRunner is constructed' )
            self.nodesPending_[memberRun] = child
      return True
      
   def handleCompletion( self, runKey, result, output ):
      """Finish callback given to member runners.  'runKey' is the member
      run that completed.  On a failure with stopOnError set, the nodes
      not yet started are dropped.  Then the suite advances."""
      with self.lock:
         del self.nodesPending_[ runKey ]
         if result == "failed" and self.options_.stopOnError:
            self.nodesToDo_ = []
            print "Aborting after failed test"
         self._advance()

   def result( self ):
      """Return the result currently stored on the suite's run."""
      return self.run_.result

# Tests that need to block (CLI) simply run a semaphore out of the handler
