# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Cell
import ConfigMount
import LazyMount
import Toggles.TrafficPolicyToggleLib

import CliSession
import CliCommand
import CliMatcher
import CliParser
import CliPlugin.VrfCli as VrfCli
from CliMode.TrafficPolicy import TrafficPoliciesConfigMode
from PolicyMapCliLib import PolicyOpChkr
from TrafficPolicyCli import trafficPolicyNameMatcher, TrafficPoliciesConfigCmd

# Entity manager and Sysdb handles used by this plugin. They all start out unset
# and are assigned by Plugin().
entityManager = None
allVrfConfig = None
l3IntfStatus = None
policiesStatus = None
policiesStatusRequestDir = None
policiesVrfConfig = None
trafficPolicyHwStatus = None

def getCpuTrafficPolicyVrfs( policyName ):
   """Return the VRFs whose CPU traffic policy is policyName.

   The result follows the iteration order of policiesVrfConfig.trafficPolicies.
   """
   bindings = policiesVrfConfig.trafficPolicies
   return [ vrfName for vrfName in bindings if bindings[ vrfName ] == policyName ]

def getVrfNames( mode ):
   """Return the names of all configured VRFs, sorted. mode is unused."""
   names = list( allVrfConfig.vrf.members() )
   names.sort()
   return names

def cpuTrafficPolicySupportedGuard( mode, token ):
   """CLI guard for the 'cpu' keyword.

   Returns None (token allowed) when the hardware status reports CPU traffic
   policy support, otherwise CliParser.guardNotThisPlatform.
   """
   supported = trafficPolicyHwStatus.cpuTrafficPolicySupported
   return None if supported else CliParser.guardNotThisPlatform

# The 'cpu' keyword node. Its guard hides it on platforms whose hardware status
# does not report CPU traffic policy support.
cpuNode = CliCommand.Node(
   matcher=CliMatcher.KeywordMatcher(
      'cpu', helpdesc='Configure CPU traffic policy' ),
   guard=cpuTrafficPolicySupportedGuard )

# The "cpu traffic-policy <name> vrf ( <vrf> [ <vrf2> ... ] | all )" command.
# In the "[no|default]" form the vrf clause is optional.
#---------------------------------------------------------------
class CpuTrafficPolicyConfigCmd( CliCommand.CliCommandClass ):
   """
   Assign a CPU traffic policy to a list of VRFs.

   The positive form replaces the list of VRFs to which POLICY is assigned with
   the given list. The no/default form removes POLICY from the given VRFs, or
   from all of its VRFs when the vrf clause is omitted.
   """
   if Toggles.TrafficPolicyToggleLib.toggleTrafficPolicyVrfListToggleEnabled():
      # One or more VRF names.
      _vrfSyntax = '{ VRF }'
      _vrfData = VrfCli.VrfNameExprFactory( inclDefaultVrf=True )
   else:
      # Only the 'all' keyword is accepted; it yields the VRF list [ 'all' ].
      _vrfSyntax = 'VRF'
      _vrfData = CliMatcher.KeywordMatcher(
         'all',
         helpdesc='Configure policy for all VRFs',
         value=lambda mode, match: [ 'all' ] )

   syntax = 'cpu traffic-policy POLICY vrf %s' % _vrfSyntax
   # The vrf clause is optional when removing the policy.
   noOrDefaultSyntax = 'cpu traffic-policy POLICY ' \
                       '[ vrf %s ]' % _vrfSyntax
   data = {
      'cpu': cpuNode,
      'traffic-policy': 'Configure traffic policy',
      'POLICY': trafficPolicyNameMatcher,
      'vrf': 'Configure VRF for CPU traffic policy',
      'VRF': _vrfData
   }

   @staticmethod
   def _checkStatus( mode, prevPoliciesVrfConfig ):
      """Verify the current VRF config and roll it back on failure.

      For every configured VRF ('all' matches every interface), a PolicyOpChkr
      verifies each L3 interface in that VRF. On the first failure an error is
      added to mode. Outside a config session, the VRF config is then restored
      from prevPoliciesVrfConfig, a { vrf: policyName } snapshot.

      Returns the failure reason string, or None if every check succeeded.
      """
      def getVrfStatuses():
         # One ( commitStatus, result ) pair per ( VRF, matching interface ).
         results = []
         for vrf in policiesVrfConfig.trafficPolicies:
            for intfId, status in l3IntfStatus.intfStatus.iteritems():
               intfVrf = status.vrf
               if intfVrf == vrf or vrf == 'all':
                  chkr = PolicyOpChkr( policiesStatusRequestDir, policiesStatus )
                  results.append( chkr.verify( 'intf', str( intfId ) ) )
         return results

      def rollback( result ):
         reason = result.error if result else 'unknown'
         mode.addError( 'Failed to commit traffic-policy : %s' % reason )

         # Inside a config session only report the failure; no rollback is
         # attempted here.
         if mode.session_.inConfigSession():
            return reason

         mode.addError( 'Rolling back to previous configuration' )

         # Drop the VRFs that were not configured before. Iterate over a snapshot
         # of the keys because entries are deleted inside the loop.
         for vrf in list( policiesVrfConfig.trafficPolicies ):
            if vrf not in prevPoliciesVrfConfig:
               del policiesVrfConfig.trafficPolicies[ vrf ]

         # Restore the previous policy of every VRF. This also re-adds VRFs that
         # the command had removed.
         for vrf, policy in prevPoliciesVrfConfig.iteritems():
            policiesVrfConfig.trafficPolicies[ vrf ] = policy

         # Rolling back multiple VRFs (i.e. many interfaces) can itself take
         # time, so block the user until this is complete.
         for commitStatus, _ in getVrfStatuses():
            if not commitStatus:
               mode.addError( 'Failed to roll back to previous configuration' )
         return reason

      for commitStatus, result in getVrfStatuses():
         if not commitStatus:
            return rollback( result )

   @staticmethod
   def handler( mode, args ):
      """Replace the list of VRFs assigned to POLICY with args[ 'VRF' ]."""
      # Save original configuration for rollback
      prevPoliciesVrfConfig = { vrf: policy for vrf, policy in
                                policiesVrfConfig.trafficPolicies.iteritems() }

      name = args[ 'POLICY' ]
      vrfList = args.get( 'VRF' )
      oldVrfList = getCpuTrafficPolicyVrfs( name )

      # Delete any VRFs from the old list that aren't in the new one.
      for oldVrfName in oldVrfList:
         if oldVrfName not in vrfList:
            del policiesVrfConfig.trafficPolicies[ oldVrfName ]

      # Add any VRFs from the new list that aren't in the old one.
      for newVrfName in vrfList:
         if newVrfName not in oldVrfList:
            policiesVrfConfig.trafficPolicies[ newVrfName ] = name

      # In a config session, defer verification to a commit handler.
      if mode.session_.inConfigSession():
         handler = lambda mode, onSessionCommit=True: \
               CpuTrafficPolicyConfigCmd._checkStatus( mode, prevPoliciesVrfConfig )
         CliSession.registerSessionOnCommitHandler(
               mode.session_.entityManager, "cpu-traffic-policy", handler )
         return

      # Skip verification while loading startup-config or when standalone.
      if not ( mode.session_.startupConfig() or
               mode.session_.isStandalone() ):
         CpuTrafficPolicyConfigCmd._checkStatus( mode, prevPoliciesVrfConfig )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      """Remove POLICY from the given VRFs, or from all of its VRFs if none given.

      VRFs that are not currently assigned POLICY are ignored.
      """
      name = args[ 'POLICY' ]
      vrfList = args.get( 'VRF' )
      oldVrfList = getCpuTrafficPolicyVrfs( name )
      if vrfList is None:
         vrfList = oldVrfList # Clear all VRFs
      # A VRF may be listed more than once. Delete it only once, because a second
      # delete of the same (already removed) key could fail.
      removed = set()
      for delVrfName in vrfList:
         if delVrfName in oldVrfList and delVrfName not in removed:
            del policiesVrfConfig.trafficPolicies[ delVrfName ]
            removed.add( delVrfName )

