#!/usr/bin/env python
# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import Tac

import CliGlobal
import SharedMem
import Smash

from CliPlugin.TunnelCli import (
   TunnelTableIdentifier,
   readMountTunnelTable,
)
from CliPlugin.SrTePolicyLibModel import SrTeSegmentListVia
from SrTePolicyLib import MplsLabel
from TypeFuture import TacLazyType

DynTunnelIntfId = TacLazyType( "Arnet::DynamicTunnelIntfId" )
FecId = TacLazyType( 'Smash::Fib::FecId' )
FecIdIntfId = TacLazyType( 'Arnet::FecIdIntfId' )
TunnelId = TacLazyType( "Tunnel::TunnelTable::TunnelId" )

def cliDisplayOrderLabelList( mplsLabelStack ):
   """Converts a BoundedMplsLabelStack to a list of labels (int) in the order that
   they should be shown in show commands.

   This is a duplicate of the code in the Mpls package. It is duplicated to avoid
   depending on the Mpls package, which would result in a cyclic dependency.

   BoundedMplsLabelStack reserves label index 0 for the bottom of the stack, however
   the CLI generally displays labels with the outermost label on the wire
   (top of stack) on the left, so the list representation for that purpose is in
   reverse order when compared to BoundedMplsLabelStack.

   Args:
      mplsLabelStack: object exposing ``stackSize`` and ``label( i )`` for
         0 <= i < stackSize.

   Returns:
      A list of labels ordered from index stackSize - 1 down to index 0; an
      empty list when stackSize is 0.
   """
   # range() instead of xrange() so the helper also works under Python 3; the
   # produced indices are identical on both interpreters.
   return [ mplsLabelStack.label( i )
            for i in range( mplsLabelStack.stackSize - 1, -1, -1 ) ]

def initFecModeSm():
   """Create the FEC mode status object and the Ira::FecModeSm bound to it.

   The state machine is constructed from gv.l3Config and the new status object;
   both are stored on gv. Plugin() arranges for this to run from its mount-group
   close callback, i.e. after gv.l3Config has been assigned.
   """
   modeStatus = Tac.newInstance( 'Smash::Fib::FecModeStatus', 'fms' )
   gv.fecModeStatus = modeStatus
   gv.fecModeSm = Tac.newInstance( 'Ira::FecModeSm', gv.l3Config, modeStatus )

#-----------------------------------------------------------------------------
# In multiAgent mode, the unifiedForwardingStatus and unifiedForwarding6Status
# provide the FECs that are published by IpRib
#-----------------------------------------------------------------------------
def getForwardingStatus():
   """Return the ( IPv4, IPv6 ) forwarding status pair to look FECs up in.

   The unified pair is returned when gv.fecModeStatus reports fecModeUnified;
   otherwise the regular forwarding/status and forwarding6/status pair is used.
   """
   unifiedMode = Tac.Type( 'Smash::Fib::FecMode' ).fecModeUnified
   useUnified = gv.fecModeStatus.fecMode == unifiedMode
   if useUnified:
      return ( gv.unifiedForwardingStatus, gv.unifiedForwarding6Status )
   return ( gv.forwardingStatus, gv.forwarding6Status )

def _tilfaStitchedVia( tilfaVia, labels ):
   '''Build an SrTeSegmentListVia that forwards over one leg (primary or backup)
   of a TI-LFA tunnel entry.

   The leg's own labels, with implicit-null entries removed, come first and
   are followed by the SR-TE via's labels. The nexthop and interface are taken
   from the TI-LFA leg.
   '''
   stack = [ label for label in
             cliDisplayOrderLabelList( tilfaVia.labels.boundedLabelStack )
             if label != MplsLabel.implicitNull ]
   stack.extend( labels )
   return SrTeSegmentListVia( nexthop=tilfaVia.nexthop,
                              interface=tilfaVia.intfId,
                              mplsLabels=stack )

def _fecSegmentListVias( via, labels ):
   '''Expand a via whose interface is a FecIdIntfId into one SrTeSegmentListVia
   per via of the referenced FEC, all carrying the same labels.

   usedByTunnelV4Adj / usedByTunnelV6Adj FEC ids are rewritten to fibV4Adj /
   fibV6Adj before the lookup. Returns an empty list when the FEC is not found
   (or its adjacency type is neither fibV4Adj nor fibV6Adj).
   '''
   fecId = FecId( FecIdIntfId.intfIdToFecId( via.intfId ) )
   if fecId.adjType() == 'usedByTunnelV4Adj':
      fecId = FecId( FecId.fecIdToNewAdjType( 'fibV4Adj', fecId ) )
   elif fecId.adjType() == 'usedByTunnelV6Adj':
      fecId = FecId( FecId.fecIdToNewAdjType( 'fibV6Adj', fecId ) )

   fec = None
   fecStatus, fec6Status = getForwardingStatus()
   if fecId.adjType() == 'fibV4Adj':
      fec = fecStatus.fec.get( fecId )
   elif fecId.adjType() == 'fibV6Adj':
      fec = fec6Status.fec.get( fecId )
   if fec is None:
      return []
   return [ SrTeSegmentListVia( nexthop=Tac.Value( "Arnet::IpGenAddr",
                                                   str( fecVia.hop ) ),
                                interface=fecVia.intfId,
                                mplsLabels=labels )
            for fecVia in sorted( fec.via.values() ) ]

