# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import Arnet
import os
import errno, struct, socket
from socket import AF_INET6
import collections

from TypeFuture import TacLazyType
from ForwardingHelper import ( noMatchNexthopInfo )
from CliPlugin.PwaModel import connStatusMsg as PseudowireConnectorStatusMsg

# Address-family selectors, e.g. for the ipv argument of getIntfPrimaryIpAddr.
IPV4, IPV6 = 'ip', 'ipv6'
# string literals used by LspTraceroute/LspPing utilities and some client libs
LspPing = 'lsp-ping'
LspTraceroute = 'lsp-traceroute'

# Target types accepted by the LSP ping/traceroute utilities.
LspPingTypeBgpLu = 'bgpLu'
LspPingTypeRaw = 'raw'
LspPingTypeLdp = 'ldp'
LspPingTypeMldp = 'mldp'
LspPingTypeRsvp = 'rsvp'
LspPingTypeSr = 'segment-routing'
LspPingTypeStatic = 'static'
LspPingTypeNhg = 'nexthop-group'
LspPingTypeSrTe = 'SrTePolicy'
LspPingTypePwLdp = 'pwLdp'
LspPingTypeNhgTunnel = 'nexthop-group-tunnel'
# Every supported target type.
LspPingTypes = [ LspPingTypeBgpLu, LspPingTypeRaw, LspPingTypeLdp, LspPingTypeRsvp,
                 LspPingTypeSr, LspPingTypeStatic, LspPingTypeNhg, LspPingTypeSrTe,
                 LspPingTypePwLdp, LspPingTypeNhgTunnel, LspPingTypeMldp ]

# Downstream mapping TLV selectors ('dsmap' / 'ddmap') and the set of both.
LspPingDSMap, LspPingDDMap = 'dsmap', 'ddmap'
LspPingDSTypes = ( LspPingDSMap, LspPingDDMap )

# Tac type handles used by the helpers in this module. Some are bound through
# TacLazyType rather than Tac.Type.
IpGenAddr = Tac.Type( 'Arnet::IpGenAddr' )
IpGenPrefix = Tac.Type( 'Arnet::IpGenPrefix' )
ConnectorKey = Tac.Type( 'Pseudowire::ConnectorKey' )
DynTunnelIntfId = Tac.Type( 'Arnet::DynamicTunnelIntfId' )
AddressFamily = Tac.Type( 'Arnet::AddressFamily' )
EthAddr = Tac.Type( 'Arnet::EthAddr' )
FecAdjType = TacLazyType( 'Smash::Fib::AdjType' )
FecIdConstants = Tac.Type( 'Smash::Fib::FecIdConstants' )
FecIdIntfId = TacLazyType( 'Arnet::FecIdIntfId' )
FecId = TacLazyType( 'Smash::Fib::FecId' )
IntfId = Tac.Type( 'Arnet::IntfId' )
MplsLabel = Tac.Type( 'Arnet::MplsLabel' )
NexthopGroupIntfIdType = Tac.Type( 'Arnet::NexthopGroupIntfId' )
PolicyEndpoint = TacLazyType( 'SrTePolicy::EndPoint' )
PolicyKey = TacLazyType( 'SrTePolicy::PolicyKey' )
PseudowireConnectorStatus = Tac.Type( "Pseudowire::PseudowireConnectorStatus" )
RoutingOutputType = Tac.Type( 'Routing::RoutingOutputType' )
TunnelId = Tac.Type( 'Tunnel::TunnelTable::TunnelId' )
TunnelType = TacLazyType( 'Tunnel::TunnelTable::TunnelType' )
LspPingReturnCode = Tac.Type( 'LspPing::LspPingReturnCode' )
LspPingMultipathBitset = Tac.Type( 'LspPing::LspPingMultipathBitset' )
# VRF ID of the default VRF; used as part of the Arp::Table::ArpKey when
# looking up IPv6 neighbor entries (see getIpv6ArpEntry).
ARP_SMASH_DEFAULT_VRF_ID = Tac.Type( 'Vrf::VrfIdMap::VrfId' ).defaultVrf

# A resolved forwarding entry: nexthop IP, label(s) and egress interface.
# Most producers in this module store 'label' as a list of labels, but
# getMldpFec stores a single top label and some producers leave intfId None.
NextHopAndLabel = collections.namedtuple( 'NextHopAndLabel', [
                   'nextHopIp',
                   'label',
                   'intfId',
                   ] )

# mLDP request parameters; the opaque-value fields mirror the opaque-value
# arguments of getMldpFec.
MldpInfo = collections.namedtuple( 'MldpInfo',
              'genOpqVal sourceAddrOpqVal groupAddrOpqVal jitter responderAddr' )

# LDP pseudowire parameters as assembled by getPwLdpFec. The fields after pwId
# may be None when only partial information is available.
PwLdpInfo = collections.namedtuple( 'PwLdpInfo', [
                   'localRouterId',
                   'remoteRouterId',
                   'pwId',
                   'pwType',
                   'vcLabel',
                   'cvTypes',
                   'ccTypes',
                   'controlWord',
                   ] )

# Short descriptions of LSP ping return codes; codes missing from this map are
# rendered as 'unknown' by lspPingRetCodeStr.
retCodeStrMap = {
   LspPingReturnCode.noRetCode : 'no return code',
   LspPingReturnCode.malFormEchoRequest : 'malformed echo req',
   LspPingReturnCode.tlvNotUnderstood : 'tlv not understood',
   LspPingReturnCode.repRouterEgress : 'egress ok',
   LspPingReturnCode.noMappingForFec : 'no map for fec',
   LspPingReturnCode.dsMappingMismatch : 'ds map mismatch',
   LspPingReturnCode.usIntfUnknown : 'us intf unknown',
   LspPingReturnCode.reservedRetCode : 'reserved ret code',
   LspPingReturnCode.labelSwitchedAtStackDep : 'label switched',
   LspPingReturnCode.labelSwitchedNoMplsFwd : 'no mpls fwd',
   LspPingReturnCode.labelMappingMismatch : 'label map mismatch',
   LspPingReturnCode.noLabelEntryAtStackDep : 'no label entry',
   LspPingReturnCode.protocolIntfMismatch : 'proto intf mismatch',
   LspPingReturnCode.premTermination : 'prem termination',
   LspPingReturnCode.seeDdmTlv : 'check downstream information'
}

# Error message templates for BGP labeled-unicast tunnel lookups; each takes
# the tunnel ID as its single %d argument.
bgpLuNoTunnelFoundErr = ( 'No BGP labeled unicast tunnel found in tunnel FIB'
                          ' for tunnel ID %d' )
bgpLuNoTunnelViaFoundErr = ( 'No BGP labeled unicast tunnel via found in'
                             ' tunnel ID %d' )

def lspPingRetCodeStr( retCode ):
   '''
   Return the short description of an LSP ping return code, or 'unknown' when
   retCodeStrMap has no (non-empty) description for it.
   '''
   return retCodeStrMap.get( retCode ) or 'unknown'

