# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
'''
MplsEtbaLib.py

This module provides helper functions that may be useful to any MPLS-related ETBA plugin.
Previously, there was some code duplication between the EvpnMpls and MplsEtba ETBA
plugins, e.g. constructMplsHeader.
'''

import hashlib
import Tac
from TypeFuture import TacLazyType
from Arnet.IpTestLib import MplsHdrSize

# Tac types bound through TacLazyType (presumably deferring the type lookup until
# first use -- see TypeFuture).
EthAddr = TacLazyType( 'Arnet::EthAddr' )
EthIntfId = TacLazyType( 'Arnet::EthIntfId' )
FecId = TacLazyType( 'Smash::Fib::FecId' )
PortChannelIntfId = TacLazyType( 'Arnet::PortChannelIntfId' )
TunnelType = TacLazyType( 'Tunnel::TunnelTable::TunnelType' )
VlanIntfId = TacLazyType( 'Arnet::VlanIntfId' )

# VrfId of the default VRF (Vrf::VrfIdMap::VrfId.defaultVrf).  The name suggests
# it is used for ARP smash lookups by importing plugins -- confirm with callers.
ARP_SMASH_DEFAULT_VRF_ID = Tac.Type( 'Vrf::VrfIdMap::VrfId' ).defaultVrf
# The implicit-NULL MPLS label value.  constructMplsHeader rejects label stacks
# containing it; removeImpNullFromLabelStack strips it out.
IMP_NULL = Tac.Type( "Arnet::MplsLabel" ).implicitNull

def _pureMd4( data ):
   '''
   Return the 16-byte MD4 digest (RFC 1320) of the byte string 'data'.

   Pure-Python fallback for interpreters whose hashlib/OpenSSL does not provide
   MD4.  Produces the same digest as hashlib's md4.
   '''
   import struct
   mask = 0xffffffff

   def rotl( x, s ):
      x &= mask
      return ( ( x << s ) | ( x >> ( 32 - s ) ) ) & mask

   # Per round: ( auxiliary function, additive constant, message-word order,
   # left-rotate amounts cycled every 4 steps ).
   rounds = (
      ( lambda x, y, z: ( x & y ) | ( ~x & z ), 0,
        tuple( range( 16 ) ), ( 3, 7, 11, 19 ) ),
      ( lambda x, y, z: ( x & y ) | ( x & z ) | ( y & z ), 0x5A827999,
        ( 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 ),
        ( 3, 5, 9, 13 ) ),
      ( lambda x, y, z: x ^ y ^ z, 0x6ED9EBA1,
        ( 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15 ),
        ( 3, 9, 11, 15 ) ),
   )
   msg = bytearray( data )
   bitLen = ( 8 * len( msg ) ) & 0xffffffffffffffff
   # Pad with a single 1 bit, then 0 bits up to 56 mod 64 bytes, then the
   # original bit length as a little-endian 64-bit integer.
   msg.append( 0x80 )
   while len( msg ) % 64 != 56:
      msg.append( 0 )
   msg += struct.pack( '<Q', bitLen )
   state = [ 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476 ]
   for offset in range( 0, len( msg ), 64 ):
      words = struct.unpack( '<16I', bytes( msg[ offset:offset + 64 ] ) )
      a, b, c, d = state
      for func, const, order, shifts in rounds:
         for step, k in enumerate( order ):
            rotated = rotl( a + func( b, c, d ) + words[ k ] + const,
                            shifts[ step % 4 ] )
            # Rotate the registers so that the next step updates the next of
            # A, D, C, B in turn; every 4 steps they line up with A, B, C, D.
            a, b, c, d = d, rotated, b, c
      state = [ ( s + v ) & mask for s, v in zip( state, ( a, b, c, d ) ) ]
   return struct.pack( '<4I', *state )

def _md4Digest( chunks ):
   '''
   Return the MD4 digest of the concatenation of the byte strings in 'chunks'.
   '''
   try:
      m = hashlib.new( 'md4' )
   except ValueError:
      # OpenSSL 3 only ships MD4 in its legacy provider, so hashlib.new raises
      # ValueError on many modern systems.  The pure-Python digest is
      # identical, so ECMP selections stay the same.
      return _pureMd4( b''.join( chunks ) )
   for chunk in chunks:
      m.update( chunk )
   return m.digest()

