# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import prettytable

from CliModel import ( Bool, DeferredModel, Dict, Enum, Float, Model,
                       Int, Str, Submodel, List )
import Arnet
from ArnetModel import IpGenericAddress
from IntfModels import Interface
import Tac

# Aliasing types: give CliModel field types RSVP-specific names so the field
# declarations below read in terms of what they hold rather than the raw type.
SessionId = Int
LspId = Int
NeighborAddress = IpGenericAddress
MplsLabel = Int
# Tac type for RSVP bandwidth priority (not a CliModel field type); its
# min/max bound the per-priority rows printed by RsvpBandwidthModel.render.
BwPriority = Tac.Type( "TrafficEngineering::BwPriority" )

def roundBw( number, digits ):
   '''Round a bandwidth value to `digits` decimal places for display.

   A strictly positive value that would round to zero is reported as
   10**-digits instead (the smallest non-zero value at that precision).
   Otherwise users would see zero bandwidth while knowing there are active
   reservations in the system.
   '''
   rounded = round( number, digits )
   if rounded == 0.0 and number > 0.0:
      return 10**( -digits )
   return rounded

class NeighborLspInfo( Model ):
   # Identifies one RSVP LSP (session ID + LSP ID) seen through a neighbor;
   # used as the element type of RsvpNeighbor.upstreams / downstreams.
   sessionId = SessionId( help="Session ID" )
   lspId = LspId( help="LSP ID" )
   sessionName = Str( help="Session name", optional=True )

class NeighborBypassInfo( Model ):
   # Bypass tunnel state for a single protected destination. This is the value
   # type of RsvpNeighbor.bypassInfoNode (node protection), where entries are
   # keyed by tunnel destination address. The fields mirror RsvpNeighbor's
   # link-protection bypass fields.
   bypassTunnel = Enum( help="Bypass tunnel state",
                        values=( 'available', 'inUse', 'notAvailable',
                                 'notRequested' ) )
   # NOTE(review): the following are optional; presumably only populated
   # when a bypass tunnel actually exists -- confirm against the producer.
   bypassTunnelName = Str( help="Bypass tunnel name", optional=True )
   bypassNextHop = IpGenericAddress( help="Bypass tunnel next hop", optional=True )
   bypassLabel = MplsLabel( help="Bypass tunnel label", optional=True )

class RsvpNeighbor( Model ):
   # Per-neighbor RSVP state; the value type of RsvpNeighbors.neighbors,
   # keyed there by neighbor address.
   upstreams = List( help="Upstream session LSP information",
                     valueType=NeighborLspInfo )
   downstreams = List( help="Downstream session LSP information",
                      valueType=NeighborLspInfo )
   neighborCreationTime = Float( help="UTC time when neighbor was created",
                                 optional=True )
   authentication = Enum( help="Cryptographic authentication",
                          values=( 'disabled', 'MD5' ),
                          optional=True )

   # Not set in summary mode
   lastHelloReceived = Float(
         optional=True,
         help="UTC time when the last RSVP hello message was received" )
   lastHelloSent = Float(
         optional=True,
         help="UTC time when the last RSVP hello message was sent" )

   # Bypass tunnel with link protection (same shape as NeighborBypassInfo,
   # flattened into this model).
   bypassTunnel = Enum( help="Bypass tunnel state",
                        values=( 'available', 'inUse', 'notAvailable',
                                 'notRequested' ) )
   bypassTunnelName = Str( help="Bypass tunnel name", optional=True )
   bypassNextHop = IpGenericAddress( help="Bypass tunnel next hop", optional=True )
   bypassLabel = MplsLabel( help="Bypass tunnel label", optional=True )

   # Bypass tunnels with node protection: one NeighborBypassInfo per tunnel
   # destination address.
   bypassInfoNode = Dict( help="A dictionary of bypass tunnel state for node "
                               "protection keyed by tunnel destination IP address",
                          keyType=IpGenericAddress, valueType=NeighborBypassInfo,
                          optional=True )

