# Copyright (c) 2008, 2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import ast
import crypt, os, re, bisect
import errno

from Aaa import makePwEntry
import AaaPluginLib
from AaaPluginLib import TR_ERROR, TR_AUTHEN, TR_AUTHZ
from AaaPluginLib import TR_SESSION, TR_INFO, TR_DEBUG
from BothTrace import traceX as bt
from BothTrace import Var as bv
import Logging
from Tracing import traceX
import Tac

# Reactors created by Plugin() are appended here; presumably this list exists
# to keep them referenced for the lifetime of the agent.
reactors_ = []

# Log message definitions.
# NOTE(review): keep these as direct module-level Logging.logD calls. The ids
# are used later as bare names (e.g. AAA_ROOT_PASSWORD_NOTUPDATED in
# Logging.log), so logD presumably injects them into this module's namespace.

# Logged when running usermod to set root's password fails.
Logging.logD( id="AAA_ROOT_PASSWORD_NOTUPDATED",
              severity=Logging.logWarning,
              format="Password setting for root not changed due to "
                     "internal error (%s)",
              explanation="The password for the root user could not be changed "
                          "from its current value. The old password (if any) "
                          "continues to be valid for the account. "
                          "This is a potential security risk.",
              recommendedAction=Logging.CALL_SUPPORT_IF_PERSISTS )

# Logged when a remote login is refused by the empty-password policy
# (arguments: remote host, user name).
Logging.logD( id="AAA_REMOTE_LOGIN_DENIED_BY_POLICY",
              severity=Logging.logWarning,
              format="Remote login from %s denied for user %s due to "
                     "authentication policy",
              explanation="A remote login attempt was denied due to "
                          "authentication policy.  For example, the configured "
                          "policy may prohibit remote logins for accounts "
                          "with empty passwords.",
              recommendedAction=Logging.NO_ACTION_REQUIRED )

# Logged when creating an ssh directory or writing a file inside it fails
# (arguments: what was being created, path, error text).
Logging.logD( id="AAA_SSHDIR_FILESYSTEM_ERROR",
              severity=Logging.logError,
              format="Error generating ssh %s at path %s: %s ",
              explanation="An attempt to create a file in the ssh directory failed "
                          "because of a filesystem error.",
              recommendedAction="Check to see if the filesystem has run out of "
                                "disk space.  (Potentially from large log files) "
                                "Contact TAC Support if this problem persists " )

# Logged by RoleConfigReactor.handleRule for a rule it cannot compile
# (arguments: rule sequence number, role name).
Logging.logD( id="AAA_INVALID_REGEX_IN_ROLE",
              severity=Logging.logError,
              format="Rule %d of Role %s has an invalid regular expression",
              explanation="The specified rule has an invalid regular expression,"
                          " and will not take effect until the syntax is fixed.",
              recommendedAction="Fix the role configuration." )

class CompiledRule( object ):
   '''Pre-compiled form of a single role rule.

   Stores whether the rule permits, the raw mode key, and compiled regular
   expressions for the mode key and the command pattern, so authorization
   does not recompile them per request. Raises re.error if either pattern
   is not a valid regular expression.'''
   def __init__( self, rule ):
      action = rule.action
      self.permit = action == 'permit'
      modePattern = rule.modeKey
      self.modeKey = modePattern
      self.modeKeyRe = re.compile( modePattern )
      self.regex = re.compile( rule.regex )

class CompiledRole( object ):
   '''Compiled view of one role: its rules keyed by sequence number.

   `seqs` holds the role's sequence numbers in ascending order and `rules`
   maps each sequence number to its CompiledRule.'''
   def __init__( self, name ):
      self.name = name
      self.seqs = list()
      self.rules = {}

# Registry of compiled roles: role name -> CompiledRole.
compiledRoles = dict()

def removeFile( filename ):
   '''Best-effort removal of `filename`; never raises.

   A missing file is not an error, because the desired end state is
   already reached. Any other OS error (e.g. permission denied) is traced
   instead of being silently discarded.'''
   try:
      os.unlink( filename )
   except OSError as e:
      if e.errno != errno.ENOENT:
         bt( TR_ERROR, "Failed to remove file ", bv( filename ), ":",
             bv( e.strerror ) )

