# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import collections
from ClientCommonLib import ( lspPingRetCodeStr, LspPingReturnCode )

# Nexthop-group ping modes. The NHG statistics renderer compares the `mode`
# element of its render args against LspPingNhgModeOneEntry and, on a match,
# adds the pinged entry index to the statistics header.
LspPingNhgModeAllEntries = 0
LspPingNhgModeOneEntry = 1

# ---------------------------------------------------------
#               general render helpers
# ---------------------------------------------------------

def lspPingReplyStr( replyPktInfo, expectedRetCode=None ):
   '''Build the one-line description of a single echo reply.

   Args:
      replyPktInfo: reply record providing replyHost, seqNum, roundTrip and
         retCode attributes; a falsy value means no reply was received.
      expectedRetCode: when not None, the reply's retCode is compared against
         it and the line is extended with "success: " or "error: " followed
         by the text lspPingRetCodeStr() gives for the reply's retCode.

   Returns:
      "Request timeout" for a falsy replyPktInfo, otherwise a line of the form
      "Reply from <host>: seq=<n>, time=<roundTrip / 1000.0>ms[, ...]".
   '''
   if not replyPktInfo:
      return "Request timeout"

   pieces = [ "Reply from ", str( replyPktInfo.replyHost ),
              ": seq=", str( replyPktInfo.seqNum ),
              ", time=", str( replyPktInfo.roundTrip / 1000.0 ),
              "ms" ] # millisecond
   if expectedRetCode is not None:
      if replyPktInfo.retCode == expectedRetCode:
         verdict = " success: "
      else:
         verdict = " error: "
      pieces.extend( [ ",", verdict,
                       lspPingRetCodeStr( replyPktInfo.retCode ) ] )
   return "".join( pieces )

# ---------------------------------------------------------
#           static ping render helpers
# ---------------------------------------------------------

def printVias( vias, resolved=True ):
   '''Print one "Via ..." line per via.

   Each via is indexed positionally: via[ 0 ] is printed right after "Via",
   via[ 1 ] is either a single int label or an iterable of labels, and
   via[ 2 ] is printed between the two only when resolved is true.
   '''
   for via in vias:
      if resolved:
         intfPart = '%s, ' % via[ 2 ]
      else:
         intfPart = ''
      labelField = via[ 1 ]
      if isinstance( labelField, int ):
         # a lone label is rendered as a one-element stack
         stack = [ labelField ]
      else:
         stack = list( labelField )
      stackPart = 'label stack: {}'.format( stack )
      print( 'Via {}, {}{}'.format( via[ 0 ], intfPart, stackPart ) )

def lspPingStaticReplyRender( clientId, replyPktInfo, renderArgs ):
   '''Print the vias of clientId followed by its indented reply line.

   renderArgs[ 0 ] maps client ids to the vias handed to printVias().
   '''
   clientIdToVias = renderArgs[ 0 ]
   printVias( clientIdToVias[ clientId ] )
   replyLine = lspPingReplyStr( replyPktInfo )
   print( '   ' + replyLine )

def lspPingStaticStatisticsRender( clientIds, time, txPkts, replyInfo, renderArgs ):
   '''Print the final statistics of a static MPLS push-label route ping.

   Args:
      clientIds: client ids to report, in print order.
      time: value printed as the elapsed time, suffixed with "ms".
      txPkts: mapping of client id to the number of packets transmitted.
         Nothing at all is printed when it is empty.
      replyInfo: mapping of client id to a mapping of reply host to the list
         of RTTs recorded for that host.
      renderArgs: ( clientIdToVias, prefix, unresolvedVias ). Unresolved vias,
         if any, are listed last under a "Not resolved" note.

   Clients without any transmitted packet are skipped, as the other
   statistics renderers in this file do, instead of failing on the loss-rate
   computation.
   '''
   clientIdToVias, prefix, unresolvedVias = renderArgs
   if not txPkts:
      return

   print( "\n--- static MPLS push-label route " + prefix +
          ": lspping statistics ---" )
   for clientId in clientIds:
      packetsSent = txPkts.get( clientId )
      if not packetsSent:
         # nothing sent for this client: no loss rate can be computed
         continue

      hostRtts = replyInfo.get( clientId ) or {}
      # sum up all recorded RTTs to figure out the total # of echo replies
      recvNum = sum( len( rtts ) for rtts in hostRtts.values() )
      # floor division keeps the loss rate a whole percentage
      lossRate = 100 - recvNum * 100 // packetsSent

      summary = ''
      vias = clientIdToVias.get( clientId )
      if vias is not None:
         printVias( vias )
         summary += '   '
      summary += str( packetsSent ) + " packets transmitted, "
      summary += str( recvNum ) + " received, "
      summary += str( lossRate ) + "% packet loss, time " + str( time ) + "ms"
      print( summary )

      # Hosts without any recorded RTT are left out: min()/max() reject
      # empty lists.
      hostSummaries = [
         '{} received from {}, rtt min/max/avg '
         '{:.3f}/{:.3f}/{:.3f} ms'.format(
            len( rtts ), host, min( rtts ), max( rtts ),
            sum( rtts ) / float( len( rtts ) ) )
         for host, rtts in hostRtts.items() if rtts ]
      # Printed once, after all hosts are collected, so several reply hosts
      # give one comma-separated line instead of a growing, repeated one.
      if hostSummaries:
         print( '   ' + ', '.join( hostSummaries ) )
      print( '' )

   if unresolvedVias:
      printVias( unresolvedVias, resolved=False )
      print( '   Not resolved' )
      print( '' )

