#!/usr/bin/env python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import TableOutput
from datetime import datetime
from IntfModels import Interface
from Intf.IntfRange import intfListToCanonical
from CliModel import (
   Bool,
   Dict,
   Float,
   Int,
   List,
   Model,
   Str,
   Submodel,
)
from SftModel import (
   FlowDetailModel,
   TrackingModel,
)
import Toggles.InbandTelemetryCommonToggleLib
from TrafficPolicyCliModel import Rule
from operator import attrgetter

class Profile( Model ):
   # Configuration and status of a single inband telemetry profile.
   # profileType is compared against 'edge' when rendering; anything else is
   # shown as a core profile.
   profileType = Str( help="The type of profile" )
   sampleRate = Int( optional=True, help="Sample rate of the profile" )
   # samplePolicy is only declared on the model when the sample-policy
   # feature toggle is enabled at class-definition time.
   if ( Toggles.InbandTelemetryCommonToggleLib.
        toggleFeatureInbandTelemetrySamplePolicyEnabled() ):
      samplePolicy = Str( optional=True, help="Sample policy" )
   egressCollection = Str( help="The collector name" )
   egressDrop = Str( optional=True,
                     help="Whether egress drop is enabled or disabled" )
   profileStatus = Str( help=
         "Whether the port is active or inactive or in error condition" )
   profileErrorReason = Str( help="None or the reason for error" )

class IntfList( Model ):
   # Interfaces associated with one profile (see ModelIntProfileSummary).
   # The list is optional and may therefore be absent.
   Intfs = List( help='Interface list', optional=True,
                 valueType=Interface )

class InbandTelemetryProfiles( Model ):
   coreProfiles = Dict( keyType=str, valueType=Profile,
   help="Maps core profiles to their corresponding configuration.", optional=True )
   edgeProfiles = Dict( keyType=str, valueType=Profile,
   help="Maps edge profiles to their corresponding configuration.", optional=True )

   def renderAttributes( self, profiles ):
      for name in profiles:
         profile = profiles[ name ]
         if profile.profileType == 'edge':
            profileType = 'Edge'
         else:
            profileType = 'Core'
         print profileType + " profile: " + name
         if profile.sampleRate:
            print "Ingress sample rate: " + str( profile.sampleRate )
         if Toggles.InbandTelemetryCommonToggleLib.\
            toggleFeatureInbandTelemetrySamplePolicyEnabled():
            if profile.samplePolicy:
               print "Ingress sample policy: " + str( profile.samplePolicy )
         # uncomment when the sample policy feature is enabled
         # if profile.samplePolicy:
         # print "Ingress sample policy: " + profile.samplePolicy
         print "Egress collection: " + profile.egressCollection
         if profile.egressDrop:
            print "Egress drop: " + profile.egressDrop
         print "Profile status: " + profile.profileStatus
         print "Profile error reason: " + profile.profileErrorReason + "\n"

   def render( self ):
      if self.coreProfiles:
         self.renderAttributes( self.coreProfiles )
      if self.edgeProfiles:
         self.renderAttributes( self.edgeProfiles )

class SamplePolicyModel( Model ):
   rules = List( valueType=Rule, help="Detailed information of match rules" )

   def render( self ):
      print "Total number of rules configured: %d" % len( self.rules )
      for rule in self.rules:
         print "match %s %s:" % ( rule.matchOption, rule.ruleString )
         dscpStr = ""
         if rule.matches.dscps:
            for dscpRange in rule.matches.dscps:
               if dscpRange.low == dscpRange.high:
                  dscpStr += "%d, " % dscpRange.high
               else:
                  dscpStr += "%d-%d, " % ( dscpRange.low, dscpRange.high )
            dscpStr = dscpStr.rstrip( " " )
            dscpStr = dscpStr.rstrip( "," )
            print "\tDSCP: %s" % dscpStr
         actions = rule.actions
         if actions.sample or actions.sampleAll:
            action = "sample" if actions.sample else "sample all"
            print "\tActions: %s" % action

