# Copyright (c) 2005-2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import Tracing

# Trace handle for this module.  t0 is bound to its trace0 function and is used
# by most of the module-level entry points below to trace each AAA request
# (presumably trace0 is the lowest-verbosity level -- confirm in Tracing).
th = Tracing.defaultTraceHandle()
t0 = th.trace0

class ProviderBase( object ):
   """Base class for classes which provide authentication, authorization and
   accounting services to the Cli.  In the production software it's expected
   that only the provider provided by Aaa will be used, but other providers are
   useful in tests.

   Subclasses must define a ``name`` attribute (normally a class attribute);
   it is the key under which the provider is registered and selected."""
   def __init__( self ):
      # An explicit check instead of a bare ``assert`` so the requirement is
      # still enforced when Python runs with -O (which strips assert
      # statements).  AssertionError is kept so callers see the same type.
      if not hasattr( self, "name" ):
         raise AssertionError( "%s must define a 'name' attribute" %
                               type( self ).__name__ )

   def authorizeCommand( self, mode, privLevel, tokens ):
      """The hook function should return a 2-tuple in which the first element
      is True or False to indicate whether the command has been authorized, and
      the second element is a message to display to the user, which can be
      empty or None.  Three parameters are passed to each invocation of the
      function: the CliParser.Mode instance for the current mode, the privilege
      level privLevel, and a list of string tokens containing the command being
      authorized.  The hook is called after tokens have been autocompleted but
      before alias expansion is performed.

      The base implementation raises NotImplementedError."""
      raise NotImplementedError( "Subclass must implement authorizeCommand" )

   def authenticateEnable( self, mode, privLevel ):
      """Performs authentication for transitioning between unprivileged exec
      mode and privileged exec mode.  Returns True if authentication
      succeeds.

      The base implementation raises NotImplementedError."""
      raise NotImplementedError( "Subclass must implement authenticateEnable" )

   def sendCommandAcct( self, mode, privLevel, tokens, waitTime=0 ):
      """Performs command accounting.  Returns None.
      Four parameters are passed to each invocation of the
      function: the CliParser.Mode instance for the current mode, privLevel,
      a list of string tokens containing the command being sent to the
      accounting server, and waitTime to wait for accounting before executing
      the command.  The hook is called after tokens have been
      autocompleted but before alias expansion is performed.

      The base implementation raises NotImplementedError."""
      raise NotImplementedError( "Subclass must implement sendCommandAcct" )

   def flushAcctQueue( self, mode, waitTime ):
      """Flushes accounting queue with a specified timeout (waitTime).

      The base implementation does nothing and returns None; providers that
      queue accounting records override it."""
      pass

   def authenSessionData( self, mode ):
      """Returns authentication session data for the current session.

      The base implementation returns a new, empty dict."""
      return {}

class DefaultProvider( ProviderBase ):
   """Fallback provider, used whenever no registered provider matches the
   selected name.  It grants authorization and enable authentication only to
   standalone sessions, and records no accounting."""
   name = "default"

   def __init__( self ):
      ProviderBase.__init__( self )

   def authorizeCommand( self, mode, privLevel, tokens ):
      # Commands are allowed only when the session is standalone.
      if not mode.session_.standalone_:
         return ( False, "Default authorization provider rejects all commands" )
      return ( True, "" )

   def authenticateEnable( self, mode, privLevel ):
      # Same policy as authorization: the standalone flag is the verdict.
      standalone = mode.session_.standalone_
      return standalone

   def sendCommandAcct( self, mode, privLevel, tokens, waitTime=0 ):
      # No accounting backend; the command is simply not recorded.
      return None

# Fallback used whenever the selected name has no registered provider.
_defaultProvider = DefaultProvider()
# Registered providers, keyed by each provider's ``name`` attribute.
_providers = dict()
# Selected provider name (None until one is selected) and the provider object
# that the module-level entry points currently delegate to.
_activeProviderName, _activeProvider = None, _defaultProvider

def authenticateEnable( mode, privLevel ):
   """Check whether the session may move from unprivileged exec mode to
   privileged exec mode.

   The decision is delegated to the active AAA provider; its result (True when
   authentication succeeds) is returned unchanged."""
   t0( "authenticateEnable: mode", mode.name, "privLevel", privLevel )
   provider = _activeProvider
   return provider.authenticateEnable( mode, privLevel )

def authorizeCommand( mode, privLevel, tokens ):
   """Ask the active AAA provider whether the command given by tokens may run
   in mode at privLevel, and return the provider's answer unchanged."""
   traceArgs = ( "authorizeCommand: mode", mode.name, "privLevel", privLevel,
                 "tokens", tokens )
   t0( *traceArgs )
   provider = _activeProvider
   return provider.authorizeCommand( mode, privLevel, tokens )

def sendCommandAcct( mode, privLevel, tokens, waitTime=0 ):
   """Send accounting information for the specified command via the active
   AAA provider.  Returns None.

   waitTime is handed to the provider's sendCommandAcct hook, whose signature
   includes it.  It is documented there as the time to wait for accounting
   before executing the command.  It is forwarded only when non-zero, so
   providers that override sendCommandAcct without a waitTime parameter keep
   working for the existing callers, which never pass it."""
   t0( "sendCommandAcct: mode", mode.name, "privLevel", privLevel, "tokens",
       tokens )
   if waitTime:
      _activeProvider.sendCommandAcct( mode, privLevel, tokens, waitTime )
   else:
      _activeProvider.sendCommandAcct( mode, privLevel, tokens )

def flushAcctQueue( mode, waitTime=15 ):
   """Ask the active AAA provider to drain its accounting queue.

   waitTime is the timeout passed to the provider's hook.  Its units are
   defined by the provider (presumably seconds; confirm).  The hook's return
   value is passed back to the caller."""
   t0( "flushAcctQueue: mode", mode.name, "waitTime", waitTime )
   provider = _activeProvider
   result = provider.flushAcctQueue( mode, waitTime )
   return result

def authenSessionData( mode ):
   """Return the active AAA provider's authentication data for the session
   that owns mode."""
   provider = _activeProvider
   sessionData = provider.authenSessionData( mode )
   return sessionData

def registerAaaProvider( provider ):
   """Register an AAA provider under ``provider.name``.

   The provider that Cli actually uses is determined by the current setting of
   the aaaProvider attribute, so registering does not by itself activate the
   provider.  The exception is when the provider's name is the one already
   selected: it then becomes the active provider immediately.  Registering a
   name again replaces the previously registered provider."""
   # Declared at the top of the function, per convention, rather than inside
   # the branch that assigns it.
   global _activeProvider
   _providers[ provider.name ] = provider
   if _activeProviderName == provider.name:
      _activeProvider = provider

def _selectProvider( name ):
   """Make the provider registered under name the active one.

   Falls back to the default provider when nothing is registered under name.
   The name is remembered, so a provider registered later under it becomes
   active at that point.  Only the reactor in Cli should call this, when the
   aaaProvider attribute changes."""
   global _activeProviderName, _activeProvider
   _activeProviderName = name
   if name in _providers:
      _activeProvider = _providers[ name ]
   else:
      _activeProvider = _defaultProvider
