#!/usr/bin/env python
# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
#

import Tac
import Toggles.MssToggleLib
import CliSave

from CliSavePlugin.Controllerdb import ( CvxConfigMode,
                                         controllerConfigPath,
                                         getClusterName )
from CliSavePlugin.MssCliSave import MssConfigSaveMode
from CliMode.MssPolicyMonitor import ( MssPolicyMonitorDynamic,
                                       MssPolicyMonitorDevice,
                                       MssPolicyMonitorVInst,
                                       MssPolicyMonitorVrf,
                                       MssPolicyMonitorRoute )
from MssPolicyMonitor.Lib import ( PANORAMA_PLUGIN,
                                   PAN_FW_PLUGIN,
                                   FORTIMGR_PLUGIN,
                                   CHKP_MS_PLUGIN,
                                   DEFAULT_POLICY_TAG,
                                   t0 )

# Policy-source plugin -> CLI keywords printed by the 'type' command.
mssTypeMap = {
   CHKP_MS_PLUGIN: 'check-point management-server',
   PANORAMA_PLUGIN: 'palo-alto panorama',
   FORTIMGR_PLUGIN: 'fortinet fortimanager',
   PAN_FW_PLUGIN: 'palo-alto firewall',
}

# Default values the add* helpers compare against to decide which device-set
# and device settings need to be written out.
defaults = Tac.Value( 'MssPolicyMonitor::CliDefaults' )
# Generic MSS CLI defaults; in this file only trafficInspection is read.
mssDefaults = Tac.Value( 'Mss::CliDefaults' )

class MssPolicyMonitorDynamicConfigMode( MssPolicyMonitorDynamic, CliSave.Mode ):
   """CliSave mode for one dynamic device-set inside the CVX MSS mode.

   saveMssPolicyMonitor creates one instance per device-set, using the
   device-set name as param.
   """
   def __init__( self, param ):
      t0( 'instantiating dynamic in saveplugin' )
      # Initialize the CLI-mode base first, then the CliSave base.
      for base in ( MssPolicyMonitorDynamic, CliSave.Mode ):
         base.__init__( self, param )

# Register the device-set mode as a child of the MSS mode; after= presumably
# places it behind the 'mss.config' sequence (defined elsewhere).
MssConfigSaveMode.addChildMode( MssPolicyMonitorDynamicConfigMode,
                                after=[ 'mss.config' ] )
MssPolicyMonitorDynamicConfigMode.addCommandSequence( 'mss.dynamic.config' )

class MssPolicyMonitorDeviceConfigMode( MssPolicyMonitorDevice, CliSave.Mode ):
   """CliSave mode for one service device inside a dynamic device-set.

   saveMssPolicyMonitor passes param as the tuple
   ( serviceDeviceName, deviceSetName, isAccessedViaAggrMgr ).
   """
   def __init__( self, param ):
      t0( 'instantiating device in saveplugin with param')
      t0( param )
      # Initialize the CLI-mode base first, then the CliSave base.
      for base in ( MssPolicyMonitorDevice, CliSave.Mode ):
         base.__init__( self, param )

# Service-device modes nest under the dynamic device-set mode.
MssPolicyMonitorDynamicConfigMode.addChildMode(
   MssPolicyMonitorDeviceConfigMode )
MssPolicyMonitorDeviceConfigMode.addCommandSequence(
   'mss.dynamic.device.config' )

class MssPolicyMonitorVInstConfigMode( MssPolicyMonitorVInst, CliSave.Mode ):
   """CliSave mode for one virtual instance of a service device.

   addVInstCommands passes param as the tuple
   ( deviceSetName, serviceDeviceName, virtualInstanceName ).
   """
   def __init__( self, param ):
      t0( 'instantiating virtual instance in saveplugin with param' )
      t0( param )
      # Initialize the CLI-mode base first, then the CliSave base.
      for base in ( MssPolicyMonitorVInst, CliSave.Mode ):
         base.__init__( self, param )

# Virtual-instance modes nest under the service-device mode.
MssPolicyMonitorDeviceConfigMode.addChildMode(
   MssPolicyMonitorVInstConfigMode )
MssPolicyMonitorVInstConfigMode.addCommandSequence(
   'mss.dynamic.device.vinst.config' )

