# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac

class MessageCounter:
   """Simple holder for a set of PIM message counter values.

   Instances are plain attribute bags with three fields: success, invalid
   and filtered. Each defaults to 0 and is stored exactly as passed in.
   """
   def __init__( self, success = 0, invalid = 0, filtered = 0 ):
      # Assign all three fields in one statement (left to right).
      self.success, self.invalid, self.filtered = success, invalid, filtered

class PimCountersWrapper:
   """Wrapper for reading and writing PIM message counters in smash status.

   Counters live in ``smashStatus.msgCounter``, keyed by a
   ``Routing::Pim::Smash::CounterKey`` built from (message type, message
   subtype, interface). Each entry carries separate rx and tx counters.
   """
   def __init__( self, smashStatus ):
      # smashStatus may be falsy (e.g. None). In that case getCounter reports
      # zeros and clearCounters does nothing.
      self.smashStatus = smashStatus

   def _checkDirection( self, direction ):
      """Raise AssertionError unless direction is 'rx' or 'tx'.

      This uses an explicit raise instead of ``assert False``, so the check
      stays active under ``python -O``. It keeps the same exception type as
      the original assert.
      """
      if direction not in ( 'rx', 'tx' ):
         raise AssertionError( "Invalid direction: %s" % direction )

   def _counterKey( self, mtype, stype, intf ):
      """Build the smash CounterKey for (mtype, stype, intf).

      A stype or intf of None is replaced by a default-constructed
      Routing::Pim::Packet::MessageSubtype / Arnet::IntfId value.
      """
      if stype is None:
         stype = Tac.Value( "Routing::Pim::Packet::MessageSubtype" )
      if intf is None:
         intf = Tac.Value( "Arnet::IntfId" )
      return Tac.Value( "Routing::Pim::Smash::CounterKey", mtype, stype, intf )

   def getCounter( self, direction, mtype, stype=None, intf=None ):
      """Return a MessageCounter snapshot for one direction of a counter key.

      Args:
         direction: 'rx' or 'tx'.
         mtype: PIM message type used in the CounterKey.
         stype: message subtype; None means a default-constructed subtype.
         intf: interface id; None means a default-constructed Arnet::IntfId.

      Returns:
         A new MessageCounter. It is all zeros when there is no smash status
         or no entry for the key.

      Raises:
         AssertionError: if direction is not 'rx' or 'tx'. This is checked
         before anything else, so a bad direction is reported regardless of
         what is stored.
      """
      self._checkDirection( direction )
      if not self.smashStatus:
         return MessageCounter( 0, 0, 0 )

      counterKey = self._counterKey( mtype, stype, intf )
      if counterKey not in self.smashStatus.msgCounter:
         return MessageCounter( 0, 0, 0 )

      data = self.smashStatus.msgCounter[ counterKey ]
      if direction == 'rx':
         tacCounter = data.rxMsgCounter
      else:
         tacCounter = data.txMsgCounter
      return MessageCounter( tacCounter.success,
                             tacCounter.invalid,
                             tacCounter.filtered )

   def setCounter( self, direction, mtype, stype = None, intf = None,
                   counter = None ):
      """Store one direction's counter for a key in smash status.

      If an entry already exists for the key, it is copied and only the
      requested direction is overwritten. Otherwise a new
      MessageCounterData entry is created. The result is written back with
      addMember.

      Args:
         direction: 'rx' or 'tx'.
         mtype, stype, intf: counter key parts, as for getCounter.
         counter: MessageCounter to store. None (the default) stores zeros.

      Raises:
         AssertionError: if direction is not 'rx' or 'tx'. This is checked
         before the collection is touched.
      """
      self._checkDirection( direction )
      if counter is None:
         # Create a fresh zero counter per call rather than sharing one
         # default instance that is evaluated once at class definition.
         counter = MessageCounter()

      counterKey = self._counterKey( mtype, stype, intf )
      msgCounter = self.smashStatus.msgCounter
      if counterKey in msgCounter:
         # Take a non-const copy of the stored entry so it can be modified.
         data = Tac.nonConst( msgCounter[ counterKey ] )
      else:
         data = Tac.Value( "Routing::Pim::Smash::MessageCounterData", counterKey )

      tacCounter = Tac.Value( "Routing::Pim::IO::MessageCounter",
                              success = counter.success,
                              invalid = counter.invalid,
                              filtered = counter.filtered )
      if direction == 'rx':
         data.rxMsgCounter = tacCounter
      else:
         data.txMsgCounter = tacCounter
      msgCounter.addMember( data )

   def clearCounters( self ):
      """Remove every counter entry; does nothing when there is no status."""
      if self.smashStatus:
         self.smashStatus.msgCounter.clear()

# PIM message type number -> message type name.
# Note that type number 9 has no entry in this table.
messageTypeDict = dict( [
   ( 0, 'Hello' ),
   ( 1, 'Register' ),
   ( 2, 'RegisterStop' ),
   ( 3, 'JoinPrune' ),
   ( 4, 'BootstrapRouter' ),
   ( 5, 'Assert' ),
   ( 6, 'Graft' ),
   ( 7, 'GraftAck' ),
   ( 8, 'CrpAdvertisement' ),
   ( 10, 'DFElection' ),
] )

# Pim message subtypes: message type name -> {subtype number: subtype name}.
# Message types without an entry here have no named subtypes.
_messageSubtypesByTypeName = {
   'Register' : {
      0 : 'Null Register',
      1 : 'Data Register',
   },
   'DFElection' : {
      1 : 'Offer',
      2 : 'Winner',
      3 : 'Backoff',
      4 : 'Pass',
   },
   'JoinPrune' : {
      0 : 'Datagram',
      1 : 'SCTP',
   },
}

def messageSubtypeDict( msgTypeNum ):
   """Return {subtype number: subtype name} for a PIM message type number.

   Args:
      msgTypeNum: PIM message type number, as used as a key of
         messageTypeDict.

   Returns:
      A new dict on every call, so callers may modify it freely. The dict is
      empty for message types with no named subtypes. That includes type
      numbers missing from messageTypeDict, which previously raised KeyError.
   """
   typeName = messageTypeDict.get( msgTypeNum )
   # Copy so callers never mutate the shared table above.
   return dict( _messageSubtypesByTypeName.get( typeName, {} ) )
