#!/usr/bin/env python
#Copyright (c) 2010, 2011, 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import ConfigMount
import LazyMount
import PimCliLib
import Tac
from AclCli import getAclConfig
from IpLibConsts import DEFAULT_VRF
from RouterMulticastCliLib import (
      AddressFamily,
      configGetters,
      doConfigMounts,
      getAddressFamilyFromMode,
)

# Module-level handles. Plugin() assigns pimsmConfigColl, pimConfigRoot and
# _allVrfConfig; staticRpConfigColl and pimBidirStaticRpConfigColl are not
# assigned by any code shown here.
# NOTE(review): pimConfigRoot, _allVrfConfig, staticRpConfigColl and
# pimBidirStaticRpConfigColl are not read by the code shown here; confirm
# whether other modules rely on them before removing any.
staticRpConfigColl = None
pimsmConfigColl = None
pimBidirStaticRpConfigColl = None
pimConfigRoot = None
_allVrfConfig = None

# Af Independent Config Types
# Tac type name of the PIM sparse-mode config collection; also passed to
# doConfigMounts() in Plugin().
PimsmConfigColl = "Routing::Pim::SparseMode::ConfigColl"

# Accessors generated by configGetters for PimsmConfigColl, whose per-VRF
# entries live in its 'vrfConfig' collection. As used in this file:
#   pimsmConfigRoot( af )            -> the collection for one address family
#   pimsmConfig( af, vrfName )       -> the per-VRF entry; only called for its
#                                       side effect (presumably creating the
#                                       entry -- confirm in configGetters)
#   pimsmConfigFromMode( mode, ... ) -> the per-VRF entry for a CLI mode;
#                                       callers check it for None
# pimsmConfigRootFromMode is not used by the code shown here.
( pimsmConfigRoot,
  pimsmConfigRootFromMode,
  pimsmConfig,
  pimsmConfigFromMode ) = configGetters( PimsmConfigColl,
                                         collectionName='vrfConfig' )

# Tac type name of the GMP SSM config collection.
SsmConfigColl = "Routing::Gmp::GmpSsmConfigColl"
# The same accessor set for SsmConfigColl; its per-VRF entries live in the
# 'config' collection. ssmConfigRootFromMode is not used by the code shown
# here.
( ssmConfigRoot,
  ssmConfigRootFromMode,
  ssmConfig,
  ssmConfigFromMode ) = configGetters( SsmConfigColl,
                                         collectionName='config' )

def getPimsmVrfConfig( vrfName ):
   """Return the PIM sparse-mode config entry for vrfName, or None.

   Looks the VRF up in the collection mounted by Plugin(); pimsmConfigColl is
   None until Plugin() has run, so this must not be called before then.
   """
   return pimsmConfigColl.vrfConfig.get( vrfName )

def _pimsmConfigCreation( vrfName ):
   """VRF-configured hook: touch vrfName's sparse-mode and SSM configs.

   Calls the pimsmConfig and ssmConfig getters for IPv4, then IPv6, and
   discards their results, so the calls are made for their side effect
   (presumably creating the per-VRF entries -- confirm in configGetters).
   """
   families = ( AddressFamily.ipv4, AddressFamily.ipv6 )
   for family in families:
      pimsmConfig( family, vrfName )
      ssmConfig( family, vrfName )

def _pimsmConfigDeletion( vrfName ):
   """VRF-deleted hook: drop vrfName's unused sparse-mode and SSM entries.

   For IPv4 then IPv6, the sparse-mode entry and then the SSM entry for
   vrfName are removed, but only when vrfName is not the default VRF and the
   entry reports isDefault(). Entries holding user configuration, and all
   default-VRF entries, are kept. This function does not itself check whether
   the VRF is still defined; presumably the hook is only invoked once it is
   not -- confirm with PimCliLib.
   """
   def _dropIfUnused( entries ):
      # Delete the entry only when it carries no non-default configuration.
      if vrfName in entries:
         entry = entries[ vrfName ]
         if vrfName != DEFAULT_VRF and entry.isDefault():
            del entries[ vrfName ]

   for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      _dropIfUnused( pimsmConfigRoot( family ).vrfConfig )
      _dropIfUnused( ssmConfigRoot( family ).config )
      
