#!/usr/bin/env python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from MplsPingHandlers import (
   handleLspPingBgpLu,
   handleLspPingRaw,
   handleLspPingLdpMldpSr,
   handleLspPingRsvp,
   handleLspPingNhg,
   handleLspPingStatic,
   handleLspPingSrTe,
   handleLspPingPwLdp,
   handleLspPingNhgTunnel
)

from MplsTracerouteHandlers import (
   handleLspTracerouteBgpLu,
   handleLspTracerouteRaw,
   handleLspTracerouteLdp,
   handleLspTracerouteMldpSr,
   handleLspTracerouteRsvp,
   handleLspTracerouteNhg,
   handleLspTracerouteStatic,
   handleLspTracerouteSrTe,
   handleLspTracerouteNhgTunnel
)

from ClientCommonLib import (
   LspPing,
   LspTraceroute, 
   LspPingTypeBgpLu,
   LspPingTypeRaw,
   LspPingTypeLdp,
   LspPingTypeMldp, 
   LspPingTypeRsvp,
   LspPingTypeSr,
   LspPingTypeStatic,
   LspPingTypeNhg,
   LspPingTypeSrTe,
   LspPingTypePwLdp,
   LspPingTypeNhgTunnel
)

from ClientState import ( 
   sessionIdIncr, 
   getGlobalState, 
   setGlobalClientIdBaseOverride 
)

import datetime
import Cell
import errno
import os
import QuickTrace
import Tac
import SharedMem
import Smash
import SmashLazyMount
from ForwardingHelper import ( forwardingHelperKwFactory )
from TypeFuture import TacLazyType
import Intf.AllIntfLib
from IpLibConsts import DEFAULT_VRF
from Arnet.NsLib import DEFAULT_NS
from SrTePolicyCommonLib import srTePolicyStatusPath
import Toggles.NexthopGroupToggleLib as NhgToggle

#---------------------------------
# QuickTrace shorthand variables
#---------------------------------
qv, qt8 = QuickTrace.Var, QuickTrace.trace8

# Tac type handles used to look up tunnel-table mount information
# (TunnelTableMounter.getMountInfo( TunnelTableIdentifier.<table> )).
TunnelTableMounter, TunnelTableIdentifier = (
   TacLazyType( "Tunnel::TunnelTable::TunnelTableMounter" ),
   TacLazyType( "Tunnel::TunnelTable::TunnelTableIdentifier" ) )

