# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import cPickle
import errno
import os
import socket
import sys
import time
import traceback

import CliCommon
import CliInputWrapper as CliInput
import FastServ
import FastServUtil
import TerminalUtil
import Tracing

traceHandle = Tracing.Handle( 'CliShellLib' )
# Short aliases for trace levels 0 through 4 of this module's trace handle.
( log, warn, info, trace, debug ) = ( traceHandle.trace0,
                                      traceHandle.trace1,
                                      traceHandle.trace2,
                                      traceHandle.trace3,
                                      traceHandle.trace4 )

class ConnectTimeout( Exception ):
   """Timeout error that CliShell.startReadlineLoop re-raises rather than
   reporting through handleCliException. Presumably raised elsewhere when a
   connection attempt takes too long -- confirm against its raisers."""

class CliShell( object ):
   """Interactive readline front end that forwards input to a Cli interface.

   Each line read with CliInput.readline is handed to cliInterface.runCmd and
   then added to the command history of the mode it originated from.
   """
   def __init__( self, cliInterface ):
      self.cliInterface_ = cliInterface
      # Indexable ( prefix, suffix ) pair from TerminalUtil.terminalCtrlSeq()
      # that is wrapped around a terminal title; falsy when the terminal
      # cannot be controlled.
      self.termTitleCtrlSeq_ = None
      # True only while cliInterface.runCmd is executing a line.
      self.runCmd_ = False

   def updateTerminalCtrlSeq( self ):
      """Refresh the cached terminal-title control sequence."""
      self.termTitleCtrlSeq_ = TerminalUtil.terminalCtrlSeq()

   def terminalTitleStr( self, title ):
      """ Returns a string that will set the terminal's title, or
      empty string if the terminal cannot be controlled. """
      seq = self.termTitleCtrlSeq_
      if not seq:
         return ""
      return "%s%s%s" % ( seq[ 0 ], title, seq[ 1 ] )

   def complete( self, key, cmd, pos ):
      """Readline completion hook for the TAB and '?' keys.

      key: the key that triggered completion ("\\t" or "?").
      cmd: the current input line.
      pos: cursor position; must be at the end of cmd.
      Returns the result of cliInterface.tabComplete (TAB) or
      cliInterface.printHelp ('?').
      """
      assert pos == len( cmd )
      result = ""
      if key == "\t":
         result = self.cliInterface_.tabComplete( cmd )
      elif key == "?":
         # Echo the '?' so the help output starts on its own line.
         print( "?" )
         result = self.cliInterface_.printHelp( cmd )
      else:
         assert False

      return result

   def prompt( self ):
      """Return the Cli prompt, prefixed by $CLI_PROMPT_PREFIX if set."""
      promptPrefix = os.environ.get( 'CLI_PROMPT_PREFIX', '' )
      return self.cliInterface_.prompt( promptPrefix )

   def _writeTerminalTitleStr( self, prompt ):
      """Write the terminal-title control sequence for prompt to stdout.

      Raises EOFError if the write fails with EIO, and re-raises any other
      OSError except EINTR, which is retried.
      """
      # Any terminal control characters are written outside
      # of readline, which includes its length in the
      # terminal width calculation. This prevents odd line
      # wrapping behavior.

      # Sometimes the first prompt is missing because the write got interrupted,
      # so retry the write in case of EINTR; it is not known which signal it was.
      fd = sys.stdout.fileno()
      data = self.terminalTitleStr( prompt )
      # Always issue at least one write, even for an empty string, so that a
      # closed pty is still detected through EIO below.
      while True:
         try:
            written = os.write( fd, data )
         except OSError as e:
            if e.errno == errno.EINTR:
               continue
            if e.errno != errno.EIO:
               raise
            # write can give an EIO if the primary end of the pseudo-tty has
            # been closed (which can happen if the parent, login or sshd, has
            # been killed, perhaps by the OOM killer).  If this has happened,
            # the Cli process ought to have been sent a SIGHUP, which ought to
            # have killed it.  However, empirically it seems that this is not
            # always the case.
            raise EOFError()
         # os.write may write fewer bytes than requested (e.g. when a signal
         # arrives mid-write); keep writing the remainder.
         data = data[ written: ]
         if not data:
            break

   def startReadlineLoop( self, firstPromptCallback=None ):
      """Read and execute command lines until EOF or SystemExit.

      firstPromptCallback, if given, is called once, right after the first
      prompt's terminal title has been written. KeyboardInterrupt while
      reading leaves config mode and re-prompts; ConnectTimeout propagates;
      other errors are reported via cliInterface.handleCliException. When
      the loop ends, config mode is exited and the session ended.
      """
      log( "start readline loop" )
      TerminalUtil.enableCtrlZ( False )
      self.updateTerminalCtrlSeq()

      while True:
         trace( "Waiting for command" )
         try:
            prompt = self.prompt()
            trace( "Received prompt", prompt )
            self._writeTerminalTitleStr( prompt )
            if firstPromptCallback:
               firstPromptCallback()
               firstPromptCallback = None
            line = ""
            # Only enable Ctrl-Z during readline so Ctrl-Z can still be
            # turned into KeyboardInterrupt, but disable it while running
            # commands, or we may hang if user presses ctrl-Z (BUG34718).
            with TerminalUtil.CtrlZ( True ):
               historyKeys = self.cliInterface_.getHistoryKeys()
               trace( "Reading line" )
               line = CliInput.readline( prompt, self.complete, *historyKeys )
               trace( "Line read:", line )
         except KeyboardInterrupt:
            trace( "Keyboard interrupt:" )
            print()
            self.cliInterface_.exitConfigMode()
            continue
         except EOFError:
            trace( "EOF ERROR:" )
            print()
            break
         except ConnectTimeout:
            raise
         except:
            log( repr( traceback.format_exc() ) )
            # the call below will blowup while pickling certain exception...
            self.cliInterface_.handleCliException( sys.exc_info(),
                                                   '(incomplete)' )
            continue

         if not line:
            continue

         try:
            trace( "Trying to run:", line )
            self.runCmd_ = True
            self.cliInterface_.runCmd( line )
         except SystemExit:
            trace( "System exit, leaving readline loop", line )
            break
         except:
            self.cliInterface_.handleCliException( sys.exc_info(), line )
         finally:
            self.runCmd_ = False
            # Record the line in the history of the mode it was entered in.
            historyKeys = self.cliInterface_.getOrigModeHistoryKeys()
            if not CliInput.addHistory( line, *historyKeys ):
               # parent history mode does not exist
               for key in self.cliInterface_.getParentHistoryKeys():
                  CliInput.newHistoryMode( *key )
               r = CliInput.addHistory( line, *historyKeys )
               assert r

      self.cliInterface_.exitConfigMode()
      self.cliInterface_.endSession()

