# Copyright (c) 2012 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import ctypes
import hashlib
import uuid

import AaaClientLib
import Tac
import Tracing
from UwsgiAaaCacheLib import Cache
from UwsgiConstants import AuthenticationConstants

traceHandle = Tracing.Handle( 'UwsgiAaa' )
# Short aliases for trace levels 1-4 of this module's trace handle.
warn, info, trace, debug = ( traceHandle.trace1, traceHandle.trace2,
                             traceHandle.trace3, traceHandle.trace4 )

class AuthenticationError( Exception ):
   """ Raised when a user or session fails to authenticate (for example bad
   credentials, a malformed password, or an unknown/expired session id). """

class DeauthenticationError( Exception ):
   """ Error representing a failure to deauthenticate, e.g. a request to log
   out of a session that does not exist. """
   pass

class CAuthDetails( ctypes.Structure ):
   """ ctypes view of the structure returned by value from uwsgiAuthenticate.
   Field order defines the memory layout and must match the C side. """
   _fields_ = [
      ( 'handle', ctypes.POINTER( ctypes.c_void_p ) ),
      ( 'aaaAuthnId', ctypes.c_char_p ),
   ]

class AuthDetails( object ):
   """ Plain-Python holder exposing the same attributes as CAuthDetails. """

   def __init__( self ):
      # PAM handle and AAA authentication id; both start out unset.
      self.handle, self.aaaAuthnId = None, None

class AuthEntry( object ):
   """ Result of a successful authentication: the account's uid/gid and
   privilege level, the handles needed to release the authentication later,
   and the request context it was created for. """

   def __init__( self, uid, gid, privLevel, aaaAuthnId, pamHandle, user,
                 requesterIp, userAgent ):
      # Account identity and privilege level.
      self.uid, self.gid, self.privLevel = uid, gid, privLevel
      # AAA authentication id and PAM handle (the latter may be None).
      self.aaaAuthnId, self.pamHandle = aaaAuthnId, pamHandle
      # Who asked, from where, and with which client.
      self.user, self.requesterIp, self.userAgent = user, requesterIp, userAgent