class MssPolicyMonitorVrfConfigMode( MssPolicyMonitorVrf, CliSave.Mode ):
   """CliSave mode for one VRF of a service device.

   addVrfCommands passes param as the tuple
   ( deviceSetName, serviceDeviceName, vrfName ).
   """
   def __init__( self, param ):
      t0( 'instantiating vrf in saveplugin with param' )
      t0( param )
      # Initialize the CLI-mode base first, then the CliSave base.
      for base in ( MssPolicyMonitorVrf, CliSave.Mode ):
         base.__init__( self, param )

# VRF modes nest under the service-device mode.
MssPolicyMonitorDeviceConfigMode.addChildMode(
   MssPolicyMonitorVrfConfigMode )
MssPolicyMonitorVrfConfigMode.addCommandSequence(
   'mss.dynamic.device.vrf.config' )

class MssPolicyMonitorRouteConfigMode( MssPolicyMonitorRoute, CliSave.Mode ):
   """CliSave mode for one routing-table entry of a VRF.

   addVrfCommands passes param as the tuple
   ( deviceSetName, serviceDeviceName, vrfName, str( routingTableKey ) ).
   """
   def __init__( self, param ):
      t0( 'instantiating route in saveplugin with param' )
      t0( param )
      # Initialize the CLI-mode base first, then the CliSave base.
      for base in ( MssPolicyMonitorRoute, CliSave.Mode ):
         base.__init__( self, param )

# Route modes nest under the VRF mode.
MssPolicyMonitorVrfConfigMode.addChildMode(
   MssPolicyMonitorRouteConfigMode )
MssPolicyMonitorRouteConfigMode.addCommandSequence(
   'mss.dynamic.device.vrf.route.config' )

@CliSave.saver( 'MssPolicyMonitor::Config', 'msspolicymonitor/config',
                requireMounts=( controllerConfigPath, ) )
def saveMssPolicyMonitor( entity, root, SysdbRoot, options, requireMounts ):
   """Save every MSS policy-monitor device-set under the CVX MSS mode.

   For each device-set in entity.deviceSet:
     * device-set level commands (state, type, policy tags, interval,
       retries, timeout, ...) go into a dynamic device-set mode;
     * each service device gets its own device mode;
     * virtual-instance (toggle-gated) and VRF/route sub-modes are
       delegated to addVInstCommands / addVrfCommands.

   SysdbRoot is unused; it is part of the saver callback signature.
   requireMounts[ controllerConfigPath ] supplies the CVX cluster name.
   """
   t0( options )
   # Loop-invariant: evaluate the toggle once per save instead of several
   # times per device-set.
   l3V2Enabled = Toggles.MssToggleLib.toggleMssL3V2Enabled()
   # The enclosing CVX/MSS mode does not depend on the device-set. Look it up
   # on the first device-set only, so that (as before) no CVX mode instance
   # is created when there are no device-sets at all.
   cvxMss = None

   for devset in entity.deviceSet.keys():
      deviceSetConfig = entity.deviceSet[ devset ]

      # Device-set level commands; append order is output order.
      cmds = []
      addState( entity, options, cmds, devset )
      addType( entity, options, cmds, devset )
      addPolicyRedirectTag( entity, options, cmds, devset )
      addPolicyOffloadTag( entity, options, cmds, devset )
      addPolicyModifierTag( entity, options, cmds, devset )
      addQueryInterval( entity, options, cmds, devset )
      addRetries( entity, options, cmds, devset )
      addTimeout( entity, options, cmds, devset )
      if l3V2Enabled:
         addTrafficInspection( entity, options, cmds, devset )
      addExceptionHandling( entity, options, cmds, devset )
      addAdminDomain( entity, options, cmds, devset )
      # The device-set level virtual domain is only saved with the toggle off.
      if not l3V2Enabled:
         addVirtualDomain( entity, options, cmds, devset )

      if cvxMss is None:
         t0( 'adding cluster' )
         clusterName = getClusterName( requireMounts[ controllerConfigPath ] )
         cvx = root[ CvxConfigMode ].getOrCreateModeInstance(
            CvxConfigMode.modeName( clusterName ) )
         cvxMss = cvx[ MssConfigSaveMode ].getOrCreateModeInstance( 'cvx-mss' )

      t0( 'defining cmdSeq for dynamic' )
      dynamic = cvxMss[ MssPolicyMonitorDynamicConfigMode ].\
                getOrCreateModeInstance( devset )
      cmdSeqDyn = dynamic[ 'mss.dynamic.config' ]
      for cmd in cmds:
         t0( 'adding cmdSeq for dynamic' )
         cmdSeqDyn.addCommand( cmd )

      serviceDevice = deviceSetConfig.serviceDevice
      deviceType = deviceSetConfig.policySourceType
      for servDev in serviceDevice:
         t0( 'service device exists!' )
         deviceConfig = serviceDevice[ servDev ]
         devCmds = []
         addDevices( devCmds, servDev, serviceDevice, options )

         t0( 'creating command seq for device' )
         # Device modes are keyed by
         # ( service device, device-set, isAccessedViaAggrMgr ).
         devModeKey = ( servDev, devset, deviceConfig.isAccessedViaAggrMgr )
         devConfigMode = dynamic[ MssPolicyMonitorDeviceConfigMode ].\
                         getOrCreateModeInstance( devModeKey )
         cmdSeqDev = devConfigMode[ 'mss.dynamic.device.config' ]
         for cmd in devCmds:
            t0( 'add command on cmdSeqDev' )
            cmdSeqDev.addCommand( cmd )

         if l3V2Enabled:
            addVInstCommands( devConfigMode, cmdSeqDev, servDev, devset,
                              deviceConfig, deviceType, options )
         addVrfCommands( devConfigMode, cmdSeqDev, servDev, devset,
                         deviceConfig, options )

