# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import errno
import os
import re
import shutil
import tempfile
import threading
import time
import uuid

import Tac
import Tracing
from glob import glob
from CertRotationModel import CertificateRotation
from CertRotationModel import StatusCode
from CertRotationModel import RotationGenerateSignRequestStatus
from CertRotationModel import RotationImportStatus
from CertRotationModel import RotationCommitStatus
from CertRotationModel import RotationCommitSslProfileStatus
from CertRotationModel import RotationClearStatus
import SslCliLib
from SslCertKey import SslCertKeyError
from SslCertKey import isCertificateMatchesKey
from SslCertKey import validateCertificate
from SslCertKey import validateRsaPrivateKey
from SslCertKey import validateCertificateData
from SslCertKey import getCertificateDates
from SslCertKey import generateCertificate
from SslCertKey import isSslDirsCreated
from SslCertKey import getCertPem
from GenericReactor import GenericReactor

# Tracing handle for this module. t0 is used for function Start/End traces,
# t1 for intermediate details.
__defaultTraceHandle__ = Tracing.Handle( 'CertRotation' )
t0 = Tracing.trace0
t1 = Tracing.trace1
# Tac constants used for rotation paths, file names/extensions, permissions
# and timeouts, and the error codes compared against the result of
# validateCertificateData.
Const = Tac.Type( "Mgmt::Security::Ssl::Constants" )
ErrorType = Tac.Type( "Mgmt::Security::Ssl::ErrorType" )

def getExceptionErrorMessage( exception ):
   '''
   Builds the generic "Request failed" message, appending details in
   parentheses for these exception types:
     - SslCertKeyError: its string form
     - EnvironmentError (OSError/IOError): its strerror, falling back to its
       string form when strerror is not set (e.g. an OSError raised with a
       single message argument), so "(None)" is never shown
   Any other exception (or no exception) yields just "Request failed".
   '''
   genericError = "Request failed"
   specificError = ""
   if exception:
      if isinstance( exception, SslCertKeyError ):
         specificError = " (%s)" % str( exception )
      elif isinstance( exception, EnvironmentError ):
         detail = exception.strerror or str( exception )
         if detail:
            specificError = " (%s)" % detail
   return "%s%s" % ( genericError, specificError )

def getErrorMessage( statusCode, exception=None ):
   '''
   Returns the user-facing error message for a non-success statusCode.

   statusCode: StatusCode value describing the failure.
   exception: optional exception behind the failure. For invalidCert and
              invalidKey it must be an SslCertKeyError, whose text becomes
              the message. For requestFailed it adds details to the generic
              "Request failed" message.

   Status codes without a dedicated message fall back to the generic
   "Request failed" message rather than failing on an unbound variable.
   '''
   if statusCode == StatusCode.invalidRotationId:
      return "Rotation ID does not exist"
   if ( statusCode == StatusCode.invalidCert or
        statusCode == StatusCode.invalidKey ):
      assert isinstance( exception, SslCertKeyError )
      return str( exception )
   if statusCode == StatusCode.invalidSslProfile:
      return ( "SSL profile does not exist or not "
               "configured with certificate and key" )
   if statusCode == StatusCode.certDoesNotExist:
      return "Certificate does not exist"
   if statusCode == StatusCode.certDoesNotMatchKey:
      return "Certificate does not match key"
   if statusCode == StatusCode.certNotYetValid:
      return "Certificate is not yet valid"
   if statusCode == StatusCode.certExpired:
      return "Certificate has expired"
   # StatusCode.requestFailed, and any code without a dedicated message above.
   return getExceptionErrorMessage( exception )

def getGenerateCsrStatus( statusCode, exception=None, csr=None, rotationId=None ):
   '''
   Builds the RotationGenerateSignRequestStatus returned by generateCsr.

   On success both csr and rotationId must be supplied and are copied into
   the status. For any other statusCode an error message derived from
   statusCode (and exception, if given) is filled in instead.
   '''
   result = RotationGenerateSignRequestStatus()
   result.statusCode = statusCode
   if statusCode == StatusCode.success:
      assert csr
      assert rotationId
      result.csr = csr
      result.rotationId = rotationId
      return result
   result.errorMessage = getErrorMessage( statusCode, exception )
   return result
   
def getImportCertStatus( statusCode, rotationId, exception=None ):
   '''
   Builds the RotationImportStatus for rotationId. Every non-success code also
   gets an error message derived from statusCode and exception.
   '''
   importStatus = RotationImportStatus()
   importStatus.statusCode = statusCode
   importStatus.rotationId = rotationId
   failed = statusCode != StatusCode.success
   if failed:
      importStatus.errorMessage = getErrorMessage( statusCode, exception )
   return importStatus

def getCommitStatus( statusCode, rotationId, exception=None ):
   '''
   Builds the RotationCommitStatus for rotationId. Every non-success code also
   gets an error message derived from statusCode and exception.
   '''
   commitStatus = RotationCommitStatus()
   commitStatus.statusCode = statusCode
   commitStatus.rotationId = rotationId
   failed = statusCode != StatusCode.success
   if failed:
      commitStatus.errorMessage = getErrorMessage( statusCode, exception )
   return commitStatus

def getCommitSslProfileStatus( statusCode, errorMessage=None, exception=None ):
   '''
   Builds a RotationCommitSslProfileStatus. For non-success codes the
   caller-supplied errorMessage is used when non-empty. Otherwise the message
   is derived from statusCode and exception.
   '''
   profileStatus = RotationCommitSslProfileStatus()
   profileStatus.statusCode = statusCode
   if statusCode != StatusCode.success:
      profileStatus.errorMessage = ( errorMessage or
                                     getErrorMessage( statusCode, exception ) )
   return profileStatus

def getClearStatus( statusCode, rotationId, exception=None ):
   '''
   Builds the RotationClearStatus for rotationId. Every non-success code also
   gets an error message derived from statusCode and exception.
   '''
   clearStatus = RotationClearStatus()
   clearStatus.statusCode = statusCode
   clearStatus.rotationId = rotationId
   failed = statusCode != StatusCode.success
   if failed:
      clearStatus.errorMessage = getErrorMessage( statusCode, exception )
   return clearStatus

def cleanupFilesAndDirs( globPathname ):
   '''
   Best-effort removal of every path matching the glob pattern globPathname.

   Regular files and symlinks (including symlinks to directories and
   dangling symlinks) are unlinked. Directories are removed recursively.
   Removal errors are ignored.
   '''
   for path in glob( globPathname ):
      # Symlinks must be unlinked: shutil.rmtree refuses to operate on a
      # symlink, and with ignore_errors=True it would silently leave it.
      if os.path.isfile( path ) or os.path.islink( path ):
         t1( "cleanupFilesAndDirs deleting file:", path )
         try:
            os.remove( path )
         except ( OSError, IOError ):
            pass
      else:
         t1( "cleanupFilesAndDirs deleting dir:", path )
         shutil.rmtree( path, ignore_errors=True )

