#!/usr/bin/env python2.7
#
# Copyright (c) 2015-2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import time
import threading
import traceback
from copy import deepcopy
from MssPolicyMonitor import Lib
from MssPolicyMonitor.Error import FirewallError, FirewallAPIError
from MssPolicyMonitor.Lib import t0, t1, t2, t3, t4, BYPASS_MODE
from MssPolicyMonitor.PluginLib import ServiceDevice, IPolicyPlugin, IHAStatePlugin


# When True, each device polling thread is profiled with cProfile and its stats
# are dumped to 'cprofile_<threadName>' when the thread stops.
ENABLE_PROFILING = False
# Number of monitoring thread ids handed out so far; update via getThreadId().
threadCount = 0
# Serializes the read-modify-write of threadCount: getThreadId() is reached from
# several monitoring threads, and an unguarded '+= 1' followed by a separate read
# can hand the same id to two callers.
_threadCountLock = threading.Lock()


def getThreadId():
   ''' Returns the next monitoring thread sequence number (1, 2, 3, ...).
       Safe to call concurrently: every caller gets a distinct value.
   '''
   global threadCount
   with _threadCountLock:
      threadCount += 1
      return threadCount


def genMonitoringThreadName( deviceSetName, deviceId, aggMgrId='' ):
   ''' Builds a monitoring thread name of the form
       't#<threadId>:<deviceSetName>:<deviceId>'.  When aggMgrId is non-empty,
       ':via:<aggMgrId>' is appended.  Each call consumes a new thread id.
   '''
   baseName = 't#%s:%s:%s' % ( getThreadId(), deviceSetName, deviceId )
   if not aggMgrId:
      return baseName
   return '%s:via:%s' % ( baseName, aggMgrId )


def genDeviceMonitor( serviceDeviceType, deviceConfig, sysdbPolicyMgr ):
   ''' Creates the DeviceMonitor used for starting and stopping the device
       polling thread.  Called by MssPolicyMonitorAgent.

       Returns None when Lib.getPlugin yields no plugin for serviceDeviceType;
       otherwise returns a DeviceMonitor whose thread name is generated from
       deviceConfig[ 'deviceSet' ] and deviceConfig[ 'ipAddress' ].
   '''
   plugin = Lib.getPlugin( serviceDeviceType )
   t4('startMonitoringPolicies using plugin:', plugin )
   if plugin:
      monitorThreadName = genMonitoringThreadName( deviceConfig[ 'deviceSet' ],
                                                   deviceConfig[ 'ipAddress' ] )
      return DeviceMonitor( plugin, deviceConfig, sysdbPolicyMgr,
                            monitorThreadName )
   return None


def startMonitoringPolicies( devMonitor ):
   ''' Launches the daemon polling thread for the passed DeviceMonitor.
       Called by MssPolicyMonitorAgent.

       An aggregation manager runs devMonitor.accessDevicesViaAggMgr; any other
       device runs devMonitor.monitorDevicePolicies with its configured
       'ipAddress' as the device id.
   '''
   threadName = devMonitor.instanceThreadName
   if devMonitor.deviceConfig[ 'isAggregationMgr' ]:
      threadTarget = devMonitor.accessDevicesViaAggMgr
      threadArgs = ()
   else:
      threadTarget = devMonitor.monitorDevicePolicies
      threadArgs = ( devMonitor.deviceConfig[ 'ipAddress' ], )
   mpmThread = threading.Thread( name=threadName, target=threadTarget,
                                 args=threadArgs )
   mpmThread.setDaemon( True )
   t0('++ START monitoring for', devMonitor.deviceType, 'thread:',
      devMonitor.instanceThreadName )
   mpmThread.start()


def stopMonitoringPolicies( monitorInstance ):
   ''' Asks the passed MPM instance to stop policy monitoring by calling its
       stopRunning method.  A falsy monitorInstance (e.g. None) is ignored.
       Called by MssPolicyMonitorAgent.
   '''
   if not monitorInstance:
      return
   monitorInstance.stopRunning()