def setProductCodeGlobals():
   '''
   Make the process behave like "product" code when invoked from the CLI:

    - never drop into pdb on an exception (NOPDB is set),
    - clear the trace setting (no tracing to stderr) unless TRACEFILE is set,
    - make Tac's activity manager use epoll.
   '''
   os.environ[ 'NOPDB' ] = '1'
   import Tracing
   # Trace output must not leak into CLI output; it is only kept when it is
   # being written to a TRACEFILE.
   traceGoesToFile = 'TRACEFILE' in os.environ
   if not traceGoesToFile:
      Tracing.traceSettingIs( '' )
   Tac.activityManager.useEpoll = True

def labelStackToList( labelStack ):
   '''
   Convert a label-stack object into a Python list of labels.

   The list is in reverse index order: labelStack( stackSize - 1 ) comes first
   and labelStack( 0 ) comes last.
   '''
   labels = [ labelStack.labelStack( position )
              for position in range( labelStack.stackSize ) ]
   labels.reverse()
   return labels

def isIpv6Addr( addr ):
   '''
   Return whether addr is an IPv6 address.

   addr is either a str or an object exposing a stringValue attribute (such as
   an Arnet address value); any other argument fails the assertion.
   '''
   if not hasattr( addr, 'stringValue' ):
      assert isinstance( addr, str )
      addrStr = addr
   else:
      addrStr = addr.stringValue
   return IpGenAddr( addrStr ).af == AddressFamily.ipv6

def isIpv4Addr( addr ):
   '''
   Return True unless addr is an IPv6 address according to isIpv6Addr; any
   non-IPv6 address family is therefore treated as IPv4.
   '''
   if isIpv6Addr( addr ):
      return False
   return True

def isNexthopGroupVia( via ):
   '''Return whether the via's interface is a nexthop-group interface.'''
   viaIntf = via.intfId
   return NexthopGroupIntfIdType.isNexthopGroupIntfId( viaIntf )

def isNexthopGroupTunnelVia( via ):
   '''
   Return True if the via's interface is a 'DynamicTunnel' interface whose
   tunnel type is nexthopGroupTunnel, False otherwise.
   '''
   viaIntf = via.intfId
   if not viaIntf.startswith( 'DynamicTunnel' ):
      return False
   viaTunnelType = TunnelId( DynTunnelIntfId.tunnelId( viaIntf ) ).tunnelType()
   return viaTunnelType == TunnelType.nexthopGroupTunnel

def getNexthopGroupId( via ):
   '''Return the nexthop-group ID encoded in the via's nexthop-group interface.'''
   viaIntf = via.intfId
   return NexthopGroupIntfIdType.nexthopGroupId( viaIntf )

def getNhgIdToName( nhgId, mount ):
   '''
   Reverse-lookup a nexthop-group name from its ID by scanning the smash
   nexthop-group entries; returns None when no entry carries nhgId.
   '''
   nhgEntries = mount.smashNhgStatus.nexthopGroupEntry
   for entryKey in nhgEntries.keys():
      if nhgEntries[ entryKey ].nhgId == nhgId:
         return entryKey.nhgName()
   return None

def getNhgId( nhgName, mount ):
   '''
   Return the ID of the nexthop group named nhgName from smash, or None when
   there is no entry for that name.
   '''
   entryKey = Tac.Value( "NexthopGroup::NexthopGroupEntryKey" )
   entryKey.nhgNameIs( nhgName )
   nhgEntry = mount.smashNhgStatus.nexthopGroupEntry.get( entryKey )
   return nhgEntry.nhgId if nhgEntry else None

def getTunnelNhgName( mount, endpoint ):
   '''
   Resolve the nexthop-group tunnel towards endpoint.

   Returns whatever resolveNhgTunnelFibEntry returns for the RIB entry's tunnel
   ID at index 0, i.e. ( nhgName, tunnelIndex, err ). Returns
   ( None, None, err ) when the nexthop-group tunnel RIB has no tunnel for the
   endpoint prefix.
   '''
   ribEntry = mount.nhgTunnelRib.entry.get( IpGenPrefix( endpoint ) )
   if ribEntry and ribEntry.tunnelId:
      return resolveNhgTunnelFibEntry( mount, ribEntry.tunnelId[ 0 ] )
   return ( None, None,
            'No nexthop-group tunnel found for prefix %s' % endpoint )

def genOpaqueValForP2mpGenericOpaqueId( oid ):
   '''
   Encode a Generic LSP identifier opaque value: a 1-byte type (1), a 2-byte
   length (4) and the 4-byte identifier oid, all in network byte order.
   '''
   tlvType, tlvLen = 1, 4
   return ( struct.pack( '!B', tlvType ) + struct.pack( '!H', tlvLen ) +
            struct.pack( '!I', oid ) )

def genOpaqueValForP2mpTransitV4SrcOpaqueId( src, grp ):
   '''
   Encode a transit IPv4 source opaque value: type 3, length 8, followed by
   the 4-byte source and group addresses in network byte order.
   '''
   tlvHeader = struct.pack( '!BH', 3, 8 )
   return tlvHeader + socket.inet_aton( src ) + socket.inet_aton( grp )

def genOpaqueValForP2mpTransitV6SrcOpaqueId( src, grp ):
   '''
   Encode a transit IPv6 source opaque value: type 4, length 32, followed by
   the 16-byte source and group addresses in network byte order.
   '''
   srcBytes = socket.inet_pton( AF_INET6, src )
   grpBytes = socket.inet_pton( AF_INET6, grp )
   return struct.pack( '!BH', 4, 32 ) + srcBytes + grpBytes

def getIntfPrimaryIpAddr( mount, intf, ipv=IPV4 ):
   '''
   Return an address configured on intf, or None if there is none.

   For ipv == IPV4 this is the interface's activeAddrWithMask.address. For any
   other ipv value it is the stringValue of the first non-link-local IPv6
   address of the interface.
   '''
   if ipv == IPV4:
      v4Statuses = mount.vrfIpIntfStatus.ipIntfStatus
      if not v4Statuses:
         return None
      v4Status = v4Statuses.get( intf )
      return v4Status.activeAddrWithMask.address if v4Status else None

   v6Statuses = mount.vrfIp6IntfStatus.ip6IntfStatus
   if not v6Statuses:
      return None
   v6Status = v6Statuses.get( intf )
   if not v6Status:
      return None
   for entry in v6Status.addr:
      # Link-local addresses are skipped: a routable address is needed.
      if not entry.address.isLinkLocal:
         return entry.address.stringValue
   return None

def generateMplsEntropyLabel( srcIp, dstIp, udpSrcPort ):
   '''
   Generates the entropy label by hashing over the srcIp, dstIp, and UDP src port
   fields in order to preserve flow.

   The entropy label will be generated between [16, 1048575] (inclusive), i.e.
   from the first unreserved label up to and including MplsLabel.max.

   Note: the label comes from Python's built-in hash() of the string forms of
   the addresses, so it is only guaranteed to be stable for the same inputs
   within one interpreter process.
   '''
   unreservedMplsLabelBase = 16
   # Number of labels in [16, MplsLabel.max]. The +1 makes MplsLabel.max itself
   # reachable, since the modulo below yields values in [0, numLabels - 1].
   numLabels = MplsLabel.max - unreservedMplsLabelBase + 1
   hashVal = hash( ( str( srcIp ), str( dstIp ), udpSrcPort ) )
   # Python's % with a positive modulus is non-negative even for a negative hash.
   entropyLabel = hashVal % numLabels + unreservedMplsLabelBase
   return entropyLabel