class InbandTelemetry( Model ):
   enabled = Bool( help="Inband telemetry is enabled on the switch" )
   deviceId = Str( help="Device ID of the switch" )
   probeMarker = Int( help="Configured probe marker for Inband telemetry" )
   probeProtocol = Int( help="IP header protocol/next-header field for INT packets" )
   operStatus = Str( help="Operational status of Inband telemetry" )
   policies = Dict( keyType=str, valueType=SamplePolicyModel,
                    help="Maps sample policy name to its configuration",
                    optional=True )
   profiles = Submodel( valueType=InbandTelemetryProfiles,
                        help="Core and edge profiles" )
   intDetectionMethod = Str( help="Default Inband telemetry detection method" )

   def render( self ):
      if self.enabled:
         print "Enabled: True"
      else:
         print "Enabled: False"
      print "Device ID: %s" % self.deviceId
      if self.intDetectionMethod == \
            Tac.Type( "InbandTelemetry::IntDetectionMethod" ).ProbeMarkerBased:
         markerA = self.probeMarker >> 32
         markerB = self.probeMarker - ( markerA << 32 )
         hexA = hex( markerA )[ : -1 ] if 'L' in hex( markerA ) else hex( markerA )
         hexB = hex( markerB )[ : -1 ] if 'L' in hex( markerB ) else hex( markerB )
         print "Probe Marker: %s %s" % ( hexA, hexB )
      else:
         print "Probe IP Protocol: %s" % self.probeProtocol
      print "Operational Status: %s" % self.operStatus
      if self.policies:
         print "\nSample policies:"
         for policy in self.policies:
            print "Sample policy %s" % policy
            self.policies[ policy ].render()
            print "\n"
      print "\nProfiles:"
      self.profiles.render()

class ModelIntProfileSummary( Model ):
   coreIntfList = Dict( keyType=str,
                  valueType=IntfList,
                  help='Core profiles',
                  optional=True )
   edgeIntfList = Dict( keyType=str,
                  valueType=IntfList,
                  help='Edge profiles',
                  optional=True )

   def render( self ):
      if self.coreIntfList:
         print 'Core profiles'
         for profName, inList in sorted( self.coreIntfList.iteritems() ):
            print '%s: %s' % ( profName,
                  ','.join( intfListToCanonical( sorted( inList.Intfs ) ) ) )
      if self.edgeIntfList:
         if self.coreIntfList:
            print '\nEdge profiles'
         else:
            print 'Edge profiles'
         for profName, inList in sorted( self.edgeIntfList.iteritems() ):
            print '%s: %s' % ( profName,
                  ','.join( intfListToCanonical( sorted( inList.Intfs ) ) ) )

class IntFlowDetailIntervalStats( Model ):
   # Statistics for one flow interval. timestamp is passed to
   # datetime.fromtimestamp when rendered, i.e. seconds since the epoch.
   # The list attributes hold one entry per device in the path.
   # NOTE(review): presumably in the same order as
   # IntFlowDetailModel.devicesInformation; confirm.
   timestamp = Float( help="Interval start time" )
   pkts = Int( help="Number of packets" )
   congestions = List( help="Congestion per device in path", valueType=bool )
   avgLatencies = List( help="Average latency per device in path (ns)",
                        valueType=long )
   maxLatencies = List( help="Maximum latency per device in path (ns)",
                        valueType=long )
   minLatencies = List( help="Minimum latency per device in path (ns)",
                        valueType=long )

class IntFlowDetailNodeInfo( Model ):
   # Per-device telemetry for one hop of a tracked flow's path; the IDs are
   # rendered in hex by IntTrackingModel.
   deviceId = Int( help="Device ID" )
   ingressPortId = Int( help="Ingress port" )
   egressPortId = Int( help="Egress port ID" )
   egressQueueId = Int( help="Egress queue ID" )
   ttl = Int( help="Inband telemetry TTL (hop count)" )
   lastPacketCongestion = Bool( help=
         "Congestion detected in the last sampled packet" )
   lastPacketLatency = Int( help="Last packet latency (ns)" )

class IntFlowDetailModel( FlowDetailModel ):
   # Inband telemetry additions to the generic flow detail model.
   # NOTE: the attribute name 'pathTransistions' is misspelled but is kept
   # as-is because it is part of the model's public interface (renderers read
   # it by name); only the help text is corrected.
   pathTransistions = Int( help="Path transitions" )
   pathPackets = Int( help="Path packets" )
   devicesInPath = Int( help="Devices in path" )
   flowIntervals = Int( help="Flow intervals" )
   hopCountExceeded = Bool( help="Hop count exceeded" )
   devicesInformation = List( valueType=IntFlowDetailNodeInfo,
                          help="List of inband telemetry device information" )
   flowIntervalStats = List( valueType=IntFlowDetailIntervalStats,
                          help="List of inband telemetry interval statistics" )

