# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import errno

import Logging
import MgmtSecuritySslStatusSm
import SignatureVerificationMapLib as SigVerifyMapLib
from SignatureVerificationMapLib import FILE_ATTR
import SuperServer
import Tac
import Tracing

# Trace handle for this agent. t0 is used below for failures (e.g. the
# signature-verification map file cannot be written); t1 for routine
# event tracing.
traceHandle = Tracing.Handle( 'SecuritySm' )
t0 = traceHandle.trace0
t1 = traceHandle.trace1

# Tac enum of SSL profile features; used to declare which features the
# SSL profile reactor in this file supports.
SslFeature = Tac.Type( "Mgmt::Security::Ssl::SslFeature" )

class SecuritySmSslReactor( MgmtSecuritySslStatusSm.SslStatusSm ):
   """Reactor for one SSL profile's status on behalf of the SecuritySm agent.

   On construction, and whenever the watched profile changes state or is
   deleted, the owning agent is asked to re-sync its configuration through
   its handleConfigChange() method.
   """
   __supportedFeatures__ = [ SslFeature.sslFeatureTrustedCert ]

   def __init__( self, sslStatus, profileName, containingAgent ):
      t1( "init SecuritySmSslReactor for profile", profileName )
      # Must be set before the base-class constructor runs, in case that
      # constructor dispatches to one of the handlers below.
      self.containingAgent = containingAgent
      super( SecuritySmSslReactor, self ).__init__( sslStatus,
                                                    profileName,
                                                    'SecuritySm' )
      self._requestAgentResync()

   def _requestAgentResync( self ):
      """Ask the owning agent to re-evaluate its configuration."""
      self.containingAgent.handleConfigChange()

   def handleProfileState( self ):
      t1( "SSL profile", self.profileName_, "is", self.profileStatus_.state )
      self._requestAgentResync()

   def handleProfileDelete( self ):
      t1( "handle profile delete for profile", self.profileName_ )
      self._requestAgentResync()

class SecurityConfigReactor( Tac.Notifiee ):
   """Keeps the signature-verification map file in sync with security config.

   Reacts to the 'enforceSignature' and 'sslProfile' attributes of
   Mgmt::Security::Config. The map file for image type 'extension' is written
   while signature enforcement is enabled and the configured SSL profile is
   valid, and deleted otherwise. If a write fails because the filesystem is
   full, the sync is retried after activityInterval_ seconds.
   """
   notifierTypeName = "Mgmt::Security::Config"

   def __init__( self, mgmtSecConfig, containingAgent ):
      Tac.Notifiee.__init__( self, mgmtSecConfig )
      self.containingAgent = containingAgent
      self.config_ = containingAgent.config_
      self.sslStatus_ = containingAgent.sslStatus
      # Clock notifiee used to retry a failed write; created on first failure.
      self.activity_ = None
      # Retry delay in seconds, added to Tac.now().
      self.activityInterval_ = 60
      self.sync()

   @Tac.handler( 'enforceSignature' )
   def handleEnforceSignatureVerification( self ):
      t1( "handle enforceSignature:", self.config_.enforceSignature )
      self.sync()

   @Tac.handler( 'sslProfile' )
   def handleSignatureSslProfile( self ):
      t1( "handle SSL profile change:", self.config_.sslProfile )
      # Re-targeting the SSL reactor triggers a sync via the agent.
      self.containingAgent.initSslReactor()

   def sslProfileValid( self ):
      """Return True if the configured SSL profile exists and is 'valid'."""
      profileStatus = self.sslStatus_.profileStatus.get( self.config_.sslProfile )
      if profileStatus:
         return profileStatus.state == 'valid'
      return False

   def handleSigVerifyMapFile( self, imgType ):
      """Write or delete the signature-verification map file for imgType.

      Returns False only when the write failed with ENOSPC (a filesystem-full
      message is logged); any other OSError/IOError from the write is
      re-raised. Returns True otherwise.
      """
      fileMgr = SigVerifyMapLib.FileMgr( imgType )
      enforce = self.config_.enforceSignature
      sslProfile = self.config_.sslProfile
      if enforce and self.sslProfileValid():
         t1( "Create/modify SigVerifyMapFile" )
         fileAttrs = { FILE_ATTR.ENFORCE_SIGNATURE: enforce,
                       FILE_ATTR.SSL_PROFILE: sslProfile }
         try:
            fileMgr.writeAttrs( fileAttrs )
         except ( OSError, IOError ) as e:
            t0( "Cannot update/create signature verification mapping file:",
                e.strerror )
            if e.errno == errno.ENOSPC:
               # pylint: disable-msg=no-member
               Logging.log( SuperServer.SYS_SERVICE_FILESYSTEM_FULL,
                            fileMgr.mapFilename, 'security' )
            else:
               raise
            return False
      else:
         t1( "Remove SigVerifyMapFile" )
         fileMgr.deleteFile()
      return True

   def sync( self ):
      """Reconcile the 'extension' map file, scheduling a retry on failure."""
      if self.handleSigVerifyMapFile( 'extension' ):
         return
      # Write failed, retry in a minute
      now = Tac.now()
      retryAt = now + self.activityInterval_
      if not self.activity_:
         self.activity_ = Tac.ClockNotifiee()
         self.activity_.handler = self.sync
         self.activity_.timeMin = retryAt
      elif self.activity_.timeMin > now:
         # A retry is still pending; keep whichever deadline comes first.
         self.activity_.timeMin = min( self.activity_.timeMin, retryAt )
      else:
         # The previous deadline has already elapsed (e.g. we are running
         # from that retry). NOTE(review): if ClockNotifiee does not reset
         # timeMin after firing, keeping the past deadline could re-fire
         # immediately and spin, so always push it out by a full interval.
         self.activity_.timeMin = retryAt