class RsvpNeighbors( DeferredModel ):
   '''Model for "show mpls rsvp neighbor": RsvpNeighbor entries keyed by
   neighbor address.'''
   neighbors = Dict( help="Neighbor information",
                     keyType=NeighborAddress,
                     valueType=RsvpNeighbor )

class RsvpNeighborStatistics( Model ):
   # Session/LSP counts for one neighbor; the value type of
   # RsvpNeighborSummary.neighborSummary.
   upstreamSessionCount = Int( help="Number of upstream sessions" )
   upstreamLspCount = Int( help="Number of upstream LSPs" )
   downstreamSessionCount = Int( help="Number of downstream sessions" )
   downstreamLspCount = Int( help="Number of downstream LSPs" )

class RsvpNeighborSummary( DeferredModel ):
   '''Model for "show mpls rsvp neighbor summary": per-neighbor session and
   LSP counts keyed by neighbor address.'''
   neighborSummary = Dict( help="Neighbor summary",
                           keyType=NeighborAddress,
                           valueType=RsvpNeighborStatistics )

class LspCount( Model ):
   # LSP counters shown by 'show mpls rsvp' (RsvpState.lspCount).
   total = Int( help="Number of LSPs" )
   operational = Int( help="Number of operational LSPs" )
   usingBypassTunnels = Int( help="Number of LSPs using bypass tunnels" )
   ingress = Int( help="Number of ingress LSPs" )
   transit = Int( help="Number of transit LSPs" )
   egress = Int( help="Number of egress LSPs" )

