# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import atexit
import simplejson
import os
import traceback

import Agent
import Cell
import JsonApiConstants
import Plugins
import Tac
import Tracing
import UwsgiAaa
import UwsgiConstants

from ApiBaseModels import ModelJsonSerializer
from ControllerdbEntityManager import Controllerdb
from UwsgiRequestContext import UwsgiRequestContext, HttpBadRequest, HttpException
from UwsgiRequestContext import HttpForbidden, HttpServiceAclDenied
from UrlMap import getHandler

traceHandle = Tracing.Handle( 'JsonUwsgiServer' )
# Short aliases for the handle's trace levels 0..4, named by intent.
( log, warn, info, trace, debug ) = ( traceHandle.trace0,
                                      traceHandle.trace1,
                                      traceHandle.trace2,
                                      traceHandle.trace3,
                                      traceHandle.trace4 )

# Controller-wide constants (e.g. the default Controllerdb socket name).
Constants = Tac.Value( "Controller::Constants" )

class JsonApiApp( Agent.Agent ):
   """Agent that serves the JSON API.

   On construction it connects to Controllerdb and collects mount specs from
   JSONApiPlugin plugins. doInit() mounts the Sysdb and Controllerdb paths the
   request handlers need. processRequest() dispatches each request to the
   handler returned by UrlMap.getHandler().
   """

   def __init__( self, sysEm ):
      trace( 'init entry' )
      Agent.Agent.__init__( self, sysEm, agentName="JsonApiApp" )
      # Set by the Sysdb / Controllerdb mount-group callbacks; see warm().
      self.mgDone = False
      self.cmgDone = False
      self.sysname = sysEm.sysname()
      self.aaaManager_ = None
      # Plugin-contributed mount specs: key -> ( path, entity type, mode ).
      self.sysdbPluginMounts = {}
      self.cdbPluginMounts = {}
      # Every mounted entity by key; passed to each request handler.
      self.mounts = {}
      self.serviceAclFilterSm_ = None
      # Service ACL is always enabled now regardless of Epoch setting.
      # We keep this flag so in the future when we reimplement the service ACL
      # in the kernel we can come back and remove any code checking this flag.
      self.serviceAclEnabled_ = True

      trace( "Connecting to controllerdb" )
      # The socket name may be overridden through the environment.
      cdbSock = os.environ.get( 'CONTROLLERDBSOCKNAME',
                                Constants.controllerdbDefaultSockname )
      self.cEm = Controllerdb( sysEm.sysname(),
                               controllerdbSockname_=cdbSock,
                               dieOnDisconnect=True,
                               mountRoot=False )
      # Created in doInit() / mountDone().
      self.aclConfigAggregatorSm = None
      self.aclCpConfigAggregatorSm = None
      self.aclConfig = None
      self.aclCpConfig = None

      self._loadPlugins()
      trace( 'init exit' )

   def _loadPlugins( self ):
      """Collect Sysdb and Controllerdb mount specs from JSONApiPlugin plugins.

      Each non-empty plugin result is indexed as ( sysdb mounts, controllerdb
      mounts ). When two plugins use the same key, the later one wins, as
      before; a warning is now traced so the collision is visible.
      """
      # TODO this function needs de-duping to make sure we do not
      # mount the same thing twice
      trace( "Loading JSONApiPlugin" )
      pd = Plugins.loadPlugins( "JSONApiPlugin" )
      for plugin in pd.plugins():
         if not plugin:
            continue
         self._mergeMounts( self.sysdbPluginMounts, plugin[ 0 ], 'Sysdb' )
         self._mergeMounts( self.cdbPluginMounts, plugin[ 1 ], 'ControllerDb' )

   @staticmethod
   def _mergeMounts( dest, src, dbName ):
      """Update dest with src, tracing a warning for each redefined key."""
      # dict() accepts both a mapping and an iterable of pairs, matching
      # what dict.update() accepted before.
      src = dict( src )
      for key, spec in src.iteritems():
         if key in dest and dest[ key ] != spec:
            warn( "%s plugin mount %s redefined: %s -> %s" %
                  ( dbName, key, dest[ key ], spec ) )
      dest.update( src )

   def handleMountFailure( self, mountUrl=None ):
      """Controllerdb mount-group failure callback: trace a warning and raise.

      Args:
         mountUrl: the mount that failed, if the caller supplies it; it is
            included in the message when present.

      Raises:
         Exception: always.
      """
      msg = "Failed to mount from ControllerDb"
      if mountUrl is not None:
         msg = "%s: %s" % ( msg, mountUrl )
      warn( msg )
      raise Exception( msg )

   def doInit( self, sysEm ):
      """Mount everything the API needs from Controllerdb and Sysdb.

      Two mount groups are opened. The Controllerdb group (root first, then
      plugin mounts) completes via cmgDone. The Sysdb group (ACL, API status,
      AAA, OpenStack config and plugin mounts) completes via mountDone. The
      agent is warm once both callbacks have run.
      """
      def mountDone():
         # Callback passed to sMountGroup.close(); presumably invoked once the
         # Sysdb mounts are available.
         trace( 'mountDone entry' )
         # Aggregator SMs fed by the ACL input dirs and CLI config, writing into
         # the local self.aclConfig / self.aclCpConfig created in doInit().
         self.aclConfigAggregatorSm = Tac.newInstance(
                                                   "Acl::AclConfigAggregatorSm",
                                                   self.mounts[ 'aclConfigDir' ],
                                                   self.mounts[ 'aclCliConfig' ],
                                                   self.aclConfig )
         self.aclCpConfigAggregatorSm = Tac.newInstance(
                                                   "Acl::AclCpConfigAggregatorSm",
                                                   self.mounts[ 'aclCpConfigDir' ],
                                                   self.mounts[ 'aclCliCpConfig' ],
                                                   self.aclCpConfig )
         self.aaaManager_ = UwsgiAaa.UwsgiAaaManager( sysEm.sysname(),
                                                      tty=JsonApiConstants.TTY_NAME )
         serviceMap = self.mounts[ 'openStackAgentConfig' ].serviceAclTypeVrfMap
         # Assign the 'OpenStack' aclStatus, then read it back so the filter SM
         # gets the object as stored on the mounted entity.
         self.mounts[ 'osApiStatus' ].aclStatus = ( 'OpenStack', )
         aclStatus = self.mounts[ 'osApiStatus' ].aclStatus
         if self.serviceAclEnabled_:
            self.serviceAclFilterSm_ = Tac.newInstance( 'Acl::ServiceAclFilterSm',
                                                        'OpenStack',
                                                        self.aclConfig,
                                                        serviceMap,
                                                        aclStatus )
         # Close any AAA sessions already registered on this server's tty
         # (presumably left over from a previous instance).
         for sessId, sess in self.mounts[ 'aaaSessStatus' ].session.iteritems():
            if sess.tty == JsonApiConstants.TTY_NAME:
               self.aaaManager_.closeSession( sessId )
         self.mgDone = True
         trace( 'mountDone exit' )

      def cmgDone():
         # Callback passed to cMountGroup.close().
         trace( 'cmgDone entry' )
         self.cmgDone = True
         trace( 'cmgDone exit' )

      trace( 'doInit entry' )

      self.aclConfig = Tac.newInstance( "Acl::Config" )
      self.aclCpConfig = Tac.newInstance( "Acl::CpConfig" )

      cMountGroup = self.cEm.mountGroup(
            mountFailureCallback=self.handleMountFailure,
            persistent=True )

      # We must mount the root or else other mounts fail
      self.mounts[ 'controllerdbRoot' ] = cMountGroup.mount( "", "Tac::Dir", "rt" )

      for name, mountInfo in self.cdbPluginMounts.iteritems():
         path, obj, mode = mountInfo
         trace( "Adding ControllerDb mount %s, %s, %s under key %s" %
               ( path, obj, mode, name ) )
         self.mounts[ name ] = cMountGroup.mount( path, obj, mode )
      cMountGroup.close( cmgDone )

      sMountGroup = sysEm.mountGroup()
      self.mounts[ 'aclConfigDir' ] = sMountGroup.mount(
                          'acl/config/input',
                          'Tac::Dir',
                          'ri' )
      self.mounts[ 'aclCliConfig' ] = sMountGroup.mount(
                          'acl/config/cli',
                          'Acl::Input::Config',
                          'r' )
      self.mounts[ 'aclCpConfigDir' ] = sMountGroup.mount(
                          'acl/cpconfig/input',
                          'Tac::Dir',
                          'ri' )
      self.mounts[ 'aclCliCpConfig' ] = sMountGroup.mount(
                          'acl/cpconfig/cli',
                          'Acl::Input::CpConfig',
                          'r' )
      self.mounts[ 'aclParamConfig' ] = sMountGroup.mount(
                          'acl/paramconfig',
                          'Acl::ParamConfig',
                          'r' )
      # TODO - move this to something not in the openstack mount point
      self.mounts[ 'osApiStatus' ] = sMountGroup.mount(
                          'mgmt/openstack/osApiStatus',
                          'JsonApi::ApiStatus',
                          'w' )
      self.mounts[ 'aaaSessStatus' ] = sMountGroup.mount(
                          Cell.path( 'security/aaa/status' ),
                          'Aaa::Status',
                          'r' )
      self.mounts[ 'openStackAgentConfig' ] = sMountGroup.mount(
                          'mgmt/openstack/config',
                          'OpenStack::Config',
                          'r' )

      for name, mountInfo in self.sysdbPluginMounts.iteritems():
         path, obj, mode = mountInfo
         trace( "Adding Sysdb mount %s, %s, %s under key %s" %
               ( path, obj, mode, name ) )
         self.mounts[ name ] = sMountGroup.mount( path, obj, mode )

      trace( 'closing mount group' )
      sMountGroup.close( mountDone )
      trace( 'doInit exit' )

   def warm( self ):
      """Return True once both the Sysdb and Controllerdb mount groups are done."""
      return self.mgDone and self.cmgDone

   def _checkAuthRole( self, requestContext, userContext ):
      """Enforce the configured OpenStack authRole, if one is set.

      Raises:
         HttpForbidden: authRole is non-empty and not among the roles
            requestContext reports for userContext.
      """
      cfgAuthRole = self.mounts[ 'openStackAgentConfig' ].authRole
      reqAuthRoles = requestContext.getAuthedRoles( userContext )
      if cfgAuthRole != '' and cfgAuthRole not in reqAuthRoles:
         raise HttpForbidden( 'User is not authorized' )

   @staticmethod
   def _errorMessage( exc ):
      """Return the text to report to the client for exc.

      Prefers exc.message, which is what this server has always reported. It
      falls back to str( exc ) when message is missing or empty:
      BaseException.message is deprecated since Python 2.6, is '' for
      exceptions built with several arguments, and does not exist in Python 3.
      """
      message = getattr( exc, 'message', None )
      if message is None or message == '':
         return str( exc )
      return message

   def processRequest( self, request ):
      """Filter, authenticate, authorize and dispatch one API request.

      Returns:
         A 4-tuple ( status line, content type, extra headers or None, body )
         where body is a JSON string. Errors are returned, not raised.
         HttpException subclasses map to their own code, name and headers.
         Any other exception becomes a 500. Both use an { 'error': ... } body.
      """
      try:
         requestContext = UwsgiRequestContext( request, self.aaaManager_,
            serviceAclFilterSm=self.serviceAclFilterSm_ )

         # The service ACL is checked before any authentication is attempted.
         if not requestContext.aclPermitConnection():
            raise HttpServiceAclDenied( 'Filtered by service ACL' )

         userContext = requestContext.authenticate()
         try:
            self._checkAuthRole( requestContext, userContext )
            rType = requestContext.getRequestType()
            parsedUrl = requestContext.getParsedUrl()
            func, kwargs = getHandler( rType, parsedUrl )
            if func is None:
               raise HttpBadRequest( 'Invalid endpoint requested' )
            trace( 'calling handler %s' % func.__name__ )
            result = func( requestContext, self.mounts, **kwargs )
            result = simplejson.dumps( result, cls=ModelJsonSerializer )
            trace( 'processRequest exit', result )
            return ( '200 OK', 'application/json', None, result )
         finally:
            # Always pair a successful authenticate() with deauthenticate(),
            # even when authorization or the handler fails.
            requestContext.deauthenticate( userContext )
      except HttpException as e:
         traceback.print_exc()
         trace( 'processRequest HttpException', e )
         msg = simplejson.dumps( { 'error' : self._errorMessage( e ) } )
         return ( '%s %s' % ( e.code, e.name ), 'application/json',
                  e.additionalHeaders, msg )
      except Exception as e:
         trace( 'processRequest Exception', e )
         traceback.print_exc()  # Log stack trace to agent log.
         msg = simplejson.dumps( { 'error' : self._errorMessage( e ) } )
         return ( '500 Internal Server Error', 'application/json', None, msg )