def getIpv6ArpEntry( mount, nexthop, intf ):
   '''
   Look up the neighbor entry for nexthop on intf in the default-VRF Arp smash
   table and return its ethAddr, or None when there is no such entry.
   '''
   nexthopAddr = IpGenAddr( str( nexthop ) )
   arpKey = Tac.Value( 'Arp::Table::ArpKey', ARP_SMASH_DEFAULT_VRF_ID,
                       nexthopAddr, intf )
   entry = mount.arpSmash.neighborEntry.get( arpKey )
   return None if entry is None else entry.ethAddr

def getProtocolIpFec( mount, prefix, protocol, genOpqVal,
                      sourceAddrOpqVal, groupAddrOpqVal ):
   '''
   Look up the forwarding information for prefix with the protocol-specific
   helper for 'ldp', 'mldp' or 'segment-routing'.

   Returns that helper's ( fecVias, errVal ) tuple, where fecVias is a list of
   NextHopAndLabel objects (or None on some errors). The opaque-value
   arguments are only used for 'mldp'. Any other protocol yields
   ( [], errno.EINVAL ).
   '''
   resolvers = {
      'ldp': lambda: getLdpFec( mount, prefix ),
      'mldp': lambda: getMldpFec( mount, prefix, genOpqVal,
                                  sourceAddrOpqVal, groupAddrOpqVal ),
      'segment-routing': lambda: getSrFec( mount, prefix ),
   }
   resolver = resolvers.get( protocol )
   if resolver is None:
      return ( [], errno.EINVAL )
   return resolver()

def getNhAndLabelInfoFromTunnelId( mount, tunnelId ):
   '''
   Resolve every via of the tunnel FIB entry for tunnelId.

   Returns ( nexthopAndLabelList, None ), where the list holds the
   getNhAndLabelFromTunnelFibVia result for each via that resolved without an
   error. Returns ( None, errMsg ) when the entry is missing, has no vias, or
   none of its vias resolved.
   '''
   fibEntry = mount.tunnelFib.entry.get( tunnelId )
   if not fibEntry:
      return None, 'No tunnel found in tunnel FIB for tunnel ID %d' % tunnelId
   if not fibEntry.tunnelVia:
      return None, 'No tunnel via found in entry for tunnel ID %d' % tunnelId

   viaResults = [ getNhAndLabelFromTunnelFibVia( mount, tunnelVia )
                  for tunnelVia in fibEntry.tunnelVia.values() ]
   # Vias that failed to resolve are silently dropped.
   nexthopAndLabelList = [ nhAndLabel for nhAndLabel, viaErr in viaResults
                           if not viaErr ]
   if not nexthopAndLabelList:
      return None, 'No resolved vias in entry for tunnel ID %d' % tunnelId
   return nexthopAndLabelList, None

def getLdpFec( mount, prefix ):
   '''
   Gets the list of L3 resolved IP nexthops with the corresponding label stack. This
   will also resolve over an RSVP tunnel if it is LDP over RSVP.

   Only IPv4 prefixes are supported. LDP vias whose label stack is not exactly
   one label deep are skipped.

   Returns ( nextHopsAndLabels, errVal ):
   - ( list of NextHopAndLabel with list-valued labels, 0 ) on success;
   - ( [], errno.EINVAL ) when no via could be resolved;
   - ( None, errno.EINVAL ) for an IPv6 prefix or a missing RIB entry.
   Error reasons are printed.
   '''
   if isIpv6Addr( prefix ):
      print 'Ipv6 Address not supported'
      return ( None, errno.EINVAL )

   ribEntry = mount.ldpTunnelRib.entry.get( IpGenPrefix( prefix ) )
   if not ribEntry or not ribEntry.tunnelId:
      print 'No tunnel found for prefix %s' % prefix
      return ( None, errno.EINVAL )

   nextHopsAndLabels = []
   for tId in ribEntry.tunnelId.values():
      tunnelTableEntry = mount.ldpTunnelTable.entry.get( tId )
      if not tunnelTableEntry:
         continue
      for via in tunnelTableEntry.via.itervalues():
         # Only single-label LDP vias are handled.
         if via.labels.stackSize != 1:
            continue
         intfId = via.intfId
         nexthop = via.nexthop.v4Addr
         ldpLabelStack = [ via.labels.labelStack( 0 ) ]

         # Check for the LDP over RSVP case
         rsvpNhAndLabelList = []
         if intfId and DynTunnelIntfId.isDynamicTunnelIntfId( intfId ):
            viaTunnelId = DynTunnelIntfId.tunnelId( intfId )
            if TunnelId( viaTunnelId ).tunnelType() == TunnelType.rsvpLerTunnel:
               rsvpNhAndLabelList, err = (
                     getNhAndLabelInfoFromTunnelId( mount, viaTunnelId ) )
               if err:
                  # An unresolvable RSVP tunnel drops this via entirely; there
                  # is no fallback to the plain LDP nexthop below.
                  continue
               for nhAndLabel in rsvpNhAndLabelList:
                  # Use the RSVP resolved info with the LDP label appended
                  # after the RSVP label(s). NOTE(review): assumes
                  # nhAndLabel.label is a list — confirm against
                  # getNhAndLabelFromTunnelFibVia.
                  nextHopsAndLabels.append(
                        NextHopAndLabel( str( nhAndLabel.nextHopIp ),
                                         nhAndLabel.label + ldpLabelStack,
                                         nhAndLabel.intfId ) )
         # If just a normal LDP tunnel, add the LDP info
         if not rsvpNhAndLabelList:
            nextHopsAndLabels.append( NextHopAndLabel( nexthop, ldpLabelStack,
                                                       intfId ) )

   errVal = errno.EINVAL if not nextHopsAndLabels else 0
   if errVal:
      print 'No tunnel found or tunnel could not be resolved for tunnelId'
   return ( nextHopsAndLabels, errVal )

