# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliModel import Model, Dict, Bool, Int, Enum, Float
from ArnetModel import Ip6Address
from CliPlugin import IntfCli
from IntfModels import Interface

# pylint: disable-msg=unsubscriptable-object

# Display string for each tri-state status keyword: the keyword capitalized.
statusEnumToStr = { status: status.capitalize()
                    for status in ( "enabled", "disabled", "default" ) }

# Display string for each on/off state keyword: the keyword capitalized.
# The keys also serve as the allowed values of the MLD snooping state Enums.
stateEnumToStr = { state: state.capitalize()
                   for state in ( "enabled", "disabled" ) }

class MldSnoopingVlanInfo( Model ):
   # Per-VLAN MLD snooping state, group-limit status and forwarding flags.
   mldSnoopingState = Enum( values=stateEnumToStr.keys(),
         help="State of MLD Snooping" )
   # 65534 is the default limit; the summary output shows it as "No limit set".
   maxGroups = Int(
         help="Maximum number of multicast groups that can join the VLAN",
         default=65534 )
   groupsOverrun = Bool(
         help="There has been an attempt to create more than the maximum "
              "number of groups",
         default=False )
   pruningActive = Bool( default=False, help="MLD snooping pruning is active" )
   floodingTraffic = Bool( default=True, help="Flooding traffic to VLAN" )

class MldSnoopingInfo( Model ):
   # Global MLD snooping settings together with per-VLAN details.
   mldSnoopingState = Enum( values=stateEnumToStr.keys(), optional=True,
         help="Global state of MLD Snooping" )
   robustness = Int( optional=True,
         help="Number of queries sent to age out a port's membership in group" )
   vlans = Dict( keyType=int, valueType=MldSnoopingVlanInfo, optional=True,
         help="A mapping of VLAN's ID to its information" )

   def render( self ):
      """Print the global MLD snooping settings, then one section per VLAN.

      Nothing is printed when the global snooping state is unset. VLAN
      sections are printed in ascending VLAN-ID order.
      """
      if self.mldSnoopingState is None:
         return

      def printField( label, value ):
         # Labels are left-aligned in a 30-character column.
         print( "%-30s : %s" % ( label, value ) )

      print( "   Global MLD Snooping configuration:" )
      print( "-------------------------------------------" )
      printField( 'MLD snooping', stateEnumToStr[ self.mldSnoopingState ] )
      printField( 'Robustness variable', self.robustness )
      print( "" )

      for vlanId, vlanInfo in sorted( self.vlans.iteritems() ):
         print( "VLAN %s :" % vlanId )
         print( "----------" )
         printField( 'MLD snooping', stateEnumToStr[ vlanInfo.mldSnoopingState ] )
         # 65534 (the maxGroups default) is reported as "no limit".
         if vlanInfo.maxGroups == 65534:
            maxGroupsStr = 'No limit set'
         else:
            maxGroupsStr = '%u' % vlanInfo.maxGroups
         printField( 'MLD max group limit', maxGroupsStr )
         printField( 'Recent attempt to exceed limit',
                     'Yes' if vlanInfo.groupsOverrun else 'No' )
         printField( 'MLD snooping pruning active', vlanInfo.pruningActive )
         printField( 'Flooding traffic to VLAN', vlanInfo.floodingTraffic )

class MldSnoopingCountersInterface( Model ):
   # MLD snooping packet counters for one interface. Counters declared
   # optional=True may be left unset (None) by the producer of this model.

   # IP / PIM level counters and the catch-all "other" send counter.
   pimPacketsReceived = Int( help="Number of PIM packets received",
         optional=True )
   shortPacketsReceived = Int( help="Number of packets received"
         " with not enough IP payload" )
   nonIpPacketsReceived = Int( help="Number of non IP packets received" )
   badChecksumIpPacketsReceived = Int( help="Number of packets received"
         " for which IP checksum check failed" )
   unknownIpPacketsReceived = Int( help="Number of packets received"
         " with unknown IP Protocol" )
   badChecksumPimPacketsReceived = Int( help="Number of packets received"
         " for which PIM checksum check failed" )
   otherPacketsSent = Int( help="Number of other packets sent", optional=True )

   # ICMPv6 / MLD counters.
   badChecksumIcmpV6PacketsReceived = Int( help="Number of packets received"
         " for which ICMP v6 checksum check failed" )
   badMldQueryReceived = Int( help="Number of invalid MLD queries received" )
   mldV1QueryReceived = Int( help="Number of MLD v1 queries received",
         optional=True )
   mldV2QueryReceived = Int( help="Number of MLD v2 queries received",
         optional=True )
   badMldV2ReportReceived = Int( help="Number of invalid MLD v2 reports received" )
   mldV2ReportReceived = Int( help="Number of MLD v2 reports received",
         optional=True )
   otherIcmpPacketsReceived = Int( help="Number of other ICMP v6 packets received",
         optional=True )
   mldQuerySend = Int( help="Number of MLD queries sent",
         optional=True )
   mldReportSend = Int( help="Number of MLD reports sent",
         optional=True )