# ---------------------------------------------------------
#           nhg ping render helpers
# ---------------------------------------------------------

def printTunnels( tunnels ):
   '''Print an "Entry <index>" / "Via <nexthop>" pair of lines per tunnel.

   Tunnels are printed in ascending entryIndex order.
   '''
   ordered = sorted( tunnels, key=lambda tunnel: tunnel.entryIndex )
   for entry in ordered:
      print( 'Entry {}'.format( entry.entryIndex ) )
      print( '   Via {}'.format( entry.nexthop ) )

def printUnresolvedTunnels( tunnels ):
   '''Print the given entry indices, ascending, after an "Entry" label.

   The list is wrapped so that no line grows past 66 characters; each
   continuation line is indented by the width of the label.
   '''
   label = 'Entry'
   widthLimit = 66
   indent = len( label )
   pieces = [ label ]
   used = indent
   for index in sorted( tunnels ):
      token = ' %d' % index
      if used + len( token ) > widthLimit:
         # start a new line and restart the width count
         pieces.append( '\n' + ' ' * indent )
         used = indent
      pieces.append( token )
      used += len( token )
   print( ''.join( pieces ) )

def lspPingNhgReplyRender( clientId, replyPktInfo, renderArgs ):
   '''Print the tunnels of clientId followed by its indented reply line.

   The elements of renderArgs used here are:
      [ 0 ] client id -> tunnels, handed to printTunnels()
      [ 4 ] client id -> nexthop-group name
      [ 5 ] prefix, or None
      [ 7 ] nexthop-group name -> nexthop-group tunnel index (may be empty)
   When a prefix is given and the client id has a nexthop-group name, a
   header naming the prefix and the nexthop-group is printed first.
   '''
   tunnels = renderArgs[ 0 ][ clientId ]
   clientIdBaseToNhgName = renderArgs[ 4 ]
   prefix = renderArgs[ 5 ]
   nhgNameToNhgTunnelIdx = renderArgs[ 7 ]

   if prefix is not None and clientId in clientIdBaseToNhgName:
      nhgName = clientIdBaseToNhgName[ clientId ]
      if nhgNameToNhgTunnelIdx:
         header = ( '\n%s: nexthop-group tunnel index %d '
                    '(nexthop-group name: %s)' %
                    ( prefix, nhgNameToNhgTunnelIdx[ nhgName ], nhgName ) )
      else:
         header = ( '\n%s: nexthop-group route (nexthop-group name: %s)' %
                    ( prefix, nhgName ) )
      print( header )

   printTunnels( tunnels )
   replyLine = lspPingReplyStr( replyPktInfo )
   print( '   ' + replyLine )