class _RemoteRequest( object ):
   def __init__( self, method, args, kwargs ):
      self.method_ = method
      self.args_ = args
      self.kwargs_ = kwargs

class _RemoteAttr( object ):
   def __init__( self, method, cliInputSock ):
      self.method_ = method
      self.cliInputSock_ = cliInputSock

   def __call__( self, *args, **kwargs ):
      request = _RemoteRequest( self.method_, args, kwargs )
      requestStr = cPickle.dumps( request )
      FastServUtil.writeString( self.cliInputSock_, requestStr )
      responseStr = FastServUtil.readString( self.cliInputSock_ )
      if not responseStr:
         return
      response = cPickle.loads( responseStr )
      if ( issubclass( type( response ), Exception ) or
           type( response ) is SystemExit ):
         raise response
      return response

class RemoteCliInput( object ):
   """Proxy whose every attribute (other than cliInputSock_) is a _RemoteAttr
   bound to the same socket, so attribute calls become remote requests."""
   def __init__( self, cliInputSock ):
      self.cliInputSock_ = cliInputSock

   def __getattribute__( self, name ):
      if name != 'cliInputSock_':
         # Re-enters __getattribute__ with 'cliInputSock_' to fetch the socket.
         return _RemoteAttr( name, self.cliInputSock_ )
      return super( RemoteCliInput, self ).__getattribute__( name )

class CliConnector( object ):
   """Base class for clients that open a session with the CliServer.

   _connectToBackend performs the common handshake; subclasses then write a
   one-byte session type and pass further file descriptors.
   """
   def _createArgStr( self, argv ):
      """Join argv into a single NUL-separated string."""
      return '\x00'.join( argv )

   def _createEnvStr( self, env ):
      """Flatten env into 'key\\0value\\0key\\0value...' form."""
      # items() rather than iteritems(): same result, and also valid on Python 3.
      return '\x00'.join( [ '%s\x00%s' % ( key, value ) for
                          key, value in env.items() ] )

   def _createAndSendSock( self, sock ):
      """Create a socketpair, send one end over sock and return the other.

      Both ends are closed if sending fails, so no descriptor leaks.
      """
      local, remote = socket.socketpair( socket.AF_UNIX, socket.SOCK_STREAM, 0 )
      sent = False
      try:
         FastServ.sendFds( sock.fileno(), [ remote.fileno() ] )
         sent = True
      finally:
         # Our copy of the sent end is no longer needed either way.
         remote.close()
         if not sent:
            local.close()
      return local

   def _connectWithRetry( self, sock, address, timeout, retryInterval ):
      """Connect sock to address, retrying while the connection is refused.

      Re-raises the socket.error immediately if it is not ECONNREFUSED, or
      once at least timeout seconds have elapsed since the first attempt.
      """
      startTime = time.time()
      while True:
         try:
            sock.connect( address )
            return
         except socket.error as e:
            if e.errno != errno.ECONNREFUSED:
               raise
            if time.time() - startTime >= timeout:
               raise
         # we sleep a bit before we retry
         time.sleep( retryInterval )

   def _connectToBackend( self, sysname, argv, env, uid, gid, ctty,
                          connectTimeout=120, retryInterval=0.1 ):
      """Connect to the CliServer for sysname and send the session header.

      Sends a signal socket, then argv, env, uid, gid and ctty as strings.
      The CliServer might not be up at the beginning of time, so refused
      connections are retried every retryInterval seconds for up to
      connectTimeout seconds.
      Returns ( sock, signalSock ); on any failure both are closed before
      the error propagates.
      """
      sock = socket.socket( socket.AF_UNIX, socket.SOCK_STREAM, 0 )
      signalSock = None
      succeeded = False
      try:
         self._connectWithRetry( sock, CliCommon.CLI_SERVER_ADDRESS_FMT % sysname,
                                 connectTimeout, retryInterval )
         signalSock = self._createAndSendSock( sock )
         FastServUtil.writeString( sock, self._createArgStr( argv ) )
         FastServUtil.writeString( sock, self._createEnvStr( env ) )
         FastServUtil.writeString( sock, str( uid ) )
         FastServUtil.writeString( sock, str( gid ) )
         FastServUtil.writeString( sock, ctty )
         succeeded = True
      finally:
         if not succeeded:
            if signalSock is not None:
               signalSock.close()
            sock.close()
      return sock, signalSock