class SecuritySm( SuperServer.SuperServerAgent ):
   """SuperServer agent managing the signature-verification map file.

   Mounts the security config and SSL status. When active, it creates a
   SecurityConfigReactor plus a SecuritySmSslReactor for the configured SSL
   profile; either of them can trigger a re-sync via handleConfigChange().
   """
   def __init__( self, entityManager ):
      SuperServer.SuperServerAgent.__init__( self, entityManager )
      self.warm_ = False
      mg = entityManager.mountGroup()
      self.config_ = mg.mount( 'mgmt/security/config',
                               'Mgmt::Security::Config', 'r' )
      self.sslStatus = mg.mount( 'mgmt/security/ssl/status',
                                 'Mgmt::Security::Ssl::Status', 'r' )
      self.securityConfigReactor_ = None
      self.sslReactor = None

      def _finished():
         # Callback passed to mg.close(); presumably invoked once the mounts
         # complete. Run only if active.
         if self.active():
            self.onSwitchover( None )
      mg.close( _finished )

   def onSwitchover( self, protocol ):
      """Start reacting to config/SSL status and mark the agent warm."""
      # The config reactor must exist before the SSL reactor is created,
      # because the SSL reactor's construction calls handleConfigChange().
      self.securityConfigReactor_ = SecurityConfigReactor( self.config_, self )
      self.initSslReactor()
      self.warm_ = True

   def initSslReactor( self ):
      """(Re)create the SSL reactor for the currently configured SSL profile.

      Any existing reactor is closed first, so a profile change does not
      leave a stale reactor still watching the previous profile.
      """
      t1( "initSslReactor for", self.config_.sslProfile )
      if self.sslReactor:
         self.sslReactor.close()
         self.sslReactor = None
      # We should always have an SSL profile but just in case...
      if self.config_.sslProfile:
         self.sslReactor = SecuritySmSslReactor( self.sslStatus,
                                                 self.config_.sslProfile,
                                                 self )

   def handleConfigChange( self ):
      """Re-sync the signature-verification map file with current state."""
      self.securityConfigReactor_.sync()

   def warm( self ):
      """Return warm status; an inactive agent always reports warm."""
      if not self.active():
         return True
      return self.warm_

def Plugin( context ):
   """Plugin entry point: build the SecuritySm agent and register it."""
   securityAgent = SecuritySm( context.entityManager )
   context.registerService( securityAgent )
