# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import errno
import ldap
import ldap.filter
import os
import re
import shutil
import tempfile
import threading

from Ark import synchronized
import AaaPluginLib
from AaaPluginLib import TR_ERROR, TR_WARN, TR_AUTHEN, TR_AUTHZ, TR_INFO, TR_DEBUG
import Arnet
from Arnet.NsLib import DEFAULT_NS
from BothTrace import traceX as bt
from BothTrace import Var as bv
from LdapUtil import RunInNetworkNamespace
from ReversibleSecretCli import decodeKey
import Tac
from Tracing import traceX

# Values handed to ldap.OPT_X_TLS_PROTOCOL_MIN.  Each one is
# ( major << 8 ) + minor with major == 3; the minor number is 1, 2 and 3 for
# the TLS1_0, TLS1_1 and TLS1_2 names respectively.
LDAP_OPT_X_TLS_PROTOCOL_TLS1_0 = 0x0301
LDAP_OPT_X_TLS_PROTOCOL_TLS1_1 = 0x0302
LDAP_OPT_X_TLS_PROTOCOL_TLS1_2 = 0x0303

# Tac type used below for certPath() and the tlsv1* version masks.
Constants = Tac.Type( "Mgmt::Security::Ssl::Constants" )

# Counter names that CounterCallback accepts; anything else is rejected there.
ldapCounterAttrs = (
   "bindRequests",
   "bindFails",
   "bindSuccesses",
   "bindTimeouts",
)

# Lock passed to @synchronized on CounterCallback.__call__.
counterSysdbStatusLock = threading.Lock()

class _Server( object ):
   """Per-server record kept by Session: the connection handle plus the
   settings that were supplied when the server was added."""

   def __init__( self, host, port, ns, baseDn, userDn, sslProfile,
                 activeGroupPolicy, searchUsernamePassword ):
      # Connection handle ( ldap object ); filled in by Session.addServer.
      self.server = None
      self.statusCallback = None
      self.caCertDir = ""

      # Values stored exactly as given by the caller.
      self.host_ = host
      self.port_ = port
      self.ns_ = ns
      self.baseDn = baseDn
      self.userDn = userDn
      self.sslProfile = sslProfile
      self.activeGroupPolicy = activeGroupPolicy
      self.searchUsernamePassword = searchUsernamePassword

      # Build the LDAP URI; a hostname in which Arnet.Ip6AddrRe finds a match
      # is wrapped in square brackets before the port is appended.
      uriHost = host.hostname
      if re.search( Arnet.Ip6AddrRe, host.hostname ):
         uriHost = "[%s]" % uriHost
      self.hostStr = "ldap://%s:%s" % ( uriHost, port )