def noTrafficPolicies( mode, args ):
   """Remove every VRF -> CPU traffic policy assignment.

   mode and args are unused. This is the handler registered with
   TrafficPoliciesConfigCmd._registerNoHandler.
   """
   vrfBindings = policiesVrfConfig.trafficPolicies
   vrfBindings.clear()

# Register noTrafficPolicies as the no-handler of TrafficPoliciesConfigCmd
# (presumably invoked for the no/default form of that command), so that it also
# clears every VRF -> CPU traffic policy assignment.
TrafficPoliciesConfigCmd._registerNoHandler( # pylint: disable=protected-access
   noTrafficPolicies )

# Make the cpu traffic-policy command available in traffic-policies mode.
TrafficPoliciesConfigMode.addCommandClass( CpuTrafficPolicyConfigCmd )

def Plugin( em ):
   """Store the entity manager and mount the Sysdb paths this plugin uses.

   Assigns the module-level handles (policiesStatus, policiesVrfConfig,
   policiesStatusRequestDir, allVrfConfig, l3IntfStatus, trafficPolicyHwStatus)
   that the command handlers in this file read and write.
   """
   global policiesStatusRequestDir, policiesStatus
   global policiesVrfConfig
   global entityManager
   global allVrfConfig
   global l3IntfStatus
   global trafficPolicyHwStatus

   rootPath = 'trafficPolicies'
   cellRootPath = 'cell/%d/trafficPolicies' % Cell.cellId()
   vrfConfigPath = '%s/%s' % ( rootPath, 'cpu/vrf' )
   statusPath = '%s/%s' % ( cellRootPath, 'status' )
   statusRequestDirPath = '%s/%s' % ( rootPath, 'statusRequest' )
   entityManager = em

   # The per-cell status directory is mounted through a mount group closed
   # without blocking.
   group = entityManager.mountGroup()
   policiesStatus = group.mount( statusPath, 'Tac::Dir', 'ri' )
   group.close( callback=None, blocking=False )

   policiesVrfConfig = ConfigMount.mount( entityManager, vrfConfigPath,
                                          'PolicyMap::VrfConfig', 'wi' )
   policiesStatusRequestDir = LazyMount.mount(
      entityManager, statusRequestDirPath,
      'PolicyMap::PolicyMapStatusRequestDir', 'w' )
   allVrfConfig = LazyMount.mount(
      entityManager, 'ip/vrf/config', 'Ip::AllVrfConfig', 'r' )
   l3IntfStatus = LazyMount.mount(
      entityManager, "l3/intf/status", "L3::Intf::StatusDir", "r" )
   trafficPolicyHwStatus = LazyMount.mount(
      entityManager, 'trafficPolicies/hardware/status/global',
      'TrafficPolicy::HwStatus', 'r' )