class RsvpState( Model ):
   '''Model for 'show mpls rsvp' '''
   adminState = Enum( help="Is RSVP enabled by config",
                      values=( "enabled", "disabled" ) )
   operationalState = Enum( help="Operational state of RSVP",
                            values=( "up", "down" ) )
   refreshInterval = Float( help="Time interval in seconds between "
                                 "RSVP neighbor state refresh messages" )
   refreshReduction = Bool( help="Is bundling of RSVP neighbor messages enabled" )
   helloEnabled = Bool( help="Are hello messages enabled" )
   helloInterval = Float( help="Time interval in seconds between "
                               "RSVP hello messages. "
                               "Set when helloEnabled is True",
                          optional=True )
   helloMultiplier = Float( help="Number of missed hellos after which "
                                 "the neighbor is expired. "
                                 "Set when helloEnabled is True",
                            optional=True )
   fastReroute = Bool( help="Is fast re-route enabled" )
   fastRerouteMode = Enum( help="Fast re-route mode",
                           values=( "none", "linkProtection", "nodeProtection" ) )
   hierarchicalFecsEnabled = Bool( help="Are hierarchical FECs enabled" )
   reversionModeConfig = Enum( help="Configured FRR reversion mode",
                               values=( "global", "local" ), optional=True )
   reversionMode = Enum( help="Operational FRR reversion mode",
                         values=( "global", "local" ), optional=True )
   authentication = Enum( help="Cryptographic authentication",
                          values=( 'disabled', 'MD5' ) )
   srlgMode = Enum( help="SRLG mode",
                    values=( "srlgNone", "srlgLoose", "srlgStrict" ) )
   preemptionPeriod = Float( help="Preemption timer value in seconds" )
   mtuSignalingEnabled = Bool( help='Is MTU signaling enabled' )
   sessionCount = Int( help="Number of sessions" )
   lspCount = Submodel( help="LSP related counts", valueType=LspCount )
   ingressCount = Int( help="Number of ingress sessions" )
   transitCount = Int( help="Number of transit sessions" )
   egressCount = Int( help="Number of egress sessions" )
   bypassTunnelCount = Int( help="Number of bypass tunnels" )
   neighborCount = Int( help="Number of neighbors" )
   interfaceCount = Int( help="Number of interfaces" )
   labelLocalTerminationMode = Enum( help='Label local termination mode',
                                     values=( "implicitNull", "explicitNull" ),
                                     optional=True )

   def render( self ):
      '''Print the human-readable output of 'show mpls rsvp'.

      Several lines are conditional:
        - hello interval/multiplier are printed only when hellos are enabled;
        - the FRR reversion line is printed only when fast re-route is
          enabled and an operational reversion mode is known;
        - the soft-preemption timer is printed only when preemptionPeriod is
          non-zero;
        - the label local termination line is printed only when that
          optional field is set.
      '''
      def enabledStr( flag ):
         return "enabled" if flag else "disabled"

      # Any FRR mode not listed here (i.e. 'none') is displayed as "none".
      frrModeNames = {
         'nodeProtection' : "node protection",
         'linkProtection' : "link protection",
      }
      srlgModeNames = {
         'srlgNone' : 'none',
         'srlgLoose' : 'loose',
         'srlgStrict' : 'strict',
      }
      labelTerminationNames = {
         'explicitNull' : 'explicit null',
         'implicitNull' : 'implicit null',
      }

      # Timers are Float fields but are displayed as whole numbers (%d).
      print( "Administrative state: %s" % self.adminState )
      print( "Operational state: %s" % self.operationalState )
      print( "Refresh interval: %d seconds" % self.refreshInterval )
      print( "Refresh reduction: %s" % enabledStr( self.refreshReduction ) )
      print( "Hello messages: %s" % enabledStr( self.helloEnabled ) )
      if self.helloEnabled:
         print( "   Hello interval: %d seconds" % self.helloInterval )
         print( "   Hello multiplier: %d" % self.helloMultiplier )

      print( "Fast Re-Route: %s" % enabledStr( self.fastReroute ) )
      print( "   Mode: %s" % frrModeNames.get( self.fastRerouteMode, "none" ) )
      print( "   Hierarchical FECs: %s" %
             enabledStr( self.hierarchicalFecsEnabled ) )
      if self.fastReroute and self.reversionMode is not None:
         if self.reversionMode == 'global' and self.reversionModeConfig == 'local':
            # Configured 'local' but operating 'global': call out the override.
            reversionModeStr = "global (implied by node protection)"
         else:
            reversionModeStr = self.reversionMode
         print( "   Reversion: %s" % reversionModeStr )

      if self.authentication == 'disabled':
         authStr = 'disabled'
      else:
         authStr = "enabled (%s)" % self.authentication
      print( "Cryptographic authentication: %s" % authStr )

      if self.srlgMode in srlgModeNames:
         print( "SRLG mode: %s" % srlgModeNames[ self.srlgMode ] )

      if self.preemptionPeriod == 0:
         print( "Soft preemption: disabled" )
      else:
         print( "Soft preemption: enabled" )
         print( "   Preemption timer: %d seconds" % self.preemptionPeriod )
      print( 'MTU signaling: %s' % enabledStr( self.mtuSignalingEnabled ) )

      # labelLocalTerminationMode is optional; None is not a key, so nothing
      # is printed when it is unset.
      if self.labelLocalTerminationMode in labelTerminationNames:
         print( "Label type for local termination: %s" %
                labelTerminationNames[ self.labelLocalTerminationMode ] )

      print( "Number of sessions: %d" % self.sessionCount )
      print( "   Ingress/Transit/Egress: %d/%d/%d" %
             ( self.ingressCount, self.transitCount, self.egressCount ) )
      lspCount = self.lspCount
      print( "Number of LSPs: %d" % lspCount.total )
      print( "   Operational: %d" % lspCount.operational )
      print( "   Ingress/Transit/Egress: %d/%d/%d" %
             ( lspCount.ingress, lspCount.transit, lspCount.egress ) )
      print( "   Currently using bypass tunnels: %d" %
             lspCount.usingBypassTunnels )
      print( "Number of bypass tunnels: %d" % self.bypassTunnelCount )
      print( "Number of neighbors: %d" % self.neighborCount )
      print( "Number of interfaces: %d" % self.interfaceCount )