def addVInstCommands( deviceConfigMode, cmdSeqDev, sdName, dsName,
                      deviceConfig, deviceType, options ):
   """Emit virtual-instance sub-modes for one service device.

   For every configured virtual instance this saves its traffic-inspection
   setting and its firewall-VRF -> network-VRF mappings. With no mappings,
   or with no virtual instances at all, the default mapping is emitted
   only when options.saveAll is set.

   cmdSeqDev is accepted for signature parity with addVrfCommands; it is
   not used here.
   """
   def vinstModeFor( vinstName ):
      # Virtual-instance modes are keyed by ( device-set, device, vinst ).
      return deviceConfigMode[ MssPolicyMonitorVInstConfigMode ].\
         getOrCreateModeInstance( ( dsName, sdName, vinstName ) )

   isPaloAlto = deviceType in ( PAN_FW_PLUGIN, PANORAMA_PLUGIN )
   vinstConfig = deviceConfig.virtualInstance

   if not vinstConfig:
      if options.saveAll:
         # Even with no virtual instance configured, the platform's default
         # instance and its default router mapping are still shown.
         if isPaloAlto:
            addVInstDefaultCommands( vinstModeFor( 'vsys1' ), deviceType )
         elif deviceType == FORTIMGR_PLUGIN:
            addVInstDefaultCommands( vinstModeFor( 'root' ), deviceType )
      return

   # Palo Alto firewalls name their routing instances 'virtual-router';
   # other types use a numeric 'vrf id'.
   vrfKeyword = 'virtual-router' if isPaloAlto else 'vrf id'
   for vinstName, vinst in vinstConfig.iteritems():
      vinstMode = vinstModeFor( vinstName )
      cmdSeqVInst = vinstMode[ 'mss.dynamic.device.vinst.config' ]

      # Traffic inspection is only written when local inspection is enabled.
      inspection = vinst.trafficInspection
      if inspection.local:
         cmdSeqVInst.addCommand( 'traffic inspection local outbound'
                                 if inspection.outbound
                                 else 'traffic inspection local' )

      # Without any network-VRF mapping, the default network VRF is tied to
      # the default virtual router (shown only with saveAll).
      if vinst.firewallVrf:
         for fwVrf, networkVrf in vinst.firewallVrf.iteritems():
            cmdSeqVInst.addCommand( 'network vrf %s %s %s' % (
                                    networkVrf, vrfKeyword, fwVrf ) )
      elif options.saveAll:
         addVInstDefaultCommands( vinstMode, deviceType )

def addVInstDefaultCommands( vinstMode, deviceType ):
   """Emit the default network-VRF mapping into a virtual-instance mode.

   Palo Alto plugins map the default network VRF to the 'default' virtual
   router. FortiManager maps it to VRF id 0. Other device types emit
   nothing.
   """
   cmdSeqVInst = vinstMode[ 'mss.dynamic.device.vinst.config' ]
   if deviceType in ( PAN_FW_PLUGIN, PANORAMA_PLUGIN ):
      cmdSeqVInst.addCommand( 'network vrf default virtual-router default' )
   elif deviceType == FORTIMGR_PLUGIN:
      # Single space before the id, matching the 'network vrf %s vrf id %s'
      # form emitted for explicitly configured mappings (was 'vrf id  0').
      cmdSeqVInst.addCommand( 'network vrf default vrf id 0' )