def lspPingNhgStatisticsRender( clientIds, time, txPkts, replyInfo, renderArgs ):
   '''Print the final statistics of a nexthop-group ping, one section per group.

   Args:
      clientIds: unused; the client ids are taken per nexthop-group from
         renderArgs.
      time: value printed as the elapsed time, suffixed with "ms".
      txPkts: mapping of client id to the number of packets transmitted.
         Nothing at all is printed when it is empty.
      replyInfo: mapping of client id to a mapping of reply host to the list
         of RTTs recorded for that host.
      renderArgs: 8-element sequence ( clientIdToTunnels, nhgNames,
         nhgNameToClientIds, nhgNameToUnresolvedTunnels, <unused>, <unused>,
         mode, nhgNameToNhgTunnelIdx ).

   Clients without any transmitted packet are skipped, as the other
   statistics renderers in this file do, instead of failing on the loss-rate
   computation.
   '''
   ( clientIdToTunnels, nhgNames, nhgNameToClientIds,
     nhgNameToUnresolvedTunnels, _, _, mode, nhgNameToNhgTunnelIdx ) = renderArgs
   if not txPkts:
      return

   for nhgName in nhgNames:
      entryInfo = ''
      if mode == LspPingNhgModeOneEntry:
         # the header names the entry of the group's first client
         firstClientId = nhgNameToClientIds[ nhgName ][ 0 ]
         entryInfo = ( ' entry %d' %
                       clientIdToTunnels[ firstClientId ][ 0 ].entryIndex )
      if nhgNameToNhgTunnelIdx:
         print( "\n--- nexthop-group tunnel index %d, "
                "nexthop-group %s%s: lspping statistics ---" %
                ( nhgNameToNhgTunnelIdx[ nhgName ], nhgName, entryInfo ) )
      else:
         print( "\n--- nexthop-group %s%s: lspping statistics ---" %
                ( nhgName, entryInfo ) )

      for clientId in nhgNameToClientIds[ nhgName ]:
         packetsSent = txPkts.get( clientId )
         if not packetsSent:
            # nothing sent for this client: no loss rate can be computed
            continue

         hostRtts = replyInfo.get( clientId ) or {}
         # sum up all recorded RTTs to figure out the total # of echo replies
         recvNum = sum( len( rtts ) for rtts in hostRtts.values() )
         # floor division keeps the loss rate a whole percentage
         lossRate = 100 - recvNum * 100 // packetsSent

         summary = ''
         tunnels = clientIdToTunnels[ clientId ]
         if tunnels is not None:
            printTunnels( tunnels )
            summary += '   '
         summary += str( packetsSent ) + " packets transmitted, "
         summary += str( recvNum ) + " received, "
         summary += ( str( lossRate ) + "% packet loss, time " +
                      str( time ) + "ms" )
         print( summary )

         # Hosts without any recorded RTT are left out: min()/max() reject
         # empty lists.
         hostSummaries = [
            '{} received from {}, rtt min/max/avg '
            '{:.3f}/{:.3f}/{:.3f} ms'.format(
               len( rtts ), host, min( rtts ), max( rtts ),
               sum( rtts ) / float( len( rtts ) ) )
            for host, rtts in hostRtts.items() if rtts ]
         # Printed once, after all hosts are collected, so several reply
         # hosts give one comma-separated line instead of a growing,
         # repeated one. NOTE(review): the line is not wrapped, so it can
         # get long when many hosts reply.
         if hostSummaries:
            print( '   ' + ', '.join( hostSummaries ) )
         print( '' )

      # print unresolved tunnels if any
      if nhgNameToUnresolvedTunnels[ nhgName ]:
         printUnresolvedTunnels( nhgNameToUnresolvedTunnels[ nhgName ] )
         print( '   Not resolved or configured\n' )

# ---------------------------------------------------------
#           pwLdp ping render helpers
# ---------------------------------------------------------

def lspPingPwLdpReplyRender( clientId, replyPktInfo, renderArgs ):
   '''Print the vias of clientId followed by its indented reply line.

   renderArgs[ 0 ] maps client ids to vias. The reply is reported as a
   success only when its return code is LspPingReturnCode.repRouterEgress.
   '''
   clientIdToVias = renderArgs[ 0 ]
   printVias( clientIdToVias[ clientId ] )
   replyLine = lspPingReplyStr(
      replyPktInfo, expectedRetCode=LspPingReturnCode.repRouterEgress )
   print( '   ' + replyLine )

def lspPingPwLdpStatisticsClientRender( clientId, vias, time, numPktsSent,
                                        replyHostRtts ):
   '''Print the packet counters and the per-host RTT summary of one client.

   Args:
      clientId: unused by this renderer.
      vias: unused by this renderer.
      time: value printed as the elapsed time, suffixed with "ms".
      numPktsSent: number of packets transmitted; must be non-zero.
      replyHostRtts: mapping of reply host to the list of RTTs recorded for
         it, or None/empty when nothing was received.
   '''
   hostRtts = replyHostRtts or {}
   recvNum = sum( len( rtts ) for rtts in hostRtts.values() )

   # floor division keeps the loss rate a whole percentage
   lossRate = 100 - recvNum * 100 // numPktsSent

   summary = str( numPktsSent ) + " packets transmitted, "
   summary += str( recvNum ) + " received, "
   summary += str( lossRate ) + "% packet loss, time " + str( time ) + "ms"
   print( summary )

   # Can we really have multiple hosts in response? If so, they are reported
   # on one comma-separated line. Hosts without any recorded RTT are left
   # out: min()/max() reject empty lists.
   hostSummaries = [
      '{} received from {}, rtt min/max/avg '
      '{:.3f}/{:.3f}/{:.3f} ms'.format(
         len( rtts ), host, min( rtts ), max( rtts ),
         sum( rtts ) / float( len( rtts ) ) )
      for host, rtts in hostRtts.items() if rtts ]
   # Printed once, after all hosts are collected, instead of once per host
   # with an ever-growing line.
   if hostSummaries:
      print( ', '.join( hostSummaries ) )