class LspUtilMount( object ):
   """Mounts the Sysdb/Smash state an LSP ping or traceroute run needs.

   Common state (interface, IP, ARP, FIB, tunnel FIB, VRF and KNI status) is
   always mounted; protocol-specific tables are mounted only for the requested
   lspPingType. Sysdb mounts share one mount group, which is closed with
   blocking=True before the derived objects (nexthop-group config merge SM,
   route tries, forwarding helper) are built on top of the mounted state.
   """

   def __init__( self, lspPingType, entityManager, vrf=None ):
      """Mount everything needed for lspPingType in the given VRF.

      Args:
        lspPingType: one of the LspPingType* constants; selects which
          protocol-specific tables are mounted.
        entityManager: Sysdb entity manager used for all mounts.
        vrf: VRF name; a falsy value selects DEFAULT_VRF. IPv6 route and
          forwarding status are only mounted for DEFAULT_VRF.
      """
      vrf = vrf or DEFAULT_VRF
      self.vrf = vrf
      self.entityManager = entityManager
      # Set for every ping type (not only the nexthop-group ones) so that the
      # routingNhgConfig property returns None instead of raising
      # AttributeError when no nexthop-group config was created.
      self._routingNhgConfig = None
      mg = self.entityManager.mountGroup()

      # Local Entities must be created inside the enclosing mount group
      localEntityTypes = [
         ( 'interface/status/all', 'Interface::AllIntfStatusDir' ),
         ( Cell.path( 'interface/status/local' ),
           'Interface::AllIntfStatusLocalDir' )
      ]
      Intf.AllIntfLib.createLocalEntities( entityManager, localEntityTypes )

      self.ipStatus = mg.mount( 'ip/status', 'Ip::Status', 'r' )
      self.ip6Status = mg.mount( 'ip6/status', 'Ip6::Status', 'r' )
      self.bridgingConfig = mg.mount( "bridging/config", "Bridging::Config", "r" )

      self.allIntfStatusDir = mg.mount( 'interface/status/all',
                                        'Interface::AllIntfStatusDir', 'r' )
      self.allIntfStatusLocalDir = mg.mount(
         Cell.path( 'interface/status/local' ),
         'Interface::AllIntfStatusLocalDir', 'r' )

      self.config = mg.mount( 'mplsutils/config',
                              'MplsUtils::Config', 'r' )
      self.vrfNameStatus = mg.mount( Cell.path( 'vrf/vrfNameStatus' ),
                                     'Vrf::VrfIdMap::NameToIdMapWrapper', 'r' )
      self.intfConfigDir = mg.mount( 'l3/intf/config', 'L3::Intf::ConfigDir', 'r' )

      startNexthopGroupSm = False
      # smash mount
      shmemEm = SharedMem.entityManager( sysdbEm=self.entityManager )
      if lspPingType in ( LspPingTypeNhg, LspPingTypeStatic,
                          LspPingTypeNhgTunnel ):
         self.routingHwNexthopGroupStatus = mg.mount(
            'routing/hardware/nexthopgroup/status',
            'Routing::Hardware::NexthopGroupStatus', 'r' )
         # Local merged config, populated by the merge SM created after the
         # mount group closes (see startNexthopGroupSm below).
         self._routingNhgConfig = Tac.newInstance( 'Routing::NexthopGroup::Config' )
         self.nexthopGroupCliConfig = mg.mount( 'routing/nexthopgroup/input/cli',
                                                'Routing::NexthopGroup::ConfigInput',
                                                'r' )
         self.nexthopGroupConfigDir = mg.mount( 'routing/nexthopgroup/input/config',
                                                'Tac::Dir', 'ri' )
         startNexthopGroupSm = True
         nexthopEntryTableInfo = Tac.Value( 'NexthopGroup::TableInfo' )
         nexthopEntryStatusMountInfo = nexthopEntryTableInfo.entryStatus( 'shadow' )
         self.smashNhgStatus = shmemEm.doMount(
               "routing/nexthopgroup/entrystatus",
               "NexthopGroup::EntryStatus",
               nexthopEntryStatusMountInfo )
      self.tunnelFib = SmashLazyMount.mount( self.entityManager,
                                             'tunnel/tunnelFib',
                                             'Tunnel::TunnelFib::TunnelFib',
                                             Smash.mountInfo( 'reader' ) )
      self.arpSmash = shmemEm.doMount( 'arp/status', 'Arp::Table::Status',
                                       Smash.mountInfo( 'reader' ) )
      self.arpSmashVrfIdMap = shmemEm.doMount( 'vrf/vrfIdMapStatus',
                                               'Vrf::VrfIdMap::Status',
                                               Smash.mountInfo( 'reader' ) )
      self.routingVrfInfoDir = mg.mount( 'routing/vrf/routingInfo/status',
                                         'Tac::Dir', 'ri' )
      self.routing6VrfInfoDir = mg.mount( 'routing6/vrf/routingInfo/status',
                                          'Tac::Dir', 'ri' )

      fibInfo = Tac.Value( 'Smash::Fib::TableInfo' )
      routeMountInfo = fibInfo.routeInfo( 'reader' )
      forwardingMountInfo = fibInfo.forwardingInfo( 'reader' )

      if vrf == DEFAULT_VRF:
         routeStatusPath = 'routing/status'
         forwardingStatusPath = 'forwarding/status'
         route6StatusPath = 'routing6/status'
         forwarding6StatusPath = 'forwarding6/status'
      else:
         routeStatusPath = 'routing/vrf/status/%s' % vrf
         forwardingStatusPath = 'forwarding/vrf/status/%s' % vrf
         # FIXME V6 in VRF?
         route6StatusPath = forwarding6StatusPath = None
      self.routeStatus = shmemEm.doMount( routeStatusPath,
                                          'Smash::Fib::RouteStatus',
                                          routeMountInfo )
      self.forwardingStatus = shmemEm.doMount( forwardingStatusPath,
                                               'Smash::Fib::ForwardingStatus',
                                               forwardingMountInfo )
      if lspPingType == LspPingTypeSrTe:
         self.srTeForwardingStatus = shmemEm.doMount(
            'forwarding/srte/status', 'Smash::Fib::ForwardingStatus',
            forwardingMountInfo )
      if route6StatusPath and forwarding6StatusPath:
         self.route6Status = shmemEm.doMount( route6StatusPath,
                                              'Smash::Fib6::RouteStatus',
                                              routeMountInfo )
         self.forwarding6Status = shmemEm.doMount( forwarding6StatusPath,
                                                   'Smash::Fib6::ForwardingStatus',
                                                   forwardingMountInfo )
      else:
         self.route6Status = self.forwarding6Status = None

      self.mplsTunnelConfig = mg.mount( 'routing/mpls/tunnel/config',
                                        'Tunnel::MplsTunnelConfig', 'r' )

      # BGP LU specific mounts. Mount BGP LU TunnelRib to get fec to tunnelId
      # mapping
      if lspPingType == LspPingTypeBgpLu:
         self.bgpLuTunnelRib = shmemEm.doMount(
            'tunnel/protoTunnelRib/bgpLu', 'Tunnel::TunnelTable::TunnelRib',
            Smash.mountInfo( 'keyshadow' ) )

      # Nexthop group tunnel specific mounts. Mount nexthop group TunnelRib to
      # get fec to tunnelId mapping
      if lspPingType == LspPingTypeNhgTunnel:
         self.nhgTunnelRib = shmemEm.doMount(
            'tunnel/protoTunnelRib/nexthopGroup', 'Tunnel::TunnelTable::TunnelRib',
            Smash.mountInfo( 'keyshadow' ) )

      # Ldp specific mounts. Mount ldp TunnelRib to get fec to
      # tunnelId mapping and tunnel table to get tunnelId to via mapping
      if lspPingType == LspPingTypeLdp:
         self.ldpTunnelRib = shmemEm.doMount(
            'tunnel/protoTunnelRib/ldp', 'Tunnel::TunnelTable::TunnelRib',
            Smash.mountInfo( 'reader' ) )
         tableInfo = TunnelTableMounter.getMountInfo(
                     TunnelTableIdentifier.ldpTunnelTable ).tableInfo
         self.ldpTunnelTable = SmashLazyMount.mount(
                self.entityManager, tableInfo.mountPath, tableInfo.tableType,
                Smash.mountInfo( 'reader' ) )

      # Mldp specific mounts. Mounting protoLfib mldp to get nexthop and
      # label info.
      if lspPingType == LspPingTypeMldp:
         self.mldpLfib = shmemEm.doMount(
            "mpls/protoLfibInputDir/mldp", "Mpls::LfibStatus",
            Smash.mountInfo( 'reader' ) )

         self.mldpOpaqueValueTable = mg.mount(
            "mpls/ldp/mldpOpaqueValueTable", "Mpls::MldpOpaqueValueTable", "r" )

      # Sr specific mounts.
      # sr TunnelRib provides: FEC --> tunnelId
      # SrTunnelTable provides: tunnelId --> via
      # Together they achieve FEC --> via
      if lspPingType == LspPingTypeSr:
         self.srTunnelRib = shmemEm.doMount(
            'tunnel/protoTunnelRib/sr', 'Tunnel::TunnelTable::TunnelRib',
            Smash.mountInfo( 'reader' ) )

         tableInfo = TunnelTableMounter.getMountInfo(
             TunnelTableIdentifier.srTunnelTable ).tableInfo
         self.srTunnelTable = SmashLazyMount.mount(
            self.entityManager, tableInfo.mountPath, tableInfo.tableType,
            Smash.mountInfo( 'reader' ) )
         tableInfo = TunnelTableMounter.getMountInfo(
            TunnelTableIdentifier.tiLfaTunnelTable ).tableInfo
         self.tiLfaTunnelTable = SmashLazyMount.mount(
            self.entityManager, tableInfo.mountPath, tableInfo.tableType,
            Smash.mountInfo( "reader" ) )

      if lspPingType == LspPingTypeRsvp:
         self.rsvpStatus = mg.mount( 'mpls/rsvp/status', 'Rsvp::RsvpStatus', 'rS' )

      # SR-TE policy specific mounts.
      if lspPingType == LspPingTypeSrTe:
         tableInfo = TunnelTableMounter.getMountInfo(
             TunnelTableIdentifier.srTeSegmentListTunnelTable ).tableInfo
         self.srTeSegmentListTunnelTable = SmashLazyMount.mount(
            self.entityManager, tableInfo.mountPath, tableInfo.tableType,
            Smash.mountInfo( 'reader' ) )
         self.policyStatus = SmashLazyMount.mount(
               self.entityManager, srTePolicyStatusPath(),
               "SrTePolicy::PolicyStatus", Smash.mountInfo( "reader" ) )
      # PW specific mounts
      if lspPingType == LspPingTypePwLdp:
         self.pwConfig = mg.mount( 'pseudowire/config', 'Pseudowire::Config', 'r' )
         self.pwRcs = mg.mount( 'pseudowire/agent/remoteConnectorStatusColl',
                                'Pseudowire::RemoteConnectorStatusColl',
                                'r' )
         self.ldpProtoConfig = mg.mount( 'mpls/ldp/ldpProtoConfigColl/',
                                         'Ldp::LdpProtoConfigColl', 'r' )

      self.kniStatus = shmemEm.doMount( 'kni/ns/%s/status' % DEFAULT_NS,
                                        'KernelNetInfo::Status',
                                        Smash.mountInfo( 'reader' ) )

      # Everything below reads mounted Sysdb state, so finish the mounts first.
      mg.close( blocking=True )
      self.vrfIpIntfStatus = self.ipStatus.vrfIpIntfStatus.get( vrf )
      self.vrfIp6IntfStatus = self.ip6Status.vrfIp6IntfStatus.get( vrf )

      if startNexthopGroupSm:
         if NhgToggle.toggleNhgCliDynamicMergeSmEnabled():
            smTypeName = 'NexthopGroup::CliDynamicConfigMergeSmWrapper'
         else:
            smTypeName = 'Ira::NexthopGroupConfigMergeSm'
         self.nexthopGroupConfigMergeSm = \
               Tac.newInstance( smTypeName,
                                self._routingNhgConfig,
                                self.nexthopGroupCliConfig,
                                self.nexthopGroupConfigDir )

      self.trie = Tac.newInstance( "Routing::Trie", "trie" )
      self.trieBuilder = Tac.newInstance( "Routing::TrieBuilder", self.routeStatus,
                                          self.trie )
      # NOTE(review): for a non-default VRF self.route6Status is None here;
      # confirm Routing6::TrieBuilder tolerates a null route status.
      self.v6trie = Tac.newInstance( "Routing6::Trie", "v6trie" )
      self.v6trieBuilder = Tac.newInstance( "Routing6::TrieBuilder",
                                            self.route6Status, self.v6trie )

      # Instantiate the Etba Forwarding helper class.
      vrfRoutingStatus = { vrf : self.routeStatus }
      vrfRouting6Status = { vrf : self.route6Status }
      vrfTrie = { vrf : self.trie }
      vrfTrie6 = { vrf : self.v6trie }
      resolveHelperKwargs = {
         'vrfRoutingStatus' : vrfRoutingStatus,
         'vrfRouting6Status' : vrfRouting6Status,
         'forwardingStatus' : self.forwardingStatus,
         'forwarding6Status' : self.forwarding6Status,
         'arpSmash' : self.arpSmash,
         'trie4' : vrfTrie,
         'trie6' : vrfTrie6,
         'vrfNameStatus' : self.vrfNameStatus,
         'intfConfigDir' : self.intfConfigDir,
         'tunnelFib' : self.tunnelFib,
      }
      if lspPingType == LspPingTypeSrTe:
         resolveHelperKwargs[ 'srTeForwardingStatus' ] = self.srTeForwardingStatus
         resolveHelperKwargs[ 'srTeSegmentListTunnelTable' ] = \
                                                   self.srTeSegmentListTunnelTable
      self.fwdingHelper = forwardingHelperKwFactory( **resolveHelperKwargs )

   @property
   def routingNhgConfig( self ):
      """Merged nexthop-group config, or None for non-nexthop-group ping types.

      Blocks until the merge SM reports configReady before returning it.
      """
      if self._routingNhgConfig is None:
         return None
      # The NexthopGroupConfigMergeSm uses async mounts, so wait for the
      # SM's configReady flag to be set before returning a reference to the merged
      # nexthop group config.
      Tac.waitFor( lambda: self._routingNhgConfig.configReady,
            description='NexthopGroupConfigMergeSm configReady flag to be True' )
      return self._routingNhgConfig

