# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

#-------------------------------------------------------------------------------
# This module contains utilities for working with packets in the
# format returned by recvmsg
#-------------------------------------------------------------------------------

from eunuchs.if_packet_h import SOL_PACKET, PACKET_AUXDATA, TP_STATUS_VLAN_VALID, \
                                TP_STATUS_VLAN_TPID_VALID
from eunuchs.if_ether_h import ETH_ALEN, ETH_P_8021Q, ETH_P_8021AD
from eunuchs.if_vlan_h import VLAN_PRIO_MASK, VLAN_PRIO_SHIFT, VLAN_CFI_MASK, \
                              VLAN_VID_MASK
from struct import pack, unpack, calcsize

#-------------------------------------------------------------------------------
# Function to get the vlan tag from the msg_control (ancillary) data.
# Returns ( vlan_tci, vlan_tpid ) with vlan_tci in native endian-ness, or
# ( None, None ) if no valid vlan tag is found.
#-------------------------------------------------------------------------------

# struct formats (native byte order and alignment) for the PACKET_AUXDATA
# payload: three unsigned ints (tp_status, tp_len, tp_snaplen) followed by
# unsigned shorts (tp_mac, tp_net, tp_vlan_tci) and, when the payload is long
# enough to carry it, one more unsigned short (tp_vlan_tpid).
minAuxData = 'IIIHHH'
auxDataMaybeWithTpid = 'IIIHHHH'
minAuxDataSize, auxDataMaybeWithTpidSize = \
   calcsize( minAuxData ), calcsize( auxDataMaybeWithTpid )

def getVlanTagFromCmsg( pktCtrl ):
   ''' Extract the vlan tag from the ancillary data returned by recvmsg.

   pktCtrl is an iterable of ( cmsgLevel, cmsgType, cmsgData ) tuples, as in
   the ancillary-data part of a recvmsg result. The first SOL_PACKET /
   PACKET_AUXDATA message whose tp_status has TP_STATUS_VLAN_VALID set is
   used; other messages are skipped.

   Returns ( vlan_tci, vlan_tpid ) with vlan_tci in native endian-ness.
   vlan_tpid is taken from the auxdata when TP_STATUS_VLAN_TPID_VALID is set
   and the auxdata is long enough to carry it; otherwise ETH_P_8021Q is
   returned. Returns ( None, None ) when no message carries a valid vlan tag.

   Raises ValueError if a PACKET_AUXDATA payload is shorter than
   minAuxDataSize.
   '''
   for cmsgLevel, cmsgType, cmsgData in pktCtrl:
      if cmsgLevel != SOL_PACKET or cmsgType != PACKET_AUXDATA:
         continue
      dataLen = len( cmsgData )
      if dataLen < minAuxDataSize:
         # auxdata is not valid
         raise ValueError( "Invalid auxiliary data" )
      if dataLen >= auxDataMaybeWithTpidSize:
         tp_status, _len, _snaplen, _mac, _net, tp_vlan_tci, tp_vlan_tpid = \
               unpack( auxDataMaybeWithTpid,
                       cmsgData[ 0 : auxDataMaybeWithTpidSize ] )
      else:
         # Short auxdata: there is no tp_vlan_tpid field to read.
         tp_status, _len, _snaplen, _mac, _net, tp_vlan_tci = \
               unpack( minAuxData, cmsgData[ 0 : minAuxDataSize ] )
         tp_vlan_tpid = None

      if not ( tp_status & TP_STATUS_VLAN_VALID ):
         continue
      if ( tp_status & TP_STATUS_VLAN_TPID_VALID ) and tp_vlan_tpid is not None:
         return ( tp_vlan_tci, tp_vlan_tpid )
      # No usable tpid in the auxdata; assume a plain 802.1Q tag rather than
      # returning a None tpid alongside a valid tci.
      return ( tp_vlan_tci, ETH_P_8021Q )
   return ( None, None )