def setUpSshDir( sshDirPath ):
   '''Ensure the ssh directory `sshDirPath` exists; never raises.

   Creates the directory with mode 0700 when missing. An existing directory
   is accepted as-is (its mode is left unchanged). Any other failure is
   traced and logged as AAA_SSHDIR_FILESYSTEM_ERROR.'''
   try:
      # 0o700 is the spelling valid on both Python 2.6+ and Python 3
      # (the bare-zero form 0700 is a syntax error on Python 3).
      os.mkdir( sshDirPath, 0o700 )
   except OSError as e:
      if e.errno == errno.EEXIST and os.path.isdir( sshDirPath ):
         return
      bt( TR_ERROR, "Failed to create directory ", bv( sshDirPath ), ":",
          bv( e.strerror ) )
      # Log the error text, matching writeSshDirFile.
      Logging.log( AAA_SSHDIR_FILESYSTEM_ERROR, 'directory', sshDirPath,
                   e.strerror )

def writeSshDirFile( filename, content, logString ):
   '''Write `content` to `filename` and restrict it to mode 0600.

   `logString` names the kind of file being written (e.g. "key file") and
   is used in the error log. IOError/OSError is traced and logged as
   AAA_SSHDIR_FILESYSTEM_ERROR rather than raised.'''
   try:
      with open( filename, 'w' ) as sshDirFile:
         sshDirFile.write( content )
      # 0o600 is valid on both Python 2.6+ and Python 3 (0600 is not).
      os.chmod( filename, 0o600 )
   except ( IOError, OSError ) as e:
      bt( TR_ERROR, "Failed to write file ", bv( filename ), ":", bv( e.strerror ) )
      Logging.log( AAA_SSHDIR_FILESYSTEM_ERROR, logString, filename, e.strerror )

class LocalUserConfigReactor( Tac.Notifiee ):
   '''Reacts to changes in LocalUser::Config.

   When root's password changes, update it in /etc/shadow (via usermod).
   When root's ssh key or principals change, rewrite or remove the
   corresponding file under /root/.ssh/. Also owns one
   LocalUserAccountReactor per configured account and one RoleConfigReactor
   per configured role, and adds/removes roles in the module-level
   compiledRoles registry.'''

   notifierTypeName = "LocalUser::Config"
   def __init__( self, cfg, agent ):
      self.cfg_ = cfg
      self.agent_ = agent
      self.localUserAcctReactors_ = {}
      self.roleConfigReactors_ = {}
      self.sshDirPath_ = '/root/.ssh/'
      self.sshKeyFilePath_ = self.sshDirPath_ + 'authorized_keys'
      self.sshPrincipalFilePath_ = self.sshDirPath_ + 'principals'
      Tac.Notifiee.__init__( self, cfg )
      # Apply the configuration that already exists at construction time.
      # The root ssh key is synced here as well, like the root principals
      # and like LocalUserAccountReactor does for per-user keys.
      self.handleEncryptedRootPasswd()
      for acct in self.notifier_.acct:
         self.handleAcct( acct )
      for role in self.notifier_.role:
         self.handleRole( role )
      self.handleRootSshKey()
      self.handleRootPrincipal()

   @Tac.handler( 'encryptedRootPasswd' )
   def handleEncryptedRootPasswd( self ):
      '''Install the configured root password hash using `usermod -p`.

      A failing usermod is logged as AAA_ROOT_PASSWORD_NOTUPDATED.'''
      bt( TR_INFO, "handle root password", bv( self.cfg_.encryptedRootPasswd ) )
      cmd = [ "/usr/sbin/usermod",
              "-p", self.cfg_.encryptedRootPasswd,
              "root"
            ]
      try:
         Tac.run( cmd )
      except Tac.SystemCommandError as e:
         Logging.log( AAA_ROOT_PASSWORD_NOTUPDATED, str( e ) )

   @Tac.handler( 'rootSshKey' )
   def handleRootSshKey( self ):
      '''Write root's authorized_keys file, or remove it when no key is set.'''
      bt( TR_INFO, "handle root ssh key" )
      authKey = self.cfg_.rootSshKey
      keysPath = self.sshKeyFilePath_

      if not authKey:
         # A key was removed and now there are none
         removeFile( keysPath )
         return

      setUpSshDir( self.sshDirPath_ )

      # Write the keys to file
      writeSshDirFile( keysPath, authKey + '\n', "key file" )

   @Tac.handler( 'rootPrincipal' )
   def handleRootPrincipal( self ):
      '''Write root's ssh principals file, or remove it when none are set.'''
      principal = self.cfg_.rootPrincipal
      principalPath = self.sshPrincipalFilePath_

      if not principal:
         removeFile( principalPath )
         return

      setUpSshDir( self.sshDirPath_ )

      # Write principals to file; SSH expects one line per principal
      content = "\n".join( principal.split() )
      writeSshDirFile( principalPath, content, "principal file" )

   @Tac.handler( 'acct' )
   def handleAcct( self, key ):
      '''Create or close the LocalUserAccountReactor for account `key`.'''
      Tac.handleCollectionChange( LocalUserAccountReactor, key,
                                  self.localUserAcctReactors_,
                                  self.cfg_.acct,
                                  reactorArgs=(self.agent_,) )

   @Tac.handler( 'role' )
   def handleRole( self, key ):
      '''Track addition/removal of role `key`.

      The compiledRoles entry is created before the RoleConfigReactor,
      because the reactor's constructor looks it up there.'''
      if key in self.notifier_.role:
         # new entry
         compiledRoles[ key ] = CompiledRole( key )
      else:
         # pop() rather than del: a removal for a role that was never
         # registered must not raise out of the handler.
         compiledRoles.pop( key, None )
      Tac.handleCollectionChange( RoleConfigReactor, key,
                                  self.roleConfigReactors_,
                                  self.cfg_.role,
                                  reactorArgs=() )

   def close( self ):
      '''Close all child reactors, then this notifiee.'''
      for r in self.localUserAcctReactors_.values():
         r.close()
      self.localUserAcctReactors_.clear()
      for r in self.roleConfigReactors_.values():
         r.close()
      self.roleConfigReactors_.clear()
      Tac.Notifiee.close( self )

