# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

from Arnet import IpGenAddr
from CliPlugin import IntfCli, IpAddrMatcher, IpGenAddrMatcher, VlanCli, VxlanCli
from CliPlugin.EthIntfCli import EthPhyAutoIntfType
from CliPlugin.LagIntfCli import LagAutoIntfType
from Intf.IntfRange import IntfRangeMatcher
import CliCommand
import ConfigMount
import Tac
import Toggles.McastVpnLibToggleLib
import Tracing

# Shorthand for Tracing.trace0; used below to log the parsed CLI arguments.
t0 = Tracing.trace0

# Per-address-family membership-join CLI config entities. They are None at
# import time and are mounted (writable) by Plugin().
membershipJoinStatusIpv4Config = None
membershipJoinStatusIpv6Config = None

# Tac type handles used to inspect address families and to build multicast
# tunnel interface ids.
AF = Tac.Type( "Arnet::AddressFamily" )
IrVxlanIpv4TunnelId = Tac.Type( "Multicast::Tunnel::IrVxlanIpv4TunnelId" )
McastTunnelIntfId = Tac.Type( "Multicast::Tunnel::McastTunnelIntfId" )
PimVxlanTunnel = Tac.Type( "Multicast::Tunnel::PimVxlanTunnel" )

def convertVtepIpToIntfId( ipAddr ):
   """Return the multicast tunnel interface id for a remote VTEP address.

   ipAddr is wrapped in an IrVxlanIpv4TunnelId, and that tunnel id is then
   converted into a McastTunnelIntfId.
   """
   return McastTunnelIntfId.makeIntfId( IrVxlanIpv4TunnelId.makeTunnel( ipAddr ) )

def pimTunnelTypeToIntfId( pimTunnelType ):
   """Return the multicast tunnel interface id for a PIM VXLAN tunnel type.

   pimTunnelType is the name handed to PimVxlanTunnel.makeTunnel, e.g.
   "pimsmIpv4Tunnel" or "pimssmIpv6Tunnel".
   """
   tunnelId = PimVxlanTunnel.makeTunnel( pimTunnelType )
   return McastTunnelIntfId.makeIntfId( tunnelId )

def getPimTunnelIntfs( args ):
   """Return tunnel interface ids for every PIM tunnel keyword in args.

   The result always follows the fixed order IPv4 sparse-mode, IPv4 SSM,
   IPv6 sparse-mode, IPv6 SSM, whatever order the keywords were typed in.
   """
   keywordToTunnelType = (
      ( 'ipv4-pim-sm', "pimsmIpv4Tunnel" ),
      ( 'ipv4-pim-ssm', "pimssmIpv4Tunnel" ),
      ( 'ipv6-pim-sm', "pimsmIpv6Tunnel" ),
      ( 'ipv6-pim-ssm', "pimssmIpv6Tunnel" ),
   )
   return [ pimTunnelTypeToIntfId( tunnelType )
            for keyword, tunnelType in keywordToTunnelType
            if keyword in args ]

def _defaultSource( group ):
   """Return the wildcard source used when no SOURCE is given for group.

   IPv6 groups get '::' and every other group gets '0.0.0.0'. This matches
   _configForGroup, which stores non-IPv6 groups in the IPv4 config.
   """
   if group.af == AF.ipv6:
      return IpGenAddr( '::' )
   return IpGenAddr( '0.0.0.0' )

def _configForGroup( group ):
   """Return the per-address-family CLI config entity that holds group."""
   if group.af == AF.ipv6:
      return membershipJoinStatusIpv6Config
   return membershipJoinStatusIpv4Config

def _memberIntfIds( args ):
   """Return the member interface ids named on the command line.

   The command takes exactly one of 'vtep VTEPIP..', a PORT range, or
   'tunnel <pim-type>..'. Returns None when none of them is present (only
   possible in the no/default form); callers treat that as "all interfaces".
   """
   if 'VTEPIP' in args:
      return [ convertVtepIpToIntfId( vtepIp ) for vtepIp in args[ 'VTEPIP' ] ]
   if 'PORT' in args:
      return list( args[ 'PORT' ].intfNames() )
   if 'tunnel' in args:
      return getPimTunnelIntfs( args )
   return None

def _removeIntfs( coll, intfIds ):
   """Remove intfIds from coll, or clear coll when intfIds is None."""
   if intfIds is None:
      coll.clear()
      return
   for intfId in intfIds:
      coll.remove( intfId )