def lspPingPwLdpStatisticsRender( clientIds, time, txPkts, replyInfo, renderArgs ):
   '''Print the final statistics of an LDP pseudowire ping.

   Args:
      clientIds: client ids to report, in print order.
      time: value printed as the elapsed time, suffixed with "ms".
      txPkts: mapping of client id to the number of packets transmitted.
         Nothing at all is printed when it is empty.
      replyInfo: mapping of client id to a mapping of reply host to RTTs.
      renderArgs: renderArgs[ 0 ] maps client ids to vias and renderArgs[ 1 ]
         is the pseudowire name shown in the header.

   Only clients with at least one transmitted packet are reported.
   '''
   clientIdToVias = renderArgs[ 0 ]
   pwName = renderArgs[ 1 ]
   if not txPkts:
      return

   print( "\n--- LDP pseudowire %s : lspping statistics ---" % ( pwName ) )

   for clientId in clientIds:
      packetsSent = txPkts.get( clientId )
      # identity test against None, then the numeric check
      if packetsSent is not None and packetsSent > 0:
         vias = clientIdToVias.get( clientId )
         lspPingPwLdpStatisticsClientRender( clientId, vias, time, packetsSent,
                                             replyInfo.get( clientId ) )
         print( '' )

# ---------------------------------------------------------
#           SrTe ping render helpers
# ---------------------------------------------------------

# Render arguments of the SR-TE ping renderers. The field order is the order
# in which lspPingSrTeStatisticsRender unpacks its renderArgs.
SrTePingRenderArgs = collections.namedtuple(
   'SrTePingRenderArgs', [ 'clientIdToTunnels', 'endpoint', 'color' ] )

def printSrTeTunnels( tunnels ):
   '''
   Prints each tunnel, given as a ( via, segmentList ) pair, in the form:
   Segment list label stack: [21, 22, 32]
      Via: 1.1.1.2
   '''
   for via, segmentList in tunnels:
      labelStack = str( list( segmentList ) )
      print( 'Segment list label stack: {}'.format( labelStack ) )
      print( '   Via: {}'.format( via ) )

def lspPingSrTeReplyRender( clientId, replyPktInfo, renderArgs ):
   '''Print the SR-TE tunnels of clientId followed by its indented reply line.

   renderArgs[ 0 ] maps client ids to tunnels. The reply is reported as a
   success only when its return code is LspPingReturnCode.repRouterEgress.
   '''
   clientIdToTunnels = renderArgs[ 0 ]
   printSrTeTunnels( clientIdToTunnels[ clientId ] )
   replyLine = lspPingReplyStr(
      replyPktInfo, expectedRetCode=LspPingReturnCode.repRouterEgress )
   print( '   ' + replyLine )

def lspPingSrTeReplyRenderWithoutCode( clientId, replyPktInfo, renderArgs ):
   '''Print the SR-TE tunnels of clientId followed by its indented reply line.

   renderArgs[ 0 ] maps client ids to tunnels. No return-code verdict is
   added to the reply line.
   '''
   clientIdToTunnels = renderArgs[ 0 ]
   printSrTeTunnels( clientIdToTunnels[ clientId ] )
   replyLine = lspPingReplyStr( replyPktInfo )
   print( '   ' + replyLine )