def getMldpFec( mount, prefix, genOpqVal, sourceAddrOpqVal, groupAddrOpqVal ):
   '''
   Gets the mLDP nexthops and labels for the P2MP FEC made of the root address
   prefix and an opaque value.

   The opaque value is built from genOpqVal (generic LSP identifier) when it is
   set, otherwise from sourceAddrOpqVal/groupAddrOpqVal, which must both be
   IPv4 or both be IPv6 addresses.

   Returns ( nextHopsAndLabels, errVal ):
   - ( list of NextHopAndLabel, 0 ) on success. Each label here is the via's
     top label, not a list, and intfId is None;
   - ( [], errno.EINVAL ) when no matching LFIB route/via is found;
   - ( None, errno.EINVAL ) for an IPv6 root or an invalid/missing opaque value.
   Error reasons are printed.
   '''
   if isIpv6Addr( prefix ):
      print( 'Ipv6 Address not supported' )
      return ( None, errno.EINVAL )

   if genOpqVal:
      opqStr = genOpaqueValForP2mpGenericOpaqueId( genOpqVal )
   elif sourceAddrOpqVal and groupAddrOpqVal:
      if isIpv6Addr( sourceAddrOpqVal ) and isIpv6Addr( groupAddrOpqVal ):
         opqStr = genOpaqueValForP2mpTransitV6SrcOpaqueId( sourceAddrOpqVal,
                                                           groupAddrOpqVal )
      elif isIpv4Addr( sourceAddrOpqVal ) and isIpv4Addr( groupAddrOpqVal ):
         opqStr = genOpaqueValForP2mpTransitV4SrcOpaqueId( sourceAddrOpqVal,
                                                           groupAddrOpqVal )
      else:
         print( 'Invalid Source/Group Address Opaque Value' )
         return ( None, errno.EINVAL )
   else:
      # Reported like the other invalid-input cases rather than asserted: an
      # assert is stripped under -O and would leave opqStr unbound.
      print( 'Invalid Opaque value' )
      return ( None, errno.EINVAL )

   opqIndex = 0 # invalid opaque index
   opqValColl = mount.mldpOpaqueValueTable.opaqueValToIndexColl
   opqValEntry = opqValColl.opaqueVal.get( opqStr ) if opqValColl else None
   if opqValEntry:
      opqIndex = opqValEntry.oValIndex

   nextHopsAndLabels = []
   for route in mount.mldpLfib.lfibRoute.itervalues():
      # p2mpfec is formed from rootIp and opqIndex
      if not ( route.fec.mldpRootIp.stringValue == prefix and
               route.fec.mldpOpaqueId == opqIndex ):
         continue
      viaSet = mount.mldpLfib.viaSet.get( route.viaSetKey )
      if viaSet is None:
         # The route references a via set that is not present in the LFIB.
         continue
      for vk in viaSet.viaKey.values():
         via = mount.mldpLfib.mplsVia.get( vk )
         if via:
            nexthopAddr = via.nextHop.v4Addr
            label = via.labelStack.topLabel()
            nextHopsAndLabels.append( NextHopAndLabel( nexthopAddr, label, None ) )
   errVal = errno.EINVAL if not nextHopsAndLabels else 0
   if errVal:
      print( 'No tunnel found for prefix %s' % prefix )
   return ( nextHopsAndLabels, errVal )

def getSrFec( mount, prefix ):
   '''
   Gets the L3 resolved nexthops, with their SR label, for the segment-routing
   tunnel to prefix.

   Vias with an all-zero nexthop are resolved through resolveHierarchical
   (HFEC or TI-LFA tunnel interfaces). The following vias are skipped:
   - vias whose label stack is not exactly one label deep;
   - vias that cannot be resolved;
   - vias for which getSrNexthopLabel returns no label.

   Returns ( list of NextHopAndLabel, 0 ) on success, ( [], errno.EINVAL ) when
   nothing resolved, or ( None, errno.EINVAL ) when the RIB has no tunnel for
   prefix. Error reasons are printed.
   '''
   IPv4, IPv6 = AddressFamily.ipv4, AddressFamily.ipv6
   rib, table = mount.srTunnelRib, mount.srTunnelTable

   ribEntry = rib.entry.get( IpGenPrefix( prefix ) )
   if not ribEntry or not ribEntry.tunnelId:
      print( 'No tunnel found for prefix %s' % prefix )
      return ( None, errno.EINVAL )

   nextHopsAndLabels = []
   for tunnelId in ribEntry.tunnelId.values():
      tunnelTableEntry = table.entry.get( tunnelId )
      if not tunnelTableEntry:
         continue
      for via in tunnelTableEntry.via.itervalues():
         if via.labels.stackSize != 1:
            continue

         if via.nexthop.isAddrZero:
            assert via.nexthop.af == AddressFamily.ipunknown, \
                   'Expected unknown AF for empty IP address'
            # ISIS SR tunnels may be optimized as HFECs, so try
            # resolving the NextLevelFecId if possible, or
            # ISIS SR tunnels might point to TI-LFA tunnel IDs
            # get the primary nexthop, primary interface from the
            # TI-LFA tunnel entry
            resolved = resolveHierarchical( mount, intf=via.intfId )
            # resolveHierarchical returns None (not a pair) e.g. for an invalid
            # FecId, so check before unpacking.
            if not resolved:
               continue
            nexthopList, intfIdList = resolved
            if not nexthopList:
               continue
         else:
            nexthopList, intfIdList = [ via.nexthop ], [ via.intfId ]

         label = getSrNexthopLabel( mount, via )
         if not label:
            continue

         for nexthop, intfId in zip( nexthopList, intfIdList ):
            assert nexthop.af in [ IPv4, IPv6 ]
            nexthopAddr = nexthop.v4Addr if nexthop.af == IPv4 else \
                          nexthop.v6Addr
            nextHopsAndLabels.append(
                  NextHopAndLabel( nexthopAddr, label, intfId ) )

   errVal = errno.EINVAL if not nextHopsAndLabels else 0
   if errVal:
      print( 'No tunnel found for tunnelId' )
   return ( nextHopsAndLabels, errVal )

