#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pkgdeps: rpmwith %{_libdir}/SysdbMountProfiles/ConfigAgent-Mld

import Tac
import Tracing, QuickTrace
import ConfigMount, SharedMem, SmashLazyMount
import Arnet
import CliPlugin.VrfCli as VrfCli

# Formatted (Tracing) and QuickTrace trace functions, one trace level per kind
# of configuration object: level 4 = interface, 5 = group, 6 = source.
traceIntf, traceGroup, traceSource = ( Tracing.trace4, Tracing.trace5,
                                       Tracing.trace6 )

qv = QuickTrace.Var
qtraceIntf, qtraceGroup, qtraceSource = ( QuickTrace.trace4, QuickTrace.trace5,
                                          QuickTrace.trace6 )

# CLI expression factory for a VRF-name token (help text: 'VRF name').
vrfExprFactory = VrfCli.VrfExprFactory( helpdesc='VRF name' )

class MldConfig( object ):
   '''Read/write helper for the MLD per-interface configuration mounted at
   "routing6/mld/config" (Routing::Mld::IntfConfig).

   Interface, static-group and static-source entries are created on demand by
   the *Is methods and removed by the *Del methods. intfUpdate removes an
   interface entry once everything it checks is back at its default value.
   Group and source addresses may be passed either as str or as Arnet.Ip6Addr.
   '''

   # Attributes of intf.querierConfig that each have a matching
   # "<name>Default" attribute; intfUpdate requires all of them to be at their
   # defaults before it deletes an interface entry.
   _querierAttrs = ( 'queryInterval',
                     'queryResponseInterval',
                     'lastListenerQueryInterval',
                     'startupQueryInterval',
                     'startupQueryCount',
                     'robustness',
                     'lastListenerQueryCount' )

   def __init__( self, entityManager ):
      self.em = entityManager
      self.intfConfig = ConfigMount.mount( self.em, "routing6/mld/config",
                                           "Routing::Mld::IntfConfig", "w" )

   @staticmethod
   def _ip6Addr( addr ):
      '''Returns addr parsed as an Arnet.Ip6Addr when it is a str; any other
      value (e.g. an Ip6Addr already) is returned unchanged.'''
      if isinstance( addr, str ):
         return Arnet.Ip6Addr( addr )
      return addr

   def intf( self, intfId ):
      '''Returns the config entry for intfId, or None if it has none.'''
      return self.intfConfig.configIntf.get( intfId )

   def intfIs( self, intfId ):
      '''Returns the config entry for intfId, creating it first if needed.'''
      intf = self.intf( intfId )
      if intf is not None:
         return intf

      traceIntf( "Adding interface %s to config" % intfId )
      qtraceIntf( "Adding interface ", qv( intfId ), " to config" )

      intf = self.intfConfig.configIntf.newMember( intfId )
      # Assigning tuples here presumably instantiates the nested static/querier
      # config objects with these constructor arguments (Tac idiom).
      intf.staticConfig = ()
      intf.staticConfig.staticGroupConfig = ( "direct", )
      intf.querierConfig = ( intfId, )

      return intf

   def intfDel( self, intfId ):
      '''Removes the config entry for intfId; no-op if it does not exist.'''
      if intfId not in self.intfConfig.configIntf:
         return

      traceIntf( "Deleting interface %s from config" % intfId )
      qtraceIntf( "Deleting interface ", qv( intfId ), " from config" )

      del self.intfConfig.configIntf[ intfId ]

   def intfUpdate( self, intfId ):
      '''Deletes the interface from config if all the attributes are back to the
      default value: MLD not enabled, every querier attribute equal to its
      default, no static groups and the default static access list.'''
      intf = self.intf( intfId )
      if intf is None:
         return

      if intf.enabled:
         return
      querier = intf.querierConfig
      # '==' (not '!=') is kept deliberately to match the original comparison
      # semantics for Tac values; all() short-circuits like the original chain.
      if not all( getattr( querier, attr ) ==
                  getattr( querier, attr + 'Default' )
                  for attr in self._querierAttrs ):
         return
      if len( intf.staticConfig.staticGroupConfig.sourceByGroup ) != 0:
         return
      if not intf.staticAccessList == intf.staticAccessListDefault:
         return
      self.intfDel( intfId )

   def group( self, intfId, groupAddr ):
      '''Returns the static group entry for groupAddr on intfId, or None if the
      interface or the group is not configured.'''
      intf = self.intf( intfId )
      if intf is None:
         return None

      groupAddr = self._ip6Addr( groupAddr )
      return intf.staticConfig.staticGroupConfig.sourceByGroup.get( groupAddr )

   def groupIs( self, intfId, groupAddr ):
      '''Returns the static group entry for groupAddr on intfId, creating the
      interface entry and/or the group entry if needed.'''
      group = self.group( intfId, groupAddr )
      if group is not None:
         return group

      intf = self.intfIs( intfId )
      groupAddr = self._ip6Addr( groupAddr )

      traceGroup( "Adding group %s to intf %s" % ( groupAddr, intfId ) )
      qtraceGroup( "Adding group ", qv( groupAddr ), " to interface ",
                   qv( intfId ) )

      return intf.staticConfig.staticGroupConfig.sourceByGroup.newMember( groupAddr )

   def groupDel( self, intfId, groupAddr ):
      '''Removes the static group entry for groupAddr on intfId, if present.'''
      intf = self.intf( intfId )
      if intf is None:
         return

      groupAddr = self._ip6Addr( groupAddr )
      sourceByGroup = intf.staticConfig.staticGroupConfig.sourceByGroup
      if groupAddr in sourceByGroup:
         traceGroup( "Deleting group %s from intf %s" % ( groupAddr, intfId ) )
         qtraceGroup( "Deleting group ", qv( groupAddr ), " from interface ",
                      qv( intfId ) )
         del sourceByGroup[ groupAddr ]

   def source( self, intfId, groupAddr, sourceAddr ):
      '''Returns the stored value for sourceAddr in the static group groupAddr
      on intfId, or None if the interface, group or source is not configured.
      '''
      group = self.group( intfId, groupAddr )
      if group is None:
         return None

      sourceAddr = self._ip6Addr( sourceAddr )
      return group.sourceAddr.get( sourceAddr )

   def sourceIs( self, intfId, groupAddr, sourceAddr ):
      '''Adds sourceAddr to the static group groupAddr on intfId, creating the
      interface and group entries if needed.'''
      group = self.groupIs( intfId, groupAddr )
      sourceAddr = self._ip6Addr( sourceAddr )

      traceSource( "Adding source %s to group %s on intf %s" %
                   ( sourceAddr, groupAddr, intfId ) )
      qtraceSource( "Adding source ", qv( sourceAddr ), " to group ",
                    qv( groupAddr ), " on interface ", qv( intfId ) )

      group.sourceAddr[ sourceAddr ] = True

   def sourceDel( self, intfId, groupAddr, sourceAddr ):
      '''Removes sourceAddr from the static group groupAddr on intfId; the
      group entry itself is removed once its last source is deleted.'''
      group = self.group( intfId, groupAddr )
      if group is None:
         return

      sourceAddr = self._ip6Addr( sourceAddr )
      if sourceAddr in group.sourceAddr:
         traceSource( "Deleting source %s from group %s on intf %s" %
                      ( sourceAddr, groupAddr, intfId ) )
         qtraceSource( "Deleting source ", qv( sourceAddr ), " from group ",
                       qv( groupAddr ), " on interface ", qv( intfId ) )
         del group.sourceAddr[ sourceAddr ]
         if len( group.sourceAddr ) == 0:
            self.groupDel( intfId, groupAddr )

   def aclIs( self, intfId, aclName ):
      '''Sets the static access list of intfId to aclName, creating the
      interface entry if needed.'''
      intf = self.intfIs( intfId )

      traceIntf( "Adding acl %s to intf %s" % ( aclName, intfId ) )
      qtraceIntf( "Adding acl ", qv( aclName ), " to interface ",
                   qv( intfId ) )

      intf.staticAccessList = aclName

   def aclDel( self, intfId ):
      '''Resets the static access list of intfId to its default. Does not
      remove the interface entry itself; see intfUpdate for that.'''
      intf = self.intf( intfId )
      if intf is None:
         return

      aclName = intf.staticAccessList

      traceIntf( "Deleting acl %s from intf %s" % ( aclName, intfId ) )
      qtraceIntf( "Deleting acl ", qv( aclName ), " from interface ",
                     qv( intfId ) )

      intf.staticAccessList = intf.staticAccessListDefault

   def _querierAttrIs( self, intfId, attr, value ):
      '''Sets intf.querierConfig.<attr> for intfId (creating the interface
      entry if needed). A value of None restores <attr>Default.'''
      querier = self.intfIs( intfId ).querierConfig
      if value is None:
         value = getattr( querier, attr + 'Default' )
      setattr( querier, attr, value )

   def queryIntervalIs( self, intfId, queryInterval=None ):
      '''Sets the query interval; None restores the default.'''
      self._querierAttrIs( intfId, 'queryInterval', queryInterval )

   def queryResponseIntervalIs( self, intfId, queryResponseInterval=None ):
      '''Sets the query response interval; None restores the default.'''
      self._querierAttrIs( intfId, 'queryResponseInterval',
                           queryResponseInterval )

   def startupQueryIntervalIs( self, intfId, queryInterval=None ):
      '''Sets the startup query interval; None restores the default.'''
      self._querierAttrIs( intfId, 'startupQueryInterval', queryInterval )

   def startupQueryCountIs( self, intfId, queryCount=None ):
      '''Sets the startup query count; None restores the default.'''
      self._querierAttrIs( intfId, 'startupQueryCount', queryCount )

   def robustnessIs( self, intfId, count=None ):
      '''Sets the robustness variable; None restores the default.'''
      self._querierAttrIs( intfId, 'robustness', count )

   def lastListenerQueryIntervalIs( self, intfId, queryInterval=None ):
      '''Sets the last listener query interval; None restores the default.'''
      self._querierAttrIs( intfId, 'lastListenerQueryInterval', queryInterval )

   def lastListenerQueryCountIs( self, intfId, queryCount=None ):
      '''Sets the last listener query count; None restores the default.'''
      self._querierAttrIs( intfId, 'lastListenerQueryCount', queryCount )