class LspNeighborInfo( Model ):
   # Upstream or downstream neighbor of an LSP; used by
   # Lsp.upstreamNeighbor / Lsp.downstreamNeighbor.
   address = NeighborAddress( help="Neighbor address" )

   # Not set in summary mode
   localInterface = Interface( help="Local interface", optional=True )
   lastRefreshReceived = Float(
         optional=True,
         help="UTC time when the last neighbor refresh was received"
         )
   lastRefreshSent = Float(
         optional=True,
         help="UTC time when the last neighbor refresh was sent" )

class Lsp( Model ):
   # Per-LSP state within an RSVP session; the value type of RsvpSession.lsps.
   state = Enum( help="Operational state",
                 values=( 'up', 'down', 'pathOnly' ) )
   # Not set in summary mode
   lspType = Enum( help="LSP purpose",
                   values=( 'primary', 'bypass' ), optional=True )
   # pathState / resvState are declared mandatory (no optional=True).
   pathState = Enum( help="Path state",
                     values=( 'up', 'down', 'receivedOnly', 'sentOnly' ) )
   resvState = Enum( help="Reservation state",
                     values=( 'up', 'down', 'receivedOnly', 'sentOnly' ) )

   lspId = Int( help="LSP ID", optional=True )

   # Not set in summary mode
   lspCreationTime = Float( help="UTC time when LSP was created", optional=True )

   # Not set in summary mode
   sourceAddress = IpGenericAddress( help="Source address", optional=True )

   # Unlike RsvpNeighbor.bypassTunnel, this state also allows 'egress'.
   bypassTunnel = Enum( help="Bypass tunnel state",
                        values=( 'available', 'inUse', 'notAvailable',
                                 'notRequested', 'egress' ) )

   # Not set in summary mode
   bypassTunnelName = Str( help="Bypass tunnel name", optional=True )
   bypassNextHop = IpGenericAddress( help="Bypass tunnel next hop", optional=True )
   bypassLabel = MplsLabel( help="Bypass tunnel label", optional=True )

   # Not set when the session is in ingressRole or in summary mode
   upstreamNeighbor = Submodel( help="Upstream neighbor information",
                                valueType=LspNeighborInfo,
                                optional=True )
   backupUpstreamNeighborAddress = NeighborAddress( optional=True,
         help="Upstream backup neighbor address" )

   # Not set when the session is in egressRole or in summary mode
   downstreamNeighbor = Submodel( help="Downstream neighbor information",
                                  valueType=LspNeighborInfo,
                                  optional=True )
   backupDownstreamNeighborAddress = NeighborAddress( optional=True,
         help="Downstream backup neighbor address" )

   # Not set when the session is in ingressRole or in summary mode
   localLabel = MplsLabel( help="Local MPLS label",
                           optional=True )
   # Not set when the session is in egressRole or in summary mode
   downstreamLabel = MplsLabel( help="Downstream MPLS label",
                                optional=True )

   sessionName = Str( help="Head end session name", optional=True )

   # Not set in summary mode
   path = List( valueType=str, help="Explicit path", optional=True )

   bypassLastHopLabel = MplsLabel( help="Merge point label", optional=True )
   frrRequestedMode = Enum( help="Fast re-route mode requested",
                            values=( "none", "linkProtection", "nodeProtection" ),
                            optional=True )
   frrOperationalMode = Enum( help="Fast re-route mode operational",
                              values=( "none", "linkProtection", "nodeProtection" ),
                              optional=True )
   bandwidth = Float( help="Requested bandwidth in bps", optional=True )
   bandwidthState = Enum(
      help="Current status of bandwidth reservation",
      values=( "unknown", "pending", "failedAdmission", "failedPreemption",
               "preempted", "reserved" ),
      optional=True )
   # Only set in detail mode
   setupPriority = Int( help="Setup priority", optional=True )
   holdPriority = Int( help="Hold priority", optional=True )
   softPreemptionRequested = Bool( help="Soft preemption has been requested",
                                   optional=True )
   mtuSignalingEnabled = Bool( help="MTU signaling is enabled for this LSP",
                               optional=True )
   inMtu = Int( help="MTU received in the incoming path message", optional=True )
   # NOTE(review): outMtu's help text says "received" for the *outgoing* path
   # message; presumably this is the MTU sent downstream -- confirm.
   outMtu = Int( help="MTU received in the outgoing path message", optional=True )
   # Only set in history mode
   lspHistory = List( valueType=str, help="Last 10 LSP history events",
                      optional=True )