def getPwLdpFec( mount, pwLdpName ):
   '''
   Gather the LDP pseudowire parameters and the MPLS forwarding info for the
   LDP pseudowire named pwLdpName.

   Returns ( pwLdpInfo, nextHopAndLabel, err ):
   - ( None, None, err ) when the local router ID, the pseudowire config, its
     neighbor address or its PW ID is missing;
   - ( pwLdpInfo, None, err ) when the pseudowire has no usable tunnel or is
     not up. If the connector status is missing or not up, pwLdpInfo only
     carries the router IDs and the PW ID; its other fields are None;
   - ( pwLdpInfo, NextHopAndLabel, None ) on success.
   '''
   # Obtain local router id from LDP config from default vrf
   localRouterId = None
   if mount.ldpProtoConfig.protoConfig.get( "default" ):
      localRouterId = mount.ldpProtoConfig.protoConfig[ "default" ].routerId
   if localRouterId is None:
      return None, None, "Local router ID is not configured"

   connectorKey = ConnectorKey()
   connectorKey.ldpConnectorKeyIs( pwLdpName )
   # Obtain remote router id from config
   remoteConnector = mount.pwConfig.connector.get( connectorKey )
   if remoteConnector is None:
      return None, None, "Pseudowire {} is not configured".format( pwLdpName )
   if not remoteConnector.neighborAddrPresent:
      return ( None, None, ( "Neighbor address for pseudowire "
                             "{} is not configured".format( pwLdpName ) ) )
   remoteRouterId = remoteConnector.neighborAddr
   if not remoteConnector.pwIdPresent:
      return ( None, None, ( "Pseudowire ID for pseudowire "
                             "{} is not configured".format( pwLdpName ) ) )
   pwId = remoteConnector.pwId

   # Partial info (no PW type, labels or VCCV types yet); this is what the
   # early error returns below hand back.
   pwLdpInfo = PwLdpInfo( IpGenAddr( localRouterId ), remoteRouterId,
                          pwId, None, None, None, None, None )

   # Obtain RCS
   rcs = mount.pwRcs.connectorStatus.get( connectorKey )
   if rcs is None or not rcs.peerInfo:
      return pwLdpInfo, None, "No tunnel in the tunnel RIB to reach this neighbor"

   status = rcs.status
   if status != PseudowireConnectorStatus.up:
      statusMsg = PseudowireConnectorStatusMsg.get( status )
      if statusMsg is None:
         statusMsg = "Unknown status %s" % status
      return ( pwLdpInfo, None,
               "Invalid pseudowire connector status '{}'".format( statusMsg ) )

   # Obtain relevant PW info from RCS/config
   # BUG497800: add support for multiple active peers
   peerInfo = rcs.peerInfo.values()[ 0 ]
   vcLabel = peerInfo.peerLabel
   pwType = rcs.remotePwType
   cvTypes = rcs.peerVccvCvTypes
   ccTypes = rcs.peerVccvCcTypes
   controlWord = peerInfo.encapControlWord
   pwLdpInfo = PwLdpInfo( IpGenAddr( localRouterId ), remoteRouterId,
                          pwId, pwType, vcLabel, cvTypes, ccTypes,
                          controlWord )

   # Lookup forwarding info.
   # We don't care about transport interface at the moment so just getting the
   # tunnelId is okay.
   tunnelId = peerInfo.tunnelId
   if not tunnelId:
      return pwLdpInfo, None, "No tunnel in the tunnel RIB to reach this neighbor"

   def getTunnelInfo( mount, tunnelId ):
      # Resolve tunnelId through the tunnel FIB via at index 0. Returns a
      # NextHopAndLabel, or None when the entry/via is missing, the via is not
      # MPLS-encapsulated, or its label stack is unknown.
      tunEntry = mount.tunnelFib.entry.get( tunnelId )
      if tunEntry is None:
         return None
      tunVia = tunEntry.tunnelVia[ 0 ] if tunEntry.tunnelVia else None
      if tunVia is None:
         return None
      nextHop = tunVia.nexthop
      intfId = tunVia.intfId
      encapId = tunVia.encapId
      if encapId.encapType != Tac.Type( "Tunnel::TunnelTable::EncapType" ).mplsEncap:
         return None
      # NextHopAndLabel code expect labels as int for later usage
      labelStack = None
      if encapId in mount.tunnelFib.labelStackEncap:
         labelStack = mount.tunnelFib.labelStackEncap[ encapId ].labelStack
      if labelStack is None:
         return None
      # NOTE(review): labels are listed in labelStack() index order here,
      # whereas labelStackToList() reverses that order; confirm which order
      # callers expect for multi-label stacks.
      labels = []
      if labelStack.stackSize == 1:
         labels = [ labelStack.labelStack( 0 ) ]
      else:
         for i in range( labelStack.stackSize ):
            labels.append( labelStack.labelStack( i ) )
      return NextHopAndLabel( nextHop, labels, intfId )

   nextHopAndLabels = getTunnelInfo( mount, tunnelId )
   # NOTE(review): getTunnelInfo also returns None for a missing FIB entry,
   # via or label stack, so this message is not always accurate.
   if nextHopAndLabels is None:
      return pwLdpInfo, None, "Tunnel does not use MPLS encapsulation"
   return pwLdpInfo, nextHopAndLabels, None

def getStaticFec( mount, prefix ):
   '''
   Look up the forwarding FEC used by the route for prefix.

   prefix is a prefix string; for IPv6 it is split on '/' into address and
   length. Returns ( fec, 0 ) on success, or ( None, errno.EINVAL ) when the
   route is absent, has an invalid FEC ID, or the FEC is not present in the
   forwarding status. Some of these errors are printed.
   '''
   if not isIpv6Addr( prefix ):
      routeStatus = mount.routeStatus
      forwardingStatus = mount.forwardingStatus
      routeKey = Tac.Value( 'Arnet::Prefix', stringValue=prefix )
   else:
      routeStatus = mount.route6Status
      forwardingStatus = mount.forwarding6Status
      addrStr, prefixLen = prefix.split( '/' )
      routeKey = Tac.Value( 'Arnet::Ip6Prefix', Arnet.Ip6Addr( addrStr ),
                            int( prefixLen ) )

   if routeKey not in routeStatus.route:
      print( 'Route %s not configured' % prefix )
      return ( None, errno.EINVAL )

   fibRoute = routeStatus.route.get( routeKey )
   if not fibRoute or fibRoute.fecId == FecIdConstants.invalidFecId:
      print( 'Invalid nexthop ID' )
      return ( None, errno.EINVAL )

   fec = forwardingStatus.fec.get( fibRoute.fecId )
   return ( fec, 0 ) if fec else ( None, errno.EINVAL )

def getRsvpFec( mount, session, lsp ):
   '''
   Gets the nexthops and label stacks of the operational RSVP LSPs of session.

   session is either the session CLI ID (int) or the session name. For a
   lookup by ID, lsp selects the LSP with that CLI ID (lsp=None keeps every
   operational LSP of the session). lsp does not filter a lookup by name.

   Returns ( nextHopsAndLabels, errVal, rsvpSpIds, rsvpSenderAddr, lspIds ):
   - on success the three lists are parallel (one element per LSP), errVal is
     0 and rsvpSenderAddr is the sender address of the last LSP collected;
   - ( None, errno.EINVAL, None, None, None ) if nothing operational is found;
   - ( None, 0, None, None, None ) if the first LSP's session role is egress.
   Reasons are printed.
   '''
   sessionStateColl = mount.rsvpStatus.sessionStateColl
   spStateColl = mount.rsvpStatus.spStateColl
   rsvpSpIds = []
   rsvpSenderAddr = None
   spStates = []
   lspIds = []

   if not sessionStateColl:
      print( 'No RSVP session found' )
      return ( None, errno.EINVAL, None, None, None )

   if not spStateColl:
      print( 'No RSVP LSP found' )
      return ( None, errno.EINVAL, None, None, None )

   #pylint: disable-msg=too-many-nested-blocks
   # SESSION BY ID
   if isinstance( session, int ):
      for sessionState in sessionStateColl.sessionState.values():
         if sessionState.sessionCliId == session:
            for spId in sessionState.spMember:
               spState = spStateColl.spState.get( spId )
               if spState is not None:
                  if not spState.operational:
                     # Ignore LSPs that are not operational, ping simply cannot
                     # work on them
                     continue
                  # If the LSP ID was not specified, store all the possible LSPs;
                  # if it was, stop at the first match
                  if lsp is None:
                     rsvpSpIds.append( spState.spId )
                     spStates.append( spState )
                     lspIds.append( spState.spCliId )
                     rsvpSenderAddr = spState.senderAddr
                  elif spState.spCliId == lsp:
                     rsvpSpIds = [ spState.spId ]
                     spStates = [ spState ]
                     lspIds = [ lsp ]
                     rsvpSenderAddr = spState.senderAddr
                     break
   # SESSION BY NAME
   else:
      for spState in spStateColl.spState.values():
         # Get all the LSP which match sessionName
         if spState.sessionAttributes.sessionName == session:
            if not spState.operational:
               # Ignore LSPs that are not operational, ping simply cannot
               # work on them
               continue
            rsvpSpIds.append( spState.spId )
            spStates.append( spState )
            lspIds.append( spState.spCliId )
            rsvpSenderAddr = spState.senderAddr
   #pylint: enable-msg=too-many-nested-blocks

   if not rsvpSpIds or not spStates:
      # Compare against None, as the lookup above does, so that an explicitly
      # requested LSP ID of 0 is still reported.
      if lsp is not None:
         print( 'No operational RSVP tunnel found for session %s and LSP %s' %
                ( session, lsp ) )
      else:
         print( 'No operational RSVP tunnel found for session %s' % session )
      return ( None, errno.EINVAL, None, None, None )

   if spStates[ 0 ].sessionRole == 'egressSessionRole':
      print( 'Session role is egress, nothing to do' )
      return ( None, 0, None, None, None )

   # List of all possible adjacencies to reach the destination. In RSVP case, there
   # is only one possible adjacency per spState.
   nextHopsAndLabels = []
   for spState in spStates:
      # Ping originating from the PLR will need to go through the bypass tunnel in
      # the scenario of FRR, this requires to get the new next hop and the label
      # for this tunnel, in addition to the label of our previously active
      # downstream.
      if spState.dsFrrInUse:
         nexthop = spState.bypassInfo.bypassNextHop.v4Addr
         bypassLabel = spState.bypassInfo.bypassLabel
         # Retrieve the MP label depending on whether link or node protection
         # is used. With link protection, the label of the MP is the same as
         # the previously used downstream label. With node protection, the
         # label of the MP is the next next hop's label.
         # Note: If the MP is at the egress, its label will be implicit null.
         # We add it regardless, it will be removed later and will not actually
         # be encapsulated.
         if spState.bypassInfo.bypassNodeProtected:
            mpLabel = spState.nnhLabel
         else:
            mpLabel = spState.dsLabel
         labelStack = [ bypassLabel, mpLabel ]
      else:
         nexthop = spState.activeDs.neighborIp.v4Addr
         labelStack = [ spState.dsLabel ]
      nextHopsAndLabels.append( NextHopAndLabel( nexthop, labelStack, None ) )

   # nextHopsAndLabels and rsvpSpIds contain at least one element. They can't
   # be empty. lspIds is used to record the lsp id of each tunnel. It contains
   # at least one element.
   return ( nextHopsAndLabels, 0, rsvpSpIds, rsvpSenderAddr, lspIds )