def _removeGroupMembership( vlanMembershipJoin, group, args, intfIds ):
   """Apply the no/default form of the command to one group of a VLAN.

   - With neither SOURCE nor interfaces, the whole group (all of its sources)
     is removed.
   - Otherwise the named interfaces (or all of them, when intfIds is None) are
     removed from the include list of (SOURCE or the wildcard source, group).
     With 'exclude' they are removed from the exclude list instead.
   - Sources left with empty include and exclude lists are then deleted, and
     so are groups left with no sources.

   VLAN-level cleanup is left to the caller.
   """
   groupStatus = vlanMembershipJoin.group.get( group )
   if not groupStatus:
      return
   if 'SOURCE' in args:
      source = args[ 'SOURCE' ]
   elif intfIds is None:
      # Only remove this group. The VLAN may still hold other groups or
      # router interfaces, so deleting the VLAN here would lose config.
      del vlanMembershipJoin.group[ group ]
      return
   else:
      source = _defaultSource( group )
   sourceStatus = groupStatus.source.get( source )
   if not sourceStatus:
      return
   if 'exclude' in args:
      membershipColl = sourceStatus.excludeIntf
   else:
      membershipColl = sourceStatus.includeIntf
   _removeIntfs( membershipColl, intfIds )
   if not sourceStatus.excludeIntf and not sourceStatus.includeIntf:
      del groupStatus.source[ source ]
   if not groupStatus.source:
      del vlanMembershipJoin.group[ group ]

class MembershipJoinStatusCmd( CliCommand.CliCommandClass ):
   """[no|default] vxlan vlan VLAN_ID member [[exclude] GROUP [SOURCE]] ...

   - With GROUP, the named VTEPs, ports or PIM tunnels are added to (or
     removed from) the include list of the (SOURCE, GROUP) entry, or the
     exclude list with 'exclude'. SOURCE defaults to the wildcard address of
     the group's address family.
   - Without GROUP, they are added to (or removed from) the VLAN's router
     interfaces in both the IPv4 and the IPv6 config.
   """
   syntax = "vxlan vlan VLAN_ID member " \
            "[ [ exclude ] GROUP [ SOURCE ] ] " \
            "( ( vtep { VTEPIP } ) | PORT | " \
            "( tunnel " \
            "{ ipv4-pim-sm | ipv4-pim-ssm | ipv6-pim-sm | ipv6-pim-ssm } ) )"
   noOrDefaultSyntax = "vxlan vlan VLAN_ID member " \
                       "[ [ exclude ] GROUP [ SOURCE ] ] " \
                       "[ ( ( vtep { VTEPIP } ) | PORT | " \
                       "( tunnel " \
                       "{ ipv4-pim-sm | ipv4-pim-ssm | " \
                       "ipv6-pim-sm | ipv6-pim-ssm } ) ) ]"
   data = {
         'vxlan': VxlanCli.vxlanNode,
         'vlan': CliCommand.Node( matcher=VxlanCli.matcherVlan,
                                  guard=VxlanCli.isVxlan1InterfaceGuard ),
         'VLAN_ID': VlanCli.vlanIdMatcher,
         'member': "Specify multicast group interface members",
         'exclude': "Exclude from traffic delivery",
         'GROUP': IpGenAddrMatcher.IpGenAddrMatcher(
            helpdesc="IP multicast group address" ),
         'SOURCE': IpGenAddrMatcher.IpGenAddrMatcher(
            helpdesc="IP multicast source address" ),
         'vtep': "VXLAN Tunnel End Points to deliver traffic",
         'VTEPIP': IpAddrMatcher.IpAddrMatcher(
            helpdesc="IP address of the remote VTEP" ),
         'PORT': IntfRangeMatcher(
            explicitIntfTypes=( EthPhyAutoIntfType, LagAutoIntfType ),
            helpdesc="Local interface to deliver traffic" ),
         'tunnel': "Configure tunnel parameters",
         'ipv4-pim-sm': "Use a IP v4 PIM sparse-mode tree",
         'ipv4-pim-ssm': "Use a IP v4 PIM source-specific tree",
         'ipv6-pim-sm': "Use a IP v6 PIM sparse-mode tree",
         'ipv6-pim-ssm': "Use a IP v6 PIM source-specific tree"
   }

   @staticmethod
   def handler( mode, args ):
      t0( "MembershipJoinStatusCmd handler args", args )
      vlanId = args[ 'VLAN_ID' ].id
      # The positive syntax requires vtep, PORT or tunnel, so this is a list.
      intfIds = _memberIntfIds( args ) or []

      if 'GROUP' not in args:
         # No group: the interfaces become router interfaces for both AFs.
         for config in ( membershipJoinStatusIpv4Config,
                         membershipJoinStatusIpv6Config ):
            routerIntf = config.newVlan( vlanId ).routerIntf
            for intfId in intfIds:
               routerIntf.add( intfId )
         return

      group = args[ 'GROUP' ]
      validationError = None
      if 'SOURCE' in args:
         source = args[ 'SOURCE' ]
         if not source.isUnicast:
            validationError = "Invalid source address"
         elif source.af != group.af:
            # Otherwise e.g. an IPv6 source would be stored under an IPv4 group.
            validationError = "Source and group address families must match"
      else:
         source = _defaultSource( group )
      # Only one error is reported; group errors take precedence.
      if not group.isMulticast:
         validationError = "Invalid multicast address"
      if group.isLinkLocalMulticast:
         validationError = "Reserved multicast address"
      if validationError:
         mode.addError( validationError )
         return

      sourceStatus = _configForGroup( group ).newVlan( vlanId ) \
                     .newGroup( group ).newSource( source )
      if 'exclude' in args:
         membershipColl = sourceStatus.excludeIntf
      else:
         membershipColl = sourceStatus.includeIntf
      for intfId in intfIds:
         membershipColl.add( intfId )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      t0( "MembershipJoinStatusCmd noOrDefaultHandler args", args )
      vlanId = args[ 'VLAN_ID' ].id
      ipv4MembershipJoin = membershipJoinStatusIpv4Config.vlan.get( vlanId )
      ipv6MembershipJoin = membershipJoinStatusIpv6Config.vlan.get( vlanId )
      # None means no interfaces were named, i.e. remove all of them.
      intfIds = _memberIntfIds( args )

      if 'GROUP' in args:
         group = args[ 'GROUP' ]
         if group.af == AF.ipv6:
            vlanMembershipJoin = ipv6MembershipJoin
         else:
            vlanMembershipJoin = ipv4MembershipJoin
         if not vlanMembershipJoin:
            return
         _removeGroupMembership( vlanMembershipJoin, group, args, intfIds )
      else:
         for vlanMembershipJoin in ( ipv4MembershipJoin, ipv6MembershipJoin ):
            if vlanMembershipJoin:
               _removeIntfs( vlanMembershipJoin.routerIntf, intfIds )

      # Drop VLAN entries left with neither groups nor router interfaces.
      for config, vlanMembershipJoin in (
            ( membershipJoinStatusIpv4Config, ipv4MembershipJoin ),
            ( membershipJoinStatusIpv6Config, ipv6MembershipJoin ) ):
         if vlanMembershipJoin and not vlanMembershipJoin.group and \
               not vlanMembershipJoin.routerIntf:
            del config.vlan[ vlanId ]