class RsvpSession( Model ):
   # One RSVP session and its LSPs; the value type of RsvpSessions.sessions.
   destination = IpGenericAddress( help="Destination address" )
   # Not set in summary mode
   tunnelId = Int( help="Tunnel ID", optional=True )
   # Not set in summary mode
   extendedTunnelId = IpGenericAddress( help="Extended tunnel ID",
                                        optional=True )
   # Not set in summary mode
   state = Enum( help="Session operational state",
                 values=( 'up', 'down' ),
                 optional=True )
   sessionCreationTime = Float( help="UTC time when session was created",
                                optional=True )
   sessionRole = Enum( help="Role for Session",
                       values=( 'ingress', 'egress', 'transit', 'unknown' ) )
   # LSPs of this session, keyed by an integer LSP number.
   lsps = Dict( help="Table of LSP information keyed by LSP number",
                keyType=int, valueType=Lsp )
   # Only set in history mode
   sessionHistory = List( valueType=str, help="Last 10 session history events",
                          optional=True )

class RsvpSessions( DeferredModel ):
   '''Model for "show mpls rsvp session",
      "show mpls rsvp session summary" and
      "show mpls rsvp session history": RsvpSession entries keyed by an
      integer session number.'''
   sessions = Dict( help="Table of session information keyed by session number",
                    keyType=int,
                    valueType=RsvpSession )

class RsvpMessageCounter( Model ):
   # RSVP message counts for one direction (rx or tx) of one interface; see
   # RsvpInterfaceCounter.  The help text lists every type renderRow reports
   # (it previously omitted ResvErr and fused "ResvTear" with "Srefresh").
   counts = Dict( keyType=str, valueType=int,
                  help="Message counts keyed by RsvpMessageType("
                       "Path, PathTear, PathErr, Resv, ResvTear, ResvErr, "
                       "Srefresh, Other, Errors )" )

   def countInc( self, mType, ct ):
      '''Add ct to the counter for message type mType; a type that has not
      been counted yet starts from zero.'''
      self.counts[ mType ] = self.counts.get( mType, 0 ) + ct

   def renderRow( self ):
      '''Return the counts as a list in fixed column order (matching the
      RsvpMessageCounters table headers); uncounted types are reported as 0.
      '''
      orderedMessageTypes = ( "Path", "PathTear", "PathErr",
                              "Resv", "ResvTear", "ResvErr",
                              "Srefresh", "Other", "Errors" )
      return [ self.counts.get( mType, 0 ) for mType in orderedMessageTypes ]


class RsvpInterfaceCounter( Model ):
   # Received and sent RSVP message counters for one interface; the value
   # type of RsvpMessageCounters.interfaces.
   rx = Submodel( help="Received message counts",
                  valueType=RsvpMessageCounter )
   tx = Submodel( help="Sent message counts",
                  valueType=RsvpMessageCounter )

def printTable( table ):
   '''Apply the plain, borderless, left-aligned style to a prettytable and
   print it.'''
   table.border = False
   # NOTE(review): the padding override comes after set_style, presumably
   # because set_style( PLAIN_COLUMNS ) applies its own padding -- keep this
   # order unless confirmed otherwise.
   table.set_style( prettytable.PLAIN_COLUMNS )
   table.right_padding_width = 2
   # Presumably a single string applies to every column (formatTable does
   # the same per column explicitly).
   table.align = 'l'
   print table

def formatTable( table ):
   '''Apply the same plain, borderless style as printTable and left-align
   every column, without printing.

   Callers may then override individual column alignments before printing
   the table themselves.
   '''
   table.border = False
   # NOTE(review): as in printTable, the padding override must follow
   # set_style; presumably set_style resets padding -- confirm before reordering.
   table.set_style( prettytable.PLAIN_COLUMNS )
   table.right_padding_width = 2
   for col in table.field_names:
      table.align[ col ] = 'l'