class RotationReq( object ):
   '''
   One certificate rotation request for an SSL profile.

   The request's state is persisted in a rotation dir named
   <rotationBaseDirPath>rotationId.profileName[<rotCommitExt>]. The dir holds
   the key, the CSR, the import timeout and, once imported, the new
   certificate, so existing requests can be reloaded by __init__. A request
   expires and is cleaned up when its import timeout elapses without a
   certificate being imported, or when the imported certificate expires.
   '''
   def __init__( self, rotationMgr, rotationBaseDirPath, sslConfig, sslServiceStatus,
                 rotationId, profileName, committing=False ):
      '''
      Initializes state for rotationId/profileName. The rotation dir is named:
         rotationId.profileName[.commit]
         rotationId: unique rotation id (a uuid)
         profileName: SSL profile name associated with this rotation id
         commit: present while this rotation id is being committed
      If the rotation dir already exists, the following steps run:
         - leftover tmp/delete entries are removed
         - the saved state is loaded
         - any partially completed commit is rolled back
         - expiry is checked
      On failure the request is cleaned up, except in commit/restore abort
      tests. If the dir does not exist, only the in-memory state is
      initialized.
      '''
      t0( "Req.__init__ Start:",
          "rotationBaseDirPath:", rotationBaseDirPath,
          "rotaionId:", rotationId,
          "profileName:", profileName,
          "committing:", committing )

      self.rotationMgr_ = rotationMgr
      self.rotationBaseDirPath_ = rotationBaseDirPath
      self.sslConfig_ = sslConfig
      self.sslServiceStatus_ = sslServiceStatus
      self.rotationId_ = rotationId
      self.profileName_ = profileName
      self.idDotProfile_ = self.rotationId_ + "."  + self.profileName_
      self.rotationDirPath_ = self._getRotationDirPath( commit=committing )
      # Disarmed (endOfTime) until _handleExpiry computes the expiry time.
      self.expiryClock_ = Tac.ClockNotifiee( self.handleExpiry,
                                             timeMin=Tac.endOfTime )
      self.lock_ = threading.Lock()
      self.key_ = None
      self.keymtime_ = None
      self.importTimeout_ = None
      self.csr_ = None
      self.cert_ = None
      self.deleted_ = False

      if os.path.isdir( self.rotationDirPath_ ):
         try:
            self._cleanupRotationReq()
            self._loadRotationReq()
            self._handleRestore()
            self.handleExpiry()
         except Exception as e: # pylint: disable-msg=W0703
            if os.environ.get( "ARTEST_COMMIT_RESTORE_ABORT_AFTER" ):
               t1( "Req.__init__ test commit/restore abort exception" )
            else:
               t1( "Req.__init__ load Error:", str( e ) )
               self._cleanup()

      t0( "Req.__init__ End:" )

   def getCliRotationModel( self ):
      '''
      Returns a CertificateRotation CLI model describing this request, or
      None if it has been deleted. rotationState is one of:
         importPending/importExpired: no certificate imported yet; the import
            deadline is creation time (key mtime) + import timeout
         commitPending/commitExpired: certificate imported; the deadline is
            the certificate's notAfter
      '''
      with self.lock_:
         if self.deleted_:
            return None
         result = CertificateRotation()
         result.rotationId = self.rotationId_
         result.profileName = self.profileName_
         # Floor division keeps the minutes value integral on any Python.
         result.importTimeout = self._getImportTimeoutSeconds() // 60
         result.creationTime = self.keymtime_
         csrPath = self._csrFilepath()
         csrModel = SslCliLib.getCertificateSigningRequestModel( csrPath )
         result.certificateSigningRequest = csrModel
         if self.cert_ is not None:
            certPath = self._certFilepath()
            result.certificate = SslCliLib.getCertificateModel( certPath )
            timeout = result.certificate.notAfter
            rotationState = ( 'commitPending' if timeout > time.time() else
                              'commitExpired' )
         else:
            timeout = result.importTimeout * 60 + result.creationTime
            rotationState = ( 'importPending' if timeout > time.time() else
                              'importExpired' )
         result.rotationState = rotationState
         return result

   def _cleanupRotationReq( self ):
      '''
      Removes leftover tmp* and *<rotDeleteExt> entries from the rotation dir,
      e.g. those left behind by an interrupted write or delete.
      '''
      t0( "Req._cleanupRotationReq Start:", self.idDotProfile_ )
      # Guard against ever globbing/deleting outside the rotation area.
      assert self.rotationDirPath_.startswith( Const.rotationBaseDirPath() )
      cleanupFilesAndDirs( "%s/tmp*" % self.rotationDirPath_ )
      cleanupFilesAndDirs( "%s/*%s" % ( self.rotationDirPath_, Const.rotDeleteExt ) )
      t0( "Req._cleanupRotationReq End:", self.idDotProfile_ )

   def _loadRotationReq( self ):
      '''
      Loads the CSR, import timeout, certificate (optional) and key from the
      rotation dir, plus the key file's mtime used as the creation time.
      Raises OSError/IOError if a required file cannot be read.
      '''
      t0( "Req._loadRotationReq Start:", self.idDotProfile_ )

      self.csr_ = self._readFile( self._csrFilepath() )
      t1( "Req._loadRotationReq csr:", self._csrFilepath() )

      self.importTimeout_ = self._readFile( self._importTimeoutFilepath() )
      t1( "Req._loadRotationReq importTimeout:", self.importTimeout_ )

      # The certificate only exists once it has been imported.
      self.cert_ = self._readFile( self._certFilepath(), ignoreEnoent=True )
      t1( "Req._loadRotationReq cert:", self._certFilepath() )

      self.key_ = self._readFile( self._keyFilepath() )
      self.keymtime_ = int( os.path.getmtime( self._keyFilepath() ) )
      t1( "Req._loadRotationReq keymtime:", self.keymtime_ )

      t0( "Req._loadRotationReq End:", self.idDotProfile_ )

   def _keyFilepath( self ):
      return "%s/%s" % ( self.rotationDirPath_, Const.rotKeyFile )

   def _certFilepath( self ):
      return "%s/%s" % ( self.rotationDirPath_, Const.rotCertFile )

   def _csrFilepath( self ):
      return "%s/%s" % ( self.rotationDirPath_, Const.rotCsrFile )

   def _importTimeoutFilepath( self ):
      return "%s/%s" % ( self.rotationDirPath_, Const.rotImportTimeoutFile )

   def _restoreCertKeyFilepath( self, certKeyExt ):
      '''
      Returns the path of the first file in the restore dir whose extension
      is certKeyExt, or None if there is no such file.
      '''
      assert os.path.isdir( self._restoreDirpath() )
      for restoreFilepath in glob( "%s/*" % self._restoreDirpath() ):
         ( _, ext ) = os.path.splitext( restoreFilepath )
         if ext == certKeyExt:
            return restoreFilepath
      return None

   def _restoreCertFilepath( self ):
      return self._restoreCertKeyFilepath( Const.rotRestoreCertExt )

   def _restoreKeyFilepath( self ):
      return self._restoreCertKeyFilepath( Const.rotRestoreKeyExt )

   def _origCertFilepath( self ):
      '''
      Recovers the original certificate path from the backup file name
      <origCertFile><rotRestoreCertExt> in the restore dir.
      '''
      restoreCertFilepath = self._restoreCertFilepath()
      certFile = os.path.basename( restoreCertFilepath )
      ( origCertFile, _ ) = os.path.splitext( certFile )
      return Const.certPath( origCertFile )

   def _origKeyFilepath( self ):
      '''
      Recovers the original key path from the backup file name
      <origKeyFile><rotRestoreKeyExt> in the restore dir.
      '''
      restoreKeyFilepath = self._restoreKeyFilepath()
      keyFile = os.path.basename( restoreKeyFilepath )
      ( origKeyFile, _ ) = os.path.splitext( keyFile )
      return Const.keyPath( origKeyFile )

   def _isCommitting( self ):
      '''
      Returns True when the rotation dir name ends with the commit extension,
      i.e. a commit of this request is (or was, before a crash) in progress.
      '''
      # Escape the extension so that any regex metacharacters in it (such as
      # a '.') are matched literally rather than as "any character".
      pattern = r'^(.*?)\.(.*?)(%s)?$' % re.escape( Const.rotCommitExt )
      match = re.match( pattern, os.path.basename( self.rotationDirPath_ ) )
      return bool( match and match.group( 3 ) )

   def _restoreDirpath( self ):
      return "%s/%s" % ( self.rotationDirPath_, Const.rotRestoreDir )

   def _readFile( self, filepath, ignoreEnoent=False ):
      '''
      Returns the contents of filepath. When ignoreEnoent is set, a missing
      file yields None instead of raising. Other errors are re-raised.
      '''
      t0( "Req._readFile Start:", self.idDotProfile_,
          "filepath:", filepath,
          "ignoreEnoent:", ignoreEnoent )

      buf = None
      try:
         with open( filepath, "r" ) as fp:
            buf = fp.read()
      except ( OSError, IOError ) as e:
         t1( "Req._readFile error:", e.strerror )
         if e.errno != errno.ENOENT or not ignoreEnoent:
            raise

      t0( "Req._readFile End:", self.idDotProfile_ )
      return buf

   def _writeFile( self, filepath, buf, mode ):
      '''
      Replaces filepath with buf. The data is written to a temporary file in
      the same directory, which is flushed, fsync'ed, chmod'ed to mode and
      finally renamed over filepath, so readers see either the old or the new
      contents. If any step fails, the temporary file is removed and the
      error propagates to the caller.
      '''
      t0( "Req._writeFile Start:", self.idDotProfile_,
          "filepath:", filepath, "mode:", mode )

      filedir = os.path.dirname( filepath )
      tmpFile = tempfile.NamedTemporaryFile( dir=filedir, delete=False )
      tmpFilepath = tmpFile.name
      t1( "Req._writeFile tempFilepath:", tmpFilepath )
      renamed = False
      try:
         with tmpFile:
            tmpFile.write( buf )
            tmpFile.flush()
            # Persist the data before the rename publishes it, so a crash
            # around the rename is less likely to expose incomplete contents.
            os.fsync( tmpFile.fileno() )
         os.chmod( tmpFilepath, mode )
         os.rename( tmpFilepath, filepath )
         renamed = True
      finally:
         if not renamed:
            # Temp files next to the original cert/key files live outside the
            # rotation dirs whose tmp* entries are cleaned on load.
            self._deleteFile( tmpFilepath )

      t0( "Req._writeFile End:", self.idDotProfile_ )

   def _deleteFile( self, filepath ):
      '''Best-effort removal of filepath; errors are ignored.'''
      t0( "Req._deleteFile Start:", self.idDotProfile_, "filepath:", filepath )
      try:
         os.remove( filepath )
      except ( OSError, IOError ):
         pass
      t0( "Req._deleteFile End:", self.idDotProfile_ )

   def _getRotationDirPath( self, commit=False ):
      '''
      Returns <rotationBaseDirPath>rotationId.profileName, with the commit
      extension appended when commit is set.
      '''
      t0( "Req._getRotationDirPath Start:", self.idDotProfile_ )

      dirName = "%s.%s%s" % ( self.rotationId_, self.profileName_,
                              Const.rotCommitExt if commit else "" )

      ret = self.rotationBaseDirPath_ + dirName
      t0( "Req._getRotationDirPath End:", self.idDotProfile_,
          "returning", ret )
      return ret

   def _getImportTimeoutSeconds( self ):
      '''
      Returns the import timeout in seconds. The stored value is in minutes,
      except under the ARTEST_CERTROT_TIMEOUT_SECS test environment where it
      is taken as seconds.
      '''
      t0( "Req._getImportTimeoutSeconds Start:", self.idDotProfile_ )

      if os.environ.get( "ARTEST_CERTROT_TIMEOUT_SECS" ):
         t1( "Req._getImportTimeoutSeconds test env. using seconds" )
         ret = int( self.importTimeout_ )
      else:
         ret = int( self.importTimeout_ ) * 60
      t0( "Req._getImportTimeoutSeconds End:", self.idDotProfile_,
          "returning", ret )
      return ret

   def _getCertExpirySeconds( self ):
      '''
      Returns the seconds until the imported certificate's notAfter, or 0 if
      it has passed. Presumably getCertificateDates returns epoch seconds,
      comparable to Tac.utcNow().
      '''
      t0( "Req._getCertExpirySeconds Start:", self.idDotProfile_ )

      ( _, na ) = getCertificateDates( self.cert_ )
      t1( "Req._getCertExpirySeconds cert naDt:", na )

      now = int( Tac.utcNow() )
      t1( "Req._getCertExpirySeconds UTC now:", now )

      ret = ( na - now ) if na >= now else 0
      t0( "Req._getCertExpirySeconds End:", self.idDotProfile_,
          "returning", ret )
      return ret

   def _getKeyExpirySeconds( self ):
      '''
      Returns the seconds until the import deadline (key mtime + import
      timeout), or 0 if it has passed.
      '''
      t0( "Req._getKeyExpirySeconds Start:", self.idDotProfile_ )

      t1( "Req._getKeyExpirySeconds key modification time", self.keymtime_ )

      now = int( Tac.utcNow() )
      t1( "Req._getKeyExpirySeconds UTC now:", now )

      importTimeoutSecs = self._getImportTimeoutSeconds()
      keyNa = self.keymtime_ + importTimeoutSecs
      t1( "Req._getKeyExpirySeconds Key notAfter:", keyNa )

      ret = ( keyNa - now ) if keyNa >= now else 0
      t0( "Req._getKeyExpirySeconds End:", self.idDotProfile_,
          "returning", ret )
      return ret

   def _getExpirySeconds( self ):
      '''
      Returns the seconds until expiry. Once a certificate is imported, its
      expiry applies; before that, the import deadline applies.
      '''
      t0( "Req._getExpirySeconds Start:", self.idDotProfile_ )

      if self.cert_:
         ret = self._getCertExpirySeconds()
      else:
         ret = self._getKeyExpirySeconds()

      t0( "Req._getExpirySeconds End:", self.idDotProfile_,
          "returning", ret )
      return ret

   def _cleanup( self ):
      '''
      Deletes the rotationDir atomically by renaming it to
      rotationId.profileName<rotDeleteExt> (dropping any commit extension)
      and then removing it. Clears the in-memory state, disarms the expiry
      clock, sets the 'deleted' flag and notifies the rotationMgr.
      Idempotent: does nothing if already deleted.
      '''
      t0( "Req._cleanup Start:", self.idDotProfile_ )

      if self.deleted_:
         t1( "Req._cleanup Already deleted" )
         return

      deleteDir = self.rotationDirPath_
      t1( "Req._cleanup deleteDir:", deleteDir )

      if self._isCommitting():
         t1( "Req._cleanup is committing" )
         deleteDir = re.sub( "%s$" % re.escape( Const.rotCommitExt ), "",
                             deleteDir )

      deleteDir = "%s%s" % ( deleteDir, Const.rotDeleteExt )
      t1( "Req._cleanup deleteDir:", deleteDir )

      try:
         os.rename( self.rotationDirPath_, deleteDir )
         shutil.rmtree( deleteDir, ignore_errors=True )
      except Exception: # pylint: disable-msg=W0703
         # Best effort: the dir may not exist (e.g. generateCsr failed before
         # creating it). Leftover *.delete dirs are removed on the next load.
         pass

      self.expiryClock_.timeMin = Tac.endOfTime
      self.key_ = None
      self.keymtime_ = None
      self.importTimeout_ = None
      self.csr_ = None
      self.cert_ = None
      self.deleted_ = True

      self.rotationMgr_.handleDeletedReq( self.rotationId_ )
      t0( "Req._cleanup End:", self.idDotProfile_ )

   def _getExpiryBufferTime( self ):
      '''
      Extra seconds added when re-arming the expiry clock. Overridable via
      the ARTEST_EXPIRY_BUFFER_TIME test environment variable.
      '''
      bufferTime = os.environ.get( "ARTEST_EXPIRY_BUFFER_TIME" )
      return int( bufferTime ) if bufferTime else Const.rotExpiryBufferTime

   def _handleExpiry( self ):
      '''
      Cleans up the request if it has expired. Otherwise re-arms the expiry
      clock. The caller must hold self.lock_.
      '''
      t0( "Req._handleExpiry Start:", self.idDotProfile_ )

      if self.deleted_:
         t0( "Req._handleExpiry Already deleted" )
         return

      expirySeconds = self._getExpirySeconds()
      t1( "Req._handleExpiry will expire in", expirySeconds, "seconds" )

      if expirySeconds == 0:
         t1( "Req._handleExpiry Rotation req expired" )
         self._cleanup()
      else:
         # I observed that sometimes the clock gets triggered early
         # and we will have to re-arm the clock again. So, adding
         # some buffer time to the clock so that when we re-check,
         # the rotation is expired.
         expirySeconds = expirySeconds + self._getExpiryBufferTime()
         t1( "Req._handleExpiry Will check expiry after", expirySeconds, "seconds" )
         self.expiryClock_.timeMin = Tac.now() + expirySeconds

      t0( "Req._handleExpiry End:", self.idDotProfile_ )

   def _getSslProfileCertKeyFiles( self ):
      '''
      Returns ( certFile, keyFile ) configured for this request's SSL
      profile, or ( None, None ) if the profile does not exist.
      '''
      t0( "Req._getSslProfileCertKeyFiles Start:", self.idDotProfile_ )

      profile = self.sslConfig_.profileConfig.get( self.profileName_ )
      certFile = None
      keyFile = None
      if profile:
         certFile = profile.certKeyPair.certFile
         keyFile = profile.certKeyPair.keyFile
      else:
         t1( "Req._getSslProfileCertKeyFiles Ssl profile does not exist" )

      t0( "Req._getSslProfileCertKeyFiles End:", self.idDotProfile_,
          "returning", certFile, keyFile )
      return ( certFile, keyFile )

   def generateCsr( self, importTimeout, digest, signReqParams, newKeyBits=None,
                    key=None ):
      '''
      If newKeyBits is supplied, generates a new key in <rotationDir>/key.pem.
      If key is supplied, copies the key to <rotationDir>/key.pem.
      Generates a CSR for the key in <rotationDir>/csr.pem.
      Writes importTimeout to <rotationDir>/import.timeout.
      All files are first built in a temporary dir, which is then renamed to
      <rotationDir>, so the rotation dir appears complete or not at all.

      Returns: RotationGenerateSignRequestStatus with statusCode as below
      statusCode.success: successful; csr holds the PEM encoded CSR
      statusCode.requestFailed: request failed (the request is cleaned up)
      '''
      assert not os.path.isdir( self.rotationDirPath_ )
      assert newKeyBits or key
      assert not ( newKeyBits and key )

      t0( "Req.generateCsr Start:", self.idDotProfile_,
          "digest:", digest,
          "signReqParams:", signReqParams,
          "newKeyBits:", newKeyBits,
          "use key:", bool( key ) )

      status = None
      tmpDir = None

      try:
         tmpDir = tempfile.mkdtemp( dir=self.rotationBaseDirPath_ )
         os.chmod( tmpDir, Const.sslDirPerm )
         t1( "Req.generateCsr tmpDir:", tmpDir )

         keyFilepath = "%s/%s" % ( tmpDir, Const.rotKeyFile )
         csrFilepath = "%s/%s" % ( tmpDir, Const.rotCsrFile )
         importTimoutFilepath = "%s/%s" % ( tmpDir, Const.rotImportTimeoutFile )

         if key:
            self._writeFile( keyFilepath, key, mode=Const.sslKeyPerm )

         ( self.csr_, _ ) = generateCertificate( keyFilepath=keyFilepath,
                                                 certFilepath=csrFilepath,
                                                 signRequest=True,
                                                 genNewKey=bool( newKeyBits ),
                                                 newKeyBits=newKeyBits,
                                                 digest=digest,
                                                 **signReqParams )
         t1( "Req.generateCsr csr:", self.csr_ )

         self.importTimeout_ = importTimeout
         self.key_ = self._readFile( keyFilepath )
         self.keymtime_ = int( os.path.getmtime( keyFilepath ) )
         self._writeFile( importTimoutFilepath, "%d" % importTimeout,
                          mode=Const.sslCertPerm )
         os.rename( tmpDir, self.rotationDirPath_ )
         status = getGenerateCsrStatus( StatusCode.success, csr=self.csr_,
                                        rotationId=self.rotationId_ )
      except ( SslCertKeyError, Exception ) as e: # pylint: disable-msg=W0703
         self._cleanup()
         t1( "Req.generateCsr Error:", str( e ) )
         status = getGenerateCsrStatus( StatusCode.requestFailed, exception=e )
      finally:
         # After a successful rename tmpDir no longer exists; this only
         # matters on failure.
         if tmpDir:
            shutil.rmtree( tmpDir, ignore_errors=True )

      t0( "Req.generateCsr End:", self.idDotProfile_,
          "returning", status )
      return status

   def importCertificate( self, certPem, validateExpiry=True ):
      '''
      Copies the certPem certificate to rotationDir/cert.pem

      Returns: RotationImportStatus with statusCode as below
      statusCode.success: successful
      statusCode.invalidRotationId: rotationId does not exist
      statusCode.invalidCert: cert is not valid
      statusCode.certDoesNotMatchKey: cert does not match with key
      statusCode.certExpired: cert has already expired
      statusCode.requestFailed: request failed
      '''
      with self.lock_:
         t0( "Req.importCertificate Start:", self.idDotProfile_,
             "certPem:", certPem )

         if self.deleted_:
            t1( "Req.importCertificate already deleted" )
            return getImportCertStatus( StatusCode.invalidRotationId,
                                        self.rotationId_ )

         try:
            validateCertificate( certPem, isFile=False )
         except SslCertKeyError as e:
            t1( "Req.importCertificate invalid certificate:", str( e ) )
            return getImportCertStatus( StatusCode.invalidCert,
                                        self.rotationId_,
                                        exception=e )
         except Exception as e: # pylint: disable-msg=W0703
            t1( "Req.importCertificate validateCertificate failed:", str( e ) )
            return getImportCertStatus( StatusCode.requestFailed,
                                        self.rotationId_, exception=e )

         try:
            if validateExpiry:
               # A not-yet-valid certificate is accepted here; its start date
               # is checked at commit time.
               error = validateCertificateData( certPem, validateStartDate=False )
               t1( "Req.importCertificate validateCertificateData error:", error )

               if error == ErrorType.certExpired:
                  t1( "Req.importCertificate certificate expired" )
                  return getImportCertStatus( StatusCode.certExpired,
                                              self.rotationId_ )

            if not isCertificateMatchesKey( certPem, self.key_ ):
               t1( "Req.importCertificate certificate does not match key" )
               return getImportCertStatus( StatusCode.certDoesNotMatchKey,
                                           self.rotationId_ )

            self._writeFile( self._certFilepath(), certPem, mode=Const.sslCertPerm )
            self.cert_ = certPem

            t1( "Req.importCertificate success" )
            return getImportCertStatus( StatusCode.success,
                                        self.rotationId_ )
         except Exception as e: # pylint: disable-msg=W0703
            t1( "Req.importCertificate error:", str( e ) )
            return getImportCertStatus( StatusCode.requestFailed,
                                        self.rotationId_, exception=e )

   def _handleTestCommitRestoreAbortAfter( self, after ):
      '''
      Test hook: raises when ARTEST_COMMIT_RESTORE_ABORT_AFTER equals after,
      simulating a crash at that step of commit/restore.
      '''
      abortAfter = os.environ.get( "ARTEST_COMMIT_RESTORE_ABORT_AFTER" )
      if abortAfter == after:
         t1( "Req._handleTestCommitRestoreAbortAfter:", after )
         raise Exception( "Test commit/restore exception after %s" % after )

   def _validateCommit( self, origCertFile, origKeyFile, validateDates ):
      '''
      Checks whether a commit can proceed. Returns None if it can, otherwise a
      RotationCommitStatus for the first problem found, in this order:
         invalidSslProfile: profile missing or without cert/key files
         certDoesNotExist: no certificate has been imported
         certNotYetValid/certExpired: (only if validateDates) the imported
            certificate is outside its validity period
      '''
      t0( "Req._validateCommit Start:", self.idDotProfile_ )

      status = None

      if not origCertFile or not origKeyFile:
         t1( "Req.commit SSL profile is not valid" )
         status = getCommitStatus( StatusCode.invalidSslProfile, self.rotationId_ )
      elif not self.cert_:
         t1( "Req.commit certificate does not exist" )
         status = getCommitStatus( StatusCode.certDoesNotExist, self.rotationId_ )
      elif validateDates:
         error = validateCertificateData( self.cert_ )
         if error == ErrorType.certNotYetValid:
            t1( "Req.commit cert not yet valid" )
            status = getCommitStatus( StatusCode.certNotYetValid, self.rotationId_ )
         elif error == ErrorType.certExpired:
            t1( "Req.commit cert expired" )
            status = getCommitStatus( StatusCode.certExpired, self.rotationId_ )

      t0( "Req._validateCommit End: returning", status )
      return status

   def _waitForSslServiceUsingNewCert( self, newCert ):
      '''
      Waits until every SSL service using this profile reports newCert as its
      certificate. The timeout is Const.rotServiceTimeout, or the
      ARTEST_SERVICE_COMMIT_TIMEOUT test override. Returns False on timeout.
      '''
      t0( "Req._waitForSslServiceUsingNewCert Start:", self.idDotProfile_,
          "new cert:", newCert )

      def allSslServicesUsingNewCert():
         for serviceName, service in self.sslServiceStatus_.items():
            if ( service.profileName == self.profileName_ and
                 service.certificate != newCert ):
               t1( "Req._waitForSslServiceUsingNewCert service",
                   serviceName, "is using", self.profileName_,
                   "but not yet using new cert" )
               return False
         return True

      timeout = os.environ.get( "ARTEST_SERVICE_COMMIT_TIMEOUT" )
      timeout = int( timeout ) if timeout else Const.rotServiceTimeout
      try:
         Tac.waitFor( allSslServicesUsingNewCert, maxDelay=2,
                      sleep=True, timeout=timeout )
      except Tac.Timeout:
         t1( "Req._waitForSslServiceUsingNewCert timedout" )
         return False

      t0( "Req._waitForSslServiceUsingNewCert End:" )
      return True

   def commit( self, validateDates=True, dontWait=False ):
      '''
      - Backs up the original cert and key in <rotationDir>/restore/
      - Renames the <rotationDir> to <rotationDir>.commit
      - Copies <rotationDir>/cert.pem and <rotationDir>/key.pem over the
        profile's files in the certificate: and sslkey: filesystems
      - calls _cleanup
      - unless dontWait, waits for SSL services to pick up the new cert
      On failure the partial commit is rolled back via _handleRestore.

      Returns: RotationCommitStatus with statusCode as below
      statusCode.success: successful
      statusCode.invalidRotationId: rotationId does not exist
      statusCode.invalidSslProfile: profile missing or without cert/key
      statusCode.certDoesNotExist: cert.pem does not exist in rotationDir
      statusCode.certNotYetValid: cert's notBefore in future
      statusCode.certExpired: cert has expired
      statusCode.requestFailed: request failed
      '''
      with self.lock_:
         t0( "Req.commit Start:", self.idDotProfile_ )

         if self.deleted_:
            t1( "Req.commit already deleted" )
            return getCommitStatus( StatusCode.invalidRotationId, self.rotationId_ )

         tmpDir = None
         try:
            ( origCertFile, origKeyFile ) = self._getSslProfileCertKeyFiles()
            errorStatus = self._validateCommit( origCertFile, origKeyFile,
                                                validateDates=validateDates )
            if errorStatus:
               return errorStatus

            tmpDir = tempfile.mkdtemp( dir=self.rotationDirPath_ )
            os.chmod( tmpDir, Const.sslDirPerm )
            t1( "Req.commit tmpDir:", tmpDir )

            origCertFilepath = Const.certPath( origCertFile )
            t1( "Req.commit origCertFilepath:", origCertFilepath )

            origKeyFilepath = Const.keyPath( origKeyFile )
            t1( "Req.commit origKeyFilepath:", origKeyFilepath )

            restoreCertFile = "%s%s" % ( origCertFile, Const.rotRestoreCertExt )
            restoreCertFilepath = "%s/%s" % ( tmpDir, restoreCertFile )
            t1( "Req.commit restoreCertFilepath:", restoreCertFilepath )

            restoreKeyFile = "%s%s" % ( origKeyFile, Const.rotRestoreKeyExt )
            restoreKeyFilepath = "%s/%s" % ( tmpDir, restoreKeyFile )
            t1( "Req.commit restoreKeyFilepath:", restoreKeyFilepath )

            # A missing original is backed up as an empty file, which
            # _handleRestore turns back into "delete the file".
            certData = self._readFile( origCertFilepath, ignoreEnoent=True )
            keyData = self._readFile( origKeyFilepath, ignoreEnoent=True )
            certData = certData if certData else ""
            keyData = keyData if keyData else ""

            self._writeFile( restoreCertFilepath, certData, Const.sslCertPerm )
            self._writeFile( restoreKeyFilepath, keyData, Const.sslKeyPerm )
            self._handleTestCommitRestoreAbortAfter( "tmpDir" )

            t1( "Req.commit Rename", tmpDir, "to", self._restoreDirpath() )
            os.rename( tmpDir, self._restoreDirpath() )
            self._handleTestCommitRestoreAbortAfter( "restoreDir" )

            # From here on the commit extension marks the dir for restore if
            # anything fails (or the process dies).
            commitDirpath = self._getRotationDirPath( commit=True )
            t1( "Req.commit Rename", self.rotationDirPath_, "to", commitDirpath )
            os.rename( self.rotationDirPath_, commitDirpath )
            self.rotationDirPath_ = commitDirpath
            self._handleTestCommitRestoreAbortAfter( "commitDir" )

            t1( "Req.commit replace", origCertFilepath, "with new cert" )
            self._writeFile( origCertFilepath, self.cert_, Const.sslCertPerm )
            self._handleTestCommitRestoreAbortAfter( "certReplace" )

            t1( "Req.commit replace", origKeyFilepath, "with new key" )
            self._writeFile( origKeyFilepath, self.key_, Const.sslKeyPerm )
            self._handleTestCommitRestoreAbortAfter( "keyReplace" )

            newCert = getCertPem( self._certFilepath() )
            self._cleanup()
            if dontWait or self._waitForSslServiceUsingNewCert( newCert ):
               return getCommitStatus( StatusCode.success, self.rotationId_ )
            else:
               return getCommitStatus( StatusCode.requestFailed, self.rotationId_ )
         except Exception as e: # pylint: disable-msg=W0703
            if os.environ.get( "ARTEST_COMMIT_RESTORE_ABORT_AFTER" ):
               # Simulated crash: leave on-disk state as-is for the test.
               t1( "Req.commit test commit/restore abort exception" )
               return

            t1( "Req.commit error:", str( e ) )
            try:
               self._handleRestore()
            except Exception as restoreError: # pylint: disable-msg=W0703
               t1( "Req.commit handleRestore error:", str( restoreError ) )
               self._cleanup()
            # Report the original commit failure, not a secondary restore error.
            return getCommitStatus( StatusCode.requestFailed,
                                    self.rotationId_, exception=e )
         finally:
            if os.environ.get( "ARTEST_COMMIT_RESTORE_ABORT_AFTER" ):
               t1( "Req.commit test commit/restore abort exception" )
            elif tmpDir:
               shutil.rmtree( tmpDir, ignore_errors=True )

   def clear( self ):
      '''
      Cleans up the rotation request

      Returns: RotationClearStatus with statusCode as below
      statusCode.success: successful
      statusCode.invalidRotationId: rotationId does not exist
      '''
      with self.lock_:
         t0( "Req.clear Start:", self.idDotProfile_ )

         if self.deleted_:
            t1( "Req.clear already deleted" )
            return getClearStatus( StatusCode.invalidRotationId, self.rotationId_ )

         self._cleanup()
         t0( "Req.clear End:", self.idDotProfile_, "returning success" )
         return getClearStatus( StatusCode.success, self.rotationId_ )

   def _handleRestore( self ):
      '''
      Called on startup or when a commit fails midway.
      If the rotation dir is not committing, any leftover restore dir is
      removed. Otherwise this rolls back the partially committed rotation:
      each backed-up original cert/key is written back, or deleted if its
      backup is empty. Then the restore dir is removed and the rotation dir is
      renamed back to its non-commit name.
      '''
      t0( "Req._handleRestore Start:", self.idDotProfile_ )

      if not self._isCommitting():
         t1( "Req._handleRestore not committing. Deleteing any restoreDir" )
         shutil.rmtree( self._restoreDirpath(), ignore_errors=True )
         return

      if os.path.isdir( self._restoreDirpath() ):
         origCertFilepath = self._origCertFilepath()
         t1( "Req._handleRestore origCertFilepath:", origCertFilepath )

         origKeyFilepath = self._origKeyFilepath()
         t1( "Req._handleRestore origKeyFilepath:", origKeyFilepath )

         origCertData = self._readFile( self._restoreCertFilepath() )
         origKeyData = self._readFile( self._restoreKeyFilepath() )

         if not origCertData:
            t1( "Req._handleRestore no certdata removing:", origCertFilepath )
            self._deleteFile( origCertFilepath )
         else:
            t1( "Req._handleRestore restoring:", origCertFilepath )
            self._writeFile( origCertFilepath, origCertData, Const.sslCertPerm )
         self._handleTestCommitRestoreAbortAfter( "certRestore" )

         if not origKeyData:
            t1( "Req._handleRestore no keyData removing:", origKeyFilepath )
            self._deleteFile( origKeyFilepath )
         else:
            t1( "Req._handleRestore restoring:", origKeyFilepath )
            self._writeFile( origKeyFilepath, origKeyData, Const.sslKeyPerm )
         self._handleTestCommitRestoreAbortAfter( "keyRestore" )

         # Rename first so a crash mid-delete leaves a *.delete dir, which is
         # cleaned up on load, rather than a partial restore dir.
         restoreDeleteDir = "%s%s" % ( self._restoreDirpath(), Const.rotDeleteExt )
         t1( "Req._handleRestore rename", self._restoreDirpath(),
             "to", restoreDeleteDir, "and delete it" )
         os.rename( self._restoreDirpath(), restoreDeleteDir )
         shutil.rmtree( restoreDeleteDir, ignore_errors=True )
         self._handleTestCommitRestoreAbortAfter( "restoreDelete" )

      origDirpath = self._getRotationDirPath()
      t1( "Req._handleRestore rename", self.rotationDirPath_, "to", origDirpath )
      os.rename( self.rotationDirPath_, origDirpath )
      self.rotationDirPath_ = origDirpath

      t0( "Req._handleRestore End:", self.idDotProfile_ )

   def handleExpiry( self ):
      '''
      Called when expiryClock fires or during init of an existing
      rotationReq. If the rotationId has expired, calls _cleanup. Any error
      while checking also cleans up the request.
      '''
      with self.lock_:
         t0( "Req.handleExpiry Start:" )

         try:
            self._handleExpiry()
         except ( SslCertKeyError, Exception ) as e: # pylint: disable-msg=W0703
            t1( "Req.handleExpiry error:", str( e ) )
            self._cleanup()

         t0( "Req.handleExpiry End:" )

   def isDeleted( self ):
      return self.deleted_

   # Called by test code only
   def forceCleanup( self ):
      t0( "Req.forceCleanup:", self.idDotProfile_ )
      self._cleanup()
   
