# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import pwd
import os
import socket
import traceback

import CliCommon
import PyClient
import Tac
import Tracing

# Tracing for this library. The aliases name trace levels 1-4 after the
# way they are used in this module (warn=1, info=2, trace=3, debug=4).
traceHandle = Tracing.Handle( 'AaaClientLib' )
warn, info, trace, debug = ( traceHandle.trace1, traceHandle.trace2,
                             traceHandle.trace3, traceHandle.trace4 )

# Passed as PyClient's connectTimeout when connecting to the Aaa agent
# (presumably seconds -- confirm against PyClient).
PYCLIENT_CONNECT_TIMEOUT = 30
# Exceptions from PyClient.eval() that are treated as a failed RPC.
PyClientErrors = ( socket.error, ValueError )

# Not referenced in this section; presumably used by importers of this module.
AUTHZ_FAILED = "Authorization failed"
disabledAaaPrivLvl = 15

# pylint: disable-msg=E1103
# Pylint is baffled by the attributes of RPC results,
# specifically, the contents of strepToValue() results.
class AaaManager( object ):
   """Client-side helper that talks to the Aaa agent over PyClient RPC.

   Opens and closes Aaa sessions and performs shell (exec) authorization
   for a user, returning the user's uid/gid and privilege level. The
   PyClient connection is created lazily and dropped after an RPC failure
   so that the next call reconnects.
   """

   def __init__( self, sysname,
                 defaultInitialPrivLevel=CliCommon.DEFAULT_PRIV_LVL ):
      """
      Args:
        sysname: system name passed to PyClient when connecting to Aaa.
        defaultInitialPrivLevel: privilege level used when authorization
          succeeds but returns no valid "privilegeLevel" attribute.
      """
      self.sysname_ = sysname
      self.defaultInitialPrivLevel_ = defaultInitialPrivLevel
      # Created on demand by _aaaPc(); reset to None after RPC failures.
      self.aaaPc_ = None

   def _aaaPc( self ):
      """Returns a PyClient instance to the Aaa agent, or None.

      The client is created on first use and cached in self.aaaPc_.
      Setup failures (socket.error, PyClient.RpcError) are traced and None
      is returned, so a later call retries the connection.
      """
      if self.aaaPc_ is None:
         try:
            self.aaaPc_ = PyClient.PyClient(
               self.sysname_, "Aaa",
               execMode=PyClient.Rpc.execModeThreadPerConnection,
               initConnectCmd="import AaaApi",
               connectTimeout=PYCLIENT_CONNECT_TIMEOUT )
         except socket.error as e:
            warn( "socket.error during PyClient setup:", str( e ) )
            self.aaaPc_ = None
         except PyClient.RpcError as e:
            warn( "PyClient.RpcError during PyClient setup:", str( e ) )
            self.aaaPc_ = None
      return self.aaaPc_

   def _aaaRpc( self, cmd ):
      """Evaluates Python code 'cmd' on the Aaa agent.

      Returns the evaluation result decoded with Tac.strepToValue(), or
      None if no client is available or pc.eval() raised one of
      PyClientErrors.
      """
      pc = self._aaaPc()
      if not pc:
         return None
      start = Tac.now()
      try:
         evalResult = pc.eval( cmd )
      except PyClientErrors as e:
         # NOTE(review): other exception types raised by eval() (e.g.
         # PyClient.RpcError) propagate to the caller -- confirm intended.
         warn( "RPC Aaa Exception:", e.__class__.__name__, str( e ) )
         info( traceback.format_exc() )
         # Drop the cached client so the next call reconnects.
         self.aaaPc_ = None
         return None

      result = Tac.strepToValue( evalResult )
      # Clamped at zero in case the clock moved backwards.
      elapsed = max( Tac.now() - start, 0.0 )
      warn( "RPC Aaa:", cmd, "=", result, "(", "%.3fs)" % elapsed )
      return result

   def _getPwNam( self, user ):
      """Returns the password-database entry for 'user', or None if unknown."""
      try:
         return pwd.getpwnam( user )
      except KeyError:
         warn( "getUid failed", user )
         return None

   def _getPrivLevel( self, user, sessionId ):
      """Runs AaaApi.authorizeShell for 'user' in Aaa session 'sessionId'.

      sessionId is converted with int(); an unconvertible value is sent
      as None.

      Returns:
        ( privLevel, pwdData ) on success, where pwdData is the user's
        pwd entry. ( -1, None ) if the user has no pwd entry, the RPC
        failed, or the authorizeShell status is not "allowed".
      """
      debug( "Getting priv level for user", user, "sessionId", sessionId )
      try:
         sessionId = int( sessionId )
      except ( ValueError, TypeError ):
         sessionId = None
      pwdData = self._getPwNam( user )
      if not pwdData:
         return -1, None

      cmd = "AaaApi.authorizeShell( uid=%r, sessionId=%r )" % ( pwdData.pw_uid,
                                                                sessionId )
      authorizeShellResult = self._aaaRpc( cmd )
      if authorizeShellResult is None:
         warn( "_getPrivLevel unable to run cmd", cmd )
         return -1, None

      if authorizeShellResult.status != "allowed":
         warn( "_getPrivLevel user not allowed", user, authorizeShellResult.status )
         return -1, None

      return self._parsePrivLevel( authorizeShellResult ), pwdData

   def _parsePrivLevel( self, authorizeShellResult ):
      """Extracts the "privilegeLevel" attribute from an authorizeShell result.

      Returns the first "privilegeLevel" value that converts to int.
      Otherwise returns self.defaultInitialPrivLevel_.
      """
      for av in authorizeShellResult.av.itervalues():
         if av.name == "privilegeLevel":
            try:
               return int( av.val )
            except ( ValueError, TypeError ):
               # TypeError covers a missing (None) value, which would
               # otherwise escape and abort authorization.
               warn( "Warning: invalid privilege level:", av.val,
                     "type", type( av.val ) )
      return self.defaultInitialPrivLevel_

   def _createSession( self, user, remoteHost, service, authenMethod, tty ):
      """Opens an Aaa session for 'user' via AaaApi.createSession.

      Returns the session id as a string, or None if the RPC failed, the
      session status is not "open", or the session id is empty/zero.
      """
      cmd = ( ( "AaaApi.createSession( user=%r, service=%r, "
                "authenMethod=%r, "
                "tty=%r, remoteHost=%r, remoteUser=%r )" ) %
              ( user, service, authenMethod, tty, remoteHost, user ) )
      session = self._aaaRpc( cmd )
      if session is None or session.status != "open":
         warn( "_createSession unable to create session", session )
         return None

      debug( "_createSession session create with id", session.id )
      if not session.id:
         return None

      return str( session.id )

   def closeSession( self, sessionId ):
      """Closes Aaa session 'sessionId' via AaaApi.closeSession.

      Empty ids ("" or None, as returned by authorizeUser() when AAA is
      disabled or authorization failed) are ignored. Any other value must
      be convertible with int().
      """
      if not sessionId:
         return
      result = self._aaaRpc( "AaaApi.closeSession( %r )" % int( sessionId ) )
      debug( "closeSession result was", result )

   def getSessionRoles( self, sessionId ):
      """Returns the AaaApi.getSessionRoles result for 'sessionId'.

      Returns None for an empty id ("" or None) or when the RPC fails.
      """
      if not sessionId:
         return None
      return self._aaaRpc( "AaaApi.getSessionRoles( %r )" % int( sessionId ) )

   @staticmethod
   def _authzFailure():
      """Returns a fresh authorizeUser() result dict describing a denial."""
      return { "authorized": False, "aaaSessionId": None,
               "uid": None, "gid": None, "privLevel" : None }

   def authorizeUser( self, user, requester, service, authenMethod, tty,
                      disableAaa ):
      """Authorizes a user for shell (exec) access.

      Args:
        user: username to be authorized
        requester: the requesting domain; passed to the Aaa session as
          remoteHost
        service: PAM service name
        authenMethod: Aaa AUTHEN method
        tty: name of the TTY
        disableAaa: if true, skip all AAA checks and authorize with the
          current process's uid/gid at CliCommon.MAX_PRIV_LVL

      Returns:
        A dict with keys:
          "authorized": bool, whether authorization succeeded.
          "aaaSessionId": the Aaa session id string on success, "" when
            disableAaa is set, None on failure.
          "uid", "gid": the user's uid/gid from the password database (the
            current process's when disableAaa is set), None on failure.
          "privLevel": the granted privilege level, None on failure.
      """
      if disableAaa:
         return { "authorized": True, "aaaSessionId": "",
                  "uid": os.getuid(), "gid": os.getgid(),
                  "privLevel" : CliCommon.MAX_PRIV_LVL }

      # Perform exec authorization to see if the user is allowed to
      # execute commands on the switch (at all), and to get the
      # privilege level for the user. Works for local and TACACS+ users.
      aaaAuthnId = self._createSession( user, requester, service,
                                        authenMethod, tty )
      if not aaaAuthnId:
         return self._authzFailure()

      privLevel, pwdData = self._getPrivLevel( user, aaaAuthnId )
      if privLevel < 0:
         # Don't leave the session opened above dangling.
         self.closeSession( aaaAuthnId )
         return self._authzFailure()

      return { "authorized": True, "aaaSessionId": aaaAuthnId,
               "uid": pwdData.pw_uid, "gid": pwdData.pw_gid,
               "privLevel" : privLevel }