def bandwithToPercentage( totalReserved, totalBandwidth ):
   '''Return totalReserved as a percentage of totalBandwidth, formatted as a
   string rounded to one decimal place (e.g. "12.5").

   Both arguments must use the same unit (the callers in this module pass
   bits per second).  Returns "N/A" when totalBandwidth is not positive,
   since no meaningful percentage exists.
   '''
   if totalBandwidth > 0.0:
      # float() forces true division: under Python 2 two int arguments
      # would otherwise floor-divide and report 0.0 for partial usage.
      use = float( totalReserved ) / totalBandwidth * 100.0
      percentageUse = str( round( use, 1 ) )
   else:
      percentageUse = "N/A"
   return percentageUse

def sortIntf( intfs ):
   '''Return intfs ordered by Arnet.sortIntf, except that a falsy leading
   entry (the default/empty IntfId) is moved to the end.
   '''
   ordered = Arnet.sortIntf( intfs )
   # Nothing to move: empty result, or the first entry is a real interface.
   if not ordered or ordered[ 0 ]:
      return ordered
   return ordered[ 1: ] + ordered[ :1 ]

class RsvpMessageCounters( Model ):
   '''Model for "show mpls rsvp counters" '''
   interfaces = Dict( help="Interface message counters",
                      keyType=Interface,
                      valueType=RsvpInterfaceCounter )

   def intfModel( self, intfId ):
      '''Return the counter model for intfId.

      On first use, an entry with empty rx/tx counters is created and stored
      in self.interfaces.
      '''
      existing = self.interfaces.get( intfId )
      if existing is not None:
         return existing
      fresh = RsvpInterfaceCounter()
      fresh.rx = RsvpMessageCounter()
      fresh.tx = RsvpMessageCounter()
      self.interfaces[ intfId ] = fresh
      return fresh

   def render( self ):
      '''Print a table of received and a table of sent message counts, with
      one row per interface.

      Rows follow sortIntf order, so the empty/default interface (labelled
      "Unknown") comes last.
      '''
      columns = [ "Interface", "Path", "PathTear", "PathErr",
                  "Resv", "ResvTear", "ResvErr", "Srefresh",
                  "Other", "Errors" ]
      dashes = [ "-" * len( name ) for name in columns ]

      rxTable = prettytable.PrettyTable( columns )
      txTable = prettytable.PrettyTable( columns )
      for table in ( rxTable, txTable ):
         # Dashed underline under the header, emitted as the first data row.
         table.add_row( dashes )

      for intfId in sortIntf( self.interfaces ):
         entry = self.interfaces[ intfId ]
         label = intfId or "Unknown"
         rxTable.add_row( [ label ] + list( entry.rx.renderRow() ) )
         txTable.add_row( [ label ] + list( entry.tx.renderRow() ) )

      print( "Received Messages:\n" )
      printTable( rxTable )

      print( "\nSent Messages:\n" )
      printTable( txTable )

class RsvpBandwidthIntfModel( Model ):
   # Per-interface RSVP bandwidth by priority; the value type of
   # RsvpBandwidthModel.interfaces.  That model's render treats a priority
   # missing from reservedBandwidth as 0.0.
   reservedBandwidth = Dict(
         help="A map of priority to bandwidth reserved in bits per second",
         keyType=int,
         valueType=float )
   totalBandwidth = Float(
         help="Total bandwidth available to RSVP in bits per second" )

class RsvpBandwidthIntfSummaryModel( Model ):
   # Per-interface reserved/total RSVP bandwidth (not split by priority); the
   # value type of RsvpBandwidthSummaryModel.interfaces.
   reservedBandwidth = Float( help="Bandwidth reserved in bits per second" )
   totalBandwidth = Float(
         help="Total bandwidth available to RSVP in bits per second" )