def allocateTraceFile( util, dstType, maxNumLogs=10 ):
   """Point QuickTrace at a fresh per-run trace file under /var/log/qt.

   The new file is named "MplsUtilLsp-<util>-<dstType>-<time>.qt". Before it
   is created, the oldest existing "MplsUtilLsp-<util>*" files are removed so
   that at most maxNumLogs such files remain once the new one is written. If
   none of the old files can be removed, traces go to a single shared fallback
   file instead, so the directory does not keep growing.

   Args:
     util: utility name (LspPing / LspTraceroute value) used in the file name.
     dstType: LSP ping type used in the file name.
     maxNumLogs: maximum number of per-util trace files to keep.
   """
   agentName = "MplsUtilLsp"
   timeStamp = str( datetime.datetime.now().time() ).replace( ".", ":" )
   traceFileName = "-".join( [ agentName, util, dstType, timeStamp ] ) + ".qt"

   os.environ[ 'QUICKTRACEDIR' ] = "/var/log/qt"
   qtPath = os.environ[ 'QUICKTRACEDIR' ]

   try:
      os.makedirs( qtPath, 0o755 )
   except OSError as e:
      # Already present (possibly created concurrently by another client).
      if e.errno != errno.EEXIST:
         raise

   # Collect ( ctime, path ) for this util's existing trace files.
   prefix = agentName + "-" + util
   existing = []
   for name in os.listdir( qtPath ):
      if not name.startswith( prefix ):
         continue
      path = os.path.join( qtPath, name )
      try:
         existing.append( ( os.path.getctime( path ), path ) )
      except OSError:
         # File vanished between listdir() and stat(); nothing to clean up.
         pass

   if len( existing ) < maxNumLogs:
      QuickTrace.initialize( traceFileName )
      return

   # Remove the oldest files until there is room for the new one.
   existing.sort()
   numToRemove = len( existing ) - maxNumLogs + 1
   numRemoved = 0
   for _, path in existing:
      if numRemoved >= numToRemove:
         break
      try:
         os.remove( path )
      except OSError:
         continue
      numRemoved += 1

   if numRemoved:
      QuickTrace.initialize( traceFileName )
   else:
      # OS deletion errored for all existing trace files, hence do not keep
      # adding new files, instead overwrite traces to one file in /var/log/qt.
      # This way qtraces doesn't fail because of OS error and
      # we don't keep adding new files.
      QuickTrace.initialize( agentName + "-" + "LspPingTracerouteClient.qt" )