def _canDeletePimsmVrfConfig( vrfName ):
   """Hook: report whether vrfName's sparse-mode configuration may be deleted.

   Returns False as soon as an IPv4 or IPv6 sparse-mode entry for vrfName
   exists and is not default, and True otherwise (including when no entry
   exists). SSM entries are not consulted.
   """
   def _hasUserConfig( family ):
      entries = pimsmConfigRoot( family ).vrfConfig
      return vrfName in entries and not entries[ vrfName ].isDefault()

   return not any( _hasUserConfig( family )
                   for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ) )

def _cleanupPimsmConfig( vrfName ):
   """Cleanup hook: reset vrfName's sparse-mode and SSM entries in place.

   For IPv4 then IPv6, any existing sparse-mode entry and then any existing
   SSM entry for vrfName has reset() called on it; nothing is deleted.
   """
   def _resetEntry( entries ):
      if vrfName in entries:
         entries[ vrfName ].reset()

   for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      _resetEntry( pimsmConfigRoot( family ).vrfConfig )
      _resetEntry( ssmConfigRoot( family ).config )

# Check if acl rule has a multicast source address
def sourceIsMulticast( rule ):
   """Return the isMulticast flag of the source address in an ACL rule."""
   # getRawAttribute is needed: for IPv4 the plain attribute is a str rather
   # than an IpAddr, and a str has no isMulticast.
   return rule.filter.source.getRawAttribute( "address" ).isMulticast

def sourceIsAny( rule, af ):
   """Return whether an ACL rule's source equals a default-constructed address.

   When af == 'ipv4' the source is compared with a default
   Arnet::IpAddrWithFullMask; for any other af, with a default
   Arnet::Ip6AddrWithMask.
   NOTE(review): presumably the default value is the "any" wildcard -- confirm.
   """
   wildcardType = ( "Arnet::IpAddrWithFullMask" if af == 'ipv4'
                    else "Arnet::Ip6AddrWithMask" )
   return rule.filter.source == Tac.Value( wildcardType )

def allSourcesAreMulticast( acl, af ):
   """Return True if every rule in acl has a multicast or default ("any") source.

   The SSM-range and fast-reroute handlers use this to decide whether to warn
   about non-multicast rules.

   acl: ACL config whose currCfg exposes ipRuleById (used when af == 'ipv4')
        or ip6RuleById (used for any other af).
   af: address family, compared against the string 'ipv4'.
       NOTE(review): callers in this file pass AddressFamily values, which
       assumes AddressFamily.ipv4 == 'ipv4' -- confirm.
   Returns True for an ACL with no rules.
   """
   if af == 'ipv4':
      rules = acl.currCfg.ipRuleById
   else:
      rules = acl.currCfg.ip6RuleById

   # A generator (not a list) lets all() stop at the first offending rule.
   return all( sourceIsMulticast( rule ) or sourceIsAny( rule, af )
               for rule in rules.itervalues() )

#-----------------------------------------------------------------------------
# legacy: switch(config)# [no] ip pim ssm default source < groupAddr > [ sourceAddr ]
# (config-af)# [ no ] ssm default source < groupAddr > [ sourceAddr ]

#-----------------------------------------------------------------------------
def setIpPimSsmConvert( mode, args ):
   """Add an IPv4 (group, source) SSM-convert mapping.

   Serves both the legacy form (recognized by the 'pim' token) and the
   config-af form. The mapping is recorded in the sparse-mode config's
   ssmConvertConfig and mirrored into the SSM config's convertGroup. The
   command is ignored for IPv6 and when the mode has no sparse-mode config.
   An unparsable group/source is reported via mode.addErrorAndStop.
   """
   groupAddr = args[ 'GROUP' ]
   sourceAddr = args[ 'SOURCE' ]
   legacy = 'pim' in args
   pimsmConfig_ = pimsmConfigFromMode( mode, legacy=legacy )
   ssmConfig_ = ssmConfigFromMode( mode, legacy=legacy )
   af = getAddressFamilyFromMode( mode, legacy=legacy )

   if af == AddressFamily.ipv6 or not pimsmConfig_:
      return

   if not pimsmConfig_.ssmConvertConfig:
      # Assigning () presumably instantiates the nested Tac entity -- confirm.
      pimsmConfig_.ssmConvertConfig = ()
   config = pimsmConfig_.ssmConvertConfig

   try:
      ( source, group ) = PimCliLib.ipPimParseSg( groupAddr, sourceAddr )
   except ValueError:
      mode.addErrorAndStop( "Must enter a multicast group and an unicast source" )

   # Ensure the group exists in each collection independently, so a group
   # present in only one of them cannot cause a KeyError below.
   if group not in config.group:
      config.group.newMember( group )
   if group not in ssmConfig_.convertGroup:
      ssmConfig_.convertGroup.newMember( group )
   config.group[ group ].groupSource[ source ] = True
   ssmConfig_.convertGroup[ group ].groupSource[ source ] = True