class RoleConfigReactor( Tac.Notifiee ):
   '''Keeps the CompiledRole for one LocalUser::Role in sync with its rules.

   Each rule is compiled into a CompiledRule and stored in this role's entry
   of the module-level compiledRoles registry. That entry must already exist
   when the reactor is constructed.'''
   notifierTypeName = "LocalUser::Role"

   def __init__( self, role ):
      Tac.Notifiee.__init__( self, role )
      self.role_ = compiledRoles[ role.name ]
      for seq in self.notifier_.rule:
         self.handleRule( seq )

   @Tac.handler( 'rule' )
   def handleRule( self, seq ):
      '''Add, update or remove the compiled rule with sequence number `seq`.

      A rule whose regular expressions do not compile is logged as
      AAA_INVALID_REGEX_IN_ROLE and dropped, so any previously compiled
      version of that rule stops applying as well.'''
      if seq in self.notifier_.rule:
         # new or updated rule
         try:
            compiled = CompiledRule( self.notifier_.rule[ seq ] )
         except re.error:
            # re.compile raises re.error (not SyntaxError) on a bad pattern.
            # The CLI should have protected us so it should be impossible,
            # but let's just handle it here.
            Logging.log( AAA_INVALID_REGEX_IN_ROLE, seq, self.notifier_.name )
         else:
            isNew = seq not in self.role_.rules
            # Store the rule before exposing its sequence number in seqs,
            # and only insert the seq for a new rule so that an update does
            # not leave a duplicate entry in seqs.
            self.role_.rules[ seq ] = compiled
            if isNew:
               bisect.insort( self.role_.seqs, seq )
            return
      # deleted (or uncompilable) rule
      if seq in self.role_.rules:
         self.role_.seqs.remove( seq )
         del self.role_.rules[ seq ]

