# Copyright (c) 2007-2012 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import inspect
import json
import os
import traceback
import types

import Tac
import Tracing

from CliCommon import JsonRpcErrorCodes

traceHandle = Tracing.Handle( 'CapiJsonRpcBase' )
# Short local aliases for this module's trace levels 1 through 5.
warn, info, trace, debug, nitty = ( traceHandle.trace1,
                                    traceHandle.trace2,
                                    traceHandle.trace3,
                                    traceHandle.trace4,
                                    traceHandle.trace5 )

class JsonRpcError( Exception ):
   """ Base class for all JSON RPC errors.

   Attributes:
      code: numeric JSON-RPC error code.
      msg:  human-readable error message string.
      data: optional json-serializable object carrying extra detail about
            the error, or None.
   """
   def __init__( self, code, msg, data=None ):
      """ Store the JSON-RPC `code`, message `msg` and optional `data`. """
      Exception.__init__( self )
      self.code, self.msg, self.data = code, msg, data

   def __str__( self ):
      template = "JSON-RPC Error %d %s: %r"
      return template % ( self.code, self.msg, self.data )

class ParseError( JsonRpcError ):
   """ Json RPC error raised when the server receives text that is not valid
   JSON; carries the PARSE_ERROR code. """
   def __init__( self, msg ):
      JsonRpcError.__init__( self, code=JsonRpcErrorCodes.PARSE_ERROR, msg=msg )

class InvalidRequestError( JsonRpcError ):
   """ Json RPC error raised when the JSON received is not a valid Request
   object; carries the INVALID_REQUEST code. """
   def __init__( self, msg ):
      JsonRpcError.__init__( self, code=JsonRpcErrorCodes.INVALID_REQUEST,
                             msg=msg )

class InvalidMethodError( JsonRpcError ):
   """ Json RPC error raised when the requested method does not exist or is
   not callable over RPC; carries the METHOD_NOT_FOUND code. """
   def __init__( self, msg ):
      JsonRpcError.__init__( self, code=JsonRpcErrorCodes.METHOD_NOT_FOUND,
                             msg=msg )

class InvalidParamsError( JsonRpcError ):
   """ Json RPC error raised when the supplied parameters do not fit the
   requested method; carries the INVALID_PARAMS code. """
   def __init__( self, msg ):
      JsonRpcError.__init__( self, code=JsonRpcErrorCodes.INVALID_PARAMS,
                             msg=msg )

class InternalError( JsonRpcError ):
   """ Json RPC error raised when an unexpected server-side error occurred;
   carries the INTERNAL_ERROR code. """
   def __init__( self, msg ):
      JsonRpcError.__init__( self, code=JsonRpcErrorCodes.INTERNAL_ERROR,
                             msg=msg )