# Calculate the bitset for the specified multipath type
def getMultipathBitset( multipathType, numMultipathBits ):
   '''
   Build the LspPingMultipathBitset for a downstream mapping.

   When multipathType is 0 a fresh, untouched bitset is returned. Otherwise:
   - the size is set to the byte count of the smallest power-of-two number of
     bits that holds numMultipathBits, with a minimum of 4 bytes;
   - the first numMultipathBits bits, starting from the most significant bit
     of byte 0, are set to 1, and every other bit is set to 0.
   '''
   multipathBitset = LspPingMultipathBitset()
   if multipathType == 0:
      return multipathBitset

   # The size of the multipathBitset should be:
   #  "mask of length 2^(32-prefix length) bits"
   # as per RFC4379, and a minimum of 4 bytes. Floor division keeps the size an
   # integer on Python 3 as well ('/' would produce a float there).
   numBytes = ( 2 ** ( numMultipathBits - 1 ).bit_length() ) // 8
   multipathBitset.size = max( numBytes, 4 )

   bitsLeft = numMultipathBits
   for byteIndex in range( multipathBitset.size ):
      if bitsLeft >= 8:
         multipathBitset.bitset[ byteIndex ] = 0xFF
         bitsLeft -= 8
      else:
         # Set only the top bitsLeft bits of this byte (0 once none remain).
         multipathBitset.bitset[ byteIndex ] = ( 0xFF << ( 8 - bitsLeft ) ) & 0xFF
         bitsLeft = 0

   return multipathBitset

def getMultipathType( multipath, numMultipathBits ):
   '''
   Return the multipath type for a downstream mapping: 8 when multipath is
   requested with a non-zero number of multipath bits, otherwise 0.
   '''
   return 8 if multipath and numMultipathBits != 0 else 0

def getMultipathInfo( multipath, numMultipathBits ):
   '''
   Return ( multipathType, numMultipathBits, multipathBitset ) for a
   downstream mapping.

   Note that the returned numMultipathBits may differ from the argument: it is
   forced to 0 whenever the multipath type is 0.
   '''
   multipathType = getMultipathType( multipath, numMultipathBits )
   effectiveBits = numMultipathBits if multipathType != 0 else 0
   return ( multipathType, effectiveBits,
            getMultipathBitset( multipathType, effectiveBits ) )

# Generate downstream mapping info for LDP, RSVP and Segment-Routing traceroute
def getDsMappingInfo( nexthop, labelStack, l3IntfMtu, multipath, baseip,
                      numMultipathBits ):
   '''
   Build an LspPing::LspPingDownstreamMappingInfo for one downstream nexthop.

   labelStack must be a list or tuple of labels; its last entry is marked as
   bottom of stack, and every label must not exceed MplsLabel.max (asserted).
   nexthop and baseip are passed to IpGenAddr. The address type is 1 for an
   IPv4 nexthop and 3 otherwise. The multipath fields come from
   getMultipathInfo( multipath, numMultipathBits ).
   '''
   dsLabelStack = Tac.newInstance( 'LspPing::LspPingDownstreamLabelStack' )
   # The given via's label stack should always be a tuple or a list
   numLabels = len( labelStack )
   for position, mplsLabel in enumerate( labelStack ):
      labelEntry = Tac.newInstance( 'LspPing::LspPingDownstreamLabelEntry' )
      labelEntry.label = mplsLabel
      labelEntry.bos = position == numLabels - 1
      assert mplsLabel <= MplsLabel.max, "Cannot assign a non-valid MPLS label"
      dsLabelStack.labelEntry[ position ] = labelEntry
   dsLabelStack.size = numLabels

   nexthopAddr = IpGenAddr( nexthop )
   addrType = 3
   if nexthopAddr.af == AddressFamily.ipv4:
      addrType = 1

   multipathType, numMultipathBits, multipathBitset = getMultipathInfo(
      multipath, numMultipathBits )

   return Tac.newInstance( 'LspPing::LspPingDownstreamMappingInfo',
                           l3IntfMtu, addrType, 0, nexthopAddr, nexthopAddr,
                           multipathType, 0, IpGenAddr( baseip ),
                           multipathBitset, dsLabelStack, IntfId(),
                           0, 0, # retcode, retsubcode
                           True )

def _resolveSrTePolicyTunnels( mount, endpoint, color, trafficAf=None ):
   '''
   Resolve the SR-TE policy ( endpoint, color ) into its forwarding tunnels.

   The policy's label FEC is used unless trafficAf is given. In that case the
   v4 FEC is used for 'v4' and the v6 FEC for any other value.

   Returns a dict mapping each tuple of ( nexthopIp, labelStack ) tunnel
   entries to the ( labelStack, dstMac, intf ) adjacency they share. Returns
   errno.EINVAL, after printing the reason, when any of these is missing:
   the policy status, its FEC, or a valid segment list.
   '''
   endpointAddr = IpGenAddr( str( endpoint ) )
   policyKey = PolicyKey( PolicyEndpoint( endpointAddr ), color )
   policyStatus = mount.policyStatus.status.get( policyKey, None )
   if not policyStatus:
      print( "There is no active candidate path for the policy" )
      return errno.EINVAL

   if not trafficAf:
      fecId = policyStatus.labelFecId
   elif trafficAf == 'v4':
      fecId = policyStatus.v4FecId
   else:
      fecId = policyStatus.v6FecId

   if not mount.srTeForwardingStatus.fec.get( fecId, None ):
      print( "The FEC for the policy has not been programmed" )
      return errno.EINVAL

   nhInfoList = mount.fwdingHelper.resolveHierarchical( True, fecId=fecId,
                                                        selectOneVia=False )
   if nhInfoList == [ noMatchNexthopInfo ]:
      print( "None of the segment lists for the policy are valid" )
      return errno.EINVAL

   # Group the ( nexthopIp, labelStack ) tunnels by the
   # ( labelStack, dstMac, intf ) adjacency they are reached through, then
   # invert the grouping. The result is used when printing ping statistics
   # from the echo replies.
   adjToTunnels = {}
   for nhInfo in nhInfoList:
      tunnel = ( nhInfo.nexthopIp, tuple( nhInfo.labelStack ) )
      adjacency = ( tuple( nhInfo.labelStack ), nhInfo.dstMac, nhInfo.intf )
      adjToTunnels.setdefault( adjacency, [] ).append( tunnel )
   return { tuple( tunnels ) : adj for adj, tunnels in adjToTunnels.items() }