def ecmpHash( coll, *values ):
   '''
   Simulate ECMP hashing: pick one entry of 'coll' based on the MD4 hash of the
   str() form of every element of 'values', hashed in order.

   The first byte of the digest, modulo len( coll ), selects the entry.  The
   choice is deterministic as long as str() of each value is.

   coll: indexable collection supporting len().
   values: arbitrary objects used as the hash key.
   Returns the selected entry.  A single-entry collection is returned without
   hashing.
   Raises ValueError if 'coll' is empty.
   '''
   num = len( coll )
   if num == 0:
      raise ValueError( 'ecmpHash: cannot select from an empty collection' )
   if num == 1:
      return coll[ 0 ]
   chunks = []
   for v in values:
      chunk = str( v )
      if not isinstance( chunk, bytes ):
         # Python 3 str is text; hashlib only accepts bytes.
         chunk = chunk.encode( 'utf-8' )
      chunks.append( chunk )
   digest = _md4Digest( chunks )
   # Indexing a bytearray yields an int on both Python 2 and Python 3.
   idx = bytearray( digest )[ 0 ] % num
   return coll[ idx ]

def tcToMplsCosStatic( trafficClass ):
   '''
   Convert a traffic class to an MPLS COS value using a fixed table.

   Traffic classes 0 and 1 are swapped (0 -> 1, 1 -> 0).  Every other value is
   returned unchanged.
   '''
   # BUG137916 - Use cos-to-tc map to find the TC first, instead of 1:1 mapping
   swappedClasses = dict( ( ( 0, 1 ), ( 1, 0 ) ) )
   if trafficClass in swappedClasses:
      return swappedClasses[ trafficClass ]
   return trafficClass

def removeImpNullFromLabelStack( labelStack ):
   '''
   Strip every implicit-NULL (IMP_NULL) label out of labelStack, in place.

   The remaining labels keep their relative order.  Nothing is returned; the
   caller sees the result through the labelStack it passed in.
   '''
   # Walk from the end so that deleting an entry never shifts an index that is
   # still to be visited.
   for position in reversed( range( len( labelStack ) ) ):
      if labelStack[ position ] == IMP_NULL:
         del labelStack[ position ]

def constructMplsHeader( labelStack, mplsTtl=255, mplsCos=None, tc=None,
                         setBos=True, flowLabelPresent=False ):
   """Build the MPLS header bytes for the given label stack.

   labelStack: labels in frame order, i.e. [ Top, ..., Bottom ].  Must be
               non-empty and must not contain IMP_NULL (strip it first with
               removeImpNullFromLabelStack).
   mplsTtl: TTL written into each label entry; should be copied from the IP
            frame, or IP TTL - 1.  Default 255.
   mplsCos / tc: at most one of these may be specified.  If tc is specified, it
                 is converted to an MPLS COS value using a static map
                 (tcToMplsCosStatic).  If neither is given, COS 0 is used.
   setBos: when True, the bottom-of-stack bit is set on the last label only;
           when False, it is cleared on every label.
   flowLabelPresent: when True (and setBos is True), the bottom label's TTL is
                     overwritten with 0.
   Returns the stringValue of the Arnet::Pkt holding the encoded headers.
   """
   assert len( labelStack ) > 0
   assert mplsCos is None or tc is None
   assert all( label != IMP_NULL for label in labelStack ), \
         'labelStack should not contain IMP_NULL {}'.format( labelStack )

   if tc is not None:
      cosValue = tcToMplsCosStatic( tc )
   else:
      cosValue = 0 if mplsCos is None else mplsCos

   pkt = Tac.newInstance( 'Arnet::Pkt' )
   pkt.newSharedHeadData = MplsHdrSize * len( labelStack )
   lastIndex = len( labelStack ) - 1
   for index, label in enumerate( labelStack ):
      # Each label occupies MplsHdrSize bytes, laid out top label first.
      entry = Tac.newInstance( 'Arnet::MplsHdrWrapper', pkt, MplsHdrSize * index )
      entry.label = label
      entry.ttl = mplsTtl
      entry.cos = cosValue
      isBottom = ( index == lastIndex ) if setBos else False
      entry.bos = isBottom
      if isBottom and flowLabelPresent:
         # The bottom label is the flow label in this case; give it TTL 0.
         entry.ttl = 0

   return pkt.stringValue