def addVrfCommands( deviceConfigMode, cmdSeqDev, sdName, dsName,
                    deviceConfig, options ):
   """Emit VRF and route sub-modes for one service device.

   Each VRF in deviceConfig.vrfConfig gets a mode holding its IPv4
   exclude-list. Each routing-table entry gets a child route mode listing
   its routes. cmdSeqDev is accepted but not used here.
   """
   vrfConfig = deviceConfig.vrfConfig
   for vrfName in vrfConfig:
      vrf = vrfConfig[ vrfName ]
      vrfMode = deviceConfigMode[ MssPolicyMonitorVrfConfigMode ].\
         getOrCreateModeInstance( ( dsName, sdName, vrfName ) )
      cmdSeqVrf = vrfMode[ 'mss.dynamic.device.vrf.config' ]

      # Every subnet is followed by one space, so a non-empty list keeps a
      # trailing space in the saved command.
      excluded = ''.join( str( subnet ) + ' ' for subnet in vrf.hostExcludeList )
      if excluded:
         cmdSeqVrf.addCommand( 'security policy ipv4 exclude-list ' + excluded )
      elif options.saveAll:
         cmdSeqVrf.addCommand( 'no security policy ipv4 exclude-list' )

      routingTable = vrf.routingTable
      for ipv4Key in routingTable:
         routeMode = vrfMode[ MssPolicyMonitorRouteConfigMode ].\
            getOrCreateModeInstance( ( dsName, sdName, vrfName, str( ipv4Key ) ) )
         cmdSeqRoute = routeMode[ 'mss.dynamic.device.vrf.route.config' ]
         for subnet in routingTable[ ipv4Key ].route:
            cmdSeqRoute.addCommand( 'route ' + str( subnet ) )

def addDevices( cmds, servDev, serviceDevice, options ):
   """Append the device-level commands for service device servDev.

   Commands are appended in this order: credentials, protocol/port/SSL
   profile, group, management virtual domain, then interface maps sorted by
   device interface. options.saveAll / options.saveAllDetail control whether
   default values and 'no' forms are written.
   """
   device = serviceDevice[ servDev ]
   passwd = CliSave.sanitizedOutput( options, device.encryptedPassword )

   # Credentials.
   if passwd:
      t0( 'appending password' )
      cmds.append( 'username %s password 7 %s' % ( device.username, passwd ) )
   elif options.saveAllDetail:
      cmds.append( 'no username' )

   # Transport: protocol, port and optional SSL profile.
   portNum = device.protocolPortNum
   sslProfile = device.sslProfileName
   nonDefaultTransport = ( portNum != defaults.protocolPortNum or
                           sslProfile != defaults.sslProfileName )
   if options.saveAll or nonDefaultTransport:
      defaultSsl = sslProfile == defaults.sslProfileName
      if ( options.saveAllDetail and portNum == defaults.protocolPortNum and
           defaultSsl ):
         cmds.append( 'no protocol' )
      elif defaultSsl:
         cmds.append( 'protocol %s %s' % ( device.protocol, str( portNum ) ) )
      else:
         cmds.append( 'protocol %s %s ssl profile %s' %
                      ( device.protocol, str( portNum ), sslProfile ) )

   # Device group.
   group = device.group
   if group:
      t0( 'appending device group' )
      cmds.append( 'group %s' % ( group ) )
   elif options.saveAllDetail:
      cmds.append( 'no group' )

   # Management virtual domain: nothing is written when it is empty.
   mgmtVDom = device.mgmtIntfVirtualDomain
   if mgmtVDom and ( options.saveAll or
                     mgmtVDom != defaults.mgmtIntfVirtualDomain ):
      if options.saveAllDetail and \
         mgmtVDom == defaults.mgmtIntfVirtualDomain:
         cmds.append( 'no management virtual domain' )
      else:
         t0( 'appending management virtual domain' )
         cmds.append( 'management virtual domain %s' % ( mgmtVDom ) )

   # Device-interface -> switch/interface maps, in sorted key order.
   intfMap = device.intfMap
   for devIntf in sorted( intfMap.keys() ):
      t0( 'appending intf maps' )
      mapping = intfMap[ devIntf ]
      cmds.append( 'map device-interface %s switch %s interface %s' % (
         devIntf, mapping.switchChassisId, mapping.switchIntf ) )

