#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import BasicCli
import CliCommand
import CliMatcher
import CliToken.Clear
import CliPlugin.AclCli as AclCli
import CliPlugin.MldCliLib as MldCliLib
import CliPlugin.MrouteCli as MrouteCli
import CliPlugin.Ip6AddrMatcher as Ip6AddrMatcher
import CliPlugin.IntfCli as IntfCli
import CliPlugin.IraIp6IntfCli as IraIp6IntfCli
from IpLibConsts import DEFAULT_VRF
import LazyMount
import McastCommonCliLib
import Tac
import Tracing

# Trace function used when an interface's MLD configuration is reset.
traceIntf = Tracing.trace4

# Module-level handles that Plugin() fills in when the CLI plugin loads.
# Every one of them stays None until then.
mldConfig = mldAddrs = aclConfigCli = None
mldStatus = mldClearConfigColl = mldClearStatusColl = None

def _mldVrfDefinitionHook( vrfName ):
   """Make sure a newly defined multicast VRF has an MLD clear-config entry.

   Plugin() registers this as an extension of
   MrouteCli.routerMcastVrfDefinitionHook. An existing entry is left alone.
   """
   clearConfigs = mldClearConfigColl.clearConfig
   if vrfName in clearConfigs:
      return
   created = mldClearConfigColl.newClearConfig( vrfName )
   assert created is not None

def _mldVrfDeletionHook( vrfName ):
   """Drop the MLD clear-config entry of a deleted multicast VRF.

   The default VRF's entry is never removed. Plugin() registers this as an
   extension of MrouteCli.routerMcastVrfDeletionHook.
   """
   if vrfName == DEFAULT_VRF:
      return
   clearConfigs = mldClearConfigColl.clearConfig
   if vrfName in clearConfigs:
      del clearConfigs[ vrfName ]

def vrfExists( vrfName ):
   """Return True for the default VRF or any VRF with an MLD clear-config entry."""
   return vrfName == DEFAULT_VRF or vrfName in mldClearConfigColl.clearConfig

# Tac type of ACL rule actions; checkSourceGroupRule compares against
# AclAction.permit.
AclAction = Tac.Type( 'Acl::Action' )

# Interface configuration modelet that the per-interface 'mld ...' command
# classes below are added to.
ipIntfMode = IraIp6IntfCli.RoutingProtocolIntfConfigModelet
# Shared 'mld' keyword node for the interface commands. It is guarded by
# McastCommonCliLib.mcast6RoutingSupportedIntfGuard.
nodeMld = CliCommand.guardedKeyword( 'mld',
      helpdesc='Multicast Listener Discovery commands',
      guard=McastCommonCliLib.mcast6RoutingSupportedIntfGuard)

class MldIntf( IntfCli.IntfDependentBase ):
   """Interface-dependent hook that removes MLD configuration for an interface.

   Plugin() registers it with IntfCli.Intf.registerDependentClass.
   """
   def setDefault( self ):
      """Delete this interface's entry from the MLD configuration."""
      name = self.intf_.name
      traceIntf( "Resetting interface %s to default" % ( name, ) )
      mldConfig.intfDel( name )

#--------------------------------------------------------------------------------
# [ no | default ] mld
#--------------------------------------------------------------------------------
class MldCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld': set or clear the 'enabled' flag of the
   # interface's MLD configuration entry.
   syntax = 'mld'
   noOrDefaultSyntax = syntax
   data = { 'mld' : nodeMld }

   @staticmethod
   def handler( mode, args ):
      # intfIs() hands back the interface's config entry; it presumably
      # creates one when missing -- confirm in MldCliLib.
      mldConfig.intfIs( mode.intf.name ).enabled = True

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = mode.intf.name
      entry = mldConfig.intf( name )
      # If the interface has no MLD entry, there is nothing to disable.
      if entry is not None:
         entry.enabled = False
         mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldCmd )

#--------------------------------------------------------------------------------
# [ no | default ] mld static-group access-list ACL_NAME
#--------------------------------------------------------------------------------
def allowedNumberOfStaticGroupsPerRule():
   """Return the largest number of groups one static-group ACL rule may cover.

   The limit is 2**14 == 16384. checkSourceGroupRule reports this value in
   its error message.
   """
   return 1 << 14

# examines the fitness of an acl IpRuleConfig as a SourceGroup range.
def checkSourceGroupRule( rule ):
   if rule.action == AclAction.permit:
      # Ip addr/prefix len
      destination = rule.filter.destinationFullMask
      source = rule.filter.sourceFullMask
      sourceAddr = source.address
      sourceMask = source.mask
      destAddr = destination.address
      destMask = destination.mask

      wildCardBits = bin( ~destMask.word0 & 0xFFFFFFFF ).count( "1" )
      wildCardBits += bin( ~destMask.word1 & 0xFFFFFFFF ).count( "1" )
      wildCardBits += bin( ~destMask.word2 & 0xFFFFFFFF ).count( "1" )
      wildCardBits += bin( ~destMask.word3 & 0xFFFFFFFF ).count( "1" )

      if wildCardBits > 14:
         return ( False, "total groups in rule exceeds maximum groups %d"
                  % allowedNumberOfStaticGroupsPerRule() )

      if not mldAddrs.validMulticastAddr( destAddr ):
         return ( False, "invalid group address - must be a multicast address" )

      if not mldAddrs.validMulticastAddr( destMask ):
         return ( False, "invalid group address range - contains unicast addresses" )

      if mldAddrs.isWildcardAddr( sourceAddr ):
         if not mldAddrs.isWildcardAddr ( sourceMask ):
            return ( False,
                     "invalid source address/mask - (*,G) requires a zero mask" )

      else:
         if ( sourceMask.word0 != 0xFFFFFFFF or
              sourceMask.word1 != 0xFFFFFFFF or
              sourceMask.word2 != 0xFFFFFFFF or
              sourceMask.word3 != 0xFFFFFFFF ) :
            return ( False,
               "invalid source address - must be a single address, not a range" )

         if not mldAddrs.validUnicastAddr( sourceAddr ):
            return ( False,
                     "invalid source address - must be a unicast address" )
   else:
      # currently deny is not supported - maybe supported in future.
      return ( False,
               "action must be 'permit'" )

   return ( True, None )

def validateAcl( mode, aclName ):
   """Run checkSourceGroupRule on every rule of IPv6 ACL aclName.

   Each unsuitable rule produces a warning on mode, tagged with its sequence
   number. Returns True only if every rule passed.
   """
   ipv6Acl = aclConfigCli.config[ 'ipv6' ].acl[ aclName ].currCfg
   ok = True
   for seqNo, ruleId in ipv6Acl.ruleBySequence.iteritems():
      valid, reason = checkSourceGroupRule( ipv6Acl.ip6RuleById[ ruleId ] )
      if valid:
         continue
      mode.addWarning( "Seq no %d: %s" % ( seqNo, reason ) )
      ok = False
   return ok

class MldStaticGroupAclCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld static-group access-list ACL_NAME'
   # Assigns an IPv6 ACL whose rules describe static groups. A missing or
   # invalid ACL is still assigned; the user only gets warnings.
   syntax = 'mld static-group access-list ACL_NAME'
   noOrDefaultSyntax = 'mld static-group access-list ...'
   data = {
      'mld' : nodeMld,
      'static-group' : 'MLD static multicast group',
      'access-list' : 'IPv6 access list for use as static group list',
      'ACL_NAME' : AclCli.userIp6AclNameMatcher,
   }

   @staticmethod
   def handler( mode, args ):
      aclName = args[ 'ACL_NAME' ]
      if aclName in aclConfigCli.config[ 'ipv6' ].acl:
         if not validateAcl( mode, aclName ):
            mode.addWarning(
               "Access list %s contains rules invalid as static group "
               "specifications.\nAssigning anyway." % aclName )
      else:
         mode.addWarning( "Access list %s not configured. Assigning anyway."
                          % aclName )
      mldConfig.aclIs( mode.intf.name, aclName )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = mode.intf.name
      mldConfig.aclDel( name )
      mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldStaticGroupAclCmd )

#--------------------------------------------------------------------------------
# [ no | default ] mld static-group GROUP [ SOURCE ]
#--------------------------------------------------------------------------------
def validateStaticGroup( mode, groupAddr, sourceAddr ):
   """Check a static (S,G) or (*,G) entry, reporting problems on mode.

   groupAddr must be a valid multicast address. sourceAddr must be either
   the wildcard address or a valid unicast address. On failure an error is
   added to mode and False is returned; otherwise True is returned.
   """
   if not mldAddrs.validMulticastAddr( groupAddr ):
      mode.addError( "Invalid multicast group %s, must be in the range %s"
                     % ( groupAddr, mldAddrs.multicastRange ) )
      return False

   if ( mldAddrs.isWildcardAddr( sourceAddr ) or
        mldAddrs.validUnicastAddr( sourceAddr ) ):
      return True

   mode.addError( "Source address is not a valid unicast host address" )
   return False

class MldStaticGroupCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld static-group GROUP [ SOURCE ]'
   # If SOURCE is omitted, mldAddrs.wcAddr is used, i.e. a (*,G) entry.
   syntax = 'mld static-group GROUP [ SOURCE ]'
   noOrDefaultSyntax = syntax
   data = {
      'mld' : nodeMld,
      'static-group' : 'MLD static multicast group',
      'GROUP' : Ip6AddrMatcher.Ip6AddrMatcher( helpdesc='IPv6 group address' ),
      'SOURCE' : Ip6AddrMatcher.Ip6AddrMatcher( helpdesc='IPv6 source address' ),
   }

   @staticmethod
   def handler( mode, args ):
      group = args[ 'GROUP' ]
      source = args.get( 'SOURCE', mldAddrs.wcAddr )
      if validateStaticGroup( mode, group, source ):
         mldConfig.sourceIs( mode.intf.name, group, source )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      group = args[ 'GROUP' ]
      source = args.get( 'SOURCE', mldAddrs.wcAddr )
      if not validateStaticGroup( mode, group, source ):
         return
      name = mode.intf.name
      mldConfig.sourceDel( name, group, source )
      mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldStaticGroupCmd )

#--------------------------------------------------------------------------------
# [ no | default ] mld query-interval INTERVAL
#--------------------------------------------------------------------------------
class MldQueryIntervalIntervalCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld query-interval INTERVAL' (1-3175 seconds)
   syntax = 'mld query-interval INTERVAL'
   noOrDefaultSyntax = 'mld query-interval ...'
   data = {
      'mld' : nodeMld,
      'query-interval' : 'MLD query interval',
      'INTERVAL' : CliMatcher.IntegerMatcher( 1, 3175,
         helpdesc='Time between queries in units of seconds' ),
   }

   @staticmethod
   def handler( mode, args ):
      mldConfig.queryIntervalIs( mode.intf.name, args[ 'INTERVAL' ] )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = mode.intf.name
      # Called without a value -- presumably restores the default; confirm in
      # MldCliLib.
      mldConfig.queryIntervalIs( name )
      mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldQueryIntervalIntervalCmd )

#--------------------------------------------------------------------------------
# [ no | default ] mld query-response-interval INTERVAL
#--------------------------------------------------------------------------------
class MldQueryResponseIntervalIntervalCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld query-response-interval INTERVAL' (1-3175 seconds)
   syntax = 'mld query-response-interval INTERVAL'
   noOrDefaultSyntax = 'mld query-response-interval ...'
   data = {
      'mld' : nodeMld,
      'query-response-interval' : 'MLD query response interval',
      'INTERVAL' : CliMatcher.IntegerMatcher( 1, 3175,
         helpdesc='Query response interval in seconds' ),
   }

   @staticmethod
   def handler( mode, args ):
      mldConfig.queryResponseIntervalIs( mode.intf.name, args[ 'INTERVAL' ] )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = mode.intf.name
      # Called without a value -- presumably restores the default; confirm in
      # MldCliLib.
      mldConfig.queryResponseIntervalIs( name )
      mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldQueryResponseIntervalIntervalCmd )

#--------------------------------------------------------------------------------
# [ no | default ] mld startup-query-interval INTERVAL
#--------------------------------------------------------------------------------
class MldStartupQueryIntervalIntervalCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld startup-query-interval INTERVAL' (1-3175 seconds)
   syntax = 'mld startup-query-interval INTERVAL'
   noOrDefaultSyntax = 'mld startup-query-interval ...'
   data = {
      'mld' : nodeMld,
      'startup-query-interval' : 'MLD startup query interval',
      'INTERVAL' : CliMatcher.IntegerMatcher( 1, 3175,
         helpdesc='Startup query interval in seconds' ),
   }

   @staticmethod
   def handler( mode, args ):
      mldConfig.startupQueryIntervalIs( mode.intf.name, args[ 'INTERVAL' ] )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = mode.intf.name
      # Called without a value -- presumably restores the default; confirm in
      # MldCliLib.
      mldConfig.startupQueryIntervalIs( name )
      mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldStartupQueryIntervalIntervalCmd )

#--------------------------------------------------------------------------------
# [ no | default ] mld startup-query-count COUNT
#--------------------------------------------------------------------------------
class MldStartupQueryCountCountCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld startup-query-count COUNT' (1-100)
   syntax = 'mld startup-query-count COUNT'
   noOrDefaultSyntax = 'mld startup-query-count ...'
   data = {
      'mld' : nodeMld,
      'startup-query-count' : 'MLD startup query count',
      'COUNT' : CliMatcher.IntegerMatcher( 1, 100,
         helpdesc='Startup query count' ),
   }

   @staticmethod
   def handler( mode, args ):
      mldConfig.startupQueryCountIs( mode.intf.name, args[ 'COUNT' ] )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = mode.intf.name
      # Called without a value -- presumably restores the default; confirm in
      # MldCliLib.
      mldConfig.startupQueryCountIs( name )
      mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldStartupQueryCountCountCmd )

#--------------------------------------------------------------------------------
# [ no | default ] mld robustness COUNT
#--------------------------------------------------------------------------------
class MldRobustnessCountCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld robustness COUNT' (1-100)
   syntax = 'mld robustness COUNT'
   noOrDefaultSyntax = 'mld robustness ...'
   data = {
      'mld' : nodeMld,
      'robustness' : 'MLD querier robustness',
      'COUNT' : CliMatcher.IntegerMatcher( 1, 100,
         helpdesc='Robustness count' ),
   }

   @staticmethod
   def handler( mode, args ):
      mldConfig.robustnessIs( mode.intf.name, args[ 'COUNT' ] )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = mode.intf.name
      # Called without a value -- presumably restores the default; confirm in
      # MldCliLib.
      mldConfig.robustnessIs( name )
      mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldRobustnessCountCmd )

#--------------------------------------------------------------------------------
# [ no | default ] mld last-listener-query-interval INTERVAL
#--------------------------------------------------------------------------------
class MldLastListenerQueryIntervalIntervalCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld last-listener-query-interval INTERVAL'
   # (1-3175 seconds)
   syntax = 'mld last-listener-query-interval INTERVAL'
   noOrDefaultSyntax = 'mld last-listener-query-interval ...'
   data = {
      'mld' : nodeMld,
      'last-listener-query-interval' : 'MLD last listener query interval',
      'INTERVAL' : CliMatcher.IntegerMatcher( 1, 3175,
         helpdesc='Last listener query interval in seconds' ),
   }

   @staticmethod
   def handler( mode, args ):
      mldConfig.lastListenerQueryIntervalIs( mode.intf.name,
                                             args[ 'INTERVAL' ] )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = mode.intf.name
      # Called without a value -- presumably restores the default; confirm in
      # MldCliLib.
      mldConfig.lastListenerQueryIntervalIs( name )
      mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldLastListenerQueryIntervalIntervalCmd )

#--------------------------------------------------------------------------------
# [ no | default ] mld last-listener-query-count COUNT
#--------------------------------------------------------------------------------
class MldLastListenerQueryCountCountCmd( CliCommand.CliCommandClass ):
   # '[ no | default ] mld last-listener-query-count COUNT' (0-100; unlike
   # the other counts, 0 is accepted here)
   syntax = 'mld last-listener-query-count COUNT'
   noOrDefaultSyntax = 'mld last-listener-query-count ...'
   data = {
      'mld' : nodeMld,
      'last-listener-query-count' : 'MLD last listener query count',
      'COUNT' : CliMatcher.IntegerMatcher( 0, 100,
         helpdesc='Last listener query count' ),
   }

   @staticmethod
   def handler( mode, args ):
      mldConfig.lastListenerQueryCountIs( mode.intf.name, args[ 'COUNT' ] )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = mode.intf.name
      # Called without a value -- presumably restores the default; confirm in
      # MldCliLib.
      mldConfig.lastListenerQueryCountIs( name )
      mldConfig.intfUpdate( name )

ipIntfMode.addCommandClass( MldLastListenerQueryCountCountCmd )

#--------------------------------------------------------------------------------
# clear mld [ vrf VRF ] statistics [ interface INTF ]
#--------------------------------------------------------------------------------
def waitForStatisticsToClear( mode, vrfName, mldClearConfig ):
   """Wait until the clear status for vrfName catches up with the request.

   Polls until mldClearStatusColl.clearStatus[ vrfName ].lastStatsClearTime
   is at least mldClearConfig.clearStatsRequestTime, using Tac.waitFor with
   maxDelay=0.5 and timeout=5. On Tac.Timeout a warning is added to mode
   instead of letting the exception escape.
   """
   def _cleared():
      # Look up the status entry on every poll rather than caching it.
      status = mldClearStatusColl.clearStatus[ vrfName ]
      return status.lastStatsClearTime >= mldClearConfig.clearStatsRequestTime

   try:
      Tac.waitFor( _cleared, description='Counters to get cleared.',
                   sleep=True, maxDelay=0.5, timeout=5 )
   except Tac.Timeout:
      mode.addWarning(
            "Mld Counters may have not been cleared. Is Mld agent running?" )

def clearMldStatisticsByVrf( mode, vrfName, intf ):
   """Ask for the MLD statistics of one VRF to be cleared, then wait.

   Args:
      mode: CLI mode; waitForStatisticsToClear reports warnings on it.
      vrfName: VRF whose counters should be cleared.
      intf: interface to clear, or None when no interface was given.

   Does nothing if vrfName has no clear-config entry or no clear-status
   entry.
   """
   # Test membership before indexing. The previous guard indexed
   # clearConfig[ vrfName ] first and relied on the result being None, which
   # does not protect against a VRF that has no entry (e.g. the default VRF
   # when invoked for "all").
   if vrfName not in mldClearConfigColl.clearConfig:
      return
   if vrfName not in mldClearStatusColl.clearStatus:
      return
   mldClearConfig = mldClearConfigColl.clearConfig[ vrfName ]
   if mldClearConfig is None:
      return

   if intf is None:
      # An empty IntfId -- presumably means "all interfaces"; confirm with
      # the Mld agent.
      mldClearConfig.intfId = Tac.newInstance( "Arnet::IntfId" )
   else:
      mldClearConfig.intfId = intf.name
   # Timestamp the request. waitForStatisticsToClear compares the status's
   # lastStatsClearTime against this value.
   mldClearConfig.clearStatsRequestTime = Tac.now()
   waitForStatisticsToClear( mode, vrfName, mldClearConfig )