####################################################################################
class DeviceMonitor( object ):
   ''' Monitor service device policies and associated links to Arista switch(es)

       The two polling loops, monitorDevicePolicies and accessDevicesViaAggMgr,
       are written to be used as thread targets.  Both keep cycling, sleeping
       pollInterval seconds between cycles, until keepRunning is cleared by
       stopRunning.
   '''
   def __init__( self, mpmPlugin, deviceConfig, agentSysdbMgr, instanceThreadName ):
      ''' mpmPlugin:          plugin used to create per-device plugin objects
                              (getPluginObj / getAggMgrPluginObj)
          deviceConfig:       dict; the keys 'deviceSet', 'ipAddress',
                              'serviceDeviceType', 'policyTags' and
                              'queryInterval' are read here, 'group' and
                              'exceptionMode' are read by other methods
          agentSysdbMgr:      agent object that receives policy and status
                              updates (stored as mpmAgent)
          instanceThreadName: name of the thread this monitor runs on
      '''
      t2('DeviceMonitor init deviceConfig:', Lib.hidePassword( deviceConfig ) )
      self.mpmAgent = agentSysdbMgr
      self.mpmPlugin = mpmPlugin
      self.deviceSetName = deviceConfig[ 'deviceSet' ]
      self.ipAddr = deviceConfig[ 'ipAddress' ]
      self.deviceType = deviceConfig[ 'serviceDeviceType' ]
      self.policyTags = deviceConfig[ 'policyTags' ]
      self.pollInterval = deviceConfig[ 'queryInterval' ]
      self.deviceConfig = deviceConfig
      self.keepRunning = True  # polling loops exit once this becomes False
      self.instanceThreadName = instanceThreadName
      self.serviceDevices = {}  # service devices to monitor, key = ip or serial#
      self.AggMgrPolicies = {}  # key=service device IP, value=SvcDevicePolicy list
      self.errorLogs = {}
      # Guards AggMgrPolicies, which is written by the aggregation manager
      # thread and read by the member device threads.
      self.AggMgrPoliciesLock = threading.Lock()

   def __str__( self ):
      return '<%s>:%s:%s_%s' % ( self.__class__.__name__, self.deviceSetName,
                               self.ipAddr, self.instanceThreadName )

   def stopRunning( self ):
      ''' Requests the polling loops to stop.  The loops only test keepRunning
          at specific points of a cycle, so they do not exit immediately.
      '''
      self.keepRunning = False
      t1('keepRunning:', self.keepRunning, 'for:', self.instanceThreadName )

   def monitorDevicePolicies( self, deviceId, isAccessedViaAggrMgr=False ):
      ''' Monitor policies on a service device.
            deviceId may be an IP address, dns name, serial number etc.
            isAccessedViaAggrMgr is passed on to initServiceDevice to select
            how the device plugin object is created.

          Runs until keepRunning is cleared or the device is marked not
          current.  Each cycle reads the HA state (for IHAStatePlugin plugins),
          policies, neighbors, interfaces and routes from the device plugin and
          hands them to the agent.  Any exception raised during a cycle is
          reported through handleMonitoringError and polling resumes after
          pollInterval seconds.
      '''
      threadName = threading.current_thread().name
      t1('monitorDevicePolicies for device', deviceId, 'runs on thread', threadName )
      if deviceId not in self.serviceDevices:
         self.serviceDevices[ deviceId ] = ServiceDevice( deviceId, self.deviceType,
                                                          threadName )
      device = self.serviceDevices[ deviceId ]
      device.deviceSetName = self.deviceSetName
      if ENABLE_PROFILING:
         import cProfile
         profiler = cProfile.Profile()
         profiler.enable()
      while self.keepRunning:
         startTime = time.time()
         try:
            Lib.checkSslProfileStatus( self.deviceConfig, self.mpmAgent.sslStatus )
            if not device.isCurrent:  # happens when dev removed from a group
               self.removeNonCurrentDevice( device )
               break
            elif not device.initComplete():
               self.initServiceDevice( deviceId, isAccessedViaAggrMgr )
               if not self.keepRunning:  # many restarts possible on eos config load
                  break
            t0('@ polling', device.name, device.mgmtIp, 'thread', threadName )
            if device.isSingleLogicalDeviceHaModel:
               device.setDeviceInfo( device.plugin.getDeviceInfo() )
            if isinstance( device.plugin, IHAStatePlugin ):
               haState = device.plugin.getHighAvailabilityState()
               t2( device.name, haState )
               device.haPeerMgmtIp = haState.getPeerManagementIp()
               isHaPassiveOrSecondary = haState.isHaPassiveOrSecondary()
            else:
               isHaPassiveOrSecondary = False

            if isHaPassiveOrSecondary:
               # Publish an empty policy set for the non-active HA member
               t1( deviceId, 'in HA Mode but not Active/Primary device, ignoring')
               self.mpmAgent.updatePolicies( {}, [], {}, None, device,
                                             haPrimary=False )
            else:
               if not self.keepRunning:
                  break
               if isinstance( device.plugin, IPolicyPlugin ):  # e.g. PAN, FNET
                  policies = device.plugin.getPolicies( self.policyTags )
               else:  # e.g. CheckPoint
                  with self.AggMgrPoliciesLock:
                     # Wrap policies into a dummy vsys named 'root'
                     # This should be fixed once Checkpoint provides an API to
                     # manage virtual system
                     policies = { 'root' : deepcopy(
                        self.AggMgrPolicies.get( device.mgmtIp, [] ) ) }
                  populateZoneIntfStatus( policies, device.plugin )

               neighbors = device.plugin.getInterfaceNeighbors()
               interfaces = device.plugin.getInterfacesInfo()
               routes = device.plugin.getDeviceRoutingTables()
               if not self.keepRunning:  # check again, can be several seconds later
                  break

               # Error recovery
               self.mpmAgent.firewallLogger.recover( deviceId )

               self.mpmAgent.updatePolicies( policies, interfaces, neighbors,
                                             routes, device )
            self.mpmAgent.updateStatus( deviceId, device.threadName )
         except FirewallError as fwError:
            # Log the error if not previously logged
            self.mpmAgent.firewallLogger.log( deviceId, fwError )
            self.handleMonitoringError( device )
         except Exception as ex:  # pylint: disable=W0703
            t0( 'Error:', ex, ' device monitoring will resume on next interval' )
            traceback.print_exc()
            self.handleMonitoringError( device )

         t2( device.name, 'cycle', delta( startTime ), 'threads',
             [ t.name for t in threading.enumerate() ] )
         time.sleep( self.pollInterval )

      t1('-- monitorDevicePolicies stopping thread:', threading.current_thread() )
      # Best-effort close.  When the loop ended through removeNonCurrentDevice
      # the connection has already been closed once there.
      self.closeDeviceApiConnection( device.plugin )
      if ENABLE_PROFILING:
         profiler.disable()
         profiler.dump_stats( 'cprofile_%s' % threadName )

   def accessDevicesViaAggMgr( self ):
      ''' Launches a thread for each service device in an aggregation manager group

          Runs until keepRunning is cleared.  Each cycle obtains the current
          group members: the keys of the policies returned by getPolicies when
          the aggregation manager plugin is an IPolicyPlugin, otherwise the
          result of getAggMgrGroupMembers.  It then starts a monitoring thread
          for every member not yet in serviceDevices and marks members that are
          no longer reported as not current.  Exceptions raised during a cycle
          are logged and the next cycle starts after pollInterval seconds.
      '''
      group = self.deviceConfig[ 'group' ]
      t1('accessDevicesViaAggMgr running on thread:', threading.current_thread(),
         'for device group:', group )
      aggMgrThreadName = threading.current_thread().name
      aggMgrId = self.ipAddr  # IP addr or DNS name
      aggMgrPlugin = None  # created on the first cycle that gets that far
      mgmtIp = ''  # fetched once, retried every cycle until available
      while self.keepRunning:
         try:
            Lib.checkSslProfileStatus( self.deviceConfig, self.mpmAgent.sslStatus )
            if not aggMgrPlugin:
               aggMgrPlugin = self.mpmPlugin.getAggMgrPluginObj( self.deviceConfig )
            if not mgmtIp:
               devInfo = aggMgrPlugin.getDeviceInfo()
               t1('aggMgr deviceInfo:', devInfo )
               if 'ipAddr' not in devInfo:
                  # An API error must have occured for this information to be missing
                  t2( 'deviceInfo is incomplete: ipAddr is missing' )
                  raise FirewallAPIError( 200, None )
               mgmtIp = devInfo[ 'ipAddr' ]
            if isinstance( aggMgrPlugin, IPolicyPlugin ):
               with self.AggMgrPoliciesLock:
                  t2('* aggMgr', aggMgrThreadName, 'get policies and group members')
                  self.AggMgrPolicies = aggMgrPlugin.getPolicies( self.policyTags )
                  groupMembers = self.AggMgrPolicies.keys()  # devices ref'd in rules
            else:
               t3('* aggMgr', aggMgrThreadName, 'checking device group members')
               groupMembers = aggMgrPlugin.getAggMgrGroupMembers( group )

            if not self.keepRunning:  # check again here
               break
            t1( aggMgrId, 'group:', group, 'currentMembers:', groupMembers )
            # Start from every known device and discard those still in the group;
            # what remains has left the group.
            previousDevices = set( self.serviceDevices.keys() )
            for memberId in groupMembers:
               if memberId in self.serviceDevices:
                  previousDevices.discard( memberId )  # remove current id
               else:
                  self.initMemberDevice( memberId, aggMgrId, aggMgrThreadName )

            t3('previous members no longer in group:', previousDevices )
            for memberId in previousDevices:  # devs that are no longer members
               try:
                  # The member's own thread notices isCurrent == False on its
                  # next cycle and removes the device.
                  self.serviceDevices[ memberId ].isCurrent = False
                  if isinstance( aggMgrPlugin, IPolicyPlugin ):
                     with self.AggMgrPoliciesLock:
                        self.AggMgrPolicies[ memberId ] = []  # clear policies
               except KeyError:
                  pass  # ignore, may have just been removed by device thread
               self.mpmAgent.deleteGroupMember( memberId, aggMgrId )

            # Error recovery
            self.mpmAgent.firewallLogger.recover( aggMgrId )

            self.mpmAgent.updateStatus( aggMgrId, aggMgrThreadName, mgmtIp=mgmtIp )
         except FirewallError as fwError:
            # Log the error if not previously logged
            self.mpmAgent.firewallLogger.log( aggMgrId, fwError )
         except Exception as ex:  # pylint: disable=W0703
            t0( 'Error:', ex, ' device monitoring will resume on next interval' )
            traceback.print_exc()

         t3( aggMgrThreadName, 'aggMgr cycle complete')
         time.sleep( self.pollInterval )
      t1('-- accessDevicesViaAggMgr exiting thread:', threading.current_thread() )
      # aggMgrPlugin is still None if the plugin object was never created
      self.closeDeviceApiConnection( aggMgrPlugin )

   def initMemberDevice( self, memberId, aggMgrId, aggMgrThreadName ):
      ''' Registers a new group member device in serviceDevices and with the
          agent, then starts a daemon thread running monitorDevicePolicies for
          it with isAccessedViaAggrMgr=True.
      '''
      memberThreadName = genMonitoringThreadName( self.deviceSetName, memberId,
                                                  aggMgrId )
      self.serviceDevices[ memberId ] = ServiceDevice( memberId, self.deviceType,
                                                       memberThreadName )
      self.mpmAgent.addGroupMember( memberId, memberThreadName,
                                    aggMgrId, aggMgrThreadName )
      t1('+ START monitor thread for group member device: ', memberId )
      devThread = threading.Thread( name=memberThreadName,
                                    target=self.monitorDevicePolicies,
                                    args=( memberId, True ) )
      devThread.setDaemon( True )
      devThread.start()

   def removeNonCurrentDevice( self, device ):
      ''' Cleans up the agent's policies for the device, closes its plugin API
          connection and removes it from serviceDevices.
          Raises KeyError if device.deviceId is not in serviceDevices.
      '''
      deviceId = device.deviceId
      t1('serviceDevice', deviceId, 'not current, deleting' )
      self.mpmAgent.cleanupPoliciesForDevice( device )
      serviceDevice = self.serviceDevices[ deviceId ]
      self.closeDeviceApiConnection( serviceDevice.plugin )
      del self.serviceDevices[ deviceId ]

   def closeDeviceApiConnection( self, plugin ):
      ''' Best-effort close of a plugin's device API connection.

          plugin may be None when no plugin object was ever created (e.g. the
          aggregation manager loop stopped before its plugin was built); there
          is nothing to close then.  Any error raised by closeApiConnection is
          traced and otherwise ignored so that thread shutdown always completes.
      '''
      if plugin is None:
         # Without this guard the call below raises AttributeError, which used
         # to be reported as an error while closing a connection.
         return
      try:
         t4('closing device API connection')
         plugin.closeApiConnection()
      except Exception as ex:  # pylint: disable=W0703
         t0('ignoring error while closing device API connection:', ex )

   def initServiceDevice( self, deviceId, isAccessedViaAggrMgr ):
      ''' Creates the plugin object for serviceDevices[ deviceId ] and loads the
          device info reported by that plugin into the device.

          For a device accessed via an aggregation manager the plugin gets a
          shallow copy of deviceConfig, extended with the member's
          'virtualInstance' and 'vrouters' when the agent has a configuration
          for that member.
      '''
      t3( '@initSvcDevice id:', deviceId, 'ip:', self.deviceConfig[ 'ipAddress' ],
          'accessViaAggMgr:', isAccessedViaAggrMgr )
      device = self.serviceDevices[ deviceId ]
      if isAccessedViaAggrMgr:
         # Copy configuration from device member and add virtual system
         # into the aggregation manager configuration which is provided to the plugin
         config = self.mpmAgent.getDeviceConfig( self.deviceSetName, deviceId )
         deviceConfig = self.deviceConfig.copy()
         if config:
            deviceConfig[ 'virtualInstance' ] = config[ 'virtualInstance' ]
            deviceConfig[ 'vrouters' ] = config[ 'vrouters' ]
         device.plugin = self.mpmPlugin.getPluginObj( deviceConfig, deviceId )
      else:
         device.plugin = self.mpmPlugin.getPluginObj( self.deviceConfig )
      devInfo = device.plugin.getDeviceInfo()
      t2( 'svcDeviceInfo:', devInfo, 'id:', deviceId, 'ip:',
          self.deviceConfig[ 'ipAddress' ] )
      device.setDeviceInfo( devInfo )

   def handleMonitoringError( self, device ):
      ''' Reports an error status for the device to the agent.  When the
          configured 'exceptionMode' is BYPASS_MODE the device's policies are
          also cleaned up.
      '''
      t1('exceptionHandlingMode:', self.deviceConfig[ 'exceptionMode' ] )
      self.mpmAgent.updateStatus( device.deviceId, device.threadName, error=True )
      if self.deviceConfig[ 'exceptionMode' ] == BYPASS_MODE:
         self.mpmAgent.cleanupPoliciesForDevice( device )