def getSegmentListVias( tunnelEntry ):
   '''
   Given a tunnelEntry return a list of filled in SrTeSegmentListVia model objects

   Returns a ( vias, backupVias ) tuple. Vias are visited in sorted order and
   handled by interface kind:
     - dynamic tunnel interface: the via points at a TI-LFA tunnel; one primary
       and one backup via are produced, with the TI-LFA labels stitched in front
       of the SR-TE labels (skipped if the TI-LFA entry is missing);
     - FecIdIntfId: one via per via of the referenced FEC (skipped if the FEC
       is not found);
     - anything else: the via is reported as-is.
   backupVias is None when no via produced a backup.
   '''
   vias = []
   backupVias = []
   for via in sorted( tunnelEntry.via.values() ):
      labels = cliDisplayOrderLabelList( via.labels.boundedLabelStack )
      if DynTunnelIntfId.isDynamicTunnelIntfId( via.intfId ):
         # SR-TE tunnels pointing to TI-LFA tunnels.
         tunnelId = TunnelId( DynTunnelIntfId.tunnelId( via.intfId ) )
         tilfaTunnelEntry = gv.tilfaTunnelTable.entry.get( tunnelId, None )
         if tilfaTunnelEntry is None:
            continue
         vias.append( _tilfaStitchedVia( tilfaTunnelEntry.via, labels ) )
         backupVias.append(
            _tilfaStitchedVia( tilfaTunnelEntry.backupVia, labels ) )
      elif FecIdIntfId.isFecIdIntfId( via.intfId ):
         vias.extend( _fecSegmentListVias( via, labels ) )
      else:
         vias.append( SrTeSegmentListVia( nexthop=via.nexthop,
                                          interface=via.intfId,
                                          mplsLabels=labels ) )
   # Report "no backup vias" as None rather than an empty list.
   return ( vias, backupVias or None )

# Module-level state shared by the functions in this file. Every entry starts as
# None: Plugin() assigns the tunnel table, forwarding status and l3Config
# entries, and initFecModeSm() assigns fecModeSm and fecModeStatus.
gv = CliGlobal.CliGlobal( dict.fromkeys( (
   'fecModeSm',
   'fecModeStatus',
   'forwardingStatus',
   'forwarding6Status',
   'l3Config',
   'tilfaTunnelTable',
   'unifiedForwardingStatus',
   'unifiedForwarding6Status',
) ) )

def Plugin( entityManager ):
   """CLI plugin entry point: mount the state used by getSegmentListVias().

   The TI-LFA tunnel table and four Smash forwarding status tables (regular and
   unified, IPv4 and IPv6) are mounted directly. l3/config is mounted through a
   mount group whose close callback runs initFecModeSm(), which reads
   gv.l3Config.
   """
   gv.tilfaTunnelTable = readMountTunnelTable(
      TunnelTableIdentifier.tiLfaTunnelTable, entityManager )

   smashEm = SharedMem.entityManager( sysdbEm=entityManager )
   readerInfo = Smash.mountInfo( 'reader' )
   # ( gv attribute, Smash path, Tac type ), mounted in this order.
   smashMounts = (
      ( 'forwardingStatus', "forwarding/status",
        "Smash::Fib::ForwardingStatus" ),
      ( 'forwarding6Status', "forwarding6/status",
        "Smash::Fib6::ForwardingStatus" ),
      ( 'unifiedForwardingStatus', "forwarding/unifiedStatus",
        "Smash::Fib::ForwardingStatus" ),
      ( 'unifiedForwarding6Status', "forwarding6/unifiedStatus",
        "Smash::Fib6::ForwardingStatus" ),
   )
   for attrName, path, tacType in smashMounts:
      setattr( gv, attrName, smashEm.doMount( path, tacType, readerInfo ) )

   mountGroup = entityManager.mountGroup()
   gv.l3Config = mountGroup.mount( "l3/config", "L3::Config", "ri" )
   # initFecModeSm() is looked up when the callback runs, as in a nested def.
   mountGroup.close( lambda: initFecModeSm() )