def lspPingSrTeStatisticsRender( clientIds, time, txPkts, replyInfo, renderArgs ):
   '''Print the final statistics of an SR-TE policy ping.

   Args:
      clientIds: client ids to report, in print order.
      time: value printed as the elapsed time, suffixed with "ms".
      txPkts: mapping of client id to the number of packets transmitted.
         Nothing at all is printed when it is empty.
      replyInfo: mapping of client id to a mapping of reply host to the list
         of RTTs recorded for that host.
      renderArgs: ( clientIdToTunnels, endpoint, color ).

   Clients without any transmitted packet are skipped, as the other
   statistics renderers in this file do, instead of failing on the loss-rate
   computation.
   '''
   clientIdToTunnels, endpoint, color = renderArgs
   if not txPkts:
      return

   print( "\n--- SR-TE Policy endpoint: {} color: {} lspping statistics ---".format(
      str( endpoint ), color ) )
   for clientId in clientIds:
      packetsSent = txPkts.get( clientId )
      if not packetsSent:
         # nothing sent for this client: no loss rate can be computed
         continue

      hostRtts = replyInfo.get( clientId ) or {}
      # sum up all recorded RTTs to figure out the total # of echo replies
      recvNum = sum( len( rtts ) for rtts in hostRtts.values() )
      # floor division keeps the loss rate a whole percentage
      lossRate = 100 - recvNum * 100 // packetsSent

      tunnels = clientIdToTunnels[ clientId ]
      if tunnels is not None:
         printSrTeTunnels( tunnels )
      summary = '   '
      summary += str( packetsSent ) + " packets transmitted, "
      summary += str( recvNum ) + " received, "
      summary += str( lossRate ) + "% packet loss, time " + str( time ) + "ms"
      print( summary )

      # Hosts without any recorded RTT are left out: min()/max() reject
      # empty lists.
      hostSummaries = [
         '{} received from {}, rtt min/max/avg '
         '{:.3f}/{:.3f}/{:.3f} ms'.format(
            len( rtts ), host, min( rtts ), max( rtts ),
            sum( rtts ) / float( len( rtts ) ) )
         for host, rtts in hostRtts.items() if rtts ]
      # Printed once, after all hosts are collected, so several reply hosts
      # give one comma-separated line instead of a growing, repeated one.
      if hostSummaries:
         print( '   ' + ', '.join( hostSummaries ) )
      print( '' )

# ---------------------------------------------------------
#           raw ping render helpers
# ---------------------------------------------------------

def lspPingRawReplyRender( clientId, replyPktInfo, fec ):
   '''Print the reply line of a raw ping; clientId and fec are not used.'''
   replyLine = lspPingReplyStr( replyPktInfo )
   print( replyLine )

def lspPingRawStatisticsClientRender( clientId, time, numPktsSent, replyHostRtt ):
   '''Print the packet counters of one raw-ping client, then a blank line.

   replyHostRtt maps reply hosts to their lists of RTTs (None or empty when
   nothing was received); only the number of RTTs is used. clientId is not
   used.
   '''
   recvNum = 0
   if replyHostRtt:
      recvNum = sum( len( rtts ) for rtts in replyHostRtt.values() )
   lossRate = 100 - recvNum * 100 / numPktsSent
   summary = "".join( [ str( numPktsSent ), " packets transmitted, ",
                        str( recvNum ), " received, ",
                        str( lossRate ), "% packet loss, time ",
                        str( time ), "ms" ] )
   print( summary )
   print( '' )

def lspPingRawStatisticsRender( clientIds, time, txPkts, replyInfo, fec ):
   '''Print the final statistics of a raw ping.

   Args:
      clientIds: client ids to report, in print order.
      time: value printed as the elapsed time, suffixed with "ms".
      txPkts: mapping of client id to the number of packets transmitted.
         Nothing at all is printed when it is empty.
      replyInfo: mapping of client id to a mapping of reply host to RTTs.
      fec: value shown in the statistics header.

   Only clients with at least one transmitted packet are reported.
   '''
   if not txPkts:
      return

   print( "\n--- %s : lspping statistics ---" % fec )
   for clientId in clientIds:
      numPktsSent = txPkts.get( clientId )
      # A client missing from txPkts yields None; test for it explicitly, as
      # the other statistics renderers do, rather than ordering None
      # against an int.
      if numPktsSent is not None and numPktsSent > 0:
         lspPingRawStatisticsClientRender( clientId, time, numPktsSent,
                                           replyInfo.get( clientId ) )

# ---------------------------------------------------------
#           RSVP ping render helpers
# ---------------------------------------------------------