# util -> { lspPingType -> handler( dst, mount, **args ) }.
# LDP, MLDP and SR pings share one handler; MLDP and SR traceroutes share one.
# There is no traceroute entry for LspPingTypePwLdp, so lspUtilHandler reports
# that combination as unsupported.
lspUtilHandlerMap = {
   LspPing: {
      LspPingTypeBgpLu: handleLspPingBgpLu,
      LspPingTypeRaw: handleLspPingRaw,
      LspPingTypeLdp: handleLspPingLdpMldpSr,
      LspPingTypeMldp: handleLspPingLdpMldpSr,
      LspPingTypeRsvp: handleLspPingRsvp,
      LspPingTypeSr: handleLspPingLdpMldpSr,
      LspPingTypeNhg: handleLspPingNhg,
      LspPingTypeStatic: handleLspPingStatic,
      LspPingTypeSrTe: handleLspPingSrTe,
      LspPingTypePwLdp: handleLspPingPwLdp,
      LspPingTypeNhgTunnel: handleLspPingNhgTunnel,
   },
   LspTraceroute: {
      LspPingTypeBgpLu: handleLspTracerouteBgpLu,
      LspPingTypeRaw: handleLspTracerouteRaw,
      LspPingTypeLdp: handleLspTracerouteLdp,
      LspPingTypeMldp: handleLspTracerouteMldpSr,
      LspPingTypeRsvp: handleLspTracerouteRsvp,
      LspPingTypeSr: handleLspTracerouteMldpSr,
      LspPingTypeNhg: handleLspTracerouteNhg,
      LspPingTypeStatic: handleLspTracerouteStatic,
      LspPingTypeSrTe: handleLspTracerouteSrTe,
      LspPingTypeNhgTunnel: handleLspTracerouteNhgTunnel,
   },
}