class IntTrackingModel( TrackingModel ):

   def renderIntFlowDetailNodeInfo( self, intDetail ):
      if not intDetail.devicesInformation:
         return
      nodeInfoHeadings = ( "Device ID", "Ingress Port ID", "Egress Port ID",
            "Egress Queue ID", "TTL", "Congestion (last pkt)",
            "Latency (last pkt) (ns)" )
      formatCommon = TableOutput.Format( justify="right", maxWidth=7, wrap=True )
      formatTTL = TableOutput.Format( justify="right", maxWidth=5, wrap=True )
      formatCongestion = TableOutput.Format( justify="left", maxWidth=13, wrap=True )
      nodeInfoTable = TableOutput.createTable( nodeInfoHeadings, indent=6 )
      nodeInfoTable.formatColumns( formatCommon, formatCommon, formatCommon,
            formatCommon, formatTTL, formatCongestion, formatCommon )
      pathIDList = []
      for nodeInfo in intDetail.devicesInformation:
         pathIDList.append( str( hex( nodeInfo.deviceId ) ) )
         lastPacketCongestion = "not congested"
         if nodeInfo.lastPacketCongestion:
            lastPacketCongestion = "congested"
         nodeInfoTable.newRow( hex( nodeInfo.deviceId ),
                               hex( nodeInfo.ingressPortId ),
                               hex( nodeInfo.egressPortId ),
                               nodeInfo.egressQueueId,
                               nodeInfo.ttl,
                               lastPacketCongestion,
                               nodeInfo.lastPacketLatency )
      tab = " " * 6
      pathIDStr = ' -> '.join( pathIDList )
      print "%sPath device IDs: %s" % ( tab, pathIDStr )
      print nodeInfoTable.output()

   def renderIntFlowDetailIntervalStats( self, intDetail ):
      tab = " " * 6
      if not intDetail.flowIntervalStats:
         return
      intrvlStatsHeadings = ( "Time Stamp", "Pkts", "Congestion Per Device In Path",
            ( "Latency Per Device In Path (ns)", ( "Avg", "Max", "Min" ) ) )
      intrvlStatsTable = TableOutput.createTable( intrvlStatsHeadings, indent=6 )
      formatTS = TableOutput.Format( justify="right", maxWidth=10, wrap=True )
      formatPkts = TableOutput.Format( justify="right", maxWidth=5, wrap=True )
      formatCongestion = TableOutput.Format( justify="left", maxWidth=10, wrap=True )
      formatLatencyTime = TableOutput.Format( justify="right", maxWidth=10,
            wrap=True )
      intrvlStatsTable.formatColumns( formatTS, formatPkts, formatCongestion,
            formatLatencyTime, formatLatencyTime, formatLatencyTime )
      intervalStats = sorted( intDetail.flowIntervalStats,
            key=attrgetter( 'timestamp' ), reverse=True )
      for intrvlStat in intervalStats:
         congestion = [ 'c' if c else '-' for c in intrvlStat.congestions ]
         congestionTup = tuple( congestion )
         congestionStr = str( congestionTup ).replace( "'", "" )
         startTime = datetime.fromtimestamp(
               intrvlStat.timestamp ).strftime( "%Y-%m-%d %H:%M:%S" )
         intrvlStatsTable.newRow( startTime, intrvlStat.pkts, congestionStr,
               str( tuple( intrvlStat.avgLatencies ) ).replace( 'L', '' ),
               str( tuple( intrvlStat.maxLatencies ) ).replace( 'L', '' ),
               str( tuple( intrvlStat.minLatencies ) ).replace( 'L', '' ) )
      print "%sFlow interval statistics" % ( tab )
      print intrvlStatsTable.output()

   def renderIntFlowDetail( self, detail ):
      tab = " " * 6
      print "%sInband telemetry information" % ( tab )
      print "%sPath transitions: %d, Path packets: %d, Hop count exceeded: " \
            "%s" % ( tab, detail.pathTransistions, detail.pathPackets,
                  str( detail.hopCountExceeded ).lower() )
      print "%sDevices in path: %d, Flow intervals: %d" % \
            ( tab, detail.devicesInPath, detail.flowIntervals )

      self.renderIntFlowDetailNodeInfo( detail )
      self.renderIntFlowDetailIntervalStats( detail )

   def renderFlowDetail( self, flow, key ):
      super( IntTrackingModel, self ).renderFlowDetail( flow, key )
      self.renderIntFlowDetail( flow.flowDetail )