class EapiCliConnector( CliConnector ):
   """Connector that opens a backend session announced with 'c' (stateless)
   or 'd' (not stateless) and receives three extra socketpair ends."""
   def __init__( self, stateless=True ):
      self.stateless_ = stateless

   def connectToBackend( self, sysname, argv, env, uid, gid ):
      """Open the session and return
      ( signalSock, responseSock, requestSock, statisticsSock ).

      The handshake socket is always closed; on failure every socket created
      here is closed too before the error propagates.
      """
      sock, signalSock = self._connectToBackend( sysname, argv, env, uid, gid, '' )
      opened = [ signalSock ]
      succeeded = False
      try:
         os.write( sock.fileno(), 'c' if self.stateless_ else 'd' )
         responseSock = self._createAndSendSock( sock )
         opened.append( responseSock )
         requestSock = self._createAndSendSock( sock )
         opened.append( requestSock )
         statisticsSock = self._createAndSendSock( sock )
         opened.append( statisticsSock )
         succeeded = True
      finally:
         sock.close()
         if not succeeded:
            for s in opened:
               s.close()
      return ( signalSock, responseSock, requestSock, statisticsSock )

class NonTtyCliConnector( CliConnector ):
   """Connector for a session announced with 'u' that hands the backend the
   caller's stdin, stdout and stderr descriptors."""
   def connectToBackend( self, sysname, argv, env, uid, gid, stdinFd, stdoutFd,
                         stderrFd ):
      """Open the session and return signalSock.

      The handshake socket is always closed; signalSock is closed as well if
      the session setup fails.
      """
      sock, signalSock = self._connectToBackend( sysname, argv, env, uid, gid, '' )
      succeeded = False
      try:
         os.write( sock.fileno(), 'u' )
         FastServ.sendFds( sock.fileno(), [ stdinFd, stdoutFd, stderrFd ] )
         succeeded = True
      finally:
         sock.close()
         if not succeeded:
            signalSock.close()
      return signalSock

class TtyCliConnector( CliConnector ):
   """Connector for a session announced with 't' that hands the backend the
   secondary pty descriptor plus request and cli-input socketpair ends."""
   def connectToBackend( self, sysname, argv, env, uid, gid, ctty, secondaryPty ):
      """Open the session and return ( signalSock, requestSock, cliInputSock ).

      The handshake socket is always closed; on failure every socket created
      here is closed too before the error propagates.
      """
      sock, signalSock = self._connectToBackend( sysname, argv, env, uid, gid, ctty )
      opened = [ signalSock ]
      succeeded = False
      try:
         os.write( sock.fileno(), 't' )
         FastServ.sendFds( sock.fileno(), [ secondaryPty ] )
         requestSock = self._createAndSendSock( sock )
         opened.append( requestSock )
         cliInputSock = self._createAndSendSock( sock )
         opened.append( cliInputSock )
         succeeded = True
      finally:
         sock.close()
         if not succeeded:
            for s in opened:
               s.close()
      return signalSock, requestSock, cliInputSock

class SimpleCliConnector( CliConnector ):
   """Connector for a session announced with 's' that hands the backend the
   caller's stdout/stderr descriptors and a request socketpair end."""
   def connectToBackend( self, sysname, argv, env, uid, gid, stdoutFd, stderrFd ):
      """Open the session and return ( signalSock, requestSock ).

      The handshake socket is always closed; on failure every socket created
      here is closed too before the error propagates.
      """
      sock, signalSock = self._connectToBackend( sysname, argv, env, uid, gid, '' )
      opened = [ signalSock ]
      succeeded = False
      try:
         os.write( sock.fileno(), 's' )
         FastServ.sendFds( sock.fileno(), [ stdoutFd, stderrFd ] )
         requestSock = self._createAndSendSock( sock )
         opened.append( requestSock )
         succeeded = True
      finally:
         sock.close()
         if not succeeded:
            for s in opened:
               s.close()
      return signalSock, requestSock