def resolveNexthop( mount, state, nexthop, intf=None ):
   '''
   Resolve nexthop to an ( egress interface, nexthop MAC ) pair.

   Uses a routing simulator cached on state (ipv4RoutingSim or ipv6RoutingSim),
   which is created on first use. Returns ( None, None ) when the nexthop
   cannot be resolved. For an IPv6 nexthop whose route is invalid or points
   to the CPU, when intf is given, returns ( intf, MAC from the Arp smash
   neighbor entry or None ).
   '''
   nexthopIsV4 = isIpv4Addr( nexthop )
   if nexthopIsV4:
      if not state.ipv4RoutingSim:
         state.ipv4RoutingSim = Tac.newInstance( 'Routing::RoutingSimulator',
                                                 mount.vrf,
                                                 mount.routingVrfInfoDir,
                                                 mount.routeStatus,
                                                 mount.forwardingStatus,
                                                 mount.forwarding6Status,
                                                 mount.arpSmash,
                                                 mount.arpSmashVrfIdMap,
                                                 mount.ipStatus,
                                                 mount.ip6Status )
      routingSim = state.ipv4RoutingSim
   else:
      if not state.ipv6RoutingSim:
         state.ipv6RoutingSim = Tac.newInstance( 'Routing6::Routing6Simulator',
                                                 mount.vrf,
                                                 mount.routing6VrfInfoDir,
                                                 mount.route6Status,
                                                 mount.forwardingStatus,
                                                 mount.forwarding6Status,
                                                 mount.arpSmash,
                                                 mount.arpSmashVrfIdMap,
                                                 mount.ipStatus,
                                                 mount.ip6Status )
      routingSim = state.ipv6RoutingSim

   routingOutput = routingSim.route( nexthop )
   unroutable = ( RoutingOutputType.invalid, RoutingOutputType.cpu )
   if routingOutput.outputType in unroutable:
      # An IPv6 nexthop on a known interface falls back to the neighbor entry.
      if nexthopIsV4 or intf is None:
         return ( None, None )
      return ( intf, getIpv6ArpEntry( mount, nexthop, intf ) )

   if ( routingOutput.nextHopEthAddr == EthAddr().stringValue or
        routingOutput.nextHopIntf == IntfId( '' ) ):
      return ( None, None )

   # FIXME routingOutput.outputType == multiple
   return ( routingOutput.nextHopIntf, routingOutput.nextHopEthAddr )

def resolveNhgTunnelFibEntry( mount, tunnelId ):
   '''
   Resolve a nexthop-group tunnel ID to its nexthop-group name and tunnel index.

   The nexthop-group ID is taken from the first via of the tunnel FIB entry
   for tunnelId and mapped to a name with getNhgIdToName.

   Returns a tuple ( nhgName, tunnelIndex, err ). On success err is None; on
   failure nhgName and tunnelIndex are None and err describes the problem.
   '''
   tunnelFibEntry = mount.tunnelFib.entry.get( tunnelId )
   if not tunnelFibEntry:
      err = 'No nexthop-group tunnel found in tunnel fib for tunnelId %s' % tunnelId
      return ( None, None, err )
   # An entry without vias would otherwise blow up on the [ 0 ] index below.
   if not tunnelFibEntry.tunnelVia:
      err = ( 'No tunnel via found in tunnel fib for nexthop-group tunnelId %s'
              % tunnelId )
      return ( None, None, err )
   nexthopGroupId = getNexthopGroupId( tunnelFibEntry.tunnelVia[ 0 ] )
   nhgName = getNhgIdToName( nexthopGroupId, mount )
   if not nhgName:
      err = 'No nexthop-group with id %d' % nexthopGroupId
      return ( None, None, err )
   tunnelIndex = TunnelId( tunnelId ).tunnelIndex()
   return ( nhgName, tunnelIndex, None )

def resolveHierarchical( mount, intf=None, fecId=None ):
   '''
   Resolve an intf or FIB FecId into parallel lists of nexthops and intfIds.
   Used in the scenario when HFECs are enabled.

   Exactly one of intf or fecId must be provided (enforced by an assert).

   - If intf is a FecIdIntfId, it is first converted to its FecId.
   - For a FecId, tunnel-adj FecIds are mapped to the matching FIB adj type,
     the FEC is looked up in the v4 or v6 forwarding status, and every via
     with a non-zero nexthop contributes one ( nexthop, intfId ) pair.
   - If intf is a dynamic tunnel intf for a TI-LFA tunnel, the nexthop IP and
     interface come from the via of the TI-LFA tunnel table entry; a zero
     nexthop yields ( [], [] ).

   Returns ( nhList, intfIdList ), or ( None, None ) when nothing can be
   resolved (invalid FecId, no FEC, non-TI-LFA tunnel, missing tunnel entry,
   or an intf of any other kind).
   '''
   assert ( ( intf and not fecId ) or ( not intf and fecId ) ), 'Either intf or '\
      'fecId must be provided and also both must not be provided simultaneously'

   if intf and FecIdIntfId.isFecIdIntfId( intf ):
      fecId = FecIdIntfId.intfIdToFecId( intf )

   if fecId:
      if not FecId( fecId ).isValid():
         # Return a pair like every other failure path so callers can unpack.
         return None, None
      fecIdAdjType = FecId( fecId ).adjType()

      # Convert to fibAdj if used by tunnel adj
      if fecIdAdjType == FecAdjType.usedByTunnelV4Adj:
         fecId = FecId.fecIdToNewAdjType( FecAdjType.fibV4Adj, fecId )
      elif fecIdAdjType == FecAdjType.usedByTunnelV6Adj:
         fecId = FecId.fecIdToNewAdjType( FecAdjType.fibV6Adj, fecId )

      # Re-assign in case it changed
      fecIdAdjType = FecId( fecId ).adjType()
      fec = None
      if fecIdAdjType == FecAdjType.fibV4Adj:
         fec = mount.forwardingStatus.fec.get( fecId )
      elif fecIdAdjType == FecAdjType.fibV6Adj:
         fec = mount.forwarding6Status.fec.get( fecId )
      if not fec:
         return ( None, None )

      nhList, intfIdList = [], []
      for i in fec.via:
         genNhAddr = IpGenAddr( str( fec.via[ i ].hop ) )
         intfId = fec.via[ i ].intfId
         # Skip vias without a real nexthop address.
         if not genNhAddr.isAddrZero:
            nhList.append( genNhAddr )
            intfIdList.append( intfId )
      return nhList, intfIdList

   if intf and DynTunnelIntfId.isDynamicTunnelIntfId( intf ):
      tunId = DynTunnelIntfId.tunnelId( intf )
      # Only handle SR tunnels pointing to TI-LFA tunnels now.
      if TunnelId( tunId ).tunnelType() != 'tiLfaTunnel':
         return None, None
      tunEntry = mount.tiLfaTunnelTable.entry.get( tunId )
      if not tunEntry:
         return None, None
      nh = IpGenAddr( str( tunEntry.via.nexthop ) )
      intfId = tunEntry.via.intfId
      if nh.isAddrZero:
         return [], []
      return [ nh ], [ intfId ]

   return None, None