class JsonApplication( object ):
   """Request-handling callable that fronts the JsonApiApp agent.

   Construction starts the agent container and the Tac activity thread, then
   blocks until the agent reports warm. Each call forwards the request to
   JsonApiApp.processRequest and emits the response through start_response.
   """

   def __init__( self ):
      trace( 'init entry' )
      self.container_ = Agent.AgentContainer( [ JsonApiApp ], passiveMount=True,
                                           agentTitle="JsonApiApp" )
      self.container_.startAgents()
      # pylint: disable-msg=W0212
      atexit.register( lambda: os._exit( 0 ) ) # on exiting we have to be brutal
      Tac.activityThread().start( daemon=True )
      Tac.waitFor( self.apiAgentWarm, description="JsonApiApp to be warm",
                   maxDelay=1, sleep=True )
      self.apiAgent_ = self.container_.agents_[ 0 ]
      trace( 'init exit' )

   def apiAgentWarm( self ):
      """Return a truthy value once the first container agent exists and is warm."""
      return self.container_.agents_ and self.container_.agents_[ 0 ].warm()

   def __call__( self, request, start_response ):
      """Handle one request and return the body as a one-element list.

      Response headers are DEFAULT_HEADERS, then any extra headers from
      processRequest, then Content-type and (for a non-empty body)
      Content-length. The extra-header sequence is copied rather than appended
      to: it may be an exception's additionalHeaders, and mutating a shared
      list would make these headers pile up across requests.
      """
      ( responseCode, contentType, extraHeaders, body ) = \
         self.apiAgent_.processRequest( request )
      headers = list( extraHeaders ) if extraHeaders else []
      headers.append( ( 'Content-type', contentType ) )
      if body:
         headers.append( ( 'Content-length', str( len( body ) ) ) )
      start_response( responseCode,
                      UwsgiConstants.DEFAULT_HEADERS + headers )
      return [ body ]