def getTunnelEntry( tunnelTables, policyForwardingStatus, ethHdr, tunnelId ):
   '''
   Look up the tunnel table entry for tunnelId.

   An SR-TE policy tunnel is first resolved to one of its segment-list tunnels.
   Its FEC is looked up in policyForwardingStatus.fec, and one of the FEC's vias
   is picked with ecmpHash keyed on the Ethernet source/destination addresses.

   tunnelTables: indexed by tunnel type; each value is an iterable of tunnel
                 tables with an 'entry' collection keyed by tunnel id.  Tables
                 are searched in iteration order.
   policyForwardingStatus: provides the 'fec' collection holding SR-TE policy
                           FECs.
   ethHdr: Ethernet header whose src and dst feed the ECMP hash.
   tunnelId: the tunnel to resolve.

   Returns the first matching entry.  Returns None if the SR-TE policy FEC is
   missing or has no vias, if the chosen via has an invalid tunnel id, or if no
   table contains the entry.
   '''
   tunnelType = tunnelId.tunnelType()
   if tunnelType == TunnelType.srTePolicyTunnel:
      # We need to retrieve the actual segment-list tunnel from the SR-TE policy FEC.
      srTePolicyFecId = FecId.tunnelIdToFecId( tunnelId )
      srTePolicyFec = policyForwardingStatus.fec.get( srTePolicyFecId )
      if not srTePolicyFec or len( srTePolicyFec.via ) == 0:
         # A FEC without vias has no segment list to resolve to, and ecmpHash
         # cannot pick from an empty collection.
         return None
      srTePolicyVia = ecmpHash( srTePolicyFec.via, ethHdr.src, ethHdr.dst )
      # Overwrite the tunnelId and tunnelType to use the segment-list.
      tunnelId = Tac.Value( 'Tunnel::TunnelTable::TunnelId', srTePolicyVia.tunnelId )
      if not tunnelId.isValid():
         return None
      tunnelType = tunnelId.tunnelType()
   for tunnelTable in tunnelTables[ tunnelType ]:
      entry = tunnelTable.entry.get( tunnelId )
      if entry:
         return entry
   return None

def getIntfVlan( bridgingConfig, intfId ):
   '''
   Return the VLAN associated with intfId.

   For a VLAN interface, the result is the VLAN id encoded in the interface id.
   Otherwise it is the native VLAN from bridgingConfig.switchIntfConfig.  The
   result is None when the interface has no switch config or its native VLAN
   is 0.
   '''
   if not VlanIntfId.isVlanIntfId( intfId ):
      conf = bridgingConfig.switchIntfConfig.get( intfId )
      if conf is not None and conf.nativeVlan != 0:
         return conf.nativeVlan
      return None
   return VlanIntfId.vlanId( intfId )

def isPortChannel( intfName ):
   '''
   Report whether intfName is a port-channel interface.

   Adapted from Ale/AleHelper.cpp; delegates to
   Arnet::PortChannelIntfId.isPortChannelIntfId.
   '''
   checkPortChannel = PortChannelIntfId.isPortChannelIntfId
   return checkPortChannel( intfName )

def isRoutedPort( intfName ):
   '''
   Report whether intfName is a port-channel or an Ethernet interface, which are
   the interface kinds treated as routed ports here.

   Adapted from Ale/AleHelper.cpp.  The Ethernet check runs only when the
   port-channel check is falsy.
   '''
   routed = isPortChannel( intfName )
   if not routed:
      routed = EthIntfId.isEthIntfId( intfName )
   return routed

def getL2Intf( bridgingStatus, vlan, macAddr ):
   '''
   Return the interface recorded for macAddr in the learnedHost table of
   bridgingStatus.fdbStatus[ vlan ].

   Falls back to a default-constructed Arnet::IntfId when the VLAN has no FDB
   status or the MAC address has no learned-host entry.
   '''
   defaultIntf = Tac.ValueConst( 'Arnet::IntfId' )
   fdbStatus = bridgingStatus.fdbStatus.get( vlan )
   macEntry = fdbStatus.learnedHost.get( macAddr ) if fdbStatus else None
   return macEntry.intf if macEntry else defaultIntf