def noIpPimSsmConvert( mode, args ):
   """Remove IPv4 SSM-convert mappings.

   With both GROUP and SOURCE, removes that source from the group; with GROUP
   only, removes the whole group; with neither, clears every mapping. Each
   change is applied to the sparse-mode ssmConvertConfig and mirrored in the
   SSM config's convertGroup. Unknown groups, IPv6, and a missing or empty
   ssmConvertConfig are ignored.
   """
   groupAddr = args.get( 'GROUP' )
   sourceAddr = args.get( 'SOURCE' )
   legacy = 'pim' in args
   pimsmCfg = pimsmConfigFromMode( mode, legacy=legacy )
   ssmCfg = ssmConfigFromMode( mode, legacy=legacy )
   family = getAddressFamilyFromMode( mode, legacy=legacy )

   if family == AddressFamily.ipv6:
      return
   if not pimsmCfg or not pimsmCfg.ssmConvertConfig:
      return
   convert = pimsmCfg.ssmConvertConfig

   try:
      ( source, group ) = PimCliLib.ipPimParseSg( groupAddr, sourceAddr )
   except ValueError:
      mode.addErrorAndStop( "Must enter a multicast group" )

   if not group:
      # No group given: wipe every mapping.
      convert.group.clear()
      ssmCfg.convertGroup.clear()
   elif group in convert.group:
      if source:
         del convert.group[ group ].groupSource[ source ]
         del ssmCfg.convertGroup[ group ].groupSource[ source ]
      else:
         del convert.group[ group ]
         del ssmCfg.convertGroup[ group ]

#------------------------------------------------------------------------------
# legacy: (config)# [ no ] ip pim ssm range <acl-name>
# legacy: (config)# [ no ] ip pim ssm range standard  
# (config-af)# [ no ] ssm range <acl-name>
# (config-af)# [ no ] ssm range standard 
#------------------------------------------------------------------------------
def setIpPimsmSsmFilter( mode, args ):
   """Set the SSM range to a standard ACL, or to the 'standard' range.

   Without an 'ACL' argument, the 'standard' keyword form is assumed. An
   existing ACL that is not a standard ACL is rejected with an error. An ACL
   containing non-multicast rules is accepted with a warning. A name that
   matches no existing ACL is accepted without checks. The setting is written
   to both the sparse-mode and the SSM config. Nothing is written when the
   mode has no sparse-mode config.
   """
   aclNameOrStandard = args.get( 'ACL', 'standard' )
   legacy = 'pim' in args
   pimsmConfig_ = pimsmConfigFromMode( mode, legacy=legacy )
   ssmConfig_ = ssmConfigFromMode( mode, legacy=legacy )
   af = getAddressFamilyFromMode( mode, legacy=legacy )
   aclType = 'ip' if ( af == AddressFamily.ipv4 ) else 'ipv6'
   # NOTE(review): the lookup also runs for the 'standard' keyword, so an ACL
   # literally named "standard" is validated here -- confirm that is intended.
   acl = getAclConfig( aclType ).get( aclNameOrStandard )

   if acl and not acl.standard:
      mode.addError( '%s is not a standard acl' % aclNameOrStandard )
      return
   if pimsmConfig_ is None:
      # Like noIpPimsmSsmFilter and the other handlers, treat a mode without a
      # sparse-mode config as a no-op instead of raising AttributeError.
      return
   if aclNameOrStandard == 'standard':
      pimsmConfig_.ssmFilter = pimsmConfig_.ssmFilterStandard
      ssmConfig_.ssmFilter = ssmConfig_.ssmFilterStandard
   else:
      if acl and not allSourcesAreMulticast( acl, af ):
         mode.addWarning( '%s contains non-multicast rule(s)'
                          % aclNameOrStandard )
      pimsmConfig_.ssmFilter = aclNameOrStandard
      ssmConfig_.ssmFilter = aclNameOrStandard