def lspPingRsvpReplyRender( clientId, replyPktInfo, renderArgs ):
   '''Print the LSP (if known), the vias and the reply line of clientId.

   renderArgs[ 0 ] maps client ids to vias. renderArgs[ 3 ] maps client ids
   to their lspId, i.e. the spCliId that can be given to the RSVP ping
   command through the `lsp` argument. It is empty when `lsp` was given on
   the command line, in which case the LSP does not need to be shown.
   '''
   clientIdToLspId = renderArgs[ 3 ]
   if clientIdToLspId:
      print( 'LSP {}'.format( clientIdToLspId[ clientId ] ) )

   clientIdToVias = renderArgs[ 0 ]
   printVias( clientIdToVias[ clientId ] )
   replyLine = lspPingReplyStr( replyPktInfo,
                                LspPingReturnCode.repRouterEgress )
   print( '   ' + replyLine )

def lspPingRsvpStatisticsClientRender( clientId, vias, time, numPktsSent,
                                       replyHostRtts, clientIdToLsp ):
   '''Print the LSP, vias, packet counters and RTT summary of one client.

   Args:
      clientId: client id, used to look up its LSP in clientIdToLsp.
      vias: vias handed to printVias(), or None to print none.
      time: value printed as the elapsed time, suffixed with "ms".
      numPktsSent: number of packets transmitted; must be non-zero.
      replyHostRtts: mapping of reply host to the list of RTTs recorded for
         it, or None/empty when nothing was received.
      clientIdToLsp: mapping of client id to LSP; when empty no "LSP" line is
         printed.
   '''
   hostRtts = replyHostRtts or {}
   recvNum = sum( len( rtts ) for rtts in hostRtts.values() )

   # floor division keeps the loss rate a whole percentage
   lossRate = 100 - recvNum * 100 // numPktsSent

   if clientIdToLsp:
      print( 'LSP {}'.format( clientIdToLsp[ clientId ] ) )

   if vias is not None:
      printVias( vias )
   summary = str( numPktsSent ) + " packets transmitted, "
   summary += str( recvNum ) + " received, "
   summary += str( lossRate ) + "% packet loss, time " + str( time ) + "ms"
   print( '   ' + summary )

   # Hosts without any recorded RTT are left out: min()/max() reject empty
   # lists.
   hostSummaries = [
      '{} received from {}, rtt min/max/avg '
      '{:.3f}/{:.3f}/{:.3f} ms'.format(
         len( rtts ), host, min( rtts ), max( rtts ),
         sum( rtts ) / float( len( rtts ) ) )
      for host, rtts in hostRtts.items() if rtts ]
   # Printed once, after all hosts are collected, so several reply hosts
   # give one comma-separated line instead of a growing, repeated one.
   if hostSummaries:
      print( '   ' + ', '.join( hostSummaries ) )

def lspPingRsvpStatisticsRender( clientIds, time, txPkts, replyInfo,
                                 renderArgs ):
   '''Print the final statistics of an RSVP ping.

   Args:
      clientIds: client ids to report, in print order.
      time: value printed as the elapsed time, suffixed with "ms".
      txPkts: mapping of client id to the number of packets transmitted.
         Nothing at all is printed when it is empty.
      replyInfo: mapping of client id to a mapping of reply host to RTTs.
      renderArgs: renderArgs[ 0 ] maps client ids to vias, renderArgs[ 1 ] and
         renderArgs[ 2 ] are the prefix and protocol shown in the header, and
         renderArgs[ 3 ] maps client ids to LSPs.

   Only clients with at least one transmitted packet are reported.
   '''
   clientIdToVias = renderArgs[ 0 ]
   prefix = renderArgs[ 1 ]
   protocol = renderArgs[ 2 ]
   clientIdToLsp = renderArgs[ 3 ]
   if not txPkts:
      return

   print( "\n--- %s target fec %s : lspping statistics ---" %
          ( protocol, prefix ) )

   for clientId in clientIds:
      packetsSent = txPkts.get( clientId )
      # identity test against None, then the numeric check
      if packetsSent is not None and packetsSent > 0:
         vias = clientIdToVias.get( clientId )
         lspPingRsvpStatisticsClientRender( clientId, vias, time, packetsSent,
                                            replyInfo.get( clientId ),
                                            clientIdToLsp )
         print( '' )

# -----------------------------------------------------------------
#        LDP, MLDP and Segment Routing Ping Render helpers
# -----------------------------------------------------------------

def lspPingLabelDistReplyRender( clientId, replyPktInfo, renderArgs ):
   '''Print the vias of clientId followed by its indented reply line.

   renderArgs[ 0 ] maps client ids to vias. The reply is reported as a
   success only when its return code is LspPingReturnCode.repRouterEgress.
   '''
   clientIdToVias = renderArgs[ 0 ]
   printVias( clientIdToVias[ clientId ] )
   replyLine = lspPingReplyStr(
      replyPktInfo, expectedRetCode=LspPingReturnCode.repRouterEgress )
   print( '   ' + replyLine )