class Session( AaaPluginLib.Session ):
   """A session with a LDAP server, or potentially multiple servers in a
   failover arrangement.

   Usage as implemented in this class: create a Session, register one or
   more servers with addServer, then call sendAuthenReq ( authentication )
   or sendAuthzReq ( authorization ) and inspect the returned
   ( result, failText, attrs ) tuple.  Call close when finished; __del__
   calls it as a fallback.

   Servers are tried in the order in which they were added.  A server whose
   result is 'unavailable' ( authentication ) or 'authzUnavailable'
   ( authorization ) is skipped in favour of the next one; a definitive
   result ends the iteration."""
   def __init__( self, hostgroup ):
      traceX( TR_INFO, "creating LDAP session for %s" % hostgroup )
      AaaPluginLib.Session.__init__( self, hostgroup )
      self.servers_ = []
      self.activeServer_ = None

   def getActiveServer( self ):
      """Return the host spec of the server used by the most recent
      sendAuthenReq/sendAuthzReq call, or None if no server was tried."""
      return self.activeServer_

   def makex509HashMap( self, x509FileList ):
      """Group certificate/CRL files by their openssl hash.

      Every name in x509FileList is turned into a full path with
      Constants.certPath, classified as a certificate or a CRL from its PEM
      header, and hashed by running "openssl x509|crl -hash -noout".

      Returns a dict mapping the hash string to the list of full paths that
      produced it ( several files may share one hash ).
      Asserts if a file carries neither PEM header."""
      traceX( TR_DEBUG, "Ldap::Session::makex509HashMap: ", str( x509FileList ) )
      x509HashMap = {}
      for x509File in x509FileList:
         traceX( TR_DEBUG, "Working on file:", x509File )
         # Update the file to have the full path for later usage
         x509File = Constants.certPath( x509File )
         opensslSubCmd = ""
         with open( x509File, "r" ) as handle:
            rawText = handle.read()
            if "BEGIN CERTIFICATE" in rawText:
               opensslSubCmd = "x509"
            elif "BEGIN X509 CRL" in rawText:
               opensslSubCmd = "crl"
         assert opensslSubCmd, "This should have been set"
         certHash = Tac.run( [ "openssl", opensslSubCmd,
                               "-hash", "-noout",
                               "-in", x509File ],
                             stdout=Tac.CAPTURE,
                             stderr=Tac.DISCARD )
         # remove whitespace
         certHash = certHash.strip()
         x509HashMap.setdefault( certHash, [] ).append( x509File )
      return x509HashMap

   def createCaCertSymlinks( self, hashMap, hashDir, isCrl=False ):
      """Create "<hash>.<n>" ( or "<hash>.r<n>" for CRLs ) symlinks in hashDir.

      hashMap is the dict produced by makex509HashMap; n is the index of the
      file within its hash bucket.  An ENOSPC failure is logged and ends the
      method early without raising; any other OSError propagates."""
      traceX( TR_DEBUG, "Ldap::Session::createCaCertSymlinks: hashDir", hashDir )
      # CRLs have a 'r' prefix on their symlink number
      numPrefix = "r" if isCrl else ""
      for hashVal, fileList in hashMap.items():
         for idx, fileName in enumerate( fileList ):
            traceX( TR_DEBUG, "hash", hashVal, "idx", idx, "file", fileName )
            symlinkName = hashVal + ".{numPrefix}{idx}".format( numPrefix=numPrefix,
                                                                idx=idx )
            symlinkPath = os.path.join( hashDir, symlinkName )
            try:
               os.symlink( fileName, symlinkPath )
            except OSError as e:
               bt( TR_ERROR, "symlink", symlinkPath, "error:", str( e.strerror ) )
               if e.errno != errno.ENOSPC:
                  # Gracefully handle no space, bubble all other errors
                  # up
                  raise
               return

   def createCaCertDir( self, trustCertList, crlList ):
      """Create a temporary directory of hash-named symlinks for TLS.

      Returns the directory path, or "" when the directory could not be
      created because the filesystem is full ( ENOSPC ).  Other errors
      propagate; in that case the temporary directory is removed first so
      that it is not left behind."""
      traceX( TR_DEBUG, "Ldap::Session::createCaCertDir" )
      try:
         caCertDir = tempfile.mkdtemp( prefix="LdapCaCertDir-" )
      except OSError as e:
         bt( TR_ERROR, "mkdtemp error:", str( e.strerror ) )
         if e.errno != errno.ENOSPC:
            # Gracefully handle no space, bubble all other errors
            # up
            raise
         return ""
      traceX( TR_DEBUG, "TLS caCertDir is", caCertDir )
      try:
         certHashMap = self.makex509HashMap( trustCertList )
         crlHashMap = self.makex509HashMap( crlList )
         self.createCaCertSymlinks( certHashMap, caCertDir )
         self.createCaCertSymlinks( crlHashMap, caCertDir, isCrl=True )
      except Exception:
         # The caller never learns the path when we raise, so nobody else
         # could clean this directory up later.
         shutil.rmtree( caCertDir, ignore_errors=True )
         raise
      return caCertDir

   def addServer( self, aaaHost, ns=DEFAULT_NS, counterCallback=None,
                  baseDn=None, userDn=None, sslProfile=None, activeGroupPolicy=None,
                  searchUsernamePassword=None ):
      """Register a server with this session and initialize its connection.

      aaaHost supplies the hostname, port and spec of the server; ns is the
      network namespace in which all LDAP calls for it are made.
      counterCallback, if given, is called as cb( counterName, 1 ) on bind
      requests, successes and errors.  When sslProfile is set it is indexed
      with "trustedCert", "crl", "cipherSuite" and "tlsVersion", and
      StartTLS is negotiated before returning.

      The server is appended to the session before it is initialized, so it
      remains registered ( and is cleaned up by close ) even on failure.
      Raises ldap.LDAPError ( after logging ) if initialization or TLS setup
      fails."""
      ss = _Server( aaaHost, aaaHost.port, ns, baseDn, userDn,
                    sslProfile, activeGroupPolicy, searchUsernamePassword )
      ss.statusCallback = counterCallback
      self.servers_.append( ss )
      traceX( TR_INFO, "addServer:", ss.host_.spec, ":", ss.port_, "ns:", ss.ns_ )

      try:
         with RunInNetworkNamespace( ns ):
            ss.server = ldap.initialize( ss.hostStr )
            traceX( TR_INFO, "Initialized server in namespace", ns )
            ss.server.protocol_version = ldap.VERSION3
            if ss.sslProfile:
               ss.caCertDir = self.createCaCertDir( ss.sslProfile[ "trustedCert" ],
                                                    ss.sslProfile[ "crl" ] )
               ss.server.set_option( ldap.OPT_X_TLS_CACERTDIR,
                                       ss.caCertDir )
               ss.server.set_option( ldap.OPT_X_TLS_REQUIRE_CERT,
                                       ldap.OPT_X_TLS_DEMAND )
               if ss.sslProfile[ "crl" ]:
                  ss.server.set_option( ldap.OPT_X_TLS_CRLCHECK,
                                          ldap.OPT_X_TLS_CRL_ALL )
               ss.server.set_option( ldap.OPT_X_TLS_CIPHER_SUITE,
                                     ss.sslProfile[ "cipherSuite" ] )

               # Set tls version
               tlsVersionMap = [ ( Constants.tlsv1_2,
                                   LDAP_OPT_X_TLS_PROTOCOL_TLS1_2 ),
                                 ( Constants.tlsv1_1,
                                   LDAP_OPT_X_TLS_PROTOCOL_TLS1_1 ),
                                 ( Constants.tlsv1,
                                   LDAP_OPT_X_TLS_PROTOCOL_TLS1_0 ) ]
               # Use TLS1_2 by default
               tls_protocol_min = LDAP_OPT_X_TLS_PROTOCOL_TLS1_2
               for mask, version in tlsVersionMap:
                  # Check from 1_2 to 1_0; there is deliberately no break, so
                  # the lowest version enabled in the profile becomes the
                  # minimum.
                  if ss.sslProfile[ "tlsVersion" ] & mask:
                     tls_protocol_min = version
               ss.server.set_option( ldap.OPT_X_TLS_PROTOCOL_MIN,
                                       tls_protocol_min )

               ss.server.set_option( ldap.OPT_X_TLS_NEWCTX, 0 )
               ss.server.start_tls_s()

      except ldap.LDAPError as error:
         # not collecting any stats for initialize or tls related failures.
         errorMsg = error.message[ 'desc' ]
         bt( TR_ERROR, "addServer exception:", bv( errorMsg ) )
         raise

   def sendAuthenReq( self, username, password ):
      """Authenticate username/password against the registered servers.

      Returns ( result, failText, attrs ).  result is 'success', 'fail' or
      'unavailable' as reported by the last server tried, or None when no
      server has been added.  On success attrs holds 'userDn' and, for a
      server with an SSL profile, 'sslProfile'."""
      result, failText, attrs = None, None, None
      self.activeServer_ = None
      for ss in self.servers_:
         result, failText, attrs = self._sendAuthenReq( ss, username, password )
         self.activeServer_ = ss.host_.spec
         if result == 'success':
            if ss.sslProfile:
               attrs[ 'sslProfile' ] = ss.sslProfile
            break
         elif result == 'fail':
            # authentication failed for searchUser or user
            break
         elif result == 'unavailable':
            # server error, try a different server
            bt( TR_ERROR, "Server error for", ss.host_.spec )

      return ( result, failText, attrs )

   def _sendAuthenReq( self, ss, username, password ):
      """Authenticate against the single server ss.

      Binds as the configured search user, searches below ss.baseDn for
      entries matching "<ss.userDn>=<username>", then tries to bind with the
      supplied password as each entry found.  Returns
      ( 'success', "", { 'userDn': dn } ), ( 'fail', message, None ) or
      ( 'unavailable', message, None )."""
      traceX( TR_AUTHEN, "Authenticating against server",
              ss.host_.spec.hostname )
      attrs = {}
      searchUsername = ss.searchUsernamePassword.username
      searchPassword = ss.searchUsernamePassword.password
      userMatches = []
      try:
         traceX( TR_AUTHEN, "Binding to", searchUsername, ss.ns_ )
         # username comes from the login attempt; escape it so characters
         # such as '*', '(' or ')' cannot change the meaning of the filter.
         userCn = ss.userDn + "=" + ldap.filter.escape_filter_chars( username )
         with RunInNetworkNamespace( ss.ns_ ):
            self.serverBind( ss, searchUsername, decodeKey( searchPassword ) )
            traceX( TR_AUTHEN, "Admin successfully authenticated" )
            userMatches = ss.server.search_s( ss.baseDn, ldap.SCOPE_SUBTREE,
                                                   "(& (%s))" % userCn, None )
            # Keep the connection open: it is reused for the user bind below.
            self.serverBindSuccess( ss, unbind=False )
         traceX( TR_AUTHEN, "Search returned results:", userMatches )

      except ldap.SERVER_DOWN:
         errorMsg = "LDAP Server Unavailable"
         traceX( TR_ERROR, "LDAP Server Unavailable during authentication" )
         self.serverError( ss, 'bindFails' )
         return ( "unavailable", errorMsg, None )
      except ldap.LDAPError as error:
         counterAttr = 'bindFails'
         if isinstance( error, ldap.TIMEOUT ):
            counterAttr = 'bindTimeouts'
         self.serverError( ss, counterAttr )
         errorMsg = error.message[ 'desc' ]
         # FIXME: handle auth failure or communication failure
         bt( TR_ERROR, "Failing to bind/search with", bv( searchUsername ),
             ":", bv( errorMsg ) )
         return ( "fail", errorMsg, None )

      for user in userMatches:
         userDn = user[ 0 ]
         if not userDn:
            # Results without a DN cannot be bound to; skip them.
            continue
         try:
            with RunInNetworkNamespace( ss.ns_ ):
               traceX( TR_DEBUG, "Attempting bind to userDn:", userDn )
               self.serverBind( ss, userDn, password )
               traceX( TR_AUTHEN, "User", userDn, "is successfully authenticate" )
               # Close the connection
               self.serverBindSuccess( ss )
            attrs[ 'userDn' ] = userDn
            return ( "success", "", attrs )
         except ldap.LDAPError as error:
            counterAttr = 'bindFails'
            if isinstance( error, ldap.TIMEOUT ):
               counterAttr = 'bindTimeouts'
            self.serverError( ss, counterAttr )
            bt( TR_ERROR, "server error:", bv( error.message ) )

      errorMsg = "User %s authentication failed" % username
      return ( "fail", errorMsg, None )

   def sendAuthzReq( self, userDn, searchFilter, groupRolePrivilege ):
      """Authorize userDn against the registered servers.

      Returns ( result, failText, attrs ) where result is 'allowed',
      'denied' or 'authzUnavailable'; the latter ( with failText
      'no usable server' ) is also returned when no server has been added."""
      result, failText, attrs = ( 'authzUnavailable', 'no usable server', None )
      self.activeServer_ = None
      for ss in self.servers_:
         result, failText, attrs = self._sendAuthzReq( ss, userDn, searchFilter,
                                                       groupRolePrivilege )
         self.activeServer_ = ss.host_.spec
         if result == 'allowed':
            break
         elif result == 'denied':
            break
         elif result == 'authzUnavailable':
            bt( TR_ERROR, "Server", self.activeServer_.stringValue(), "unavailable" )
            # try next Server

      return ( result, failText, attrs )

   def _sendAuthzReq( self, ss, userDn, searchFilter, groupRolePrivilege ):
      """Authorize userDn against the single server ss.

      Binds as the configured search user and searches below ss.baseDn for
      objects of class searchFilter.group whose searchFilter.member
      attribute equals userDn.  The first entry of groupRolePrivilege whose
      group is among the matched group names supplies the role and privilege
      level placed in attrs.

      Returns ( 'allowed', "", attrs ) whenever the search completes ( attrs
      is empty if no configured group matched ),
      ( 'authzUnavailable', message, {} ) if the server is down, or
      ( 'denied', message, {} ) on any other LDAP error."""
      assert ss
      try:
         traceX( TR_AUTHZ, "Binding to", ss.searchUsernamePassword.username, ss.ns_ )
         with RunInNetworkNamespace( ss.ns_ ):
            self.serverBind( ss, ss.searchUsernamePassword.username,
                             decodeKey( ss.searchUsernamePassword.password ) )
         traceX( TR_AUTHZ, "Admin", ss.searchUsernamePassword.username,
                 "successfully authenticated" )
         # A DN may contain '(', ')', '*' or '\\', all of which are special
         # inside a search filter, so escape it before interpolating.
         escapedUserDn = ldap.filter.escape_filter_chars( userDn )
         with RunInNetworkNamespace( ss.ns_ ):
            matchedGroups = ss.server.search_s(
               ss.baseDn, ldap.SCOPE_SUBTREE, "(&(objectclass=%s)(%s=%s))" % (
                  searchFilter.group, searchFilter.member, escapedUserDn ),
               attrsonly=1 )
            traceX( TR_AUTHZ, "Matched groups:", matchedGroups )
         # Captures the value of the first "<attr>=<value>," component of
         # each group DN.
         groupNameRegex = r"[^,]=([^,]+),"
         groupSet = set()
         for group in matchedGroups:
            if not group[ 0 ]:
               continue
            matchGroup = re.search( groupNameRegex, group[ 0 ] )
            if matchGroup:
               groupSet.add( matchGroup.group( 1 ) )
         attrs = {}
         traceX( TR_AUTHZ, "Trying to find matching role for group",
                 ', '.join( groupSet ) )
         for grp in groupRolePrivilege.itervalues():
            if grp.group in groupSet:
               bt( TR_AUTHZ, "Matched role", bv( grp.role ),
                   "privilege", bv( grp.privilege ),
                   "for group", bv( grp.group ) )
               attrs[ AaaPluginLib.roles ] = [ grp.role ]
               attrs[ AaaPluginLib.privilegeLevel ] = grp.privilege
               break
         with RunInNetworkNamespace( ss.ns_ ):
            self.serverBindSuccess( ss )
         traceX( TR_DEBUG, "Updated attrs", attrs )
         return ( "allowed", "", attrs )

      except ldap.SERVER_DOWN:
         errorMsg = "LDAP Server Unavailable"
         bt( TR_ERROR, "LDAP Server Unavailable during authorization" )
         self.serverError( ss, 'bindFails' )
         return ( "authzUnavailable", errorMsg, {} )
      except ldap.LDAPError as error:
         counterAttr = 'bindFails'
         if isinstance( error, ldap.TIMEOUT ):
            counterAttr = 'bindTimeouts'
         self.serverError( ss, counterAttr )
         bt( TR_ERROR, "Authorization failed for", bv( userDn ), ":", bv( error ) )
         return ( "denied", str( error ), {} )

   def serverBind( self, ss, username, password ):
      """Count a bind request for ss, then perform a synchronous bind.
      Errors raised by bind_s propagate to the caller."""
      cb = ss.statusCallback
      if cb:
         cb( 'bindRequests', 1 )
      else:
         traceX( TR_DEBUG, "No bindRequests callback found for", ss.host_.hostname )
      ss.server.bind_s( username, password )

   def serverBindSuccess( self, ss, unbind=True ):
      """Count a successful bind for ss.  With unbind=True ( the default ) the
      connection is unbound and ss.server is reset to None first."""
      if unbind:
         ss.server.unbind_s()
         ss.server = None
      cb = ss.statusCallback
      if cb:
         cb( 'bindSuccesses', 1 )
      else:
         traceX( TR_DEBUG, "No bindSuccesses callback found for", ss.host_.hostname )

   def serverError( self, ss, counterAttr ):
      """Report one occurrence of counterAttr ( e.g. 'bindFails' ) for ss."""
      cb = ss.statusCallback
      if cb:
         cb( counterAttr, 1 )
      else:
         traceX( TR_DEBUG, "No", counterAttr, "callback found for",
                 ss.host_.stringValue() )

   def _removeCaCertDir( self, ss ):
      """Delete the temporary CA cert directory of ss, if it has one.  A
      directory that is already gone is not an error."""
      if not ss.caCertDir:
         return
      try:
         shutil.rmtree( ss.caCertDir )
      except OSError as excep:
         if excep.errno != errno.ENOENT:
            raise
      ss.caCertDir = ""

   def close( self ):
      """Unbind from every server that is still connected and remove the
      temporary CA cert directories.  Safe to call more than once."""
      for ss in self.servers_:
         try:
            if ss.server:
               traceX( TR_INFO, "Unbind from server" )
               ss.server.unbind_s()
         finally:
            # Even if the unbind fails, drop the handle and remove the cert
            # directory so that neither outlives the session.
            ss.server = None
            self._removeCaCertDir( ss )

   def __del__( self ):
      # If close was not called explicitly
      self.close()