class MldStatus( object ):
   '''Lazily mounts, and caches per VRF, the read-only Smash status tables used
   here: GMP status, MLD group status (static or dynamic) and MLD querier
   status.'''

   def __init__( self, entityManager ):
      self.em = entityManager
      self.smashMount = SharedMem.entityManager( sysdbEm=self.em )
      self.readerInfo = SmashLazyMount.mountInfo( 'reader' )

      # Per-VRF caches of mounted tables, keyed by VRF name.
      self.gmpStatusByVrf = {}
      self.mldQuerierStatusByVrf = {}
      # Group status caches are additionally split by group type.
      self.mldStatusByVrf = { "staticGroup" : {}, "dynamicGroup" : {} }

      # Local instances; within this class they are only used to compute
      # mount paths.
      self.mldStatus = Tac.newInstance( "Routing::Mld::Smash::Status", "Cli" )
      self.mldQuerierStatus = \
         Tac.newInstance( "Routing::Mld::Smash::QuerierStatus", "Cli" )

   def _cachedMount( self, cache, key, mountPathFn, typeName ):
      '''Returns cache[ key ]. When it is missing (or None), mounts typeName at
      mountPathFn() with the reader mount info and stores the result first.
      mountPathFn is only called on a miss.'''
      status = cache.get( key )
      if status is None:
         status = self.smashMount.doMount( mountPathFn(), typeName,
                                           self.readerInfo )
         cache[ key ] = status
      return status

   def getGmpStatus( self, vrfName ):
      '''Returns the Routing::Gmp::Smash::Status table for vrfName.'''
      return self._cachedMount( self.gmpStatusByVrf, vrfName,
                                lambda: "routing6/gmp/status/" + vrfName,
                                "Routing::Gmp::Smash::Status" )

   def getMldStatus( self, vrfName, groupType ):
      '''Returns the Routing::Mld::Smash::Status table for vrfName; groupType
      must be "staticGroup" or "dynamicGroup".'''
      assert groupType in ( "staticGroup", "dynamicGroup" )

      def mountPath():
         return self.mldStatus.mountPath( groupType, vrfName )

      return self._cachedMount( self.mldStatusByVrf[ groupType ], vrfName,
                                mountPath, "Routing::Mld::Smash::Status" )

   def getMldQuerierStatus( self, vrfName ):
      '''Returns the Routing::Mld::Smash::QuerierStatus table for vrfName.'''
      return self._cachedMount(
         self.mldQuerierStatusByVrf, vrfName,
         lambda: self.mldQuerierStatus.mountPath( vrfName ),
         "Routing::Mld::Smash::QuerierStatus" )