def lspPingLabelDistStatisticsClientRender( vias, time, numPktsSent,
                                            protocol, replyHostRtts ):
   '''Print the vias, packet counters and RTT summary of one client.

   Args:
      vias: vias handed to printVias(), or None to print none.
      time: value printed as the elapsed time, suffixed with "ms".
      numPktsSent: number of packets transmitted; must be non-zero.
      protocol: when equal to 'MLDP' the packet-loss figure is left out of
         the counters line.
      replyHostRtts: mapping of reply host to the list of RTTs recorded for
         it, or None/empty when nothing was received.
   '''
   hostRtts = replyHostRtts or {}
   recvNum = sum( len( rtts ) for rtts in hostRtts.values() )

   # floor division keeps the loss rate a whole percentage
   lossRate = 100 - recvNum * 100 // numPktsSent

   if vias is not None:
      printVias( vias )
   summary = str( numPktsSent ) + " packets transmitted, "
   summary += str( recvNum ) + " received"
   if protocol != 'MLDP':
      summary += ", " + str( lossRate ) + "% packet loss"
   summary += ", time " + str( time ) + "ms"
   print( '   ' + summary )

   # Hosts without any recorded RTT are left out: min()/max() reject empty
   # lists.
   hostSummaries = [
      '{} received from {}, rtt min/max/avg '
      '{:.3f}/{:.3f}/{:.3f} ms'.format(
         len( rtts ), host, min( rtts ), max( rtts ),
         sum( rtts ) / float( len( rtts ) ) )
      for host, rtts in hostRtts.items() if rtts ]
   # Printed once, after all hosts are collected, so several reply hosts
   # give one comma-separated line instead of a growing, repeated one.
   if hostSummaries:
      print( '   ' + ', '.join( hostSummaries ) )

def lspTracerouteLabelDistStatisticsClientRender( vias, time, numPktsSent,
                                                  protocol, replyHostRtts ):
   '''Print the vias, packet counters and per-host RTT lines of a traceroute.

   The packet-loss figure is left out of the counters line when protocol is
   'mldp'. Each reply host with at least one recorded RTT gets a line of its
   own. numPktsSent must be non-zero.
   '''
   hostRtts = replyHostRtts if replyHostRtts else {}
   recvNum = sum( len( rtts ) for rtts in hostRtts.values() )
   lossRate = 100 - recvNum * 100 / numPktsSent

   if vias:
      printVias( vias )

   if protocol == 'mldp':
      lossStr = ''
   else:
      lossStr = ", {}% packet loss".format( lossRate )
   counters = ( '{sent} packets transmitted, {recv} received{loss}, '
                'time {time}ms' )
   print( '   ' + counters.format( sent=numPktsSent, recv=recvNum,
                                   loss=lossStr, time=time ) )

   for host, rtts in hostRtts.items():
      if not rtts:
         continue
      hostLine = ( '   {} received from {}, rtt min/max/avg '
                   '{:.3f}/{:.3f}/{:.3f} ms' )
      print( hostLine.format( len( rtts ), host, min( rtts ), max( rtts ),
                              sum( rtts ) / len( rtts ) ) )


def lspPingLabelDistStatisticsRender( clientIds, time, txPkts, replyInfo,
                                      renderArgs ):
   '''Print the final statistics of a label-distribution protocol ping.

   Args:
      clientIds: client ids to report, in print order.
      time: value printed as the elapsed time, suffixed with "ms".
      txPkts: mapping of client id to the number of packets transmitted.
         Nothing at all is printed when it is empty.
      replyInfo: mapping of client id to a mapping of reply host to RTTs.
      renderArgs: renderArgs[ 0 ] maps client ids to vias; renderArgs[ 1 ] and
         renderArgs[ 2 ] are the prefix and protocol shown in the header.

   Only clients with at least one transmitted packet are reported.
   '''
   clientIdToVias = renderArgs[ 0 ]
   prefix = renderArgs[ 1 ]
   protocol = renderArgs[ 2 ]
   if not txPkts:
      return

   print( "\n--- %s target fec %s : lspping statistics ---" %
          ( protocol, prefix ) )

   for clientId in clientIds:
      packetsSent = txPkts.get( clientId )
      # identity test against None, then the numeric check
      if packetsSent is not None and packetsSent > 0:
         vias = clientIdToVias.get( clientId )
         lspPingLabelDistStatisticsClientRender( vias, time,
                                                 packetsSent, protocol,
                                                 replyInfo.get( clientId ) )
         print( '' )