#
# The JsonRPC plugin base class.
#
class JsonRpcBase( object ):
   """Base class for CAPI JSON-RPC plugins.

   Subclasses expose procedures by decorating methods with `JsonRpcBase.rpc`.
   `processRequest` parses one POST body, dispatches it to the requested
   procedure and writes the JSON-RPC response to the given file descriptor.
   """
   def __init__( self ):
      """ Base Json RPC object """
      # Per-request state; processRequest resets all of it after each request.
      self.requestId_ = None
      self.streamedResponse_ = False
      self.outputFd_ = None
      self.requestCount_ = 0
      self.commandCount_ = 0

   @staticmethod
   def rpc( m ):
      """Decorate those methods which are RPC
      procedure implementations with this.

      Only methods carrying the `isRpcProcedure` marker set here may be
      invoked through the JSON-RPC "method" field."""
      m.isRpcProcedure = True
      return m

   def handlePostRawText( self, request ):
      """Override handlePostRawText to accept HTTP POST raw text methods.

      Called with the whole request body when it does not look like a JSON
      object. This default implementation rejects the request."""
      trace( 'handlePostRawText entry' )
      msg = 'POST raw text not supported'
      trace( 'handlePostRawText exit', msg )
      raise InvalidRequestError( msg )

   #
   # The following methods and classes are for use by the CAPI
   # infrastructure and should not be called or overridden.
   #
   def setRequestCount( self, requestCount ):
      """Record the request count that processRequest returns."""
      self.requestCount_ = requestCount

   def setCommandCount( self, commandCount ):
      """Record the command count that processRequest returns."""
      self.commandCount_ = commandCount

   def outputFd( self ):
      """Return the response fd while streaming, otherwise None."""
      return self.outputFd_ if self.streamedResponse_ else None

   def processRequest( self, request, outputFd ):
      """Handle one request body, writing the JSON-RPC response to `outputFd`.

      Unexpected exceptions are printed and their stack trace logged rather
      than propagated. Returns the ( requestCount, commandCount ) recorded via
      setRequestCount/setCommandCount while handling the request; all
      per-request state is reset before returning."""
      trace( 'processRequest enter' )
      self.outputFd_ = outputFd
      try:
         self._processRequestInternal( request )
      except Exception as e: # pylint: disable-msg=W0703
         # Single-argument print: same output under Python 2, valid Python 3.
         print( "ERROR: %s" % e )
         traceback.print_exc()  # Log stack trace to agent log.
      finally:
         requestCount = self.requestCount_
         commandCount = self.commandCount_
         self.requestId_ = None
         self.streamedResponse_ = False
         self.outputFd_ = None
         self.requestCount_ = 0
         self.commandCount_ = 0
      return requestCount, commandCount

   def _processRequestInternal( self, request ):
      """Parse `request` and write the complete response to self.outputFd_.

      A parse/validation failure produces a full error response. Otherwise the
      response header is written and the procedure runs in streaming or
      non-streaming mode, as selected while parsing."""
      try:
         procedure, params = self._parseRequest( request )
      except JsonRpcError as e:
         result = self._generateFullErrorResponse( e.code, e.msg, e.data )
         os.write( self.outputFd_, result )
         return
      except Exception as e: # pylint: disable-msg=W0703
         traceback.print_exc()  # Log stack trace to agent log.
         result = self._generateFullErrorResponse( JsonRpcErrorCodes.INTERNAL_ERROR,
                                                   str( e ) )
         os.write( self.outputFd_, result )
         return

      os.write( self.outputFd_, self._generateJsonRespHeader() )
      if not self.streamedResponse_:
         self._processRequestNonStreaming( procedure, params )
      else:
         self._processRequestStreaming( procedure, params )

   def _processRequestStreaming( self, procedure, params ):
      """Run `procedure` inside an open "result" array and close the response.

      The procedure's return value is ignored here; presumably it writes the
      array elements itself via outputFd() -- confirm with implementations.
      On failure the array is closed and the error code and message are
      appended so the client still receives a well-formed JSON object."""
      os.write( self.outputFd_, '"result": [' )
      try:
         self._executeRpcProcedure( procedure, params )
      except JsonRpcError as e:
         self._writeStreamingErrorTrailer( e.code, e.msg )
      except Exception as e: # pylint: disable-msg=W0703
         # The header and '[' are already on the wire; without this the
         # response would be left as an unterminated JSON document.
         traceback.print_exc()  # Log stack trace to agent log.
         self._writeStreamingErrorTrailer( JsonRpcErrorCodes.INTERNAL_ERROR,
                                           str( e ) )
      else:
         os.write( self.outputFd_, ']}' )

   def _writeStreamingErrorTrailer( self, code, msg ):
      """Close a streamed "result" array and the response object, adding
      `code` and `msg` as top-level "code" and "message" fields."""
      trailer = '], "code": %s, "message": %s}' % ( json.dumps( code ),
                                                     json.dumps( msg ) )
      os.write( self.outputFd_, trailer )

   def _processRequestNonStreaming( self, procedure, params ):
      """Run `procedure` and write its JSON-encoded result (or an error
      object) to complete the response whose header was already written."""
      out = None
      try:
         result = self._executeRpcProcedure( procedure, params )
         out = [ '"result": ', self._py2json( result ), '}' ]
      except JsonRpcError as e:
         out = [ self._generatePartialErrorResponse( e.code, e.msg, e.data ) ]
      except Exception as e: # pylint: disable-msg=W0703
         traceback.print_exc()  # Log stack trace to agent log.
         errCode = JsonRpcErrorCodes.INTERNAL_ERROR
         out = [ self._generatePartialErrorResponse( errCode, str( e ) ) ]
      finally:
         # `out` is still None only if a non-Exception escaped, or if one of
         # the error handlers above raised itself.
         if out is None:
            print( "Out was None!!! %s %s" % ( procedure, params ) )
            traceback.print_exc()  # log stack trace to agent log.
            out = []

         for i in out:
            os.write( self.outputFd_, i )

   def _parseRequest( self, request ):
      """ parses the request and returns the procedure and params associated with
          this request.

      Raises ParseError for an empty body, and any JsonRpcError raised while
      parsing the JSON or validating the params against the procedure."""
      trace( '_produceContent entry', request )

      if not request:
         msg = 'No data provided in HTTP POST body'
         trace( '_produceContent raise', msg )
         raise ParseError( msg )

      # The posted content may be either json or plain text. Infer this from
      # the first non-whitespace character: JSON text may legally begin with
      # whitespace (e.g. a pretty-printed body). If plain text, pass to the
      # special procedure. Otherwise process as JSON-RPC.
      trace( '_produceContent exit' )
      if request.lstrip()[ : 1 ] == '{':
         procedure, params = self._parseJsonRequest( request )
      else:
         procedure = self.handlePostRawText
         params = [ request ]

      self._validateRpcMethodParams( procedure, params )
      return procedure, params

   def _parseJsonRequest( self, request ):
      """Decode a JSON-RPC 2.0 request object and resolve its method.

      Sets self.requestId_ and, when requested, self.streamedResponse_.
      Returns ( procedure, params ): the bound method named by "method" and
      its params as a list or dict.
      Raises ParseError, InvalidRequestError, InvalidMethodError or
      InvalidParamsError for malformed requests."""
      def byteify( val ):
         # Recursively convert the decoder's unicode strings (keys and
         # values) into UTF-8 encoded byte strings.
         if isinstance( val, dict ):
            return { byteify( key ): byteify( value )
                     for key, value in val.iteritems() }
         elif isinstance( val, list ):
            return [ byteify( element ) for element in val ]
         elif isinstance( val, unicode ):
            return val.encode( 'utf-8' )
         else:
            return val

      # Parse the request content.
      try:
         trace( 'requestObject:', request )
         # cjson's decoding is buggy with utf8, so use json (see BUG 106554),
         # use !strict to allow \n in strings ala cjson (json is also impl in c!)
         decoder = json.JSONDecoder( strict=False )
         decodedObj = decoder.decode( request )
         requestObject = byteify( decodedObj )
      except ValueError as e:
         msg = 'Invalid JSON: %s' % e
         trace( '_produceContentFromJson raise', msg )
         raise ParseError( msg )

      if not isinstance( requestObject, dict ):
         msg = 'Invalid request: expected a top level JSON object'
         trace( '_validateRequestStructure raise', msg )
         raise InvalidRequestError( msg )

      # Make sure we're dealing with a sane request that complies with
      # the JSON RPC 2.0 spec. TODO: Accept JSON-RPC 1.0 commands.
      self._validateJsonRpcField(
         requestObject, "jsonrpc", expectedTypes=( basestring, ),
         expectedValue="2.0" )
      self._validateJsonRpcField( requestObject, "id",
                                  expectedTypes=( types.NoneType, basestring,
                                                  int, float, long ) )
      # Set the id now, so any errors we send have a valid id in the response
      self.requestId_ = requestObject[ "id" ]

      # figure if we are streaming or not
      if "streaming" in requestObject:
         # Note the trailing comma: `( bool )` is not a tuple.
         self._validateJsonRpcField( requestObject, "streaming",
                                     expectedTypes=( bool, ) )
         self.streamedResponse_ = requestObject[ "streaming" ]

      # get the method
      self._validateJsonRpcField( requestObject, "method",
                                  expectedTypes=( basestring, ) )

      # Get the procedure corresponding to the specified JSON RPC Method
      method = requestObject[ "method" ]
      try:
         procedure = getattr( self, method )
      except AttributeError:
         msg = "Invalid 'method' specified: %r" % method
         trace( '_produceContentFromJson raise', msg )
         raise InvalidMethodError( msg )

      # The procedure must be annotated as such.
      if not hasattr( procedure, 'isRpcProcedure' ):
         msg = "Invalid 'method' specified: %r" % method
         trace( '_produceContentFromJson raise', msg )
         raise InvalidMethodError( msg )

      # JSON-RPC 2.0 allows "params" to be omitted; treat that as no
      # arguments. An explicit null or any non-structured value is rejected.
      params = requestObject.get( "params", [] )
      if not isinstance( params, ( list, dict ) ):
         msg = "Invalid type of 'params' specified: expected array or object"
         trace( '_produceContentFromJson raise', msg )
         raise InvalidParamsError( msg )

      if isinstance( params, dict ) and params.get( "streaming" ):
         self.streamedResponse_ = True

      return procedure, params

   def _validateRpcMethodParams( self, procedure, params ):
      """ This takes a procedure and params and validates that
          the params match up to the names/quantity that the procedure expects.

      Every argument without a default must be supplied. Positional (list)
      params may not exceed the declared arguments unless the procedure takes
      *args; named (dict) params may not use unknown names unless the
      procedure takes **kwargs.

      Returns True when the params are acceptable; raises InvalidParamsError
      otherwise."""
      expectedArgs, varArgs, varKwargs, defaults = inspect.getargspec( procedure )
      if inspect.ismethod( procedure ):
         # A bound method's first argument ("self") is supplied implicitly.
         expectedArgs = expectedArgs[ 1 : ]
      numDefaults = len( defaults ) if defaults else 0
      # Arguments without a default value are mandatory. This must also hold
      # when there are no defaults at all, in which case every argument is.
      mandatoryArgs = expectedArgs[ : len( expectedArgs ) - numDefaults ]

      if isinstance( params, dict ):
         # Check all mandatory args are specified
         for arg in mandatoryArgs:
            if arg not in params:
               msg = "Expected parameter %r for method %r not provided" % (
                  arg, procedure.__name__ )
               raise InvalidParamsError( msg )
         # Without **kwargs, unknown names would fail at call time.
         if varKwargs is None:
            for arg in sorted( params ):
               if arg not in expectedArgs:
                  msg = "Unexpected parameter %r for method %r" % (
                     arg, procedure.__name__ )
                  raise InvalidParamsError( msg )
         # parameters are valid, so now run the method:
         return True

      if isinstance( params, list ):
         if len( params ) < len( mandatoryArgs ):
            msg = "%r takes at least %d arguments (%d given)" % (
               procedure.__name__, len( mandatoryArgs ), len( params ) )
            raise InvalidParamsError( msg )
         if varArgs is None and len( params ) > len( expectedArgs ):
            msg = "%r takes at most %d arguments (%d given)" % (
               procedure.__name__, len( expectedArgs ), len( params ) )
            raise InvalidParamsError( msg )
         # parameters are valid, so now run the method:
         return True

      # Raise instead of assert so the check survives `python -O`.
      raise InvalidParamsError( "Unexpected parameter format %r" % ( params, ) )

   def _executeRpcProcedure( self, procedure, params ):
      """Call `procedure` with `params`: as keyword arguments for a dict, as
      positional arguments for a list. Returns the procedure's result."""
      trace( '_executeRpcProcedure: Executing procedure', procedure, params )
      if isinstance( params, dict ):
         return procedure( **params )

      elif isinstance( params, list ):
         return procedure( *params )

   def _generateFullErrorResponse( self, code, msg, data=None ):
      """Return a complete error response: header plus "error" member."""
      trace( '_generateFullErrorResponse entry', code, msg )
      initialResponse = self._generateJsonRespHeader()
      errorResponse = self._generatePartialErrorResponse( code, msg, data )
      return '%s %s' % ( initialResponse, errorResponse )

   def _generatePartialErrorResponse( self, code, msg, data=None ):
      """Return the '"error": {...}}' tail that closes a response object.
      `data` is only included when it is truthy."""
      errorObj = { "code": code, "message": str( msg ) }
      if data:
         errorObj[ "data" ] = data
      return '"error": %s}' % ( self._py2json( errorObj ) )

   def _generateJsonRespHeader( self ):
      """Return the opening of a response object, up to and including the
      comma after the request id."""
      idStr = self._py2json( self.requestId_ )
      return '{"jsonrpc": "2.0", "id": %s, ' % idStr

   def _py2json( self, jsonObject ):
      """Serialize `jsonObject` to a JSON string.

      Raises JsonRpcError (INTERNAL_ERROR) when the object contains text that
      json cannot decode (UnicodeDecodeError)."""
      trace( '_py2json entry' )
      try:
         return json.dumps( jsonObject )
      except UnicodeDecodeError:
         trace( '_py2json raise EncodeError' )
         raise JsonRpcError( JsonRpcErrorCodes.INTERNAL_ERROR,
                             'response contains incompatible text encoding' )

   def _validateJsonRpcField( self, requestObject, field, expectedTypes,
                              expectedValue=None ):
      """ Validate `field` in the `requestObject`, checking that it
      exists, that its value is an instance of `expectedTypes` (a type or a
      tuple of types) and, if `expectedValue` is not None, that the value
      equals it.

      Raises InvalidRequestError on any mismatch."""

      trace( "_validateJsonRpcField field=%r, expectedTypes=%r, expectedValue=%r" %
             ( field, expectedTypes, expectedValue ) )
      if isinstance( expectedTypes, type ):
         # Accept a bare type too; the error message below iterates the types.
         expectedTypes = ( expectedTypes, )

      if field not in requestObject:
         msg = 'Expected field %r not specified' % field
         raise InvalidRequestError( msg )

      if not isinstance( requestObject[ field ], expectedTypes ):
         typeStr = ', '.join( t.__name__ for t in expectedTypes )
         msg = 'Invalid type of %r specified: expected %s' % ( field, typeStr )
         raise InvalidRequestError( msg )

      # Compare against None so falsy expected values are still enforced.
      if expectedValue is not None and requestObject[ field ] != expectedValue:
         msg = 'Invalid %r value specified: must be %r' % ( field, expectedValue )
         raise InvalidRequestError( msg )