def noIpPimsmSsmFilter( mode, args ):
   """Restore the default SSM range in both the sparse-mode and SSM configs.

   Does nothing when the mode has no sparse-mode config.
   """
   legacy = 'pim' in args
   pimsmCfg = pimsmConfigFromMode( mode, legacy=legacy )
   ssmCfg = ssmConfigFromMode( mode, legacy=legacy )
   if pimsmCfg is not None:
      pimsmCfg.ssmFilter = pimsmCfg.ssmFilterDefault
      ssmCfg.ssmFilter = ssmCfg.ssmFilterDefault

#------------------------------------------------------------------------------
# legacy: (config)# [ no ] ip pim sparse-mode fast-reroute <acl-name>
# (config-af)# [ no ] fast-reroute <acl-name>
#------------------------------------------------------------------------------
def setIpPimsmFrrFilter( mode, args ):
   """Set the sparse-mode fast-reroute filter to the named standard ACL.

   An existing ACL that is not a standard ACL is rejected with an error. One
   containing non-multicast rules is accepted with a warning. A name that
   matches no existing ACL is accepted without checks. Nothing is written when
   the mode has no sparse-mode config.
   """
   aclName = args[ 'ACL' ]
   legacy = 'pim' in args
   pimsmConfig_ = pimsmConfigFromMode( mode, legacy=legacy )
   af = getAddressFamilyFromMode( mode, legacy=legacy )
   aclType = 'ip' if ( af == AddressFamily.ipv4 ) else 'ipv6'
   acl = getAclConfig( aclType ).get( aclName )

   if acl and not acl.standard:
      mode.addError( '%s is not a standard acl' % aclName )
      return
   if pimsmConfig_ is None:
      # Match noIpPimsmFrrFilter: a mode without a sparse-mode config is a
      # no-op rather than an AttributeError.
      return
   if acl and not allSourcesAreMulticast( acl, af ):
      mode.addWarning( '%s contains non-multicast rule(s)' % aclName )
   pimsmConfig_.frrFilter = aclName

def noIpPimsmFrrFilter( mode, args ):
   """Reset the fast-reroute filter to its default; no-op without a config."""
   pimsmCfg = pimsmConfigFromMode( mode, legacy='pim' in args )
   if pimsmCfg is not None:
      pimsmCfg.frrFilter = pimsmCfg.frrFilterDefault

#------------------------------------------------------------------------------
# legacy: 
#   (config)# [ no ] ip pim spt-threshold <infinity|zero> group-list <acl-name>
# (config-af)# [ no ] spt threshold <infinity|zero> match list <acl-name>
#------------------------------------------------------------------------------
def noPimsmSptThresh( mode, args ):
   """Undo an SPT-threshold setting.

   With an ACL argument, removes only that ACL's drSwitchAcl entry, reporting
   an error if there is none. Without one, resets drSwitch to 'immediate'.
   The legacy form is recognized by the 'spt-threshold' token. Does nothing
   when the mode has no sparse-mode config.
   """
   pimsmCfg = pimsmConfigFromMode( mode, legacy='spt-threshold' in args )
   if pimsmCfg is None:
      return

   aclName = args.get( 'ACL' )
   if not aclName:
      pimsmCfg.drSwitch = 'immediate'
   elif aclName not in pimsmCfg.drSwitchAcl:
      mode.addError( "Unknown ACL: %s" % aclName )
   else:
      # Delete only the specific group-list entry the user specified.
      del pimsmCfg.drSwitchAcl[ aclName ]