class MldAddrs( object ):
   '''Classifies IPv6 addresses for MLD validation (multicast, SSM group,
   wildcard). Every predicate accepts an address as a str (parsed with
   Arnet.Ip6Addr) or as an already-parsed address.'''

   def __init__( self ):
      self.multicastRange = Arnet.Ip6Prefix( "ff00::/8" )
      # The IPv6 SSM range is FF3x::/32 for every scope nibble x (RFC 4607,
      # RFC 3306). The previous literal "ff3::/32" parses as 0ff3::/32, which
      # is not a multicast range at all, so enumerate the sixteen /32 prefixes
      # explicitly and test membership against each.
      self.ssmRanges = [ Arnet.Ip6Prefix( "ff3%x::/32" % scope )
                         for scope in range( 16 ) ]
      # Single covering prefix for the whole FF3x block; retained in case other
      # code reads ssmRange. validSsmGroup uses the exact ssmRanges instead.
      self.ssmRange = Arnet.Ip6Prefix( "ff30::/12" )
      self.wcAddr = Arnet.Ip6Addr( "::" )

   @staticmethod
   def _ip6Addr( addr ):
      '''Returns addr parsed as an Arnet.Ip6Addr when it is a str; any other
      value is returned unchanged.'''
      if isinstance( addr, str ):
         return Arnet.Ip6Addr( addr )
      return addr

   def isWildcardAddr( self, addr ):
      '''Returns True if addr is the unspecified address "::".'''
      return self._ip6Addr( addr ) == self.wcAddr

   def validSsmGroup( self, addr ):
      '''Returns True if addr lies in one of the FF3x::/32 SSM prefixes.'''
      addr = self._ip6Addr( addr )
      return any( ssmRange.contains( addr ) # pylint: disable-msg=E1103
                  for ssmRange in self.ssmRanges )

   def validMulticastAddr( self, addr ):
      '''Returns True if addr lies in ff00::/8.'''
      addr = self._ip6Addr( addr )
      return self.multicastRange.contains( addr ) # pylint: disable-msg=E1103

   def validUnicastAddr( self, addr ):
      '''Returns True if addr is not a multicast address.'''
      return not self.validMulticastAddr( addr )