# Register the 'vxlan vlan ... member' command on the VXLAN interface modelet
# only when the underlay-multicast feature toggle is enabled.
if Toggles.McastVpnLibToggleLib.toggleMcastVpnUnderlayMulticastEnabled():
   VxlanCli.VxlanIntfModelet.addCommandClass( MembershipJoinStatusCmd )

class VxlanIntfMcastVpnIntf( IntfCli.IntfDependentBase ):
   """Clean up membership-join config in response to VTI destruction.

   NOTE(review): setDefault empties every VLAN of both the IPv4 and IPv6
   config without checking which interface it was invoked for. Confirm that
   the dependent framework only calls it for the VXLAN tunnel interface.
   """
   def setDefault( self ):
      for config in ( membershipJoinStatusIpv4Config,
                      membershipJoinStatusIpv6Config ):
         config.vlan.clear()

def Plugin( entityManager ):
   """Mount the IPv4/IPv6 membership-join CLI config and register dependents.

   - Both configs are mounted writable ("w") at the 'cli' mount path of the
     Irb::Multicast::Gmp::MembershipJoinStatusCli type, one per address
     family.
   - The interface dependent class is registered only when the
     underlay-multicast toggle is enabled.
   """
   global membershipJoinStatusIpv4Config, membershipJoinStatusIpv6Config
   typeName = 'Irb::Multicast::Gmp::MembershipJoinStatusCli'

   def mountCliConfig( af ):
      mountPath = Tac.Type( typeName ).mountPath( af, 'cli' )
      return ConfigMount.mount( entityManager, mountPath, typeName, "w" )

   membershipJoinStatusIpv4Config = mountCliConfig( 'ipv4' )
   membershipJoinStatusIpv6Config = mountCliConfig( 'ipv6' )
   if Toggles.McastVpnLibToggleLib.toggleMcastVpnUnderlayMulticastEnabled():
      IntfCli.Intf.registerDependentClass( VxlanIntfMcastVpnIntf )