def _refreshZoneIntfs( polZoneIntfs, deviceIntfs ):
   ''' Returns the interface objects from deviceIntfs (interface name ->
       interface) that match polZoneIntfs by name, each tagged with the zone of
       the policy zone interface it replaces.  Unmatched names are dropped.
   '''
   updatedZoneIntfs = []
   for polZoneIntf in polZoneIntfs:
      if polZoneIntf.name in deviceIntfs:
         deviceIntf = deviceIntfs[ polZoneIntf.name ]
         deviceIntf.zone = polZoneIntf.zone
         updatedZoneIntfs.append( deviceIntf )
   return updatedZoneIntfs


def populateZoneIntfStatus( policies, plugin ):
   ''' Populate policy zone interface status with latest service device
       interface status.

       policies: dict of virtual system name -> list of policies.  The
                 srcZoneInterfaces and dstZoneInterfaces of every policy are
                 replaced in place.
       plugin:   object providing getInterfacesInfo( resolveZoneNames=False ).
                 NOTE(review): the result is treated as a dict of virtual system
                 name -> { interface name: interface }, which is inferred from
                 the intfsInfo[ vsys ][ name ] indexing used here -- confirm
                 against the plugin API.

       Zone interfaces whose name the device does not report for the policy's
       virtual system (including when the virtual system itself is missing)
       are dropped from the policy.
   '''
   intfsInfo = plugin.getInterfacesInfo( resolveZoneNames=False )
   # items() rather than iteritems() so the loop runs on Python 2 and 3 dicts
   for vsys, policyList in policies.items():
      # Look interface names up in this virtual system's table; testing
      # membership against the top-level dict compares names with vsys keys.
      vsysIntfs = intfsInfo.get( vsys, {} )
      for policy in policyList:
         policy.srcZoneInterfaces = _refreshZoneIntfs( policy.srcZoneInterfaces,
                                                       vsysIntfs )
         policy.dstZoneInterfaces = _refreshZoneIntfs( policy.dstZoneInterfaces,
                                                       vsysIntfs )


def delta( startTime ):
   ''' Returns the time elapsed since startTime (a time.time() value) formatted
       as 'time=<seconds>s' with two decimals.
   '''
   elapsed = time.time() - startTime
   return 'time=%.2fs' % elapsed