class UwsgiAaaManager( AaaClientLib.AaaManager ):
   """The AAA manager coordinates communications with the AAA agent.

   Credentials are verified through PAM via libUwsgi.so. When authentication
   is disabled, only an AAA session is created via _createSession. Results
   are kept in two caches:
     - authCache_: AuthEntry objects keyed on
       ( user, password hash, requester IP, user agent ), used by
       authenticate()/deauthenticate().
     - sessionCache_: AuthEntry objects keyed on a random session id, used by
       the *Session() methods. cacheKeyToSessionId_ and sessionIdToCacheKey_
       map between the two kinds of key."""

   def __init__( self, sysname,
                 authCacheTimeout=AuthenticationConstants.AUTH_CACHE_TIMEOUT,
                 sessionCacheTimeout=AuthenticationConstants.DEFAULT_SESSION_TIMEOUT,
                 tty=AuthenticationConstants.TTY_NAME ):
      trace( 'UwsgiAaaManager.__init__' )
      super( UwsgiAaaManager, self ).__init__( sysname )
      libUwsgi = ctypes.CDLL( "libUwsgi.so" )

      # ctypes only reads 'argtypes'. Without it, arguments are passed using
      # ctypes' default conversions (e.g. Python 2 unicode as wchar_t *) and
      # are not type-checked.
      self.libUwsgiAuthenticate = libUwsgi.uwsgiAuthenticate
      self.libUwsgiAuthenticate.argtypes = [
         ctypes.c_char_p,                     # user
         ctypes.c_char_p,                     # password
         ctypes.c_char_p,                     # PAM service name
         ctypes.c_char_p,                     # tty
         ctypes.c_char_p,                     # requester IP
         ctypes.POINTER( ctypes.c_void_p ) ]  # authentication context
      self.libUwsgiAuthenticate.restype = CAuthDetails
      # Passed by reference to uwsgiAuthenticate and uwsgiCleanupPamHandle.
      self.libUwsgiAuthenticateContext = ctypes.c_void_p()

      self.libUwsgiCleanupPamHandle = libUwsgi.uwsgiCleanupPamHandle
      self.libUwsgiCleanupPamHandle.argtypes = [
         ctypes.POINTER( ctypes.c_void_p ),   # authentication context
         ctypes.POINTER( ctypes.c_void_p ) ]  # PAM handle
      self.libUwsgiCleanupPamHandle.restype = None

      # extendTimeout=True: presumably each use refreshes an entry's timeout.
      self.authCache_ = Cache( entryCleanupFn=self._cleanupPamHandleEntry,
                               timeout=authCacheTimeout,
                               extendTimeout=True )
      self.sessionCache_ = Cache( entryCleanupFn=self._cleanupSession,
                                  timeout=sessionCacheTimeout )
      self.cacheKeyToSessionId_ = {}
      self.sessionIdToCacheKey_ = {}
      self.tty_ = tty

   def _generateCacheKey( self, user, passwdHash, requesterIp, userAgent ):
      """Returns the key shared by authCache_ and the session-id maps."""
      return ( user, passwdHash, requesterIp, userAgent )

   def _hashPassword( self, passwd ):
      """Returns the hex SHA-1 digest of passwd, used only in cache keys.

      Raises AuthenticationError if passwd cannot be hashed (e.g. None, or a
      unicode string that is not ASCII-encodable)."""
      try:
         # Sometimes these could be unicode. Python 2 implicitly encodes them
         # as ASCII, raising UnicodeEncodeError for non-ASCII input; treat
         # that the same as an unhashable type.
         # pylint: disable-msg=E1101
         return hashlib.sha1( passwd ).hexdigest()
      except ( TypeError, UnicodeError ):
         trace( 'authenticate exit bad passwd' )
         raise AuthenticationError( "Illegal password format" )

   def _getAuthDetails( self, requesterIp, userAgent, user, passwd,
                        disableAuth ):
      """Authenticates user and returns an object carrying 'handle' and
      'aaaAuthnId'.

      If disableAuth is False, the credentials are checked by
      uwsgiAuthenticate (PAM) and 'handle' is the PAM handle it returned.
      Otherwise, only an AAA session is created and 'handle' is None.
      userAgent is currently unused. Raises AuthenticationError on failure."""
      if not disableAuth:
         trace( 'invoking PAM authentication' )
         try:
            authDetails = self.libUwsgiAuthenticate(
               user, passwd, AuthenticationConstants.PAM_SERVICE_NAME,
               self.tty_, requesterIp,
               ctypes.byref( self.libUwsgiAuthenticateContext ) )
         except ctypes.ArgumentError:
            # argtypes rejects values that cannot be converted to char *.
            trace( 'authenticate exit bad argument' )
            raise AuthenticationError( "Illegal username/password format" )
         trace( 'authDetails for PAM AUTH', authDetails.handle,
                authDetails.aaaAuthnId )
         # A NULL handle means PAM rejected the credentials.
         if not authDetails.handle:
            trace( 'authenticate exit not authentic' )
            raise AuthenticationError( "Bad username/password combination" )
      else:
         authDetails = AuthDetails()
         authDetails.handle = None
         authDetails.aaaAuthnId = self._createSession(
            user, requesterIp, AuthenticationConstants.PAM_SERVICE_NAME,
            AuthenticationConstants.PAM_SERVICE_NAME, self.tty_ )
         if not authDetails.aaaAuthnId:
            trace( 'authenticate exit not authentic' )
            raise AuthenticationError( "Bad username" )
      return authDetails

   def _getPrivLevelAndPwdData( self, user, authDetails ):
      """Returns ( privLevel, pwdData ) for user.

      A negative privilege level is a failure: the authentication held in
      authDetails is released and AuthenticationError is raised."""
      privLevel, pwdData = self._getPrivLevel( user, authDetails.aaaAuthnId )
      if privLevel < 0:
         self._releaseAuth( authDetails.handle, authDetails.aaaAuthnId )
         trace( 'authenticate exit not authentic' )
         raise AuthenticationError( "Bad username/password combination" )
      return privLevel, pwdData

   def _createAuthEntry( self, requesterIp, userAgent, user, passwd,
                         disableAuth ):
      """Authenticates and authorizes user and returns a new AuthEntry."""
      authDetails = self._getAuthDetails( requesterIp, userAgent, user, passwd,
                                          disableAuth )
      privLevel, pwdData = self._getPrivLevelAndPwdData( user, authDetails )
      return AuthEntry( pwdData.pw_uid, pwdData.pw_gid, privLevel,
                        authDetails.aaaAuthnId, authDetails.handle,
                        user, requesterIp, userAgent )

   def _releaseAuth( self, pamHandle, aaaAuthnId ):
      """Releases one authentication: the PAM handle if there is one,
      otherwise the AAA session identified by aaaAuthnId."""
      if pamHandle:
         self._cleanupPamHandle( pamHandle )
      else:
         self.closeSession( aaaAuthnId )

   def _cleanupPamHandle( self, pamHandle ):
      """Hands pamHandle back to libUwsgi for cleanup."""
      assert pamHandle, "No pam handle specified"
      assert self.libUwsgiAuthenticateContext, "No auth context specified"
      libUwsgiCtx = ctypes.byref( self.libUwsgiAuthenticateContext )
      self.libUwsgiCleanupPamHandle( libUwsgiCtx, pamHandle )

   def _cleanupPamHandleEntry( self, key, entry ):
      """Cleanup callback for authCache_: releases entry's authentication."""
      self._releaseAuth( entry.pamHandle, entry.aaaAuthnId )

   def _forgetSessionMapping( self, sessionId ):
      """Removes sessionId from both session-id <-> cache-key maps, if present."""
      cacheKey = self.sessionIdToCacheKey_.pop( sessionId, None )
      # Only drop the reverse entry if it still refers to this session.
      if self.cacheKeyToSessionId_.get( cacheKey ) == sessionId:
         del self.cacheKeyToSessionId_[ cacheKey ]

   def _cleanupSession( self, key, entry ):
      """Cleanup callback for sessionCache_: forgets the mappings for session
      `key` and releases its authentication. A missing mapping does not
      prevent the release."""
      self._forgetSessionMapping( key )
      self._releaseAuth( entry.pamHandle, entry.aaaAuthnId )

   @Tac.withActivityLock
   def logoutSession( self, sessionId ):
      """Terminates sessionId immediately.

      Raises DeauthenticationError if the session is unknown."""
      debug( 'logoutSession', sessionId )
      if not self.sessionCache_.hasKey( sessionId ):
         raise DeauthenticationError( 'Illegal logout request' )
      self.sessionCache_.cleanupEntry( sessionId )

   @Tac.withActivityLock
   def createSession( self, requesterIp, userAgent, user, passwd ):
      """Logs user in and returns ( authEntry, sessionId, expiryTime ).

      If this user/password/IP/user-agent combination already has a live
      session, that session is returned. A stale one (expired, or no longer
      in the session cache) is discarded first, so the user can log in again.
      Raises AuthenticationError on failure."""
      debug( 'createSession', user )
      passwdHash = self._hashPassword( passwd )
      cacheKey = self._generateCacheKey( user, passwdHash, requesterIp, userAgent )

      sessionId = self.cacheKeyToSessionId_.get( cacheKey )
      if sessionId is not None:
         if self.sessionCache_.hasKey( sessionId ):
            expiryTime = self.sessionCache_.getExpiryTime( sessionId )
            # Same liveness rule as getSession().
            if expiryTime >= Tac.now():
               debug( 'createSession reusing', sessionId )
               # Don't increment the usage count: logging in never uses the
               # entry itself.
               authEntry = self.sessionCache_.get( sessionId, False )
               return ( authEntry, sessionId, expiryTime )
            # Expired but not yet purged. Tear it down as logoutSession does;
            # presumably this runs _cleanupSession.
            self.sessionCache_.cleanupEntry( sessionId )
         # Make sure no stale mapping survives before creating a new session.
         self._forgetSessionMapping( sessionId )

      authEntry = self._createAuthEntry( requesterIp, userAgent, user, passwd,
                                         disableAuth=False )
      sessionId = str( uuid.uuid4() )
      while self.sessionCache_.hasKey( sessionId ):
         # Regenerate until the id does not collide with an existing session.
         sessionId = str( uuid.uuid4() )
      self.cacheKeyToSessionId_[ cacheKey ] = sessionId
      self.sessionIdToCacheKey_[ sessionId ] = cacheKey
      self.sessionCache_.insert( sessionId, authEntry )
      expiryTime = self.sessionCache_.getExpiryTime( sessionId )
      return ( authEntry, sessionId, expiryTime )

   @Tac.withActivityLock
   def getSession( self, sessionId, incrementUsageCnt=True ):
      """Returns ( authEntry, expiryTime ) for a live session.

      incrementUsageCnt is forwarded to the session cache's get(). Raises
      AuthenticationError if the id is unknown or the session has expired."""
      debug( 'getSession', sessionId )
      if not self.sessionCache_.hasKey( sessionId ):
         raise AuthenticationError( "Bad session Id" )
      expiryTime = self.sessionCache_.getExpiryTime( sessionId )
      if expiryTime < Tac.now():
         raise AuthenticationError( "Session Expired" )
      authEntry = self.sessionCache_.get( sessionId, incrementUsageCnt )
      return ( authEntry, expiryTime )

   @Tac.withActivityLock
   def releaseSession( self, sessionId ):
      """Calls the session cache's release() for sessionId."""
      debug( 'releaseSession', sessionId )
      self.sessionCache_.release( sessionId )

   @Tac.withActivityLock
   def authenticate( self, requesterIp, userAgent, user=None, passwd=None,
                     disableAuth=False ):
      """ Performs authentication and authorization for a username/password
      combination and returns the resulting AuthEntry.

      The auth cache is checked first. On a miss, the credentials are
      verified through PAM unless disableAuth is True. In that case
      authentication is skipped, but authorization (the privilege-level
      lookup) cannot be bypassed. The new entry is cached and marked in use;
      deauthenticate() releases it again. On failure this raises
      AuthenticationError."""
      trace( 'authenticate entry', user )
      passwdHash = self._hashPassword( passwd )
      cacheKey = self._generateCacheKey( user, passwdHash, requesterIp, userAgent )
      authEntry = self.authCache_.get( cacheKey )
      if authEntry is not None:
         trace( 'authenticate', user, 'found in cache' )
         return authEntry

      debug( 'authenticate', user, 'starting...' )
      authEntry = self._createAuthEntry( requesterIp, userAgent, user, passwd,
                                         disableAuth )
      debug( 'authenticate', user, 'done, authentry:', authEntry )
      self.authCache_.insert( cacheKey, authEntry )
      # Call get so the cache knows that this entry is in use.
      return self.authCache_.get( cacheKey )

   @Tac.withActivityLock
   def deauthenticate( self, user=None, passwd=None, requesterIp=None,
                       userAgent=None ):
      """ Release this entry for the user """
      debug( 'deauthenticate', user, 'starting...' )
      passwdHash = self._hashPassword( passwd ) if passwd is not None else None
      cacheKey = self._generateCacheKey( user, passwdHash, requesterIp, userAgent )
      # Don't trace cacheKey itself: it contains the unsalted password hash.
      debug( 'deauthenticate', user, 'requester', requesterIp, userAgent )
      self.authCache_.release( cacheKey )

   @Tac.withActivityLock
   def cleanup( self ):
      """ Cleans up any current state for the manager """
      self.sessionCache_.cleanup()
      self.authCache_.cleanup()