class CounterCallback( object ):
   """Callable that adds a delta to one LDAP counter of a single host.

   status is the object whose "counter" collection, indexed by hostspec,
   holds the per-host counters; a false status ( used for testing ) makes
   the call trace-only."""

   def __init__( self, status, hostspec ):
      self.status = status
      self.hostspec = hostspec

   @synchronized( counterSysdbStatusLock )
   def __call__( self, attrName, delta ):
      """Add delta to the counter attrName for this host.

      attrName must be one of ldapCounterAttrs; unknown names are logged and
      ignored, as is a host that has no entry in the counter collection."""
      # Failures and timeouts are traced as errors.  The counter is named
      # 'bindTimeouts' ( plural ), matching ldapCounterAttrs and the callers.
      traceLevel = ( TR_ERROR if attrName in ( 'bindFails', 'bindTimeouts' )
                     else TR_INFO )
      bt( traceLevel, "LDAP counter", bv( attrName ), "host",
          bv( self.hostspec.stringValue() ) )
      if not self.status:
         return # used for testing
      if attrName not in ldapCounterAttrs:
         bt( TR_ERROR, "unknown counter type", attrName )
         return
      counterToUse = self.status.counter
      if self.hostspec not in counterToUse:
         bt( TR_WARN, "LDAP counters missing entry for",
             self.hostspec.stringValue() )
         return # Operator probably removed ldap host in the
         # middle of the authentication request.
      # Read-modify-write: update the object returned by Tac.nonConst, then
      # store it back into the collection.
      c = Tac.nonConst( counterToUse[ self.hostspec ] )
      new = getattr( c, attrName ) + delta
      traceX( TR_DEBUG, "Counter", attrName, "set to", new )
      setattr( c, attrName, new )
      counterToUse[ self.hostspec ] = c