def addType( entity, options, cmds, devset ):
   """Append the 'type' command for device-set devset.

   A non-default policy source type is always written, using its CLI
   keywords from mssTypeMap. A default type is written as 'no type', but
   only when saveAll or saveAllDetail is set.
   """
   sourceType = entity.deviceSet[ devset ].policySourceType
   if sourceType != defaults.policySourceType:
      cmds.append( 'type %s' % ( mssTypeMap[ sourceType ] ) )
   elif options.saveAll or options.saveAllDetail:
      cmds.append( 'no type' )

def addExceptionHandling( entity, options, cmds, devset ):
   """Append the 'exception device unreachable' setting of devset.

   The command is written when saveAll is set or the value is not the
   default. With saveAllDetail and a default value, the 'no' form is
   appended as well, after the explicit value.
   """
   handling = entity.deviceSet[ devset ].exceptionHandling
   if not ( options.saveAll or handling != defaults.exceptionHandling ):
      return
   cmds.append( 'exception device unreachable %s' % ( handling ) )
   if options.saveAllDetail and handling == defaults.exceptionHandling:
      cmds.append( 'no exception device unreachable' )

def addAdminDomain( entity, options, cmds, devset ):
   """Append the 'admin domain' setting of device-set devset.

   The command is written when saveAll is set or the value is not the
   default. With saveAllDetail and a default value, 'no admin domain' is
   appended after it.
   """
   adminDomain = entity.deviceSet[ devset ].adminDomain
   if not ( options.saveAll or adminDomain != defaults.adminDomain ):
      return
   cmds.append( 'admin domain %s' % adminDomain )
   if options.saveAllDetail and adminDomain == defaults.adminDomain:
      cmds.append( 'no admin domain' )

def addVirtualDomain( entity, options, cmds, devset ):
   """Append the device-set level 'virtual domain' setting of devset.

   The command is written when saveAll is set or the value is not the
   default. With saveAllDetail and a default value, 'no virtual domain' is
   appended after it.
   """
   virtualDomain = entity.deviceSet[ devset ].virtualDomain
   if not ( options.saveAll or virtualDomain != defaults.virtualDomain ):
      return
   cmds.append( 'virtual domain %s' % virtualDomain )
   if options.saveAllDetail and virtualDomain == defaults.virtualDomain:
      cmds.append( 'no virtual domain' )

def quotifyTag( tag ):
   """Return tag followed by a space, double-quoted if it contains a space.

   The trailing space lets callers build a tag list by concatenation.
   Embedded double quotes are not escaped.
   """
   if ' ' not in tag:
      return tag + ' '
   return '"%s" ' % tag

def addPolicyRedirectTag( entity, options, cmds, devset ):
   """Append the 'policy tag redirect' setting of device-set devset.

   The tag list is written when saveAll is set or the tags are anything
   other than exactly [ DEFAULT_POLICY_TAG ]. With saveAllDetail and only
   the default tag configured, 'default policy tag redirect' is appended.
   """
   policyTag = entity.deviceSet[ devset ].policyTag
   onlyDefaultTag = policyTag.keys() == [ DEFAULT_POLICY_TAG ]
   if options.saveAll or not onlyDefaultTag:
      tagStr = ''.join( quotifyTag( tag ) for tag in policyTag.iterkeys() )
      cmds.append( 'policy tag redirect %s' % ( tagStr ) )
   if options.saveAllDetail and onlyDefaultTag:
      cmds.append( 'default policy tag redirect' )

def addPolicyOffloadTag( entity, options, cmds, devset ):
   """Append the 'policy tag offload' setting of device-set devset.

   Nothing is written unless saveAll is set or offload tags exist. An empty
   tag set is written as 'no policy tag offload'.
   """
   offloadTag = entity.deviceSet[ devset ].offloadTag
   if not ( options.saveAll or offloadTag.keys() ):
      return
   tagStr = ''.join( quotifyTag( tag ) for tag in offloadTag.iterkeys() )
   if not tagStr:
      cmds.append( 'no policy tag offload' )
   else:
      cmds.append( 'policy tag offload ' + tagStr )