class ClearMldStatisticsCmd( CliCommand.CliCommandClass ):
   # 'clear mld [ VRF ] statistics [ interface INTF ]'
   # VRF may also be "all", which clears every VRF with a clear-config entry
   # plus the default VRF.
   syntax = 'clear mld [ VRF ] statistics [ interface INTF ]'
   data = {
      'clear' : CliToken.Clear.clearKwNode,
      'mld' : CliCommand.guardedKeyword( 'mld',
         helpdesc='Multicast Listener Discovery commands',
         guard=McastCommonCliLib.mcast6RoutingSupportedGuard ),
      'VRF' : MldCliLib.vrfExprFactory,
      'statistics' : 'Clear Mld Counters',
      'interface' : 'Clear statistics on a specific interface',
      'INTF' : IntfCli.Intf.matcher,
   }

   @staticmethod
   def handler( mode, args ):
      vrfName = args.get( 'VRF', DEFAULT_VRF )
      intf = args.get( 'INTF' )
      if not vrfExists( vrfName ) and vrfName != "all":
         mode.addError( "Invalid vrf name: %s" % vrfName )
         return

      # pylint: disable-msg=W0212
      if isinstance( mldClearConfigColl, LazyMount._Proxy ):
         LazyMount.force( mldClearConfigColl )
      if isinstance( mldStatus, LazyMount._Proxy ):
         LazyMount.force( mldStatus )

      if vrfName == "all":
         # The default VRF normally has its own clear-config entry
         # (_mldVrfDeletionHook never removes it). Add it only when it is
         # missing so it is not cleared, and waited on, twice.
         vrfNames = list( mldClearConfigColl.clearConfig )
         if DEFAULT_VRF not in vrfNames:
            vrfNames.append( DEFAULT_VRF )
      else:
         vrfNames = [ vrfName ]

      for vrf in vrfNames:
         clearMldStatisticsByVrf( mode, vrf, intf )

BasicCli.EnableMode.addCommandClass( ClearMldStatisticsCmd )

def Plugin( entityManager ):
   """CLI plugin entry point.

   Creates the MLD config/status/address helpers and lazily mounts the ACL
   config and the MLD clear config/status collections. It then registers
   MldIntf as an interface-dependent class and hooks multicast VRF
   definition and deletion.
   """
   global mldConfig, mldStatus, mldAddrs, aclConfigCli
   global mldClearConfigColl, mldClearStatusColl

   mldConfig = MldCliLib.MldConfig( entityManager )
   mldStatus = MldCliLib.MldStatus( entityManager )
   mldAddrs = MldCliLib.MldAddrs()

   aclConfigCli = LazyMount.mount(
      entityManager, 'acl/config/cli', 'Acl::Input::Config', 'r' )
   mldClearConfigColl = LazyMount.mount(
      entityManager, 'routing6/mld/clearConfig',
      'Routing::Mld::ClearConfigByVrf', 'w' )
   mldClearStatusColl = LazyMount.mount(
      entityManager, 'routing6/mld/clearStatus',
      'Routing::Mld::ClearStatusByVrf', 'r' )

   IntfCli.Intf.registerDependentClass( MldIntf, priority=22 )

   # Keep the per-VRF clear-config entries in step with multicast VRF
   # definitions and deletions.
   vrfHooks = (
      ( MrouteCli.routerMcastVrfDefinitionHook, _mldVrfDefinitionHook ),
      ( MrouteCli.routerMcastVrfDeletionHook, _mldVrfDeletionHook ),
   )
   for hook, extension in vrfHooks:
      hook.addExtension( extension )