class RsvpBandwidthSummaryModel( Model ):
   interfaces = Dict( help="A mapping of interface to bandwidth summary",
                      keyType=Interface, valueType=RsvpBandwidthIntfSummaryModel )

   def render( self ):
      '''Print one row per interface (in sortIntf order) with reserved
      bandwidth (Mbps), percentage of total bandwidth used, and total
      bandwidth (Mbps).
      '''
      # Setup Table Headers.  header2 carries the units and is emitted as the
      # first data row, followed by a dashed separator row.
      header1 = [ "Interface",
                  "Reserved bandwidth",
                  "Percentage used",
                  "Total bandwidth" ]
      header2 = [ "         ",
                  "              Mbps",
                  "              %",
                  "           Mbps" ]
      separator = [ "-" * ( len( col ) + 1 ) for col in header2 ]

      # The constructor already sets field_names; re-assigning them (as the
      # code used to) was redundant.
      table = prettytable.PrettyTable( header1 )
      table.add_row( header2 )
      table.add_row( separator )

      # formatTable left-aligns every column; numeric columns (everything
      # after "Interface") are then right-aligned.
      formatTable( table )
      for col in header1[ 1: ]:
         table.align[ col ] = 'r'

      # Add rows
      for intfId in sortIntf( self.interfaces ):
         intfModel = self.interfaces[ intfId ]
         reservedBandwidthMbps = roundBw( intfModel.reservedBandwidth / 1e6, 1 )
         percentageUse = bandwithToPercentage( intfModel.reservedBandwidth,
                                               intfModel.totalBandwidth )
         totalBandwidthMbps = roundBw( intfModel.totalBandwidth / 1e6, 1 )
         table.add_row( [ intfId,
                          reservedBandwidthMbps,
                          percentageUse,
                          totalBandwidthMbps ] )
      print( table )

class RsvpBandwidthModel( Model ):
   interfaces = Dict( help="Interface's Bandwidth information",
                      keyType=Interface, valueType=RsvpBandwidthIntfModel )

   def render( self ):
      '''Print a per-interface table of reserved bandwidth at every priority
      from BwPriority.min to BwPriority.max.

      Each interface's rows are followed by a "Total" row and a blank spacer
      row.  Bandwidth is shown in Mbps; the last column is the percentage of
      the interface's totalBandwidth.
      '''
      header1 = [ "Interface", "Priority", "Bandwidth", "" ]
      header2 = [ "         ", "        ", "Mbps     ", "  % " ]
      separator = [ "-" * ( len( col ) + 1 ) for col in header2 ]

      # The units row and the dashed separator are ordinary data rows placed
      # directly under the real header.
      table = prettytable.PrettyTable( header1 )
      table.add_row( header2 )
      table.add_row( separator )

      # formatTable left-aligns everything; numeric values go to the right.
      formatTable( table )
      table.align[ "Bandwidth" ] = 'r'
      table.align[ "Priority" ] = 'r'
      table.align[ "" ] = 'r'

      for intfId in sortIntf( self.interfaces ):
         intfModel = self.interfaces[ intfId ]
         totalReserved = 0.0
         for priority in range( BwPriority.min, BwPriority.max + 1 ):
            # A priority missing from the dict is shown as zero bandwidth.
            bw = intfModel.reservedBandwidth.get( priority, 0.0 )
            totalReserved += bw
            reservedBandwidthMbps = roundBw( bw / 1e6, 1 )
            percentageUse = bandwithToPercentage( bw, intfModel.totalBandwidth )
            # Only the first priority row of each interface carries its name.
            # Compare against BwPriority.min (the loop's start) rather than a
            # hard-coded 0 so the two cannot drift apart.
            firstCol = intfId if priority == BwPriority.min else ""
            table.add_row( [ firstCol,
                             priority,
                             reservedBandwidthMbps,
                             percentageUse ] )

         # Add total row, then a blank spacer row between interfaces.
         totalReservedMbps = roundBw( totalReserved / 1e6, 1 )
         percentageUse = bandwithToPercentage( totalReserved,
                                               intfModel.totalBandwidth )
         table.add_row( [ "",
                          "Total",
                          totalReservedMbps,
                          percentageUse ] )
         table.add_row( [ "" for _ in header1 ] )
      print( table )