class RotationMgr( object ):
   '''
   RotationMgr manages all rotationIds. It should be created
   only on the master supervisor and only after SslManager has been
   initialized.

   Rotation requests are kept in self.rotationReq_, keyed by rotationId.
   The public request entry points (generateCsr, importCertificate, commit,
   clear) serialize on self.lock_.
   '''
   def __init__( self, rotationBaseDirPath, redundancyStatus, sslConfig,
                 sslServiceStatus ):
      '''
      Store the supplied state and register a reactor on
      redundancyStatus.mode, which is also invoked immediately
      (callBackNow=True). Whenever the mode is "active" the reactor:
        - cleans up <rotationBaseDirPath>/tmp* dirs
        - cleans up <rotationBaseDirPath>/*<Const.rotDeleteExt> dirs
        - creates rotationReq objects for the rotation dirs under
          <rotationBaseDirPath>/
        - sets initDone_. Until then, isReady() reports False and the
          rotation commands are guarded.
      '''
      t0( "Mgr.__init__ Start:", rotationBaseDirPath )
      self.rotationBaseDirPath_ = rotationBaseDirPath
      self.redundancyStatus_ = redundancyStatus
      self.sslConfig_ = sslConfig
      self.sslServiceStatus_ = sslServiceStatus
      # rotationId -> RotationReq
      self.rotationReq_ = {}
      self.lock_ = threading.Lock()
      # Must be initialized before the reactor below, since callBackNow=True
      # may set it synchronously.
      self.initDone_ = False
      self.redundancyStatusReactor_ = GenericReactor( self.redundancyStatus_,
                                                      [ "mode" ],
                                                      self._handleRedundancyStatus,
                                                      callBackNow=True )
      t0( "Mgr.__init__ End:" )

   def _handleRedundancyStatus( self, notifiee=None ):
      '''
      On becoming "active", remove leftover temporary/deleted rotation dirs,
      reload rotation requests from disk and mark initialization done.
      Load errors are traced but do not prevent initDone_ from being set.
      '''
      t0( "Mgr._handleRedundancyStatus Start:" )
      if self.redundancyStatus_.mode == "active":
         try:
            self._cleanupRotationReqs()
            self.loadRotationReqs()
         except Exception as e: # pylint: disable-msg=W0703
            t1( "Mgr._handleRedundancyStatus load error:", str( e ) )

         self.initDone_ = True
      t0( "Mgr._handleRedundancyStatus End:" )

   def _cleanupRotationReqs( self ):
      '''
      Remove <base>/tmp* and <base>/*<Const.rotDeleteExt> entries left over
      from interrupted operations.
      '''
      t0( "Mgr._cleanupRotationReqs Start:" )
      # Safety net: never glob-delete outside the rotation base directory.
      assert self.rotationBaseDirPath_.startswith( Const.rotationBaseDirPath() )
      tmpDirs = "%s%s" % ( self.rotationBaseDirPath_, "tmp*" )
      cleanupFilesAndDirs( tmpDirs )

      delDirs = "%s*%s" % ( self.rotationBaseDirPath_, Const.rotDeleteExt )
      cleanupFilesAndDirs( delDirs )
      t0( "Mgr._cleanupRotationReqs End:" )

   def _rotationReqArgs( self, rotationDir ):
      '''
      Parse a rotation directory name of the form
      <rotationId>.<profileName>[<Const.rotCommitExt>].

      Returns ( rotationId, profileName, committing ), or None if the name
      does not match. committing is True when the name ends with
      Const.rotCommitExt.
      '''
      t0( "Mgr._rotationReqArgs Start:", rotationDir )

      # The extension is escaped so any regex metacharacters in it match
      # literally. Without this, an extension such as ".commit" would let
      # its '.' match any character, and a profile named e.g. "xcommit"
      # would be misparsed as profile "x" in the committing state.
      pattern = r'^(.*?)\.(.*?)(%s)?$' % re.escape( Const.rotCommitExt )
      match = re.match( pattern, rotationDir )
      if not match:
         return None
      rotationId = match.group( 1 )
      profileName = match.group( 2 )
      committing = bool( match.group( 3 ) )
      ret = ( rotationId, profileName, committing )
      t0( "Mgr._rotationReqArgs End: returning:", ret )
      return ret

   # Also called by test code
   def loadRotationReqs( self ):
      '''
      Create a RotationReq for every parseable entry in the rotation base
      directory that is not already tracked. Requests that report
      isDeleted() after construction are not added.
      '''
      t0( "Mgr._loadRotationReqs Start:" )

      for rotationDir in os.listdir( self.rotationBaseDirPath_ ):
         rrArgs = self._rotationReqArgs( rotationDir )
         if not rrArgs:
            t1( "Unable to get args for rotationDir:", rotationDir )
            continue

         ( rotationId, profileName, committing ) = rrArgs

         if rotationId in self.rotationReq_:
            t1( "Mgr._loadRotationReqs rotation ID", rotationId, "already present" )
            continue

         t1( "Mgr._loadRotationReqs creating req for", rotationId )
         rr = RotationReq( self, self.rotationBaseDirPath_, self.sslConfig_,
                           self.sslServiceStatus_, rotationId, profileName,
                           committing=committing )

         if not rr.isDeleted():
            t1( "Mgr._loadRotationReqs adding", rotationId )
            self.rotationReq_[ rotationId ] = rr
         else:
            t1( "Mgr._loadRotationReqs not adding", rotationId )

      t0( "Mgr._loadRotationReqs End:" )

   def generateCsr( self, profileName, importTimeout, digest,
                    signReqParams, newKeyBits=None, key=None,
                    handleExpiry=True ):
      '''
      Create a RotationReq with a fresh uuid1-based rotationId and call its
      generateCsr. On success the request is registered and, if
      handleExpiry is set, its expiry handling is started.

      Returns the RotationGenerateSignRequestStatus from the request.
      '''
      with self.lock_:
         t0( "Mgr.generateCsr Start:",
             "profileName:", profileName,
             "importTimeout:", importTimeout,
             "digest:", digest,
             "signReqParams:", signReqParams,
             "newKeyBits:", newKeyBits,
             "use key:", bool( key ) )

         rotationId = uuid.uuid1().hex
         t1( "Mgr.generateCsr: Rotation Id:", rotationId )

         rr = RotationReq( self, self.rotationBaseDirPath_, self.sslConfig_,
                           self.sslServiceStatus_, rotationId, profileName )
         status = rr.generateCsr( importTimeout, digest, signReqParams,
                                  newKeyBits=newKeyBits, key=key )
         t1( "Mgr.generateCsr: statusCode:", status.statusCode,
             "errorMessage:", status.errorMessage,
             "CSR:", status.csr )

         if status.statusCode == StatusCode.success:
            t1( "Mgr.generateCsr: generateCsr success for", rotationId )
            self.rotationReq_[ rotationId ] = rr
            if handleExpiry:
               rr.handleExpiry()
         else:
            t1( "Mgr.generateCsr: Generate CSR failed for", rotationId )

         t0( "Mgr.generateCsr End: returning:", status )
         return status

   def importCertificate( self, rotationId, certPem, validateExpiry=True,
                          handleExpiry=True ):
      '''
      Import certPem into the rotation request identified by rotationId.
      Returns the request's import status, or an invalidRotationId status if
      rotationId is unknown.
      '''
      with self.lock_:
         t0( "Mgr.importCertificate Start: rotationId:", rotationId,
             "certPem:", certPem )

         rr = self.rotationReq_.get( rotationId )
         if rr:
            status = rr.importCertificate( certPem, validateExpiry=validateExpiry )
            if status.statusCode == StatusCode.success:
               if handleExpiry:
                  rr.handleExpiry()
         else:
            status = getImportCertStatus( StatusCode.invalidRotationId, rotationId )

         t0( "Mgr.importCertificate End: returning", status )
         return status

   def commit( self, rotationId, validateDates=True, dontWait=False ):
      '''
      Commit the rotation request identified by rotationId. Returns the
      request's commit status, or an invalidRotationId status if rotationId
      is unknown.
      '''
      with self.lock_:
         t0( "Mgr.commit Start: rotationId:", rotationId )
         rr = self.rotationReq_.get( rotationId )
         if rr:
            status = rr.commit( validateDates=validateDates, dontWait=dontWait )
         else:
            status = getCommitStatus( StatusCode.invalidRotationId, rotationId )

         t0( "Mgr.commit End: returning", status )
         return status

   def commitSslProfile( self, profileName, key, cert, validateDates, dontWait ):
      '''
      Install key/cert into profileName in one step by driving a full
      rotation: validate the key, generateCsr (with the given key),
      importCertificate (cert) and commit. If import or commit fails, the
      intermediate rotation request is cleared.

      Returns the status built by getCommitSslProfileStatus.
      '''
      def validateKey():
         with self.lock_:
            statusCode = StatusCode.success
            errorMessage = None
            try:
               validateRsaPrivateKey( key, isFile=False )
            except SslCertKeyError as e:
               t1( "Mgr.commitSslProfile invalid key:", str( e ) )
               statusCode = StatusCode.invalidKey
               errorMessage = getErrorMessage( statusCode, e )
            except Exception as e: # pylint: disable-msg=W0703
               t1( "Mgr.commitSslProfile validateRsaPrivateKey failed:", str( e ) )
               statusCode = StatusCode.requestFailed
               errorMessage = getErrorMessage( statusCode, e )
            return ( statusCode, errorMessage )

      def generateCsr():
         # Only the rotation request is needed; the CSR contents are not
         # returned, so a placeholder subject is used.
         status = self.generateCsr( profileName,
                                    Const.defImportTimeoutMins,
                                    Const.defaultDigest,
                                    { "commonName": "common.name" },
                                    key=key,
                                    handleExpiry=False )
         return ( status.statusCode, status.errorMessage, status.rotationId )

      def importCert( rotationId ):
         status = self.importCertificate( rotationId,
                                          cert,
                                          validateExpiry=False,
                                          handleExpiry=False )
         assert status.statusCode not in ( StatusCode.invalidRotationId,
                                           StatusCode.certExpired )
         return ( status.statusCode, status.errorMessage )

      def commit():
         # rotationId is bound in the enclosing scope by the generateCsr step
         # before this is called.
         status = self.commit( rotationId, validateDates=validateDates,
                               dontWait=dontWait )
         unexpectedStatus = ( StatusCode.invalidRotationId,
                              StatusCode.certDoesNotExist )
         if not validateDates:
            unexpectedStatus += ( StatusCode.certNotYetValid,
                                  StatusCode.certExpired )
         assert status.statusCode not in unexpectedStatus
         return ( status.statusCode, status.errorMessage )

      t0( "Mgr.commitSslProfile Start: profileName:", profileName,
          "dontWait:", dontWait )
      ( statusCode, errorMessage ) = validateKey()
      if statusCode != StatusCode.success:
         status = getCommitSslProfileStatus( statusCode, errorMessage=errorMessage )
         t0( "Mgr.commitSslProfile returning:", status )
         return status

      ( statusCode, errorMessage, rotationId ) = generateCsr()
      if statusCode != StatusCode.success:
         status = getCommitSslProfileStatus( statusCode, errorMessage=errorMessage )
         t0( "Mgr.commitSslProfile returning:", status )
         return status

      ( statusCode, errorMessage ) = importCert( rotationId )
      if statusCode != StatusCode.success:
         self.clear( rotationId )
         status = getCommitSslProfileStatus( statusCode, errorMessage=errorMessage )
         t0( "Mgr.commitSslProfile returning:", status )
         return status

      ( statusCode, errorMessage ) = commit()
      if statusCode != StatusCode.success:
         self.clear( rotationId )

      status = getCommitSslProfileStatus( statusCode, errorMessage=errorMessage )
      t0( "Mgr.commitSslProfile returning:", status )
      return status

   def clear( self, rotationId ):
      '''
      Clear the rotation request identified by rotationId. Returns the
      request's clear status, or an invalidRotationId status if rotationId
      is unknown.
      '''
      with self.lock_:
         t0( "Mgr.clear Start: rotationId:", rotationId )
         rr = self.rotationReq_.get( rotationId )
         if rr:
            status = rr.clear()
         else:
            status = getClearStatus( StatusCode.invalidRotationId, rotationId )

         t0( "Mgr.clear End: returning", status )
         return status

   def handleDeletedReq( self, rotationId ):
      '''
      Forget the rotation request identified by rotationId. Unknown ids are
      ignored. Does not take self.lock_.
      '''
      t0( "Mgr.handleDeletedReq Start: rotatioId:", rotationId )
      self.rotationReq_.pop( rotationId, None )
      t0( "Mgr.handleDeletedReq End: rotatioId:", rotationId )

   def isReady( self ):
      '''True once initialization is done and the SSL dirs exist.'''
      return self.initDone_ and isSslDirsCreated()

   # Below methods called by test code only
   def getRotationReq( self, rotationId ):
      return self.rotationReq_.get( rotationId )

   def getAllRotationReqs( self ):
      return self.rotationReq_

   def forceCleanup( self, cleanupDirs=True ):
      '''
      Drop all tracked rotation requests and reset initDone_. When
      cleanupDirs is set, each request's forceCleanup is run first.
      '''
      t0( "Mgr.forceCleanup" )
      if cleanupDirs:
         # Iterate over a snapshot: a request's cleanup may call back into
         # handleDeletedReq, which removes entries from self.rotationReq_.
         # Mutating a dict while iterating its view raises RuntimeError on
         # Python 3.
         for rr in list( self.rotationReq_.values() ):
            rr.forceCleanup()
      self.rotationReq_.clear()
      self.initDone_ = False