# Get the nexthop label for an SR FEC. The routine handles the case
# where the SR tunnel is being protected by a TI-LFA tunnel. In such
# a scenario, the nexthop label could be the top label of the label stack in
# the primary via of the TI-LFA tunnel if the SR tunnel is modeled as pop +
# forward into TI-LFA tunnel. For all other cases, return the top label of
# via.labels
def getSrNexthopLabel( mount, via ):
   '''
   Return the nexthop label for the SR FEC via.

   Returns None (after printing a message) when via forwards into a TI-LFA
   tunnel that is not present in mount.tiLfaTunnelTable.
   '''
   label = via.labels.labelStack( 0 )
   # Implicit null on a via whose intf is a dynamic tunnel intf is the
   # "pop + forward into TI-LFA tunnel" case: use the TI-LFA via's top label.
   if ( label == MplsLabel.implicitNull and
        DynTunnelIntfId.isDynamicTunnelIntfId( via.intfId ) ):
      tiLfaTunId = DynTunnelIntfId.tunnelId( via.intfId )
      tunEnt = mount.tiLfaTunnelTable.entry.get( tiLfaTunId )
      if not tunEnt:
         print( "No tunnel with ID %x exists in TI-LFA tunnel table" % tiLfaTunId )
         return None
      label = tunEnt.via.labels.labelStack( 0 )
   return label

def getBgpLuTunnelFibEntry( mount, prefix ):
   '''
   Look up the tunnel FIB entry backing the BGP labeled unicast route for
   prefix.

   The first tunnel ID of the BGP-LU tunnel RIB entry is used. Returns
   ( tunnelFibEntry, None ) when the entry exists and has at least one via,
   otherwise ( None, err ) with a description of what was missing.
   '''
   rib = mount.bgpLuTunnelRib.entry.get( IpGenPrefix( prefix ) )
   if not ( rib and rib.tunnelId ):
      return None, 'No BGP labeled unicast tunnel found for prefix %s' % prefix

   tunId = rib.tunnelId[ 0 ]
   fibEntry = mount.tunnelFib.entry.get( tunId )
   if not fibEntry:
      return None, bgpLuNoTunnelFoundErr % tunId
   if not fibEntry.tunnelVia:
      return None, bgpLuNoTunnelViaFoundErr % tunId
   return fibEntry, None

def getBgpLuLabels( mount, tunnelVia ):
   '''
   Fetch the MPLS label stack for tunnelVia's encap ID from the tunnel FIB.

   Returns ( labels, None ) on success, where labels is a list with every
   implicit-null label removed. Returns ( None, err ) if the encap type is not
   MPLS or the tunnel FIB has no label stack encap entry for the encap ID.
   '''
   encapId = tunnelVia.encapId
   mplsEncap = Tac.Type( "Tunnel::TunnelTable::EncapType" ).mplsEncap
   if encapId.encapType != mplsEncap:
      return None, ( 'Label stack encap information found in tunnel FIB for'
                     ' encap ID %d is not MPLS' % encapId.encapIdValue )

   encapTable = mount.tunnelFib.labelStackEncap
   if encapId not in encapTable:
      return None, ( 'No label stack encap information found in tunnel FIB for'
                     ' encap ID %d' % encapId.encapIdValue )

   # Implicit null is dropped since callers concatenate label stacks.
   stack = labelStackToList( encapTable[ encapId ].labelStack )
   labels = [ lbl for lbl in stack if lbl != MplsLabel.implicitNull ]
   return labels, None

def getNhAndLabelFromTunnelFibVia( mount, tunnelVia ):
   '''
   Resolve a tunnel FIB via down to its nexthop, interface and label stack.

   While the via's interface is a DynamicTunnelIntfId, that via's MPLS labels
   (from getBgpLuLabels) are collected and resolution continues with the first
   via of the tunnel FIB entry the interface points to. Each newly resolved
   via's labels are prepended to those gathered so far.

   Returns ( NextHopAndLabel, None ) on success or ( None, err ) on failure,
   including when the tunnel chain loops back on itself. An empty final label
   list is returned as [ MplsLabel.implicitNull ].
   '''
   labels = []
   # Tunnel IDs already followed; without this a tunnel FIB loop would make
   # the while loop below spin forever.
   visitedTunnelIds = set()

   # Resolve DynamicTunnelIntfId tunnelVias
   while DynTunnelIntfId.isDynamicTunnelIntfId( tunnelVia.intfId ):
      labelStack, err = getBgpLuLabels( mount, tunnelVia )
      if err:
         return None, err
      labels = labelStack + labels
      tunnelId = DynTunnelIntfId.tunnelId( tunnelVia.intfId )
      if tunnelId in visitedTunnelIds:
         err = 'Tunnel resolution loop detected at tunnel ID %s' % tunnelId
         return None, err
      visitedTunnelIds.add( tunnelId )
      tunnelFibEntry = mount.tunnelFib.entry.get( tunnelId )
      if not tunnelFibEntry:
         err = bgpLuNoTunnelFoundErr % tunnelId
         return None, err
      if not tunnelFibEntry.tunnelVia:
         err = bgpLuNoTunnelViaFoundErr % tunnelId
         return None, err
      tunnelVia = tunnelFibEntry.tunnelVia[ 0 ]

   # Get label stack from final resolved tunnelVia
   labelStack, err = getBgpLuLabels( mount, tunnelVia )
   if err:
      return None, err
   labels = labelStack + labels

   # If final label stack is empty, make it implicit null
   if not labels:
      labels = [ MplsLabel.implicitNull ]

   nextHopAndLabel = NextHopAndLabel( tunnelVia.nexthop, labels, tunnelVia.intfId )
   return nextHopAndLabel, None

def validateBgpLuResolvedPushVia( resolvedPushVia, nexthop, label ):
   '''
   Check the BGP labeled unicast push vias resolved for nexthop.

   Returns None if the result is usable. Otherwise returns an error string:
   either no vias were found (mentioning label when one was given), or
   several vias were found and no label was given to disambiguate them.
   '''
   if resolvedPushVia:
      if len( resolvedPushVia ) > 1 and not label:
         return ( 'Multiple BGP labeled unicast paths for BGP next hop '
                  '%s found, please specify label stack' % nexthop )
      return None

   pieces = [ 'No BGP labeled unicast paths for BGP next hop %s' % nexthop ]
   if label:
      pieces.append( ' label stack %s' % str( label ) )
   pieces.append( ' found' )
   return ''.join( pieces )
