# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliExtensions

# Hook for platform-specific mirroring capacity checks. Registered extensions
# are consulted in order by isMaxSessionsReached() as
#    hook( mode, sessionToAdd, src=..., forwardingDrop=... )
# and must return True (max sessions reached), False (not reached) or None
# (inconclusive). The first non-None answer wins; if every hook is
# inconclusive, the generic hardware-capability check is used instead.
mirroringCapacityHook = CliExtensions.CliHook()

def getDirForSession( sess, srcIntfToExclude=None ):
   """Return the combined mirroring direction of sess's source interfaces.

   The result is 'directionBoth' when the considered sources cover both
   transmit and receive, 'directionTx' when they cover transmit only, and
   'directionRx' otherwise -- including when there is no source to consider.

   Entries of sess.srcIntf whose value compares equal to srcIntfToExclude
   are skipped.
   """
   # NOTE(review): this compares collection *values* against srcIntfToExclude.
   # A caller passing an interface key will not get that interface excluded
   # unless the value type compares equal to its key -- confirm the intended
   # argument type.
   sawTx = sawRx = False
   for srcIntf in sess.srcIntf.itervalues():
      if srcIntf == srcIntfToExclude:
         continue
      intfDir = srcIntf.direction
      if intfDir in ( 'directionTx', 'directionBoth' ):
         sawTx = True
      if intfDir in ( 'directionRx', 'directionBoth' ):
         sawRx = True
      if sawTx and sawRx:
         # Nothing can widen the result further; stop scanning early.
         return 'directionBoth'
   return 'directionTx' if sawTx else 'directionRx'

def sessionsForDir( direction, mirroringHwCapability ):
   """Return how many hardware sessions a session mirroring `direction` uses.

   A bidirectional session costs mirroringHwCapability.sessionsUsedForBothDir;
   any other direction costs exactly one.
   """
   return ( mirroringHwCapability.sessionsUsedForBothDir
            if direction == 'directionBoth' else 1 )

def numOfSessionUsed( sess, mirroringHwCapability ):
   """Return the number of hardware sessions consumed by session config `sess`.

   A session with both a GRE tunnel and a drop GRE tunnel destination address
   configured (non-zero) is charged 2. Otherwise the cost follows from the
   combined direction of its source interfaces, via sessionsForDir().
   """
   greConfigured = not sess.greTunnelKey.dstIpGenAddr.isAddrZero
   if greConfigured and not sess.dropGreTunnelKey.dstIpGenAddr.isAddrZero:
      return 2
   sessDirection = getDirForSession( sess )
   return sessionsForDir( sessDirection, mirroringHwCapability )

def numValidSessions( mirroringHwCapability, mirroringConfig ):
   """Return the total number of hardware sessions used by mirroringConfig.

   Only sessions with at least one srcIntf or targetIntf entry are charged;
   each such session contributes numOfSessionUsed() hardware sessions.
   """
   return sum( numOfSessionUsed( cfg, mirroringHwCapability )
               for cfg in mirroringConfig.session.itervalues()
               if cfg.srcIntf or cfg.targetIntf )

def _directionAfterSrcUpdate( sess, intf, direction ):
   """Return the direction of `sess` once source `intf` mirrors `direction`.

   Every other source keeps its configured direction; an existing entry for
   `intf` is treated as replaced, so its old direction no longer counts.
   """
   txDirs = ( 'directionTx', 'directionBoth' )
   rxDirs = ( 'directionRx', 'directionBoth' )
   # Count how many configured sources contribute each direction ...
   txCount = rxCount = 0
   for srcIntf in sess.srcIntf.itervalues():
      if srcIntf.direction in txDirs:
         txCount += 1
      if srcIntf.direction in rxDirs:
         rxCount += 1
   # ... then drop the contribution of the entry that is being replaced.
   if intf in sess.srcIntf:
      oldDirection = sess.srcIntf[ intf ].direction
      if oldDirection in txDirs:
         txCount -= 1
      if oldDirection in rxDirs:
         rxCount -= 1
   tx = txCount > 0 or direction in txDirs
   rx = rxCount > 0 or direction in rxDirs
   if tx and rx:
      return 'directionBoth'
   return 'directionTx' if tx else 'directionRx'

def _sessionsDeltaForSrc( sess, intf, direction, mirroringHwCapability ):
   """Return the change in hardware sessions used by existing session `sess`
   when source `intf` is added to it (or re-configured) with `direction`.
   """
   # numValidSessions() only charges sessions that already have a source or a
   # target, so any other session is currently counted as zero.
   if sess.srcIntf or sess.targetIntf:
      oldCost = numOfSessionUsed( sess, mirroringHwCapability )
   else:
      oldCost = 0
   if ( not sess.greTunnelKey.dstIpGenAddr.isAddrZero and
        not sess.dropGreTunnelKey.dstIpGenAddr.isAddrZero ):
      # Same rule as numOfSessionUsed(): such a session costs 2 whatever the
      # directions of its sources are.
      newCost = 2
   else:
      newDirection = _directionAfterSrcUpdate( sess, intf, direction )
      newCost = sessionsForDir( newDirection, mirroringHwCapability )
   return newCost - oldCost

def isMaxSessionsReached( mode, sessionToAdd, mirroringHwCapability, mirroringConfig,
                          src=None, forwardingDrop=False ):
   """Return whether a pending change to session `sessionToAdd` would exceed
   the hardware mirroring session limit.

   mode: CLI mode, passed through to the platform hooks.
   sessionToAdd: key of the session in mirroringConfig.session (it may not
      exist yet).
   src: optional ( intf, direction ) tuple for a source interface being
      added to, or re-configured in, the session.
   forwardingDrop: only consulted when `src` is not given; presumably True
      when the change configures the drop GRE tunnel rather than the regular
      GRE tunnel -- confirm with callers.

   Hooks registered on mirroringCapacityHook are consulted first, and the
   first non-None answer is returned unchanged. Otherwise the current count
   from numValidSessions() is adjusted for the pending change and compared
   against mirroringHwCapability.maxSupportedSessions. A falsy limit (e.g. 0)
   is returned as-is, so the limit is never reported as reached.
   """
   # Platform-specific hooks: True = max reached, False = max not reached,
   # None = inconclusive (fall through to the generic check below).
   for hook in mirroringCapacityHook.extensions():
      res = hook( mode, sessionToAdd, src=src, forwardingDrop=forwardingDrop )
      if res is not None:
         return res

   maxSessions = mirroringHwCapability.maxSupportedSessions
   numSessions = numValidSessions( mirroringHwCapability, mirroringConfig )

   if sessionToAdd in mirroringConfig.session:
      sess = mirroringConfig.session[ sessionToAdd ]
      if src:
         intf, direction = src
         # Charge the exact difference between the session's current cost
         # (as counted by numValidSessions) and its cost after the change.
         numSessions += _sessionsDeltaForSrc( sess, intf, direction,
                                              mirroringHwCapability )
      else:
         # Configuring one GRE tunnel kind while the other kind is already
         # set costs one more hardware session.
         addsSecondTunnel = (
            ( not forwardingDrop and
              not sess.dropGreTunnelKey.dstIpGenAddr.isAddrZero ) or
            ( forwardingDrop and
              not sess.greTunnelKey.dstIpGenAddr.isAddrZero ) )
         if addsSecondTunnel:
            numSessions += 1
   elif src:
      # Brand new session: it costs whatever its first source direction costs.
      _, direction = src
      numSessions += sessionsForDir( direction, mirroringHwCapability )
   else:
      numSessions += 1
   return maxSessions and numSessions > maxSessions