class LocalUserAccountReactor( Tac.Notifiee ):
   '''Keeps a local account's ssh files in sync with its configuration.

   Maintains /home/<user>/.ssh/authorized_keys from `sshAuthorizedKey` and
   /home/<user>/.ssh/principals from `principal`. Each file is written
   (and chowned to the user) when its attribute is set, removed when it is
   cleared, and both are removed when the reactor is closed (user deleted).'''
   notifierTypeName = "LocalUser::Account"

   def __init__( self, acct, agent ):
      Tac.Notifiee.__init__( self, acct )
      self.acct_ = acct
      self.agent_ = agent
      self.userName_ = self.acct_.userName
      self.sshKeyDirPath_ = '/home/%s/.ssh/' % acct.userName
      self.sshKeyFilePath_ = self.sshKeyDirPath_ + 'authorized_keys'
      self.sshPrincipalFilePath_ = self.sshKeyDirPath_ + 'principals'
      self.handleSshAuthorizedKey()
      self.handleSshPrincipal()

   def _adjustFileOwnership( self, filename ):
      '''Give ownership of `filename` to this account's user and group.

      The preceding directory/file creation is best-effort and only logs on
      failure, so the path may not exist; a chown failure is therefore
      traced instead of raised out of the handler.'''
      pwent = self.agent_.getPwEnt( name=self.userName_ )
      assert pwent is not None
      try:
         os.chown( filename, pwent.userId, pwent.groupId )
      except OSError as e:
         bt( TR_ERROR, "Failed to set ownership of ", bv( filename ), ":",
             bv( e.strerror ) )

   def _setUpLocalSshDir( self ):
      '''Make sure the user and home directory exist, then create ~/.ssh.'''
      sshDirPath = self.sshKeyDirPath_
      agent = self.agent_

      # We need to set up the user in case it hasn't been done yet (if the
      # account was created in the Cli, it will not be set up until the first
      # login). We need to create it now because we need a valid PwEnt
      # in order to create our .ssh directory and keys.
      with agent.mutex:
         if self.userName_ not in agent.status.account:
            agent.setUpNewUser( self.userName_ )
      agent.ensureHomeDirExists( self.userName_ )

      setUpSshDir( sshDirPath )
      self._adjustFileOwnership( sshDirPath )

   @Tac.handler( 'sshAuthorizedKey' )
   def handleSshAuthorizedKey( self ):
      '''Write the user's authorized_keys file, or remove it when unset.'''
      authKey = self.acct_.sshAuthorizedKey
      keysPath = self.sshKeyFilePath_

      if not authKey:
         # A key was removed and now there are none
         removeFile( keysPath )
         return

      self._setUpLocalSshDir()

      # Write the keys to file
      writeSshDirFile( keysPath, authKey + '\n', "key file" )
      self._adjustFileOwnership( keysPath )

   @Tac.handler( 'principal' )
   def handleSshPrincipal( self ):
      '''Write the user's ssh principals file, or remove it when unset.'''
      principal = self.acct_.principal
      principalPath = self.sshPrincipalFilePath_

      if not principal:
         removeFile( principalPath )
         return

      self._setUpLocalSshDir()

      # Write principals to file; SSH expects one line per principal
      content = "\n".join( principal.split() )
      writeSshDirFile( principalPath, content, "principal file" )
      self._adjustFileOwnership( principalPath )

   def close( self ):
      '''User deleted: remove the ssh files, then close the notifiee.'''
      removeFile( self.sshKeyFilePath_ )
      removeFile( self.sshPrincipalFilePath_ )
      Tac.Notifiee.close( self )

