#!/usr/bin/env python
# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Arnet
import collections
import errno
from functools import wraps
import random
import sys
import Tac
from TypeFuture import TacLazyType
from PseudowireLib import (
      pwPingCcEnumToVal,
      PwPingCcType,
)
from ClientCommonLib import ( 
   getDsMappingInfo, 
   getIntfPrimaryIpAddr,
   isIpv6Addr,
   IPV4, IPV6,
   LspPingTypeBgpLu,
   LspPingTypeLdp,
   LspPingTypeMldp,
   LspPingTypeSr,
   LspPingDDMap,
)
from MplsTracerouteClientLib import (
   tracerouteDownstreamInfoRender,
   tracerouteLabelStackRender,
   tracerouteReplyHdrRender,
)
from CliPlugin.MplsUtilModel import DownstreamInfo, LabelStackInfo, HopPktInfo
from ClientState import (
   sessionIdIncr,
   cleanupGlobalState
)

import Toggles.MplsUtilsToggleLib
import Toggles.PseudowireToggleLib

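# RFC4379
# An MPLS echo request is a UDP packet.  The IP header is set as follows: the source
# IP address is a routable address of the sender; the destination IP address is a
# (randomly chosen) IPv4 address from the range 127/8 or IPv6 address from the range
# 0:0:0:0:0:FFFF:127/104.
#
# XXX Packets destined to 0:0:0:0:0:FFFF:127.0.0.1 are not trapped on Arad today,
# so use the V6 loopback address here. (BUG123983)
# NOTE(review): the IPv6 default below is '::ffff:127.0.0.1', which is the same
# address as 0:0:0:0:0:FFFF:127.0.0.1; confirm whether this workaround note is
# still accurate. Unlike the RFC text, a fixed (not random) address is used.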
# Default echo-request destination address, keyed by IP version.
DefaultLspPingReqDst = { IPV4 : '127.0.0.1',
                         IPV6 : '::ffff:127.0.0.1' }

# Tac type and enum handles used throughout this module.
AddressFamily = Tac.Type( 'Arnet::AddressFamily' )
IpGenAddr = Tac.Type( 'Arnet::IpGenAddr' )
IpGenPrefix = Tac.Type( 'Arnet::IpGenPrefix' )
Ipv4Unnumbered = Tac.enumValue( 'LspPing::LspPingAddrType', 'ipv4Unnumbered' )
Ipv6Unnumbered = Tac.enumValue( 'LspPing::LspPingAddrType', 'ipv6Unnumbered' )
Ipv4Numbered = Tac.enumValue( 'LspPing::LspPingAddrType', 'ipv4Numbered' )
Ipv6Numbered = Tac.enumValue( 'LspPing::LspPingAddrType', 'ipv6Numbered' )
LspPingDownstreamMappingType = TacLazyType(
      'LspPing::LspPingDownstreamMappingType' )
LspPingTunnelType = Tac.Type( 'LspPing::LspPingTunnelType' )
LspPingPwInfo = Tac.Type( 'LspPing::LspPingPwInfo' )
LspPingP2mpFecStackSubTlvType = TacLazyType(
      'LspPing::LspPingP2mpFecStackSubTlvType' )
LspPingTxCountConst = Tac.Type( 'LspPing::LspPingTxCountConstants' )
LspPingMultipathBitset = Tac.Type( 'LspPing::LspPingMultipathBitset' )
LspPingReturnCode = Tac.Type( 'LspPing::LspPingReturnCode' )
MplsLabelStack = Tac.Type( 'Arnet::MplsLabelStack' )
MplsLabel = Tac.Type( 'Arnet::MplsLabel' )
# Return bundle for traceroute: ( retVal, txPkts, replyHostRtts ).
TracerouteRetArgs = collections.namedtuple( 'TracerouteRetArgs',
                                            [ 'retVal', 'txPkts', 'replyHostRtts' ] )
TunnelTableIdentifier = TacLazyType( "Tunnel::TunnelTable::TunnelTableIdentifier" )

# Timer wheel handed to the LspPingClientRoot built in createLspPingClient.
# NOTE(review): the numeric arguments (100, 10 * 60 * 5, True, 1024) are
# interpreted by Ark::TimerWheel; confirm their meaning there before changing.
timerWheel = Tac.newInstance( "Ark::TimerWheel", Tac.activityManager.clock, 
                              100, 10 * 60 * 5, True, 1024 )
# IPv4 / IPv6 all-routers multicast addresses.
IpAddrAllRoutersMulticast = "224.0.0.2"
Ip6AddrAllRoutersMulticast = "ff02::2"

class LspPingResult( object ):
   '''
   Plain record describing the outcome of a ping run.

   duration  -- elapsed time of the run (as passed in by the caller)
   requests  -- per-client record of requests processed
   replyInfo -- per-client reply details
   kbInt     -- True if the run was stopped by a KeyboardInterrupt
   '''
   def __init__( self, time, requests, replyInfo, kbInt ):
      self.duration, self.requests = time, requests
      self.replyInfo, self.kbInt = replyInfo, kbInt

# ---------------------------------------------------------
#               Common utility functions
# ---------------------------------------------------------

def createLspPingClient( mount, state, interface, src, dst, tracerouteInfo=None,
                         verbose=False ):
   '''
      Creates, assigns, and returns a global instance of the LspPingClientRoot SM.

      The instance is created only on the first call and cached in
      state.clientRoot; later calls return the cached instance and ignore the
      other arguments. On creation this:
        - builds a server-mode UDP PAM (IPv6-enabled if src or dst is an IPv6
          address) and stores it in state.lspPingClientRootUdpPam,
        - records the PAM's rx port in tracerouteInfo.srcUdpPort, if given,
        - creates the LspPingUtilRoot entity and its lspPingClientRoot,
        - waits until `interface` appears in the client root's
          ethDevPamCollection.

      Returns the client root, or None if the UDP socket buffer could not be
      enlarged.
   '''
   if not state.clientRoot:
      # Get the udpPam
      lspPingClientRootUdpPam = Tac.newInstance( 'Arnet::UdpPam', 'udppam' )
      # Only the interface status of the address family in use is handed to the
      # client root; the other one stays None.
      vrfIpIntfStatus = vrfIp6IntfStatus = None
      if isIpv6Addr( src ) or isIpv6Addr( dst ):   
         lspPingClientRootUdpPam.ipv6En = True
         lspPingClientRootUdpPam.txDstIp6Addr = Arnet.Ip6Addr( dst )
         vrfIp6IntfStatus = mount.vrfIp6IntfStatus
      else:
         lspPingClientRootUdpPam.txDstIpAddr = dst
         vrfIpIntfStatus = mount.vrfIpIntfStatus
      # Use the local UDP port stored in state, if one is set.
      if state.clientRootUdpPamSrcPort is not None:
         lspPingClientRootUdpPam.rxPortLocal = state.clientRootUdpPamSrcPort

      lspPingClientRootUdpPam.mode = 'server'
      if verbose:
         print "Listening on UDP port %d" % lspPingClientRootUdpPam.rxPort

      # Bump up the sock buffer size to 1MB
      if lspPingClientRootUdpPam.incSockBufSize( 1048576 ) == -1:
         return None

      state.lspPingClientRootUdpPam = lspPingClientRootUdpPam
      local = mount.entityManager.root().parent
      if tracerouteInfo:
         # set pam rxPort to be used in multipath code
         tracerouteInfo.srcUdpPort = lspPingClientRootUdpPam.rxPort
      lspPingUtilRoot = local.newEntity( 'LspPing::LspPingUtilRoot', 'LspPingUtil' )
      lspPingUtilRoot.lspPingClientRoot = ( mount.allIntfStatusDir, 
                                            mount.allIntfStatusLocalDir, 
                                            mount.kniStatus,
                                            vrfIpIntfStatus, 
                                            vrfIp6IntfStatus,
                                            lspPingClientRootUdpPam,
                                            timerWheel, tracerouteInfo )
      lspPingClientRoot = lspPingUtilRoot.lspPingClientRoot
      # The pam may not be ready yet while lspPingClientRoot initializes, so it may
      # not be added to the EthDevPamCollection immediately; wait until it is.
      pamColl = lspPingClientRoot.lspPingEthDevPamManager.ethDevPamCollection
      Tac.waitFor( lambda: interface in pamColl )
      state.clientRoot = lspPingClientRoot

   return state.clientRoot

def createLspPingClientConfig( mount, state, clientId, tunnelType, interface, 
                               labelStack, src, dst, dmac, smac=None, interval=1,
                               mplsTtl=None, count=None, verbose=False, tc=None, 
                               dstPrefix=None, dsMappingInfo=None, rsvpSpId=None,
                               pwInfo=None, rsvpSenderAddr=None, mldpInfo=None,
                               standard=None, size=None, padReply=False, tos=None,
                               dstype=None, setFecValidateFlag=True,
                               egressValidateAddress=None,
                               lookupLabelCount=1 ):
   '''
      Adds an entry to the global LspPingClientConfig collection, 
      which is used to craft and send packets.

      Expects state.clientRoot and state.lspPingClientRootUdpPam to be set up
      (as createLspPingClient does). Several optional fields (egress validation,
      tos, dstPrefix, packet size, mLDP options, RSVP/pseudowire info) are only
      written when the corresponding argument is given/truthy. mplsTc defaults to
      0 and txCount to LspPingTxCountConst.infinite when count is None.
      Label TTLs: the entropy label indicator and the entropy label that follows
      it get TTL 0; other labels use mplsTtl[ i ] if mplsTtl is given, else 255.
      `verbose` is not used here.

      Returns a ( lspPingClientConfig, lspPingClientStatus ) tuple, or None if the
      resulting config is not enabled (its `reason` is printed) or no client
      status exists for clientId.
   '''
   lspPingClientRoot = state.clientRoot
   lspPingClientConfigColl = lspPingClientRoot.lspPingClientConfigColl
   lspPingClientStatusColl = lspPingClientRoot.lspPingClientStatusColl
   lspPingClientRootUdpPam = state.lspPingClientRootUdpPam

   # Config entry for this client id in the client root's config collection.
   lspPingClientConfig = lspPingClientConfigColl.newPingClientConfig( clientId )
   lspPingClientConfig.oamStandard = standard or mount.config.oamStandard

   # Egress validation is enabled only when an address to validate is given.
   if egressValidateAddress is not None:
      lspPingClientConfig.egressValidate = True
      egressAddress = Tac.newInstance( 'Arnet::IpGenAddr',
                                       egressValidateAddress )
      lspPingClientConfig.egressValidateAddress = egressAddress

   lspPingClientConfig.tunnelType = tunnelType
   lspPingClientConfig.txIntfId = interface

   # By default mappingTlv is downstreamMap
   if dstype == LspPingDDMap:
      lspPingClientConfig.mappingTlv = \
         LspPingDownstreamMappingType.downstreamDetailedMap
   if dsMappingInfo and dsMappingInfo.valid:
      lspPingClientConfig.dsMappingInfo = dsMappingInfo
   if tos:
      lspPingClientConfig.tos = tos
   lspPingClientConfig.mplsTc = tc if tc else 0
   lspPingClientConfig.srcIpAddr = IpGenAddr( src )
   lspPingClientConfig.dstIpAddr = IpGenAddr( dst )
   if dstPrefix:
      lspPingClientConfig.dstPrefix = IpGenPrefix( dstPrefix )

   # Replies come back to the client root's UDP PAM port.
   lspPingClientConfig.srcUdpPort = lspPingClientRootUdpPam.rxPort
   lspPingClientConfig.srcEthAddr = ( smac or mount.bridgingConfig.bridgeMacAddr )
   lspPingClientConfig.tunnelNexthopEthAddr = dmac
   lspPingClientConfig.txRate = interval
   # Random 50-100 ms delay, stored as 0.050-0.100.
   # NOTE(review): presumably pktDelay is in seconds -- confirm.
   lspPingClientConfig.pktDelay = random.randint( 50, 100 ) / 1000.0
   lspPingClientConfig.txCount = LspPingTxCountConst.infinite if count is None \
                                 else count

   if size:
      lspPingClientConfig.pktSize = size
      lspPingClientConfig.padReply = padReply

   lspPingClientConfig.lookupLabelCount = lookupLabelCount
   lspPingClientConfig.setFecValidateFlag = setFecValidateFlag

   # mldp config
   if mldpInfo:
      if mldpInfo.jitter:
         lspPingClientConfig.echoJitter = mldpInfo.jitter
      if mldpInfo.genOpqVal:
         lspPingClientConfig.opqType = LspPingP2mpFecStackSubTlvType.GenericLsp
         lspPingClientConfig.genericOpqVal = mldpInfo.genOpqVal
      # If source/group opaque values are given as well, opqType ends up SGIpv4
      # because this assignment comes after the generic one.
      if mldpInfo.sourceAddrOpqVal and mldpInfo.groupAddrOpqVal:
         lspPingClientConfig.opqType = LspPingP2mpFecStackSubTlvType.SGIpv4
         lspPingClientConfig.sourceAddrOpqVal = \
               IpGenAddr( mldpInfo.sourceAddrOpqVal )
         lspPingClientConfig.groupAddrOpqVal = IpGenAddr( mldpInfo.groupAddrOpqVal )
      if mldpInfo.responderAddr:
         lspPingClientConfig.nodeResponderAddr = IpGenAddr( mldpInfo.responderAddr )
   if tunnelType == LspPingTunnelType.tunnelMldp:
      # For mldp ,Downstream Detailed Mapping Tlv is added for dsMappingInfo
      lspPingClientConfig.mappingTlv = \
         LspPingDownstreamMappingType.downstreamDetailedMap

   # Add the MPLS labels, possibly including the entropy label
   stackSize = len( labelStack )
   isEntropyLabelNext = False
   for i in range( stackSize ):
      label = int( labelStack[ i ] )
      lspPingClientConfig.mplsLabel[ i ] = MplsLabel( label )
      # If we see the ELI, we need to set it and the following entropy label's
      # TTL to 0 as per RFC6970
      if label == MplsLabel.entropyLabelIndicator:
         assert mount.mplsTunnelConfig.entropyLabelPush
         isEntropyLabelNext = True
         mplsTtlVal = 0
      elif isEntropyLabelNext:
         isEntropyLabelNext = False
         mplsTtlVal = 0
      else:
         # Assumes mplsTtl, when given, has an entry for every label index.
         mplsTtlVal = mplsTtl[ i ] if mplsTtl else 255
      lspPingClientConfig.mplsTtl[ i ] = mplsTtlVal

   if rsvpSpId:
      lspPingClientConfig.rsvpSpId = rsvpSpId
   if rsvpSenderAddr:
      lspPingClientConfig.rsvpSenderAddr = rsvpSenderAddr
   if pwInfo:
      lspPingClientConfig.pwInfo = pwInfo
   # The config reports whether it is usable via `enabled`, and why not via
   # `reason`; print the reason and give up if it is disabled.
   if not lspPingClientConfig.enabled:
      print lspPingClientConfig.reason
      return None

   lspPingClientStatus = lspPingClientStatusColl.pingClientStatus.get( clientId )  
   if not lspPingClientStatus:
      return None
   return ( lspPingClientConfig, lspPingClientStatus )

def isValidDestAddr( dst, ipv ):
   '''
   Validate the destination address is in correct range for IPv4/IPv6:
   127.0.0.0/8 (loopback) for IPv4 and ::ffff:127.0.0.0/104 for IPv6.
   Returns True if the provided destination address is valid for the ip version,
   False otherwise (including when ipv is neither IPV4 nor IPV6).
   '''
   # Compare with ==, not `is`: identity only holds when the caller passes the
   # very same object as IPV4/IPV6, which an equal value built elsewhere (e.g. a
   # string produced at runtime) is not guaranteed to be.
   if ipv == IPV4:
      if not Arnet.IpAddr( dst ).isLoopback:
         return False
   elif ipv == IPV6:
      ip6DestRange = Arnet.Ip6AddrWithMask( '::ffff:127.0.0.0/104' )
      if not ip6DestRange.contains( Arnet.Ip6Addr( dst ) ):
         return False
   else:
      return False
   return True

def verifyConsistentAfType( addr1Str, addr2Str ):
   '''
   Return True when both address strings parse to the same address family
   (as reported by Arnet.IpGenAddr(...).af), False otherwise.
   '''
   firstAf, secondAf = ( Arnet.IpGenAddr( addrStr ).af
                         for addrStr in ( addr1Str, addr2Str ) )
   return firstAf == secondAf

# ---------------------------------------------------------
#               Core ping helpers
# ---------------------------------------------------------

def echoRequestsSent( lspPingClientStatusColl, clientIds, seq ):
   '''
   Return True once every client in clientIds has a status entry whose
   sequenceNum has reached seq; False as soon as one has not.
   '''
   def _hasSent( cid ):
      status = lspPingClientStatusColl.pingClientStatus.get( cid )
      return bool( status ) and status.sequenceNum >= seq

   return all( _hasSent( cid ) for cid in clientIds )

def echoReplyReceived( clientStatusColl, clientId, seq ):
   '''
   Return True if client clientId has a status entry that has reached sequence
   number seq and holds a (truthy) reply record for seq; False otherwise.
   '''
   status = clientStatusColl.pingClientStatus.get( clientId )
   if not status:
      return False
   if status.sequenceNum < seq:
      return False
   return bool( status.replyPktInfo.get( seq ) )

def echoRepliesPending( clientStatusColl, clientIds, seq ):
   '''
   Remove from clientIds, in place, every client whose reply for seq has been
   received, and return that same list object (now holding only the clients
   still waiting for a reply).
   '''
   stillWaiting = []
   for cid in clientIds:
      if not echoReplyReceived( clientStatusColl, cid, seq ):
         stillWaiting.append( cid )
   clientIds[ : ] = stillWaiting
   return clientIds

def checkLspPingClientStatus( clientStatusColl, clientIdBase, clientIds, replyRender,
                              statisticsRender, renderArg, pktTimeout, interval=None,
                              count=None, pollRate=0.1 ):
   '''
   Drive the send/receive loop for a set of ping clients and report results.

   For each sequence number (starting at 1) this waits until every client has
   sent that request, then polls for replies (every pollRate) until
   min( startTime + interval * seq, now ) + pktTimeout. Each client's reply for
   the sequence (or None if it did not arrive) is passed to replyRender, and
   RTTs ( replyPktInfo.roundTrip / 1000.0 ) are accumulated per reply host.
   The loop ends after `count` requests (never, if count is falsy) or on
   KeyboardInterrupt; statisticsRender, if given, is called at the end.

   Client ids passed to the render callbacks and used as keys in the result are
   relative to clientIdBase.

   Returns an LspPingResult whose duration is the elapsed time in milliseconds,
   requests maps client -> last sequence number processed, and replyInfo maps
   client -> { replyHost: [ rtt, ... ] }.
   '''
   txPkts = {}
   seq = 1 # sequence numbers start at 1
   kbInt, replyInfo = False, {}
   interval = interval or 1
   startTime = Tac.now()
   try:
      while True:
         # Wait until every client has sent request `seq`.
         Tac.waitFor( lambda : echoRequestsSent( clientStatusColl, clientIds, seq ),
                      maxDelay=0.2 )
         # Reply deadline: pktTimeout after the request's scheduled slot, but
         # never more than pktTimeout from now.
         timeout = min( startTime + interval * seq + pktTimeout,
                        Tac.now() + pktTimeout )
         pendingClients = list( clientIds )
         while Tac.now() < timeout:
            pendingClients = echoRepliesPending( clientStatusColl,
                                                 pendingClients, seq )
            if not pendingClients:
               break
            Tac.runActivities( pollRate )

         for clientId in clientIds:
            relClientId = clientId - clientIdBase
            clientStatus = clientStatusColl.pingClientStatus.get( clientId )
            # A missing status entry is treated as a missing reply rather than
            # raising AttributeError on None.
            replyPktInfo = ( clientStatus.replyPktInfo.get( seq )
                             if clientStatus is not None else None )
            if replyPktInfo is not None and replyPktInfo.seqNum == seq:
               replyHostRtt = replyInfo.setdefault( relClientId, {} )
               replyHostRtt.setdefault( replyPktInfo.replyHost, [] ).append(
                                                 replyPktInfo.roundTrip / 1000.0 )
            else:
               replyPktInfo = None

            if replyRender:
               replyRender( relClientId, replyPktInfo, renderArg )
            # update txPkts ONLY after we have processed the reply
            txPkts[ relClientId ] = seq
         if count and seq == count:
            break
         seq += 1

   except KeyboardInterrupt:
      kbInt = True
   # Tac.now() is in seconds, so this is the elapsed time in milliseconds.
   elapsedMs = int( ( Tac.now() - startTime ) * 1000 )
   if statisticsRender:
      relativeClientIds = [ clientId - clientIdBase for clientId in clientIds ]
      statisticsRender( relativeClientIds, elapsedMs, txPkts, replyInfo, renderArg )
   return LspPingResult( elapsedMs, txPkts, replyInfo, kbInt )

# ---------------------------------------------------------
#               Core ping implementation
# ---------------------------------------------------------

def ping( mount, state, src, dst, smac, interval, count, viaInfo, replyRender,
          statisticsRender, renderArg, protocol=None, ipv=IPV4, pwLdpInfo=None,
          rsvpSpIds=None, rsvpSenderAddr=None, tc=None, timeout=5, mldpInfo=None,
          standard=None, size=None, padReply=False, setFecValidateFlag=True,
          egressValidateAddress=None, lookupLabelCount=1, tos=None ):
   count = count or None # ping until interruptted
   # While dst option is not supported via the CLI at this time, some of the tests
   # use it. Also, we may in future start allowing destination address to be
   # specified via the CLI. This check makes sure that if the destination address is
   # provided, it must be in 127.0.0.0/8 range for IPv4 and  ::fff:127.0.0.0/104
   # range for IPv6
   if dst:
      if not isValidDestAddr( dst, ipv ):
         msg = 'Destination address must be in ' + ( '127.0.0.0/8' if ipv is IPV4
               else '::ffff:127.0.0.0/104' ) + ' range.'
         print msg
         return -1
   else:
      dst = DefaultLspPingReqDst[ ipv ]
   pwInfo, rsvpSpId = None, None
   # create ping clients
   clientIds = []
   for idx, ( l3Intf, label, nhMac ) in enumerate( viaInfo ):
      if isinstance( label, int ):
         labelStack = [ label ]
      elif isinstance( label, tuple ):
         labelStack = list( label )
      else:
         labelStack = label
      clientId = state.clientIdBase + idx
      reqSrcIp = src or getIntfPrimaryIpAddr( mount, l3Intf, ipv )
      if not reqSrcIp:
         print 'No usable source ip address for ping'
         return -1

      # ensure both src and dst ip are of the consistent ipv type
      if not verifyConsistentAfType( reqSrcIp, dst ):
         msg = "Inconsistent source (" + str( reqSrcIp ) + ") and destination (" + \
               str( dst ) + ") IP address families.\n"
         print msg
         return -1

      rsvpSpId = rsvpSpIds[ idx ] if rsvpSpIds else None
      tunnelType, tunnelPrefix, pwInfo = \
            getTunnelTypeAndPrefix( protocol, renderArg[ 1 ], # renderArg[ 1 ]=prefix
                                    rsvpSpId, ipv, pwLdpInfo )

      if ( ( tunnelType == LspPingTunnelType.tunnelModeRaw ) and
           ( not egressValidateAddress ) and
           ( MplsLabel.explicitNullIpv4 not in labelStack ) and
           ( MplsLabel.explicitNullIpv6 not in labelStack ) ):
         nullLabel = ( MplsLabel.explicitNullIpv4 if ipv == IPV4 else
                       MplsLabel.explicitNullIpv6 )
         labelStack.extend( [ nullLabel ] )
         
      if not createLspPingClient( mount, state, l3Intf, reqSrcIp, dst ):
         print 'failed to create lsp ping client'
         return -1

      if not createLspPingClientConfig( mount, state, clientId, tunnelType, l3Intf, 
                                        labelStack, reqSrcIp, dst, nhMac, smac,
                                        interval, None, count,
                                        dstPrefix=tunnelPrefix, tc=tc,
                                        rsvpSpId=rsvpSpId, pwInfo=pwInfo,
                                        rsvpSenderAddr=rsvpSenderAddr,
                                        mldpInfo=mldpInfo, standard=standard,
                                        size=size, padReply=padReply, tos=tos,
                                        setFecValidateFlag=setFecValidateFlag,
                                        egressValidateAddress=(
                                           egressValidateAddress ),
                                        lookupLabelCount=lookupLabelCount ):
         print 'failed to create ping client config'
         return -1

      clientIds.append( clientId )
   # check ping client status, XXX timeout hard coded to 3 for now
   clientStatusColl = state.clientRoot.lspPingClientStatusColl 
   clientIdBase = state.clientIdBase
   pingResult = checkLspPingClientStatus( clientStatusColl, clientIdBase, clientIds,
                                          replyRender, statisticsRender, renderArg,
                                          timeout, interval, count )
   return errno.EINTR if pingResult.kbInt else 0

# ---------------------------------------------------------
#                 traceroute helpers 
# ---------------------------------------------------------

# Calculate the bitset for the specified multipath type
def getMultipathBitset( multipathType, numMultipathBits ):
   '''
   Build the multipath bitset carried in an echo request.

   For multipathType 0 an empty LspPingMultipathBitset is returned. Otherwise
   the bitset size in bytes is ( smallest power of two >= numMultipathBits ) / 8,
   with a minimum of 4 bytes, and the first numMultipathBits bits are set (most
   significant bit first within each byte); all remaining bits are cleared.
   '''
   multipathBitset = LspPingMultipathBitset()

   if multipathType != 0:
      # The size of the multipathBitset should be:
      #  "mask of length 2^(32-prefix length) bits"
      # as per RFC4379, and a minimum of 4 bytes.
      # Use floor division explicitly: this is a byte count, and `/` would yield
      # a float under Python 3 semantics.
      multipathBitset.size = max(
         ( 2 ** ( numMultipathBits - 1 ).bit_length() ) // 8, 4 )

      remainingBits = numMultipathBits
      for byteIndex in range( multipathBitset.size ):
         # Bits to set in this byte: 0..8 (negative counts set nothing).
         bitsInByte = min( max( remainingBits, 0 ), 8 )
         # The top `bitsInByte` bits of the byte, e.g. 3 -> 0xE0, 8 -> 0xFF.
         multipathBitset.bitset[ byteIndex ] = ( 0xFF00 >> bitsInByte ) & 0xFF
         remainingBits -= bitsInByte

   return multipathBitset

def getMultipathType( multipath, numMultipathBits ):
   '''
   Return the multipath type to request: 8 when multipath is enabled and the
   number of multipath bits is non-zero, otherwise 0 (no multipath).
   '''
   useMultipath = bool( multipath ) and numMultipathBits != 0
   return 8 if useMultipath else 0

# Note that this may return a different value for numMultipathBits than what
# was passed in
def getMultipathInfo( multipath, numMultipathBits ):
   '''
   Return ( multipathType, numMultipathBits, multipathBitset ) for a request.

   When no multipath is used (type 0) the returned bit count is forced to 0, so
   it may differ from the numMultipathBits argument.
   '''
   mpType = getMultipathType( multipath, numMultipathBits )
   effectiveBits = 0 if mpType == 0 else numMultipathBits
   return ( mpType, effectiveBits, getMultipathBitset( mpType, effectiveBits ) )

# We use this when we do not know what DsMapping info to send in the echo
# request for the next hop. We send the DsMapping Info with the downstream
# address set to 224.0.0.2 in the case of Ipv4 and FF02::2 in the case of
# Ipv6. On doing this, we indicate that the LSR which recieves this echo request
# must ignore DsMapping validation but send a DsMapping TLV in its echo reply.
def getGenericDsMappingInfo( addrFamily=AddressFamily.ipv4 ):
   '''
   Return a "generic" DsMapping info for use when the downstream mapping for the
   next hop is unknown: the downstream address is 224.0.0.2 for IPv4 (FF02::2
   for any other family), with an empty label stack, MTU 1500, multipath
   disabled, base IP 0.0.0.0 and 0 multipath bits.
   '''
   if addrFamily == AddressFamily.ipv4:
      genericDsAddr = '224.0.0.2'
   else:
      genericDsAddr = 'FF02::2'
   # The MTU carries no meaning here: DsMap validation is not performed by the
   # router that receives an echo request with this TLV.
   return getDsMappingInfo( genericDsAddr, [], 1500, False, '0.0.0.0', 0 )

def tracerouteReplyRender( ttl, replyPktInfo ):
   '''
   Render the header line for traceroute hop `ttl`.

   Without a reply only the ttl is rendered. Otherwise the reply host, hop MTU
   (from the first downstream mapping entry, 0 if there is none), round trip
   ( int( roundTrip / 1000.0 ) ) and return code are rendered as well.
   '''
   if not replyPktInfo:
      tracerouteReplyHdrRender( HopPktInfo( ttl=ttl ) )
      return

   dsMappings = replyPktInfo.downstreamMappingInfo
   mtu = dsMappings[ 0 ].mtu if dsMappings else 0
   tracerouteReplyHdrRender(
      HopPktInfo( replyHost=str( replyPktInfo.replyHost ),
                  hopMtu=mtu,
                  roundTrip=int( replyPktInfo.roundTrip / 1000.0 ),
                  retCode=replyPktInfo.retCode,
                  ttl=ttl ) )

def downstreamInfoRender( downstreamInfos ):
   '''Render downstreamInfos via tracerouteDownstreamInfoRender; no-op when falsy.'''
   if downstreamInfos:
      tracerouteDownstreamInfoRender( downstreamInfos )

def getLabelStack( labelStack ):
   '''
   Return the entries of labelStack (index 0 first) as display strings:
   'implicit-null' for the implicit-null label, the label value otherwise.
   '''
   rendered = []
   for idx in range( labelStack.size ):
      value = labelStack.labelEntry[ idx ].label
      if value == MplsLabel.implicitNull:
         rendered.append( 'implicit-null' )
      else:
         rendered.append( str( value ) )
   return rendered

def labelStackRender( labelStack ):
   '''Render labelStack as a LabelStackInfo; does nothing when labelStack is falsy.'''
   if labelStack:
      tracerouteLabelStackRender(
         LabelStackInfo( labelStack=getLabelStack( labelStack ) ) )

def getTunnelTypeAndPrefix( protocol, prefix,
                            rsvpSpId, ipv, pwLdpInfo=None ):
   '''
   Map the requested ping/traceroute target to an LspPing tunnel type.

   The first matching case wins, in this order: protocol (LDP, mLDP, SR,
   BGP-LU), then rsvpSpId, then pwLdpInfo; otherwise raw mode is used.

   Returns a ( tunnelType, tunnelPrefix, pwInfo ) tuple:
     - tunnelPrefix is `prefix` for the protocol cases, the RSVP session's
       destination IP string for RSVP, and None for pseudowire and raw mode;
     - pwInfo is an LspPingPwInfo only for the pseudowire (LDP) case, else None.
   '''
   tunnelType, tunnelPrefix, pwInfo = None, None, None
   if protocol == LspPingTypeLdp:
      tunnelType = LspPingTunnelType.tunnelLdpIpv4
      tunnelPrefix = prefix
   elif protocol == LspPingTypeMldp:
      tunnelType = LspPingTunnelType.tunnelMldp
      tunnelPrefix = prefix
   elif protocol == LspPingTypeSr:
      assert ipv in [ IPV4, IPV6 ]
      tunnelType = ( LspPingTunnelType.tunnelSrIpv4 if ipv == IPV4
                     else LspPingTunnelType.tunnelSrIpv6 )
      tunnelPrefix = prefix
   elif protocol == LspPingTypeBgpLu:
      assert ipv in [ IPV4, IPV6 ]
      # Select the tunnel type by address family, as for SR above.
      tunnelType = ( LspPingTunnelType.tunnelBgpLuIpv4 if ipv == IPV4
                     else LspPingTunnelType.tunnelBgpLuIpv6 )
      tunnelPrefix = prefix
   elif rsvpSpId:
      tunnelType = LspPingTunnelType.tunnelRsvpIpv4
      tunnelPrefix = rsvpSpId.sessionId.dstIp.stringValue
   elif pwLdpInfo:
      tunnelType = LspPingTunnelType.tunnelPwLdp
      tunnelPrefix = None
      pwInfo = LspPingPwInfo( pwLdpInfo.localRouterId,
                              pwLdpInfo.remoteRouterId,
                              pwLdpInfo.pwId,
                              pwLdpInfo.pwType,
                              pwLdpInfo.controlWord )
      # activeCcType is control word (CW) if CW is being used and is supported by
      # peer. Router alert (RA) is used otherwise. This logic may change when we
      # implement TTL expiry as a ccType.
      if ( pwLdpInfo.controlWord and
           pwLdpInfo.ccTypes & pwPingCcEnumToVal[ PwPingCcType.cw ] and
           Toggles.PseudowireToggleLib.togglePseudowirePingCwEnabled() ):
         pwInfo.activeCcType = pwPingCcEnumToVal[ PwPingCcType.cw ]
      else:
         pwInfo.activeCcType = pwPingCcEnumToVal[ PwPingCcType.ra ]
   else:
      tunnelType = LspPingTunnelType.tunnelModeRaw
      tunnelPrefix = None

   return ( tunnelType, tunnelPrefix, pwInfo )

def getAndRenderDownstreamInfoOrIls( replyPktInfo, tunnelType, dsInfoExpected=True ):
   '''
   Extract and render the downstream information carried in an echo reply.

   - If the reply has downstream mapping info: for non-mLDP tunnels only entry
     [ 0 ] is considered, for mLDP all entries are. Each valid entry with a
     non-empty label stack is rendered as a DownstreamInfo and its label stack
     is collected.
   - Otherwise, if the reply's interface-and-label-stack info is valid and has
     a non-empty label stack, that stack is rendered and collected.
   - Otherwise, when dsInfoExpected is set and the reply return code is
     seeDdmTlv or labelSwitchedAtStackDep, an error line is printed.

   Returns the list of collected label stacks (possibly empty).
   '''
   mplsLabelStack, downstreamInfos = [], []
   if replyPktInfo.downstreamMappingInfo:
      if tunnelType != LspPingTunnelType.tunnelMldp:
         dsInfos = [ replyPktInfo.downstreamMappingInfo[ 0 ] ]
      else:
         # mLDP replies may carry several downstream entries; use all of them.
         dsInfos = replyPktInfo.downstreamMappingInfo.values()
      for dsInfo in dsInfos:
         # Entries that are invalid or carry no labels are skipped.
         if dsInfo.valid and dsInfo.labelStack.size:
            labelStack = dsInfo.labelStack
            labels = getLabelStack( labelStack )
            addrType = Tac.enumName( 'LspPing::LspPingAddrType', dsInfo.addrType )
            downstreamInfos.append(
               DownstreamInfo( addrType=addrType,
                               dsIntfAddr=str( dsInfo.dsIntfAddr ),
                               dsIntfIndex=dsInfo.intfIndex(),
                               dsIpAddr=str( dsInfo.dsIpAddr ),
                               labelStack=labels,
                               retCode=dsInfo.retCode,
                               dsType=replyPktInfo.dsType ) )
            mplsLabelStack.append( dsInfo.labelStack )
      downstreamInfoRender( downstreamInfos )
   elif ( replyPktInfo.intfAndLabelStackInfo.valid and
          replyPktInfo.intfAndLabelStackInfo.labelStack.size ):
      mplsLabelStack.append( replyPktInfo.intfAndLabelStackInfo.labelStack )
      labelStackRender( replyPktInfo.intfAndLabelStackInfo.labelStack )
   elif dsInfoExpected:
      # We expected either dsInfo or ILS except for few scenarios where we don't
      # send DSMAP in the first place e.g. NHG traceroute. However we should only do
      # this check for success/labels switched scenario
      if ( replyPktInfo.retCode == LspPingReturnCode.seeDdmTlv or
           replyPktInfo.retCode == LspPingReturnCode.labelSwitchedAtStackDep ):
         errOutput = '     error: downstream information missing in echo response'
         print errOutput
   return mplsLabelStack

def labelStackIsExplicitNull( ipv, mplsLabelStack ):
   '''
   Return True if every label stack in mplsLabelStack holds exactly one label
   and that label is the explicit-null label for ipv (IPv4 or IPv6 explicit
   null). Returns False for an empty mplsLabelStack or an ipv other than
   IPV4/IPV6.

   mplsLabelStack is a list of label-stack objects exposing `size` and
   `labelEntry[ i ].label`, such as the list returned by
   getAndRenderDownstreamInfoOrIls.
   '''
   if ipv == IPV4:
      explicitNull = MplsLabel.explicitNullIpv4
   elif ipv == IPV6:
      explicitNull = MplsLabel.explicitNullIpv6
   else:
      return False
   if not mplsLabelStack:
      # all() would vacuously return True; with no downstream stacks nothing
      # shows an explicit-null label.
      return False
   # The condition must hold for every downstream label stack, not just one.
   return all( labelStack.size == 1 and
               labelStack.labelEntry[ 0 ].label == explicitNull
               for labelStack in mplsLabelStack )

def isLabelSwitchedOrSeeDdmapTlv( prefix, retCode ):
   '''
   Return True when there is no prefix, or when retCode (a return-code name) is
   'labelSwitchedAtStackDep' or 'seeDdmTlv'; False otherwise.
   '''
   if not prefix:
      return True
   return retCode == 'labelSwitchedAtStackDep' or retCode == 'seeDdmTlv'

def isNoRetCodeOrLabelSwitchedInSingleDdmapTlv( replyPktInfo ):
   '''
   Checks the retcode returned in the only DDMAP tlv received in an echo reply.

   Replies that do not carry exactly one DDMAP tlv return True. Otherwise this
   returns True if the reply header retcode is labelSwitchedAtStackDep, or if it
   is seeDdmTlv and the DDMAP retcode is labelSwitchedAtStackDep.
   ( As per rfc 8029, a labelSwitched header retcode should come with a
     noRetcode DDMAP retcode, but that is not enforced here because cisco
     replies with both the header and the DDMAP retcode set to labelSwitched. )
   '''
   isSingleDdmap = (
      replyPktInfo.dsType == LspPingDownstreamMappingType.downstreamDetailedMap and
      len( replyPktInfo.downstreamMappingInfo ) == 1 )
   if not isSingleDdmap:
      return True

   labelSwitched = LspPingReturnCode.labelSwitchedAtStackDep
   if replyPktInfo.retCode == labelSwitched:
      return True
   return ( replyPktInfo.retCode == LspPingReturnCode.seeDdmTlv and
            replyPktInfo.downstreamMappingInfo[ 0 ].retCode == labelSwitched )

def isLabelSwitchedInAllDdmapTlvs( replyPktInfo ):
   '''
   Checks the retcode returned in all DDMAP tlvs received in the echo reply.

   For a reply carrying more than one DDMAP tlv, returns True only if the reply
   header retcode is seeDdmTlv AND every DDMAP tlv carries
   labelSwitchedAtStackDep. Replies that are not multi-DDMAP return True.
   '''
   if ( replyPktInfo.dsType == LspPingDownstreamMappingType.downstreamDetailedMap and
        len( replyPktInfo.downstreamMappingInfo ) > 1 ):
      # Both conditions must hold (previously written as `not A or not B`,
      # which accepted a seeDdmTlv header regardless of the DDMAP retcodes).
      return ( replyPktInfo.retCode == LspPingReturnCode.seeDdmTlv and
               all( dsInfo.retCode == LspPingReturnCode.labelSwitchedAtStackDep
                    for dsInfo in replyPktInfo.downstreamMappingInfo.itervalues() ) )
   return True

def createFirstMultipathSearchNode( packetTracerSm, dsMappingInfo, dstIpStr,
                                    labelStack, l3Intf ):
   '''
   Build the root LspPingGraphNode (hop 1) that seeds a multipath search.

   The node gets a fresh client id from packetTracerSm, the given DsMapping
   info, the labels of labelStack (stored at indices 0..n-1), the egress
   interface l3Intf and the destination dstIpStr.
   '''
   rootNode = Tac.Value( "LspPing::LspPingGraphNode" )
   rootNode.hopNum = 1
   rootNode.clientId = packetTracerSm.generateClientId()
   rootNode.dsMappingInfo = dsMappingInfo
   for position, labelValue in enumerate( labelStack ):
      rootNode.mplsLabel[ position ] = labelValue
   rootNode.intf = l3Intf
   rootNode.dstIpAddr = IpGenAddr( dstIpStr )
   rootNode.isRootNode = True
   return rootNode

def getAndClearPacketTracerSm( lspPingClientRoot ):
   '''
   Return the client root's lspPingPacketTracerSm after resetting it: its search
   stack, trace path and ping client configs are cleared and searchCompleted is
   set back to False.
   '''
   packetTracerSm = lspPingClientRoot.lspPingPacketTracerSm
   # force packetTracerSm state to be empty
   packetTracerSm.searchStack.clear()
   packetTracerSm.tracePath.clear()
   packetTracerSm.searchCompleted = False
   packetTracerSm.clientConfigColl.pingClientConfig.clear()
   return packetTracerSm

def createMultipathTracerouteInfo( tunnelType, srcIpStr, dstIpStr, smac,
                                   mount, dmac, dsMapInfo, tc, count,
                                   interval, topLabel, nextHopIpStr,
                                   dstPrefix=None, dstype=None, tos=None,
                                   standard=None, size=None, padReply=False ):
   '''
   Build and return an LspPing::LspPingMultipathTracerouteInfo for a multipath
   traceroute request.

   tos, dstPrefix, a detailed-downstream-map dstype and size/padReply are only
   applied when given. Output is directed to this process's stdout in text
   format (format 1). The object is marked valid as the last step.
   '''
   # create LspPingMultipathTracerouteInfo object
   tracerouteInfo = Tac.newInstance( "LspPing::LspPingMultipathTracerouteInfo" )
   tracerouteInfo.oamStandard = standard or mount.config.oamStandard
   tracerouteInfo.tunnelType = tunnelType
   tracerouteInfo.nextHopIp = IpGenAddr( nextHopIpStr )
   tracerouteInfo.srcIpAddr = IpGenAddr( srcIpStr )
   tracerouteInfo.dstIpAddr = IpGenAddr( dstIpStr )
   if tos:
      tracerouteInfo.tos = tos
   if dstPrefix:
      tracerouteInfo.dstPrefix = IpGenPrefix( dstPrefix )
   # Fall back to the bridge MAC when no source MAC is given.
   tracerouteInfo.srcEthAddr = ( smac or mount.bridgingConfig.bridgeMacAddr )
   tracerouteInfo.tunnelNexthopEthAddr = dmac
   tracerouteInfo.dsMapInfo = dsMapInfo
   # By default tracerouteInfo.dstype is LspPingDownstreamMappingType.downstreamMap
   if dstype == LspPingDDMap:
      tracerouteInfo.dstype = LspPingDownstreamMappingType.downstreamDetailedMap
   if size:
      tracerouteInfo.pktSize = size
      tracerouteInfo.padReply = padReply
   tracerouteInfo.mplsTc = tc
   tracerouteInfo.txCount = count
   tracerouteInfo.txRate = interval
   # Results are written by the C++ side to this file descriptor.
   tracerouteInfo.fd = sys.stdout.fileno()
   # 1 indicates text output for cliPrinter
   tracerouteInfo.format = 1
   tracerouteInfo.topLabel = Tac.newInstance( "Arnet::MplsLabel", topLabel )
   # Set last, once every other field has been populated.
   tracerouteInfo.valid = True
   return tracerouteInfo

def doMultipathTraceroute( lspPingClientRoot, dsMappingInfo, dstIpStr,
                           labelStack, l3Intf ):
   """Run a multipath traceroute and block until the search completes.

   Unlike the default traceroute, all probing logic lives in the C++
   LspPingPacketTracerSm. This helper resets that SM, pushes the first search
   node, starts probing and then services Tac activities until the SM reports
   searchCompleted.

   Returns:
      Always 0.
   """
   tracerSm = getAndClearPacketTracerSm( lspPingClientRoot )
   tracerSm.searchStack.push(
      createFirstMultipathSearchNode( tracerSm, dsMappingInfo, dstIpStr,
                                      labelStack, l3Intf ) )
   tracerSm.probeNextHop()
   # NOTE(review): this wait has no upper bound; it relies on the SM always
   # setting searchCompleted eventually.
   pollSeconds = 0.050
   while True:
      if tracerSm.searchCompleted:
         break
      Tac.runActivities( pollSeconds )
   return 0

# ---------------------------------------------------------
#          Core traceroute implementation
# ---------------------------------------------------------
def bumpSessionId( fn ):
   """Decorator: reset client global state after every call to ``fn``.

   When ``fn`` finishes, whether it returns normally or raises, the existing
   global state is cleaned up with cleanupGlobalState() and the session id is
   then bumped with sessionIdIncr(). The next invocation therefore starts with
   fresh global state. ``fn``'s return value, or its exception, is passed
   through unchanged.
   """
   @wraps( fn )
   def wrapper( *args, **kwargs ):
      try:
         return fn( *args, **kwargs )
      finally:
         # Runs in ``finally`` so that an exception (e.g. KeyboardInterrupt
         # while waiting for replies) does not leave stale global state and
         # an unbumped session id for the next invocation. The existing
         # state must be cleaned up before the sessionId is bumped.
         cleanupGlobalState()
         sessionIdIncr()

   return wrapper

def _tracerouteWaitForReply( clientStatusColl, clientId, seqnum, interval ):
   # Helper for traceroute(): wait (Tac.waitFor) until echoRequestsSent()
   # reports request `seqnum` of `clientId`. Then poll every 50ms for its echo
   # reply until Tac.now() passes the start time + `interval`.
   Tac.waitFor(
      lambda: echoRequestsSent( clientStatusColl, [ clientId ], seqnum ),
      maxDelay=0.1 )
   deadline = Tac.now() + interval
   while Tac.now() < deadline:
      if echoReplyReceived( clientStatusColl, clientId, seqnum ):
         return
      Tac.runActivities( 0.050 )

def _tracerouteAllRoutersDsMapInfo( ipv, dsMapInfo ):
   # Helper for traceroute(): build the DsMapping info used when an mLDP echo
   # request is destined for more than one node. It keeps the mtu, dsFlags and
   # labelStack of `dsMapInfo`.
   # RFC 6425 section 4.3.4: the Downstream IP Address MUST be the ALLROUTERS
   # multicast address, and the Address Type MUST be IPv4 or IPv6 Unnumbered
   # according to the Target FEC Stack TLV.
   # RFC 4379 section 4.6: the interface index MUST be 0. Arnet::IntfId cannot
   # hold 0, so an empty IntfId is passed here, and LspPingClientTxSm writes 0
   # when dsIpAddr is the AllRoutersMulticast address.
   if ipv == IPV4:
      dsIpAddr = IpGenAddr( IpAddrAllRoutersMulticast )
      addrType = Ipv4Unnumbered
   else:
      dsIpAddr = IpGenAddr( Ip6AddrAllRoutersMulticast )
      addrType = Ipv6Unnumbered
   return Tac.newInstance(
         'LspPing::LspPingDownstreamMappingInfo',
         dsMapInfo.mtu, addrType, dsMapInfo.dsFlags,
         dsIpAddr, IpGenAddr(), 0, 0, IpGenAddr(),
         LspPingMultipathBitset( 0 ), dsMapInfo.labelStack,
         Arnet.IntfId( '' ), 0, 0, True )

@bumpSessionId
def traceroute( mount, state, l3Intf, labelStack, src, dst, srcMac, dstMac, count,
                interval, protocol=None, prefix=None, verbose=False, hops=None,
                ipv=IPV4, rsvpSpId=None, rsvpSenderAddr=None, dsMappingInfo=None,
                tc=None, multipath=False, nextHopIp=None,
                mldpInfo=None, standard=None, size=None, padReply=False,
                dstype=None, setFecValidateFlag=True, egressValidateAddress=None,
                lookupLabelCount=1, tos=None ):
   """Run an MPLS LSP traceroute and render each hop's reply.

   Multipath traceroute (`multipath=True`) is delegated to
   doMultipathTraceroute(). Otherwise, one LspPing client config is created
   per hop, with the top label TTL incremented after every hop. This continues
   until a reply indicates that probing should stop, or until the hop limit
   (`hops`, default 64) is exhausted.

   `count` is used only for multipath (txCount). `interval` is the per-request
   reply wait for non-multipath and txRate for multipath. `verbose` is
   accepted but not used here.

   Returns:
      TracerouteRetArgs( retVal, txPkts, replyHostRtts ):
        retVal        -- 0 if probing stopped on a received reply, -1 on a
                         setup error or when the hop limit is exhausted; for
                         multipath, the value of doMultipathTraceroute().
        txPkts        -- number of echo requests sent (non-multipath).
        replyHostRtts -- dict: reply host -> list of
                         replyPktInfo.roundTrip / 1000.0 values.
   """
   txPkts, replyHostRtts = 0, {}
   retArgs = TracerouteRetArgs( retVal=-1, txPkts=txPkts,
                                replyHostRtts=replyHostRtts )
   reqSrcIp = src or getIntfPrimaryIpAddr( mount, l3Intf, ipv )
   if not reqSrcIp:
      print( 'No usable source ip address for traceroute' )
      return retArgs

   baseErrorMsg = 'Error: Could not proceeed with traceroute\n'
   # The dst option is not supported via the CLI at this time, but some tests
   # use it, and the CLI may allow it in future. If dst is provided, make sure
   # it is valid.
   if dst:
      # ensure both src and dst ip are of the consistent ipv type
      if not verifyConsistentAfType( reqSrcIp, dst ):
         msg = "Inconsistent source (" + str( reqSrcIp ) + ") and destination (" + \
               str( dst ) + ") IP address families.\n"
         print( baseErrorMsg + msg )
         return retArgs

      # Provided destination address must be in 127.0.0.0/8 range for IPv4 and
      # ::ffff:127.0.0.0/104 range for IPv6
      if not isValidDestAddr( dst, ipv ):
         # Compare with ==, not `is`: IPV4 is a value, not a singleton.
         destRange = '127.0.0.0/8' if ipv == IPV4 else '::ffff:127.0.0.0/104'
         msg = 'Destination address must be in ' + destRange + ' range.'
         print( baseErrorMsg + msg )
         return retArgs

   # For multipath, the destination is the multipath base address and a
   # user-specified destination is rejected. Otherwise, fall back to the
   # default destination when none is provided.
   if multipath and dsMappingInfo:
      if dst:
         msg = ( 'Destination address is not supported for multipath traceroute. ' +
                 'Please use multipath base address.' )
         print( baseErrorMsg + msg )
         return retArgs
      dst = dsMappingInfo.dsMultipathBaseAddr.stringValue
   else:
      dst = dst or DefaultLspPingReqDst[ ipv ]

   # pseudowire does not have traceroute implementation yet so
   # pwInfo returned is irrelevant.
   tunnelType, tunnelPrefix, _ = \
               getTunnelTypeAndPrefix( protocol, prefix, rsvpSpId, ipv )

   if multipath:
      tracerouteInfo = createMultipathTracerouteInfo(
            tunnelType, reqSrcIp, dst, srcMac, mount, dstMac, dsMappingInfo, tc,
            count, interval, labelStack[ 0 ], nextHopIp, tunnelPrefix, dstype,
            tos, standard=standard, size=size, padReply=padReply )
   else:
      tracerouteInfo = Tac.newInstance( "LspPing::LspPingMultipathTracerouteInfo" )
   # create LspPingClientRootSm
   lspPingClientRoot = createLspPingClient( mount, state, l3Intf, reqSrcIp, dst,
                                            tracerouteInfo )
   if not lspPingClientRoot:
      print( baseErrorMsg )
      return retArgs

   if multipath:
      ret = doMultipathTraceroute( lspPingClientRoot, dsMappingInfo, dst,
                                   labelStack, l3Intf )
      return retArgs._replace( retVal=ret )

   # If a DsMap TLV was supplied, send one in every subsequent echo request.
   sendDsMappingTlv = bool( dsMappingInfo )
   nextDsMappingInfo = dsMappingInfo if dsMappingInfo else None
   dsInfoExpected = dsMappingInfo is not None

   # Some use cases already push an Exp-NULL label at the BOS. In that case
   # there is no need to push another Exp-NULL label.
   if ( ( tunnelType == LspPingTunnelType.tunnelModeRaw ) and
        ( not egressValidateAddress ) and
        ( MplsLabel.explicitNullIpv4 not in labelStack ) and
        ( MplsLabel.explicitNullIpv6 not in labelStack ) ):
      nullLabel = ( MplsLabel.explicitNullIpv4 if ipv == IPV4 else
                    MplsLabel.explicitNullIpv6 )
      labelStack = list( labelStack ) + [ nullLabel ]

   labelTtl, numHops = [ 1 ] * len( labelStack ), 0
   maxHops, maxRequests = 64, 1
   clientStatusColl = state.clientRoot.lspPingClientStatusColl
   # pylint: disable=too-many-nested-blocks
   for hop in range( 1, ( hops or maxHops ) + 1 ):
      numHops += 1
      clientId = state.clientIdBase + hop
      ret = createLspPingClientConfig( mount, state, clientId, tunnelType, l3Intf,
                                       labelStack, reqSrcIp, dst, dstMac, srcMac,
                                       mplsTtl=labelTtl, count=maxRequests,
                                       dstPrefix=tunnelPrefix, rsvpSpId=rsvpSpId,
                                       rsvpSenderAddr=rsvpSenderAddr,
                                       dsMappingInfo=nextDsMappingInfo, tc=tc,
                                       mldpInfo=mldpInfo, standard=standard,
                                       size=size, padReply=padReply, tos=tos,
                                       dstype=dstype,
                                       setFecValidateFlag=setFecValidateFlag,
                                       egressValidateAddress=egressValidateAddress,
                                       lookupLabelCount=lookupLabelCount )
      if not ret:
         print( baseErrorMsg )
         state.clientIdBase += numHops
         # Report the requests already sent instead of the initial 0.
         return retArgs._replace( retVal=-1, txPkts=txPkts,
                                  replyHostRtts=replyHostRtts )

      clientConfig, clientStatus = ret
      unreachable = True
      for seqnum in range( 1, maxRequests + 1 ):
         _tracerouteWaitForReply( clientStatusColl, clientId, seqnum, interval )
         txPkts += 1

         replyPktInfo = clientStatus.replyPktInfo.get( seqnum )
         if not ( replyPktInfo and replyPktInfo.seqNum == seqnum ):
            if seqnum < maxRequests:
               clientConfig.txCount = 1
            continue

         # we got an echo response back
         replyHostRtts.setdefault( replyPktInfo.replyHost, [] ).append(
               replyPktInfo.roundTrip / 1000.0 )
         tracerouteReplyRender( hop, replyPktInfo )
         mplsLabelStack = getAndRenderDownstreamInfoOrIls( replyPktInfo,
                                                           tunnelType,
                                                           dsInfoExpected )
         # Check for DownstreamInfo and for the return codes in the echo reply
         # header and DDMAP TLVs. Stop probing and return if there is no
         # downstreamInfo, the label stack is explicit NULL (needed when
         # tracing a single-hop LSP, otherwise more than one hop would be
         # probed), or an undesired retcode is returned.
         if ( not mplsLabelStack or
              labelStackIsExplicitNull( ipv, mplsLabelStack ) or
              not isLabelSwitchedOrSeeDdmapTlv( prefix, replyPktInfo.retCode ) or
              not isLabelSwitchedInAllDdmapTlvs( replyPktInfo ) or
              not isNoRetCodeOrLabelSwitchedInSingleDdmapTlv( replyPktInfo ) ):
            state.clientIdBase += numHops
            return retArgs._replace( retVal=0, txPkts=txPkts,
                                     replyHostRtts=replyHostRtts )

         unreachable = False
         # increment top label ttl for next ping
         labelTtl[ 0 ] += 1
         if sendDsMappingTlv:
            # dsMappingInfo is only populated for LDP, RSVP, SR and mLDP
            # traceroutes; only one dsMapInfo is expected in the reply. If the
            # reply has none, omit it from subsequent echo requests.
            dsMapColl = replyPktInfo.downstreamMappingInfo
            nextDsMappingInfo = dsMapColl.values()[ 0 ] if dsMapColl else None
            if protocol == 'mldp' and len( dsMapColl.values() ) > 1:
               nextDsMappingInfo = _tracerouteAllRoutersDsMapInfo(
                     ipv, nextDsMappingInfo )

      if unreachable:
         tracerouteReplyRender( hop, None )
         labelTtl[ 0 ] += 1
         if sendDsMappingTlv:
            # This hop timed out, so there is no DsMapping TLV from it. Reset
            # it here; the generic one below is sent to subsequent hops.
            nextDsMappingInfo = None

      if sendDsMappingTlv and not nextDsMappingInfo:
         addrFamily = AddressFamily.ipv4 if ipv == IPV4 else AddressFamily.ipv6
         nextDsMappingInfo = getGenericDsMappingInfo( addrFamily=addrFamily )

   state.clientIdBase += numHops
   # Hop limit exhausted without a terminating reply; still report the
   # requests that were actually sent.
   return retArgs._replace( retVal=-1, txPkts=txPkts,
                            replyHostRtts=replyHostRtts )
