# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import Logging

# Logged when a firewall cannot be accessed. Format arguments: the firewall,
# then the error text.
# The None assignment is a placeholder; presumably Logging.logD binds the real
# log handle to this module-level name (TODO confirm).
CVX_MSS_FIREWALL_ACCESS_ERROR = None
Logging.logD(
   id='CVX_MSS_FIREWALL_ACCESS_ERROR',
   severity=Logging.logError,
   format='Unable to access firewall %s: %s',
   explanation='Unable to access firewall',
   recommendedAction=( 'Check firewall credentials and '
                       'network reachability, '
                       'contact Arista support as needed' ) )

# Logged when access to a firewall is restored. Format argument: the firewall.
# NOTE(review): this recovery message is registered with logError severity,
# the same severity as the failure it clears; confirm that is intended.
# The None assignment is a placeholder; presumably Logging.logD binds the real
# log handle to this module-level name (TODO confirm).
CVX_MSS_FIREWALL_ACCESS_RECOVERED = None
Logging.logD(
   id='CVX_MSS_FIREWALL_ACCESS_RECOVERED',
   severity=Logging.logError,
   format='Access to firewall %s is restored',
   explanation='Access to firewall is restored',
   recommendedAction=Logging.NO_ACTION_REQUIRED )

# Logged when a firewall API request fails. Format arguments: the firewall,
# the HTTP status (a %d specifier, so it must be an integer), and the error
# code.
# The None assignment is a placeholder; presumably Logging.logD binds the real
# log handle to this module-level name (TODO confirm).
CVX_MSS_FIREWALL_API_ERROR = None
Logging.logD( id='CVX_MSS_FIREWALL_API_ERROR', severity=Logging.logError,
              # Keep the format free of trailing whitespace so messages do not
              # end in a stray space.
              format=( 'Firewall %s API request failed: '
                       'HTTP status %d, error code %s' ),
              explanation='Firewall API request failed',
              recommendedAction=Logging.CALL_SUPPORT_IF_PERSISTS )

# Logged when firewall API requests succeed again. Format argument: the
# firewall.
# NOTE(review): this recovery message is registered with logError severity,
# the same severity as the failure it clears; confirm that is intended.
# The None assignment is a placeholder; presumably Logging.logD binds the real
# log handle to this module-level name (TODO confirm).
CVX_MSS_FIREWALL_API_RECOVERED = None
Logging.logD(
   id='CVX_MSS_FIREWALL_API_RECOVERED',
   severity=Logging.logError,
   format='Firewall %s API request succeeded',
   explanation='Firewall API request succeeded',
   recommendedAction=Logging.NO_ACTION_REQUIRED )

# Logged when a tagged policy is ignored by MSS. Format arguments: policy
# name, firewall, virtual instance, reason (MssSkippingPolicyLogger.log passes
# policyName, deviceId, vinst, msg in that order).
# The None assignment is a placeholder; presumably Logging.logD binds the real
# log handle to this module-level name (TODO confirm).
CVX_MSS_SKIPPING_POLICY = None
Logging.logD(
   id='CVX_MSS_SKIPPING_POLICY',
   severity=Logging.logWarning,
   format='Skipping policy %s from firewall %s, virtual instance %s: %s',
   explanation='A tagged policy is ignored by MSS',
   recommendedAction=( 'Check policy config, '
                       'contact Arista support as needed' ) )

class MssSkippingPolicyLogger( object ):
   '''
   De-duplicates CVX_MSS_SKIPPING_POLICY messages.

   A skipped policy is logged once per distinct message: logging the same
   message again for the same ( deviceId, vinst, policyName ) is suppressed
   until the entry is cleared or the message changes.
   '''

   def __init__( self ):
      # Maps ( deviceId, vinst ) -> { policyName: last logged message }.
      # Despite the attribute name, both levels are dicts, not sets.
      self.logSet = dict()

   def log( self, policyName, deviceId, vinst, msg ):
      '''
      Log that policyName on ( deviceId, vinst ) is skipped, unless exactly
      this msg was already logged for it.
      '''
      policyMsgs = self.logSet.setdefault( ( deviceId, vinst ), {} )
      if policyMsgs.get( policyName ) == msg:
         # Already reported with this message; stay quiet.
         return
      policyMsgs[ policyName ] = msg
      Logging.log( CVX_MSS_SKIPPING_POLICY, policyName, deviceId, vinst, msg )

   def clearLog( self, policyName, deviceId, vinst ):
      '''
      Forget the message recorded for policyName (the policy is now valid),
      so a later failure of the same policy is logged again.
      '''
      policyMsgs = self.logSet.get( ( deviceId, vinst ) )
      if not policyMsgs:
         return
      policyMsgs.pop( policyName, None )

   def cleanupLoggedPolicies( self, deviceId, vinst, validPolicyNameSet ):
      '''
      Drop recorded messages of ( deviceId, vinst ) for policies that are not
      in validPolicyNameSet, i.e. policies removed from the config.
      '''
      policyMsgs = self.logSet.get( ( deviceId, vinst ) )
      if not policyMsgs:
         return
      # Collect first, then delete: the dict must not change size while
      # it is being iterated.
      stale = [ name for name in policyMsgs if name not in validPolicyNameSet ]
      for name in stale:
         del policyMsgs[ name ]

   def cleanupLoggedVInst( self, deviceId, validVInst ):
      '''
      Drop every entry of deviceId whose virtual instance is not in
      validVInst.
      '''
      staleKeys = [ ( devId, vinst ) for devId, vinst in self.logSet
                    if devId == deviceId and vinst not in validVInst ]
      for key in staleKeys:
         del self.logSet[ key ]

   def cleanupLoggedDevice( self, validDeviceIds ):
      '''
      Drop every entry whose device is not in validDeviceIds.
      '''
      staleKeys = [ ( devId, vinst ) for devId, vinst in self.logSet
                    if devId not in validDeviceIds ]
      for key in staleKeys:
         del self.logSet[ key ]

class MssFirewallLogger( object ):
   '''
   Tracks the outstanding error, if any, per firewall device, so that the
   same error is not reported repeatedly and its recovery is reported once.

   Error objects must provide log( deviceId ) and recover( deviceId ); a new
   error is detected by comparing against the recorded one with !=.
   '''

   def __init__( self ):
      # Maps deviceId -> the error object most recently logged for it.
      self.deviceLog = {}

   def log( self, deviceId, errorObj ):
      '''
      Report errorObj for deviceId unless an equal error is already recorded,
      then record it as the device's outstanding error.
      '''
      if errorObj != self.deviceLog.get( deviceId ):
         errorObj.log( deviceId )
         self.deviceLog[ deviceId ] = errorObj

   def recover( self, deviceId ):
      '''
      Clear the recorded error for deviceId, if any, and report its recovery.
      Does nothing when no error is recorded.
      '''
      errorObj = self.deviceLog.pop( deviceId, None )
      # Test presence explicitly rather than by truthiness: an error object
      # that happens to be falsy (defines __bool__/__len__) must still have
      # its recovery reported. None is never stored by log().
      if errorObj is not None:
         errorObj.recover( deviceId )

   def cleanupLoggedDevice( self, validDeviceIds ):
      '''
      Forget recorded errors of devices not in validDeviceIds. No recovery
      is reported for them.
      '''
      for devId in list( self.deviceLog ):
         if devId not in validDeviceIds:
            del self.deviceLog[ devId ]