#-------------------------------------------------------------------------------
# Function to insert a vlan tag into a packet.
# If maxsize > 0 then that is the maximum length of the returned packet,
# otherwise it returns the whole packet after vlan tag insertion.
# Argument vlan_tci is in native endian-ness.
#-------------------------------------------------------------------------------

# The vlan tag goes right after the destination and source MAC addresses.
vlanTagOffset = ETH_ALEN * 2

def insertVlanTag( pkt, maxsize=0, vlan_tci=None, vlan_tpid=ETH_P_8021Q ):
   ''' Insert a 4-byte vlan tag ( vlan_tpid, vlan_tci ) into pkt.

   The tag is packed in network byte order and inserted vlanTagOffset bytes
   into the packet. vlan_tci is given in native endian-ness.

   pkt is returned without a tag when vlan_tci is None, when vlan_tpid is
   neither ETH_P_8021Q nor ETH_P_8021AD, or when pkt is too short
   ( <= vlanTagOffset bytes ) to hold a tag.

   If maxsize > 0 the returned packet is truncated to at most maxsize bytes,
   on every path; otherwise the whole packet is returned.
   '''
   def _clip( data ):
      # Apply maxsize uniformly so every return path honours the contract.
      return data[ 0 : maxsize ] if maxsize > 0 else data

   if vlan_tci is None or vlan_tpid not in ( ETH_P_8021Q, ETH_P_8021AD ):
      return _clip( pkt )
   if len( pkt ) <= vlanTagOffset:
      # No room for vlan tag
      return _clip( pkt )
   pktData = pkt[ 0 : vlanTagOffset ] + \
             pack( '>HH', vlan_tpid, vlan_tci ) + \
             pkt[ vlanTagOffset : ]
   return _clip( pktData )

#-------------------------------------------------------------------------------
# Function to delete a separate vlan tag from a packet.
#-------------------------------------------------------------------------------

def delSeparateVlanTag( pkt ):
   ''' Clear the separate (out-of-band) vlan tag stored on pkt.

   Marks the tag absent (tciPresent = False) and zeroes tciField and vlanTpid.
   '''
   pkt.tciPresent, pkt.tciField, pkt.vlanTpid = False, 0, 0

#-------------------------------------------------------------------------------
# Function to set a separate vlan tag in a packet.
#-------------------------------------------------------------------------------

def newSeparateVlanTag( pkt, tciField, vlanTpid ):
   ''' Store a separate vlan tag on pkt from an already-assembled tciField.

   Sets pkt.tciField and pkt.vlanTpid, then marks the tag as present.
   '''
   pkt.tciField, pkt.vlanTpid, pkt.tciPresent = tciField, vlanTpid, True

#-------------------------------------------------------------------------------
# Function to create a tciField from its fields
#-------------------------------------------------------------------------------

def makeTciField( tciPriority, tciCfi, tciVlanId ):
   ''' Assemble a vlan TCI value from its priority, CFI and vlan id parts.

   Each part is masked to its field: the vlan id with VLAN_VID_MASK, the
   priority shifted by VLAN_PRIO_SHIFT and masked with VLAN_PRIO_MASK. The
   CFI bit is set whenever tciCfi is non-zero.
   '''
   vidBits = tciVlanId & VLAN_VID_MASK
   cfiBits = VLAN_CFI_MASK if tciCfi != 0 else 0
   prioBits = ( tciPriority << VLAN_PRIO_SHIFT ) & VLAN_PRIO_MASK
   return vidBits | cfiBits | prioBits

#-------------------------------------------------------------------------------
# Function to set a separate vlan tag by field in a packet.
#-------------------------------------------------------------------------------

def newSeparateVlanTagByField( pkt, tciPriority, tciCfi, tciVlanId, vlanTpid ):
   ''' Store a separate vlan tag on pkt, built from individual TCI fields.

   The priority, CFI and vlan id are combined with makeTciField, and the
   result is stored together with vlanTpid via newSeparateVlanTag.
   '''
   newSeparateVlanTag( pkt,
                       makeTciField( tciPriority, tciCfi, tciVlanId ),
                       vlanTpid )