def setPimsmSptThresh( mode, args ):
   """Set the SPT switchover threshold, globally or for one group-list ACL.

   A THRESHOLD of 'infinity' maps to 'never'; any other value maps to
   'immediate'. With an ACL argument, the value is stored in drSwitchAcl under
   that name; otherwise it becomes drSwitch. The legacy form is recognized
   by the 'spt-threshold' token. Does nothing when the mode has no
   sparse-mode config.
   """
   pimsmCfg = pimsmConfigFromMode( mode, legacy='spt-threshold' in args )
   if pimsmCfg is None:
      return

   if args[ 'THRESHOLD' ] == 'infinity':
      thresh = 'never'
   else:
      thresh = 'immediate'

   aclName = args.get( 'ACL' )
   if not aclName:
      pimsmCfg.drSwitch = thresh
   else:
      pimsmCfg.drSwitchAcl[ aclName ] = thresh

#------------------------------------------------------------------------------
# legacy: (config)# [ no ] ip pim sparse-mode sg-expiry-timer 120-259200
# (config-af)# [ no ] sg-expiry-timer 120-259200
#------------------------------------------------------------------------------
def setPimsmSgExpiry( mode, args ):
   """Set the (S,G) expiry timer from EXPIRY, or restore its default if absent.

   Does nothing when the mode has no sparse-mode config.
   """
   pimsmCfg = pimsmConfigFromMode( mode, legacy='pim' in args )
   if pimsmCfg is None:
      return
   expiry = args.get( 'EXPIRY', pimsmCfg.sgExpiryTimerDefault )
   pimsmCfg.sgExpiryTimer = expiry

#------------------------------------------------------------------------------
# [no] make-before-break
#------------------------------------------------------------------------------
def setMbb( mode, args ):
   """Set disableMbb to whether the 'disabled' token was given.

   Does nothing when the mode has no sparse-mode config.
   """
   pimsmCfg = pimsmConfigFromMode( mode )
   if pimsmCfg is not None:
      pimsmCfg.disableMbb = ( 'disabled' in args )

def setRouteSgInstallThresh( mode, args ):
   """Set sgInstallThresh from CRITERIA, or restore its default when absent.

   Does nothing when the mode's sparse-mode config is missing (falsy).
   """
   pimsmCfg = pimsmConfigFromMode( mode )
   if not pimsmCfg:
      return
   criteria = args.get( 'CRITERIA', pimsmCfg.sgInstallThreshDefault )
   pimsmCfg.sgInstallThresh = criteria

def Plugin( entityManager ):
   """CLI plugin entry point: mount this module's configs and register hooks.

   Mounts the sparse-mode and SSM config types through doConfigMounts, binds
   the writable sparse-mode collection to pimsmConfigColl, binds read-only
   LazyMount handles for the PIM config and the VRF config to pimConfigRoot
   and _allVrfConfig, and registers this module's VRF lifecycle callbacks with
   PimCliLib.
   """
   # AF-independent config mounts.
   configTypes = [ PimsmConfigColl, SsmConfigColl, ]
   doConfigMounts( entityManager, configTypes )

   global pimsmConfigColl, _allVrfConfig
   # Writable ('w') mount read by getPimsmVrfConfig().
   # NOTE(review): this is the same type as PimsmConfigColl mounted above via
   # doConfigMounts; confirm whether both mounts are needed.
   pimsmConfigColl = ConfigMount.mount( entityManager,
                           'routing/pim/sparsemode/config',
                           'Routing::Pim::SparseMode::ConfigColl', 'w' )

   global pimConfigRoot
   # Read-only ('r') mounts; neither handle is read by the code shown here.
   pimConfigRoot = LazyMount.mount( entityManager,
         'routing/pim/config',
         'Routing::Pim::ConfigColl', 'r' )
   _allVrfConfig = LazyMount.mount( entityManager, 'ip/vrf/config',
                                    'Ip::AllVrfConfig', 'r' )

   # VRF lifecycle callbacks: create entries when a VRF is configured, drop
   # unused entries when it is deleted, veto deletion while user config
   # remains, and reset entries on cleanup.
   PimCliLib.pimSparseModeVrfConfiguredHook.addExtension( _pimsmConfigCreation )
   PimCliLib.pimSparseModeVrfDeletedHook.addExtension( _pimsmConfigDeletion )
   PimCliLib.canDeletePimSparseModeVrfHook.addExtension( _canDeletePimsmVrfConfig )
   PimCliLib.pimSparseModeCleanupHook.addExtension( _cleanupPimsmConfig )