class EnableAuthenticator( AaaPluginLib.Authenticator ):
   '''Authenticates entry into 'enable' mode against the configured
   encrypted enable password.

   A small state machine: needPassword -> askedPassword -> succeeded/failed.
   With no enable password configured the request succeeds immediately;
   otherwise the user gets up to `maxAttempts` tries.'''
   # state machine values
   needPassword = 1
   askedPassword = 2
   failed = 3
   succeeded = 4

   stateMap = { needPassword : "needPassword", askedPassword : "askedPassword",
                failed : "failed", succeeded : "succeeded" }

   # Number of password attempts before the request fails.
   maxAttempts = 3

   def __init__( self, config, aaaConfig, method, type, service, remoteHost,
                 remoteUser, tty, user, privLevel ):
      self.config = config
      self.privLevel = privLevel
      self.attempts = 0
      self.state = self.needPassword
      AaaPluginLib.Authenticator.__init__( self, aaaConfig,
                                           method, type, service,
                                           remoteHost, remoteUser, tty, user )
      assert type == 'authnTypeEnable'

   def authenticate( self, *responses ):
      '''Advance the state machine using the user's `responses`.

      Returns the result dict produced by transition().'''
      configuredPasswd = self.config.encryptedEnablePasswd
      if self.state == self.needPassword:
         if configuredPasswd == "":
            # There is no 'enable' password set -- allow the user to switch
            # into 'enable' mode.
            return self.transition( self.succeeded, 'success' )
         return self.transition( self.askedPassword, 'inProgress',
                                 [ self.passwordPrompt() ] )
      elif self.state == self.askedPassword:
         if len( responses ) == 1:
            traceX( TR_AUTHEN, "authenticate: state=askedPassword" )
            self.attempts += 1
            if self.checkPassword( responses[ 0 ] ):
               return self.transition( self.succeeded, 'success' )
            elif self.attempts < self.maxAttempts:
               return self.transition( self.askedPassword, 'inProgress',
                                       [ self.passwordPrompt() ] )
            else:
               failMsg = Tac.Value( "AaaApi::AuthenMessage",
                                    style='error', text='Bad secret' )
               return self.transition( self.failed, 'fail', [ failMsg ] )
      else:
         # Was `BT(...)`, an undefined name that raised NameError here.
         bt( TR_ERROR, "EnableAuthenticator.authenticate unexpected state:",
             bv( self.state ) )
         # fall through
      return self.transition( self.failed, 'fail' )

   def transition( self, state, authenStatus, messages=None ):
      '''Move to `state` and build the authentication result dict.

      `messages` defaults to a fresh empty list on every call (a shared
      mutable default would leak between results).'''
      if messages is None:
         messages = []
      sm = self.stateMap
      traceX( TR_AUTHEN, "EnableAuthenticator transitioning from", sm[ self.state ],
              "to", sm[ state ] )
      self.state = state
      return { "status" : authenStatus, "messages" : messages, "user" : self.user,
               "authToken" : "" }

   def checkPassword( self, password ):
      '''Return True if `password` matches the configured enable password.'''
      # Pass in the current encrypted password as the salt for the
      # encryption, so we don't choose a new salt, which would result
      # in the password entered at the CLI being encrypted to a
      # different string, so even if it matches the configured passwd,
      # we wouldn't be able to tell...
      configuredPasswd = self.config.encryptedEnablePasswd
      encryptedPasswd = crypt.crypt( password, configuredPasswd )
      return ( configuredPasswd == encryptedPasswd )

class LocalUserAuthenticator( AaaPluginLib.BasicUserAuthenticator ):
   '''Login authenticator for accounts defined in LocalUser::Config.

   A password is checked by crypt()-ing it with the stored hash as salt
   and comparing the result to the stored hash.'''
   def __init__( self, config, aaaConfig, method, type, service, remoteHost,
                 remoteUser, tty, user, privLevel ):
      self.config = config
      AaaPluginLib.BasicUserAuthenticator.__init__( self, aaaConfig, method,
                                                    type, service,
                                                    remoteHost, remoteUser, tty,
                                                    user, privLevel )
      # BUG 13784 requires that we return unavailable for unknown user,
      # but we don't log the typical FALLBACK message.
      self.logFallback = False

   def checkUser( self, user ):
      '''Return True if `user` is a configured local account.'''
      traceX( TR_AUTHEN, "checkUser: user=", user )
      return user in self.config.acct

   def checkEmptyPassword( self, user ):
      '''Check `user` against the empty password.'''
      return self.checkPassword( user, "" )

   def checkPassword( self, user, password ):
      '''Check `password` for local account `user`.

      Returns a result dict with 'state' and 'authenStatus'. Every failure
      (unknown user, policy denial, wrong password) returns the same shape,
      including 'user' and 'authToken'.
      NOTE(review): the success result reports self.user/self.password,
      presumably set by the base class, not the arguments.'''
      traceX( TR_AUTHEN, "checkPassword: user=", user, "password=*****" )
      acct = self.config.acct.get( user )
      failMsg = Tac.Value( "AaaApi::AuthenMessage",
                           style='error', text='Bad secret' )
      failure = { 'state': self.failed, 'authenStatus': 'fail',
                  'messages': [ failMsg ], 'user': user, 'authToken': password }
      if not acct:
         return failure
      if self.policyDeniesLogin( acct ):
         host = self.remoteHost if self.remoteHost else "(unknown host)"
         Logging.log( AAA_REMOTE_LOGIN_DENIED_BY_POLICY, host, user )
         return failure
      encrypted = crypt.crypt( password, acct.encryptedPasswd ) or ''
      traceX( TR_AUTHEN, "  encrypted:", encrypted, "in database:",
              acct.encryptedPasswd )
      if encrypted == acct.encryptedPasswd:
         attrs = dict( roles = [ acct.role ] )
         return { 'state': self.succeeded, 'authenStatus': 'success',
                  'messages': [], 'user': self.user, 'authToken': self.password,
                  'sessionData' : attrs }
      return failure

   def policyDeniesLogin( self, acct ):
      '''Return True if policy forbids a remote login to `acct`, which is
      the case for an empty password unless
      allowRemoteLoginWithEmptyPassword is set.'''
      # service is set to 'login' for console login, 'remote' for telnet, and
      # 'sshd' for ssh.
      remoteServices = ( 'sshd', 'remote', 'command-api' )
      return ( not self.config.allowRemoteLoginWithEmptyPassword and
               self.service in remoteServices and
               acct.encryptedPasswd == '' )