def lspTracerouteLabelDistStatisticsRender( time, txPkts, replyHostRtts,
                                            protocol, prefix, vias ):
   '''Print the final statistics of a label-distribution protocol traceroute.

   Nothing is printed unless txPkts is greater than zero. The header shows
   the protocol in upper case.
   '''
   if not txPkts > 0:
      return

   header = "\n--- %s target fec %s : lsptraceroute statistics ---" % (
      protocol.upper(), prefix )
   print( header )
   lspTracerouteLabelDistStatisticsClientRender( vias, time, txPkts, protocol,
                                                 replyHostRtts )
   print( '' )

# -----------------------------------------------------------------
#        BGP LU Ping Render helpers
# -----------------------------------------------------------------

def lspPingBgpLuReplyRender( clientId, replyPktInfo, renderArgs ):
   '''Print the vias of clientId followed by its indented reply line.

   renderArgs[ 0 ] maps client ids to vias. The reply is reported as a
   success only when its return code is LspPingReturnCode.repRouterEgress.
   '''
   clientIdToVias = renderArgs[ 0 ]
   printVias( clientIdToVias[ clientId ] )
   replyLine = lspPingReplyStr(
      replyPktInfo, expectedRetCode=LspPingReturnCode.repRouterEgress )
   print( '   ' + replyLine )

def lspPingBgpLuStatisticsClientRender( clientId, vias, time, numPktsSent,
                                        replyHostRtts ):
   '''Print the vias, packet counters and RTT summary of one client.

   Args:
      clientId: unused by this renderer.
      vias: vias handed to printVias(), or None to print none.
      time: value printed as the elapsed time, suffixed with "ms".
      numPktsSent: number of packets transmitted; must be non-zero.
      replyHostRtts: mapping of reply host to the list of RTTs recorded for
         it, or None/empty when nothing was received.
   '''
   hostRtts = replyHostRtts or {}
   recvNum = sum( len( rtts ) for rtts in hostRtts.values() )

   # floor division keeps the loss rate a whole percentage
   lossRate = 100 - recvNum * 100 // numPktsSent

   if vias is not None:
      printVias( vias )
   summary = ( '{} packets transmitted, {} received, '
               '{}% packet loss, time {}ms' ).format( numPktsSent, recvNum,
                                                      lossRate, time )
   print( '   ' + summary )

   # Hosts without any recorded RTT are left out: min()/max() reject empty
   # lists.
   hostInfo = [
      '{} received from {}, rtt min/max/avg '
      '{:.3f}/{:.3f}/{:.3f} ms'.format(
         len( rtts ), host, min( rtts ), max( rtts ),
         sum( rtts ) / float( len( rtts ) ) )
      for host, rtts in hostRtts.items() if rtts ]
   # Printed once, after all hosts are collected, so several reply hosts
   # give one comma-separated line instead of a growing, repeated one.
   if hostInfo:
      print( '   ' + ', '.join( hostInfo ) )

def lspPingBgpLuStatisticsRender( clientIds, time, txPkts, replyInfo, renderArgs ):
   '''Print the final statistics of a BGP LU ping.

   Args:
      clientIds: client ids to report, in print order.
      time: value printed as the elapsed time, suffixed with "ms".
      txPkts: mapping of client id to the number of packets transmitted.
         Nothing at all is printed when it is empty.
      replyInfo: mapping of client id to a mapping of reply host to RTTs.
      renderArgs: ( clientIdToVias, prefix, protocol, unresolvedVias ).
         Unresolved vias, if any, are listed last under a "Not resolved"
         note.

   Only clients with at least one transmitted packet are reported.
   '''
   clientIdToVias, prefix, protocol, unresolvedVias = renderArgs
   if not txPkts:
      return

   print( "\n--- %s target fec %s : lspping statistics ---" %
          ( protocol, prefix ) )

   for clientId in clientIds:
      packetsSent = txPkts.get( clientId )
      # identity test against None, then the numeric check
      if packetsSent is not None and packetsSent > 0:
         vias = clientIdToVias.get( clientId )
         lspPingBgpLuStatisticsClientRender( clientId, vias, time, packetsSent,
                                             replyInfo.get( clientId ) )
         print( '' )

   if unresolvedVias:
      printVias( unresolvedVias, resolved=False )
      print( '   Not resolved' )
      print( '' )
