#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac, Tracing
import CliSave
import IntfCliSave
import Arnet
from RoutingIntfUtils import allRoutingProtocolIntfNames

# Tracing alias used when saving individual interfaces.
traceIntf = Tracing.trace4

# Register the 'Mld.intf' command sequence in interface config mode, ordered
# after the 'Ira.ipIntf' sequence.
IntfCliSave.IntfConfigMode.addCommandSequence( 'Mld.intf',
                                               after=[ 'Ira.ipIntf' ] )

def isWildcardAddr( addr ):
   """Return True if addr equals the IPv6 unspecified address "::".

   addr may be a str, which is parsed with Arnet.Ip6Addr first. Any other
   value is compared directly against Arnet.Ip6Addr( "::" ).
   """
   parsed = Arnet.Ip6Addr( addr ) if isinstance( addr, str ) else addr
   wildcard = Arnet.Ip6Addr( "::" )
   return parsed == wildcard

@CliSave.saver( 'Routing::Mld::IntfConfig', 'routing6/mld/config',
                requireMounts=( 'interface/config/all',
                                'interface/status/all' ) )
def saveMldConfig( mldConfig, root, sysdbRoot, options, requireMounts ):
   """Save the MLD interface configuration as running-config commands.

   The set of interfaces visited depends on the save options:
     - saveAllDetail: every name returned by allRoutingProtocolIntfNames
       with includeEligible=True.
     - saveAll: routing-protocol interfaces plus every interface that has
       an entry in mldConfig.configIntf.
     - otherwise: only interfaces with an entry in mldConfig.configIntf.

   Interfaces without a configIntf entry are skipped unless saveAll is set.
   In that case a freshly created Routing::Mld::ConfigIntf (presumably
   holding default values) is saved instead, so defaults are printed.
   """
   saveAll = options.saveAll
   saveAllDetail = options.saveAllDetail

   if saveAllDetail:
      intfNames = allRoutingProtocolIntfNames( sysdbRoot, includeEligible=True,
                                               requireMounts=requireMounts )
   elif saveAll:
      # Routing configuration is allowed on switchports as well.
      # Save configuration on all routing protocol interfaces and switchports
      # with non-default config. Build a set union instead of concatenating
      # with '+': concatenation only works when both operands are lists. It
      # breaks if the helper returns a set/generator, or under Python 3
      # where keys() is a view.
      intfNames = set( allRoutingProtocolIntfNames(
            sysdbRoot, requireMounts=requireMounts ) )
      intfNames.update( mldConfig.configIntf.keys() )
   else:
      intfNames = mldConfig.configIntf

   for intfName in intfNames:
      intf = mldConfig.configIntf.get( intfName )

      if intf is None:
         if not saveAll:
            # Nothing configured here and defaults are not being dumped.
            continue
         # Synthesize an unconfigured entry so saveMldIntf prints defaults.
         intf = Tac.newInstance( 'Routing::Mld::ConfigIntf', intfName )
         intf.staticConfig = ()
         intf.staticConfig.staticGroupConfig = ( "direct", )
         intf.querierConfig = ( intfName, )

      saveMldIntf( intf, root, sysdbRoot, saveAll, saveAllDetail )

def saveMldIntf( intf, root, sysdbRoot, saveAll, saveAllDetail ):
   """Add the MLD running-config commands for a single interface.

   Commands are added to the 'Mld.intf' sequence of the interface's config
   mode under root.

   A setting is printed when it differs from its default. With saveAll,
   default settings are printed as well: either as an explicit value, or as
   a 'no ...' form for enable, the static-group access list and the startup
   query interval. Static group commands are printed for every configured
   (group, source) pair regardless of saveAll.

   sysdbRoot is not referenced here; saveAllDetail is only traced.
   """
   traceIntf( "Saving interface %s, save all %s, detail %s" %
              ( intf.intfId, saveAll, saveAllDetail ) )

   intfModes = root[ IntfCliSave.IntfConfigMode ]
   cmds = intfModes.getOrCreateModeInstance( intf.intfId )[ 'Mld.intf' ]

   # Interface enable state.
   if intf.enabled:
      cmds.addCommand( 'mld' )
   elif saveAll:
      cmds.addCommand( 'no mld' )

   # Static-group access list.
   accessList = intf.staticAccessList
   if accessList != intf.staticAccessListDefault:
      cmds.addCommand( 'mld static-group access-list %s' % accessList )
   elif saveAll:
      cmds.addCommand( 'no mld static-group access-list' )

   # Static groups: one command per (group, source). A wildcard ("::") source
   # is printed as the group address alone.
   sourceByGroup = intf.staticConfig.staticGroupConfig.sourceByGroup
   for group in sourceByGroup.values():
      for source in group.sourceAddr.keys():
         if isWildcardAddr( source ):
            cmds.addCommand( 'mld static-group %s' % ( group.groupAddr ) )
         else:
            cmds.addCommand( 'mld static-group %s %s' % ( group.groupAddr,
                                                          source ) )

   querier = intf.querierConfig

   def addQuerierCmd( attrName, cmdFmt ):
      # Print a querier setting when it is non-default, or always on saveAll.
      # The default lives in the sibling attribute '<attrName>Default'.
      value = getattr( querier, attrName )
      if value != getattr( querier, attrName + 'Default' ) or saveAll:
         cmds.addCommand( cmdFmt % value )

   addQuerierCmd( 'queryInterval', 'mld query-interval %d' )
   addQuerierCmd( 'queryResponseInterval', 'mld query-response-interval %d' )

   # The startup query interval is a special case. Its default is 0, meaning
   # "not configured, use 1/4 of the query interval". Neither
   # 'mld startup-query-interval 0' (out of range) nor a fractional value
   # such as 31.25 (1/4 * 125; must be a whole number) is accepted, so the
   # default is saved as the 'no' form.
   startupInterval = querier.startupQueryInterval
   if startupInterval != querier.startupQueryIntervalDefault:
      cmds.addCommand( 'mld startup-query-interval %d' % startupInterval )
   elif saveAll:
      cmds.addCommand( 'no mld startup-query-interval' )

   addQuerierCmd( 'startupQueryCount', 'mld startup-query-count %d' )
   addQuerierCmd( 'robustness', 'mld robustness %d' )
   addQuerierCmd( 'lastListenerQueryInterval',
                  'mld last-listener-query-interval %d' )
   addQuerierCmd( 'lastListenerQueryCount',
                  'mld last-listener-query-count %d' )