class LocalUserPlugin( AaaPluginLib.Plugin ):
   '''AAA plugin ("local" method) backed by LocalUser::Config.

   Provides authenticators for login and enable, shell authorization for
   known accounts, role-based command authorization using the module-level
   compiledRoles registry, and user shell / passwd-entry lookups.'''
   def __init__( self, config, aaaConfig, aaaStatus ):
      AaaPluginLib.Plugin.__init__( self, aaaConfig, "local" )
      self.aaaStatus = aaaStatus
      self.config = config

   def _userIsKnown( self, user ):
      return user in self.config.acct

   def ready( self ):
      return True

   def logFallback( self ):
      # BUG 13784 requires that we return unavailable for unknown user,
      # but we don't log the typical FALLBACK message.
      return False

   def createAuthenticator( self, method, type, service, remoteHost,
                            remoteUser, tty, user=None, privLevel=0 ):
      '''Return an authenticator for `type`, or None if it is unsupported.'''
      assert method == self.name
      if type == 'authnTypeLogin':
         a = LocalUserAuthenticator( self.config, self.aaaConfig, method, type,
                                     service, remoteHost, remoteUser, tty, user,
                                     privLevel)
      elif type == 'authnTypeEnable':
         a = EnableAuthenticator( self.config, self.aaaConfig, method, type,
                                  service, remoteHost, remoteUser, tty, user,
                                  privLevel )
      else:
         bt( TR_ERROR, "unknown authentication type:", bv( type ) )
         a = None
      return a

   def openSession( self, authenticator ):
      traceX( TR_SESSION, "openSession for user", authenticator.user )
      return authenticator

   def closeSession( self, token ):
      traceX( TR_SESSION, "closeSession for user", token.user )

   def authorizeShell( self, method, user, session ):
      '''Return ( status, message, attrs ) for starting a shell as `user`.'''
      traceX( TR_AUTHZ, "authorizeShell for method", method, "user", user )
      assert method == self.name

      # My policy is to authorize shells for users I know about
      attrs = {}
      acct = self.config.acct.get( user )
      if acct is not None:
         status = 'allowed'
         message = ''
         attrs[ AaaPluginLib.privilegeLevel ] = acct.privilegeLevel
         attrs[ AaaPluginLib.roles ] = [ acct.role ]
      else:
         status = 'authzUnavailable'
         message = "Unknown user"
      return ( status, message, attrs )

   def _ruleMatchByMode( self, rule, mode ):
      '''Return True if `rule` applies in `mode`
      (a ( modeName, modeKey, longModeKey ) tuple).'''
      modeName, modeKey, longModeKey = mode
      # A rule without a modeKey is applied to all the Cli modes
      if not rule.modeKey:
         return True
      elif modeName == 'Exec':
         return rule.modeKey == 'exec'
      elif rule.modeKey == 'config-all':
         return True
      elif modeName == 'Configure':
         return rule.modeKey == 'config'
      elif rule.modeKey == modeKey:
         return True
      elif rule.modeKeyRe.match( longModeKey ):
         return True

      # Not applicable
      return False

   @staticmethod
   def _firstRole( rolesAttr ):
      '''Return the first role in a session 'roles' attribute, or None.

      The attribute is presumably the string form of a list of role names
      (checkPassword stores roles as a list). It is parsed with
      ast.literal_eval, which accepts only Python literals, so session data
      is never executed as code the way eval() would.'''
      try:
         return ast.literal_eval( rolesAttr )[ 0 ]
      except ( ValueError, SyntaxError, TypeError, IndexError, KeyError ):
         return None

   def authorizeShellCommand( self, method, user, session, mode, privlevel, tokens ):
      '''Authorize the command `tokens` in `mode` for `user` by role rules.

      Rules of the user's effective role are evaluated in ascending sequence
      order; the first rule whose mode and regex match decides. Returns
      ( status, message, attrs ); the default is "denied".'''
      traceX( TR_AUTHZ, "authorizeShellCommand for method", method, "user", user,
              "mode", mode, "privlevel", privlevel, "tokens", tokens )
      assert method == self.name

      if not session:
         return ( "authzUnavailable", "Unknown user", {} )

      # Deny the established session of a deleted local user
      if session.authenMethod == 'local' and not self._userIsKnown( user ):
         return ( "authzUnavailable", "Unknown user", {} )

      role = None
      traceX( TR_DEBUG, 'user', user, 'sessionId', session.id )
      sessionData = session.property.get( session.authenMethod )
      if sessionData:
         traceX( TR_DEBUG, 'sessionData', sessionData )
         # We may support multiple roles per user in the future
         attr = sessionData.attr.get( 'roles' )
         if attr:
            role = self._firstRole( attr )
      if role is None:
         return ( "authzUnavailable", "Unknown role", {} )

      # Get the effective role
      if not role or role not in self.config.role:
         role = self.config.defaultRole
         traceX( TR_DEBUG, 'user', user, 'fall back to role', role )

      compiledRole = compiledRoles.get( role )
      if compiledRole is None:
         return ( "denied", "Unknown role", {} )

      # The command string is the same for every rule; build it once.
      command = ' '.join( tokens )
      # Iterate over a snapshot of the sequence list: it may be modified
      # from another thread while we are matching.
      for seq in list( compiledRole.seqs ):
         rule = compiledRole.rules.get( seq )
         if rule is None:
            # this may be due to modification from another thread - ignore it
            continue
         if not self._ruleMatchByMode( rule, mode ):
            continue
         if rule.regex.match( command ):
            return ( "allowed" if rule.permit else "denied", '', {} )

      return ( "denied", '', {} )

   def hasUserShell( self ):
      return True

   def getUserShell( self, name ):
      '''Return the configured shell for account `name`, or None.'''
      acct = self.config.acct.get( name )
      if acct and acct.shell:
         return acct.shell
      else:
         return None

   def getPwEnt( self, name ):
      '''Return a passwd entry for local account `name`, or
      AAA_PWENT_RESULT.INVALID if the account is not configured.'''
      acct = self.config.acct.get( name )
      if acct:
         shell = acct.shell if acct.shell else self.aaaConfig.shell
         ent = makePwEntry( name, shell, authenMethod="local" )
         return ent
      return AaaPluginLib.AAA_PWENT_RESULT.INVALID

def Plugin( ctx ):
   '''Plugin entry point: build the "local" AAA plugin.

   Mounts 'security/aaa/local/config' (type LocalUser::Config, mode 'r').
   The callback handed to the mount group's close() creates a
   LocalUserConfigReactor and keeps it in reactors_.'''
   group = ctx.entityManager.mountGroup()
   localConfig = group.mount( 'security/aaa/local/config', 'LocalUser::Config',
                              'r' )
   agent = ctx.aaaAgent
   agentConfig, agentStatus = agent.config, agent.status

   def _onMountsReady():
      reactors_.append( LocalUserConfigReactor( localConfig, ctx.aaaAgent ) )

   group.close( _onMountsReady )
   return LocalUserPlugin( localConfig, agentConfig, agentStatus )