def addPolicyModifierTag( entity, options, cmds, devset ):
   """Append one 'policy tag modifier <type>' command per modifier type.

   Nothing is written unless saveAll is set or modifier tags exist. A
   modifier type with no tags is written as 'no policy tag modifier <type>'.
   """
   modifierTags = entity.deviceSet[ devset ].modifierTag
   if not ( options.saveAll or modifierTags.keys() ):
      return
   for tagType, modifierTag in modifierTags.iteritems():
      tagStr = ''.join( quotifyTag( tag ) for tag in modifierTag.tag.iterkeys() )
      if not tagStr:
         cmds.append( 'no policy tag modifier ' + tagType )
      else:
         cmds.append( 'policy tag modifier ' + tagType + ' ' + tagStr )

def addQueryInterval( entity, options, cmds, devset ):
   """Append the 'interval' (query interval) setting of device-set devset.

   The command is written when saveAll is set or the value is not the
   default. With saveAllDetail and a default value, 'no interval' is
   appended after it.
   """
   interval = entity.deviceSet[ devset ].queryInterval
   if not ( options.saveAll or interval != defaults.queryInterval ):
      return
   cmds.append( 'interval %s' % ( interval ) )
   if options.saveAllDetail and interval == defaults.queryInterval:
      cmds.append( 'no interval' )

def addRetries( entity, options, cmds, devset ):
   """Append the 'retries' setting of device-set devset.

   The command is written when saveAll is set or the value is not the
   default. With saveAllDetail, 'default retries' is also appended for a
   default value, and 'no retries' for a value of 0 (both may apply).
   """
   retries = entity.deviceSet[ devset ].retries
   if not ( options.saveAll or retries != defaults.retries ):
      return
   cmds.append( 'retries %s' % ( retries ) )
   if not options.saveAllDetail:
      return
   if retries == defaults.retries:
      cmds.append( 'default retries' )
   if retries == 0:
      cmds.append( 'no retries' )

def addState( entity, options, cmds, devset ):
   """Append the 'state' setting of device-set devset.

   The command is written when saveAll is set or the value is not the
   default. The internal value 'dryRun' is spelled 'dry-run' on the CLI.
   With saveAllDetail and a default value, 'no state' is appended after it.
   """
   state = entity.deviceSet[ devset ].state
   if not ( options.saveAll or state != defaults.state ):
      return
   cmds.append( 'state dry-run' if state == 'dryRun'
                else 'state %s' % ( state ) )
   if options.saveAllDetail and state == defaults.state:
      cmds.append( 'no state' )

def addTimeout( entity, options, cmds, devset ):
   """Append the 'timeout' setting of device-set devset.

   The command is written when saveAll is set or the value is not the
   default. With saveAllDetail and a default value, 'no timeout' is
   appended after it.
   """
   timeout = entity.deviceSet[ devset ].timeout
   if not ( options.saveAll or timeout != defaults.timeout ):
      return
   cmds.append( 'timeout %s' % ( timeout ) )
   if options.saveAllDetail and timeout == defaults.timeout:
      cmds.append( 'no timeout' )

def addTrafficInspection( entity, options, cmds, devset ):
   """Append the device-set level 'traffic inspection' setting.

   The setting is compared against mssDefaults.trafficInspection. It is
   written as 'traffic inspection local', with ' outbound' appended when
   the outbound flag is set.
   """
   # NOTE(review): unlike addVInstCommands, this does not check
   # trafficInspection.local before writing 'traffic inspection local';
   # confirm a non-local, non-default value cannot reach this point.
   inspection = entity.deviceSet[ devset ].trafficInspection
   if not ( options.saveAll or
            inspection != mssDefaults.trafficInspection ):
      return
   cmds.append( 'traffic inspection local outbound' if inspection.outbound
                else 'traffic inspection local' )

def addVerifyCertificate( entity, options, cmds, devset ):
   """Append the 'verify certificate' setting of device-set devset.

   The command is written when saveAll is set or the value is not the
   default: 'verify certificate' when enabled, otherwise
   'no verify certificate'.
   """
   # NOTE(review): saveMssPolicyMonitor in this file does not call this
   # helper; confirm whether that is intentional.
   verify = entity.deviceSet[ devset ].verifyCertificate
   if not ( options.saveAll or verify != defaults.verifyCertificate ):
      return
   cmds.append( 'verify certificate' if verify else 'no verify certificate' )