class MldSnoopingCounters( Model ):
   # Per-interface MLD snooping counters, rendered either as a detailed
   # error-counter table or as an input/output summary table.
   interfaces = Dict( keyType=Interface, valueType=MldSnoopingCountersInterface,
         help="Map Interface name with its counter details" )
   _errorSpecific = Bool( help="Display only error counters" )

   def render( self ):
      """Print the MLD snooping counters table, one row per interface.

      When _errorSpecific is set, each error counter gets its own column;
      otherwise input/output totals are shown. Rows are sorted by interface.
      """
      if self._errorSpecific:
         self._renderErrorCounters()
      else:
         self._renderSummaryCounters()

   def _renderErrorCounters( self ):
      # Print one column per individual error counter.
      errorCounterFormat = '%-10s %-10s %-9s %-9s %-12s %-9s %-9s %-9s %-9s'
      print( errorCounterFormat % ( '', 'Packet',
            '  Packet', 'Bad IP', 'Unknown', 'Bad PIM', 'Bad ICMP',
            'Bad MLD', 'Bad MLD' ) )
      print( errorCounterFormat % ( 'Port', 'Too Short',
            '  Not IP', 'Checksum', 'IP Protocol', 'Checksum', 'Checksum',
            'Query', 'Report' ) )
      print( "-" * 93 )
      for intf, counters in sorted( self.interfaces.iteritems() ):
         print( '%-10s %9s %9s %9s %12s %9s %9s %9s %9s' % (
               IntfCli.Intf.getShortname( intf ),
               counters.shortPacketsReceived,
               counters.nonIpPacketsReceived,
               counters.badChecksumIpPacketsReceived,
               counters.unknownIpPacketsReceived,
               counters.badChecksumPimPacketsReceived,
               counters.badChecksumIcmpV6PacketsReceived,
               counters.badMldQueryReceived,
               counters.badMldV2ReportReceived ) )

   def _renderSummaryCounters( self ):
      # Print aggregated input/output totals per interface. Several counters
      # are declared optional and may be None; they count as 0 here so that
      # the additions and '%d' formatting below cannot raise TypeError.
      def count( value ):
         return 0 if value is None else value

      print( "{:^38}|{:>10}".format( 'Input', 'Output' ) )
      print( "Port   Queries Reports  Others  Errors|"
             "Queries Reports  Others" )
      print( "-" * 62 )
      for intf, counters in sorted( self.interfaces.iteritems() ):
         queries = ( count( counters.mldV1QueryReceived ) +
                     count( counters.mldV2QueryReceived ) )
         others = ( count( counters.pimPacketsReceived ) +
                    count( counters.otherIcmpPacketsReceived ) )
         errors = sum( count( value ) for value in (
               counters.shortPacketsReceived,
               counters.nonIpPacketsReceived,
               counters.badChecksumIpPacketsReceived,
               counters.unknownIpPacketsReceived,
               counters.badChecksumPimPacketsReceived,
               counters.badChecksumIcmpV6PacketsReceived,
               counters.badMldQueryReceived,
               counters.badMldV2ReportReceived ) )
         print( '%-6s %7d %7d %7d %7d %7d %7d %7d' % (
               IntfCli.Intf.getShortname( intf ),
               queries,
               count( counters.mldV2ReportReceived ),
               others,
               errors,
               count( counters.mldQuerySend ),
               count( counters.mldReportSend ),
               count( counters.otherPacketsSent ) ) )

class MldSnoopingVlanQuerier( Model ):
   # MLD querier information for a single VLAN.
   querierAddress = Ip6Address( help="Address of MLD querier in the VLAN" )
   # 'unknown' is displayed as '-' by MldSnoopingQuerier.render.
   mldVersion = Enum( values=( 'v1', 'v2', 'unknown' ),
         help="Version of MLD snooping querier in the vlan" )
   querierInterface = Interface( help="Interface where the querier is located" )
   queryResponseInterval = Float(
         help="Effective maximum period a recipient can wait before responding "
              "with a membership report in seconds" )

class MldSnoopingQuerier( Model ):
   # MLD querier information keyed by VLAN ID.
   vlans = Dict( keyType=int, valueType=MldSnoopingVlanQuerier,
         help="Mapping vlan Id to the vlan querier info" )
   _vlanSpecific = Bool(
         help="Display information in vlan specific format if True",
         default=False )

   def render( self ):
      """Print MLD querier details; print nothing when no VLAN is present.

      In VLAN-specific mode the first VLAN of the mapping is printed as
      label/value lines. Otherwise a table with one row per VLAN is printed in
      ascending VLAN-ID order. An 'unknown' MLD version is shown as '-'.
      """
      def versionStr( version ):
         return '-' if version == 'unknown' else version

      if not self.vlans:
         return

      if self._vlanSpecific:
         querier = self.vlans[ next( iter( self.vlans ) ) ]
         fields = (
            ( 'IP Address', querier.querierAddress ),
            ( 'MLD Version', versionStr( querier.mldVersion ) ),
            ( 'Port', IntfCli.Intf.getShortname( querier.querierInterface ) ),
            ( 'Max response time', querier.queryResponseInterval ),
         )
         # Build every line first, then emit them with a single print.
         print( '\n'.join( '%-20s : %s' % field for field in fields ) )
         return

      rowFormat = '%-5s %-24s %-8s %s'
      print( rowFormat % ( 'Vlan', 'IP Address', 'Version', 'Port' ) )
      print( "-" * 44 )
      for vlanId in sorted( self.vlans ):
         querier = self.vlans[ vlanId ]
         print( rowFormat % ( vlanId, querier.querierAddress,
               versionStr( querier.mldVersion ),
               IntfCli.Intf.getShortname( querier.querierInterface ) ) )