# ------------------------------------------------------------------------
#       lspUtilHandler: entry point used by the ping/traceroute utilities
# ------------------------------------------------------------------------

def lspUtilHandler( util, entityManager, args ):
   """Dispatch an LSP ping/traceroute request to its type-specific handler.

   Args:
     util: LspPing or LspTraceroute.
     entityManager: Sysdb entity manager handed to LspUtilMount.
     args: dict of utility arguments; must contain 'destination' and 'type'.
       The keys 'cidbase', 'sport' (ping only), 'destination', and a truthy
       'session_name' / 'session_id' are popped from the dict (it is modified
       in place); everything left is passed to the handler as keyword
       arguments. 'vrf' is optional and defaults to the default VRF.

   Returns:
     The handler's return code, or errno.EINVAL for an unsupported util,
     non-dict/empty args, or a 'type' with no handler for this util.
   """
   if util not in ( LspPing, LspTraceroute ):
      print( 'Util not supported: %s' % util )
      return errno.EINVAL

   if not args or not isinstance( args, dict ):
      print( 'Wrong type of arguments' )
      return errno.EINVAL

   sessionIdIncr()
   currState = getGlobalState()

   # Ping scale test only: set ClientIdBase and request src port
   clientIdBaseOverride = args.pop( 'cidbase', None )
   if clientIdBaseOverride is not None:
      setGlobalClientIdBaseOverride( clientIdBaseOverride )
      currState.clientIdBase = int( clientIdBaseOverride )
   if util == LspPing:
      # Absent or empty 'sport' means no specific source port was requested.
      reqSrcPort = args.pop( 'sport', None )
      if reqSrcPort:
         currState.clientRootUdpPamSrcPort = int( reqSrcPort )

   dst = args.pop( 'destination' )
   dstType = args[ 'type' ]
   handler = lspUtilHandlerMap[ util ].get( dstType )

   # Both selectors are forwarded to the handler as 'session'; session_id
   # wins when both are given.
   if args.get( 'session_name' ):
      args[ 'session' ] = args.pop( 'session_name' )
   if args.get( 'session_id' ):
      args[ 'session' ] = args.pop( 'session_id' )

   if not handler:
      print( '%s type not supported: %s' % ( util, dstType ) )
      return errno.EINVAL

   if args.get( 'label' ):
      # Comma-separated label stack, e.g. "100,200" -> [ 100, 200 ]
      args[ 'label' ] = [ int( lbl ) for lbl in args[ 'label' ].split( ',' ) ]
   mount = LspUtilMount( dstType, entityManager, args.get( 'vrf' ) )
   allocateTraceFile( util, dstType )
   return handler( dst, mount, **args )

