# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable-msg=ungrouped-imports
import CliSave
import MplsLib
import Tac
import Toggles.MplsToggleLib
from CliMode.Mpls import StaticMulticastModeBase, TunnelStaticModeBase
from IntfCliSave import IntfConfigMode
from MplsLib import tunTypeEnumDict
from MplsTypeLib import tunnelTypeRevXlate
from RoutingIntfUtils import allRoutingProtocolIntfNames
from TypeFuture import TacLazyType

# Register the global and interface command sequences this module writes to.
# The 'after' lists presumably order each sequence after the named ones in the
# saved config (CliSave semantics -- confirm). 'Ira.routes', 'Ira.routing' and
# 'Ira.ipIntf' are not registered in this file.
CliSave.GlobalConfigMode.addCommandSequence( 'Mpls.labelRange',
                                             after=[ 'Ira.routes' ] )
CliSave.GlobalConfigMode.addCommandSequence( 'Mpls.routes',
                                             after=[ 'Mpls.labelRange' ] )
CliSave.GlobalConfigMode.addCommandSequence( 'Mpls.config', after=[ 'Ira.routing' ] )
CliSave.GlobalConfigMode.addCommandSequence( 'Tunnel.static.config',
                                             after=[ 'Ira.routing' ] )
CliSave.GlobalConfigMode.addCommandSequence( 'Mpls.staticL3Vpn',
                                             after=[ 'Tunnel.static.config' ] )
# Per-interface 'mpls ip' / 'no mpls ip' commands (written by saveIntfConfig).
IntfConfigMode.addCommandSequence( 'Mpls.intf', after=[ 'Ira.ipIntf' ] )

# Tac types used by the savers below.
FecIdIntfId = TacLazyType( 'Arnet::FecIdIntfId' )
MplsVia = Tac.Type( 'Tunnel::TunnelTable::MplsVia' )
TacDyTunIntfId = Tac.Type( 'Arnet::DynamicTunnelIntfId' )
LabelRangeInfo = TacLazyType( 'Mpls::LabelRangeInfo' )

def routeKey( rk ):
   """Sort key for static MPLS route keys: a 1-tuple holding rk.topLabel."""
   topLabel = rk.topLabel
   return ( topLabel, )

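# Example commands produced by saveRoutingTable:
#   mpls static top-label 14460 00fe:11c7::1245:0000 swap-label 12902 metric 1
#   mpls static top-label 14500 00fd::1234:0001 pop payload-type ipv4
# Label ranges are saved separately (see saveLabelRange) in the form
#   mpls label range <range-type> <base> <size>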

@CliSave.saver( 'Mpls::RouteConfigInput', 'routing/mpls/route/input/cli' )
def saveRoutingTable( entity, root, sysdbRoot, options, requireMounts ):
   """Save static MPLS routes as 'mpls static top-label ...' commands.

   Routes whose 'dynamic' flag is set are skipped. One command is added to
   the 'Mpls.routes' sequence for every via of every remaining route. Routes
   are ordered by routeKey() and vias in sorted order. The metric is appended
   when it differs from Mpls::Constants.metricDefault or when options.saveAll
   is set.
   """
   metricDefaultValue = Tac.Type( 'Mpls::Constants' ).metricDefault
   routeCmd = 'mpls static top-label'

   for rk in sorted( entity.route, key=routeKey ):
      route = entity.route[ rk ]
      if route.dynamic:
         continue

      for via in sorted( route.via.keys() ):
         # Build the command from tokens joined by single spaces, so that no
         # command gets a trailing or doubled space (e.g. after the
         # nexthop-group name when ' pop' or ' metric' follows).
         tokens = [ routeCmd, rk.routeKey.labelStack.cliString() ]
         # 'payload-type auto' is only spelled out for interface/next-hop vias.
         addPayload = True
         if via.nexthopGroup:
            tokens.append( 'nexthop-group %s' % via.nexthopGroup )
            addPayload = False
         elif via.intf and TacDyTunIntfId.isDynamicTunnelIntfId( via.intf ):
            # Via resolves over a dynamic tunnel: save it by tunnel type/index.
            tunnelId = Tac.Value( 'Tunnel::TunnelTable::TunnelId',
                                  TacDyTunIntfId.tunnelId( via.intf ) )
            tunType = tunTypeEnumDict[ tunnelId.tunnelType() ]
            tokens.append( 'debug tunnel-type %s tunnel-index %d'
                           % ( tunType, tunnelId.tunnelIndex() ) )
            addPayload = False
         elif via.intf:
            tokens.append( '%s' % via.intf )
            tokens.append( via.nextHop.stringValue )
         else:
            tokens.append( via.nextHop.stringValue )

         if via.labelAction == 'swap':
            tokens += [ 'swap-label', str( via.outLabel ) ]
         elif via.labelAction == 'pop':
            tokens.append( 'pop' )
            if via.payloadType != 'autoDecide':
               tokens += [ 'payload-type', via.payloadType ]
            elif addPayload:
               tokens += [ 'payload-type', 'auto' ]
            if via.skipEgressAcl:
               tokens += [ 'access-list', 'bypass' ]
         if rk.metric != metricDefaultValue or options.saveAll:
            tokens += [ 'metric', str( rk.metric ) ]

         root[ 'Mpls.routes' ].addCommand( ' '.join( tokens ) )

@CliSave.saver( 'Mpls::Config', 'routing/mpls/config',
                requireMounts=( 'interface/config/all', 'interface/status/all',
                                  'routing/hardware/mpls/capability' ) )
def saveConfig( entity, root, sysdbRoot, options, requireMounts ):
   """Save the global MPLS configuration, then the per-interface MPLS config.

   The following are written to the 'Mpls.config' sequence:
   - 'mpls ip'
   - 'mpls next-hop resolution allow default-route'
   - the tunnel termination TTL/DSCP model
   - the PHP TTL/DSCP model
   - 'mpls fec ip sharing disabled'
   - 'mpls entropy-label pop', only when the entropy-label toggle is enabled.
   With options.saveAll the default ('no ...') forms are emitted as well.
   Interface-level config is saved last via saveIntfConfig.
   """
   if entity.mplsRouting:
      root[ 'Mpls.config' ].addCommand( "mpls ip" )
   elif options.saveAll:
      root[ 'Mpls.config' ].addCommand( "no mpls ip" )

   cmd = "mpls next-hop resolution allow default-route"
   if entity.nexthopResolutionAllowDefaultRoute:
      root[ 'Mpls.config' ].addCommand( cmd )
   elif options.saveAll:
      root[ 'Mpls.config' ].addCommand( "no %s" % cmd )

   cmd = "mpls tunnel termination model ttl %s dscp %s"
   hwCapability = requireMounts[ 'routing/hardware/mpls/capability' ]
   ttlMode = entity.labelTerminationTtlMode
   dscpMode = entity.labelTerminationDscpMode
   # Mixed (ttl != dscp) models are only saved when supported:
   # - a mismatched pair with ttl 'pipe' (dscp uniform) is never saved;
   # - any other mismatched pair is saved only when the hardware reports
   #   mplsTtlUniformDscpPipeModelSupported.
   unsupported = ttlMode != dscpMode and (
      ttlMode == "pipe" or
      not hwCapability.mplsTtlUniformDscpPipeModelSupported )
   if not unsupported:
      if ttlMode == "pipe" or dscpMode == "pipe":
         root[ 'Mpls.config' ].addCommand( cmd % ( ttlMode, dscpMode ) )
      elif options.saveAll:
         root[ 'Mpls.config' ].addCommand( cmd % ( "uniform", "uniform" ) )

   # The meaning of the default "undefined" differs between platforms.
   # Therefore only output the cmd if it has been explicitly set to
   # "pipe" or "uniform".
   cmdStem = 'mpls tunnel termination php model'
   cmd = cmdStem
   ttlMode = entity.labelPhpTtlMode
   dscpMode = entity.labelPhpDscpMode
   if ttlMode != 'undefinedTtlMode':
      cmd += ' ttl %s' % ttlMode
   # NOTE(review): the DSCP mode is also compared with 'undefinedTtlMode';
   # this assumes both PHP mode attributes share one enum type -- confirm.
   if dscpMode != 'undefinedTtlMode':
      cmd += ' dscp %s' % dscpMode
   if cmd != cmdStem:
      root[ 'Mpls.config' ].addCommand( cmd )
   elif options.saveAll:
      root[ 'Mpls.config' ].addCommand( "no " + cmd )

   # Sharing is on by default, so only the disabled state is saved normally.
   cmd = "mpls fec ip sharing disabled"
   if not entity.optimizeStaticRoutes:
      root[ 'Mpls.config' ].addCommand( cmd )
   elif options.saveAll:
      root[ 'Mpls.config' ].addCommand( "no " + cmd )

   if Toggles.MplsToggleLib.toggleMplsEntropyLabelSupportEnabled():
      cmd = "mpls entropy-label pop"
      if entity.entropyLabelPop:
         root[ 'Mpls.config' ].addCommand( cmd )
      elif options.saveAll:
         root[ 'Mpls.config' ].addCommand( 'no ' + cmd )

   # Save Interface level config
   saveIntfConfig( entity, root, sysdbRoot, options, requireMounts )

def saveIntfConfig( entity, root, sysdbRoot, options, requireMounts ):
   """Save per-interface 'mpls ip' / 'no mpls ip' commands.

   Without options.saveAll, only interfaces present in
   entity.mplsRoutingDisabledIntf are visited, and 'no mpls ip' is emitted
   for those whose entry is true.

   With options.saveAll, every interface returned by
   allRoutingProtocolIntfNames is visited. Each one gets 'no mpls ip' if its
   entry is present and true, and 'mpls ip' otherwise.
   """
   disabledIntfs = entity.mplsRoutingDisabledIntf
   if options.saveAll:
      intfs = allRoutingProtocolIntfNames( root, requireMounts=requireMounts )
   else:
      intfs = disabledIntfs.keys()

   for intf in intfs:
      mode = root[ IntfConfigMode ].getOrCreateModeInstance( intf )
      cmdSequence = mode[ 'Mpls.intf' ]
      # An entry may exist with a false value. Treat it as MPLS enabled, so
      # that saveAll still emits an explicit 'mpls ip' instead of nothing.
      disabled = intf in disabledIntfs and disabledIntfs[ intf ]
      if disabled:
         cmdSequence.addCommand( 'no mpls ip' )
      elif options.saveAll:
         cmdSequence.addCommand( 'mpls ip' )

# Sequence for global MPLS tunnel commands (written by saveMplsTunnelConfig).
# Unlike the other sequences here, it is registered without an 'after' list.
CliSave.GlobalConfigMode.addCommandSequence( 'Mpls.tunnel.config' )

@CliSave.saver( 'Tunnel::MplsTunnelConfig', 'routing/mpls/tunnel/config' )
def saveMplsTunnelConfig( entity, root, sysdbRoot, options ):
   """Save 'mpls tunnel entropy-label push' (or its 'no' form under saveAll).

   Nothing is saved unless the MPLS entropy-label toggle is enabled.
   """
   if not Toggles.MplsToggleLib.toggleMplsEntropyLabelSupportEnabled():
      return
   pushCmd = "mpls tunnel entropy-label push"
   if entity.entropyLabelPush:
      line = pushCmd
   elif options.saveAll:
      line = 'no ' + pushCmd
   else:
      return
   root[ 'Mpls.tunnel.config' ].addCommand( line )

@CliSave.saver( 'Mpls::Config', 'routing/mpls/config' )
def saveLabelRanges( entity, root, sysdbRoot, options, requireMounts ):
   """Save every configured label range, in sorted range-type order."""
   orderedTypes = sorted( entity.labelRange )
   for labelRangeType in orderedTypes:
      saveLabelRange( entity, root, options, labelRangeType )

def saveLabelRange( entity, root, options, rangeType ):
   """Save 'mpls label range <type> <base> <size>' for one range type.

   The command is emitted when the configured range differs from
   MplsLib.labelRangeDefault for that type, or always under options.saveAll.
   """
   rangeDefault = MplsLib.labelRangeDefault( entity, rangeType )
   current = MplsLib.labelRange( entity, rangeType )
   if rangeType == LabelRangeInfo.rangeTypeL2evpnSharedEs:
      # The CLI keyword for this range type differs from the enum value.
      typeKeyword = 'l2evpn ethernet-segment'
   else:
      typeKeyword = rangeType
   if not ( options.saveAll or current != rangeDefault ):
      return
   root[ 'Mpls.labelRange' ].addCommand(
      'mpls label range ' + typeKeyword +
      ' %d %d' % ( current.base, current.size ) )

@CliSave.saver( 'Mpls::Config', 'routing/mpls/config' )
def saveLookupLabelCount( entity, root, sysdbRoot, options, requireMounts ):
   """Save 'mpls lookup label count N'.

   The command is skipped when N is 1, unless options.saveAll is set.
   """
   labelCount = entity.mplsLookupLabelCount
   if labelCount == 1 and not options.saveAll:
      return
   root[ 'Mpls.config' ].addCommand( 'mpls lookup label count %d' % labelCount )

class TunnelStaticSaveMode( TunnelStaticModeBase, CliSave.Mode ):
   """CliSave child mode for one static MPLS tunnel.

   Instances are created with a (tunnel name, tep) pair, and the tunnel name
   is used as the CliSave mode key.
   """

   def __init__( self, param ):
      name, tep = param
      TunnelStaticModeBase.__init__( self, name, tep )
      CliSave.Mode.__init__( self, name )

   def __cmp__( self, other ):
      # Order tunnel modes by tunnel name (Python 2 comparison protocol).
      mine, theirs = self.tunName, other.tunName
      return cmp( mine, theirs )

# Register the static tunnel child mode and its 'via ...' command sequence.
# NOTE(review): no saver in this file writes to 'Mpls.staticTunnelConfig'; it
# appears to serve only as an ordering anchor for the child mode -- confirm.
CliSave.GlobalConfigMode.addCommandSequence( 'Mpls.staticTunnelConfig',
                                             after=[ 'Tunnel.static.config' ] )
CliSave.GlobalConfigMode.addChildMode( TunnelStaticSaveMode,
                                       after=[ 'Mpls.staticTunnelConfig' ] )
# Holds the 'via <nexthop> <intf> ...' commands inside each tunnel mode.
TunnelStaticSaveMode.addCommandSequence( 'Mpls.staticTunnelConfigVias' )

def _labelStackCliTokens( labelsObj ):
   """Return the CLI tokens describing a static tunnel via's label stack.

   Labels are read from index stackSize - 1 down to 0. A stack holding only
   label 3 (MPLS implicit null) is rendered as ['imp-null-tunnel']. Any other
   stack, including an empty one, is rendered as ['label-stack', <labels>...].
   """
   labels = [ str( labelsObj.labelStack( idx ) )
              for idx in reversed( range( labelsObj.stackSize ) ) ]
   if labels == [ "3" ]:
      return [ 'imp-null-tunnel' ]
   return [ 'label-stack' ] + labels

@CliSave.saver( 'Tunnel::Static::Config', 'tunnel/static/config' )
def saveStaticTunnelConfig( entity, root, sysdbRoot, options ):
   """Save all static MPLS tunnel entries.

   Entries configured in static tunnel mode are saved as a child mode with
   one 'via ...' command per via. Other entries are saved as a single global
   'mpls ... tunnel static ...' command for their (at most one) via. A
   non-default backup via is always saved as a global 'debug backup' command.
   """
   for name in entity.entry:
      entry = entity.entry[ name ]
      if len( entry.via ) > 1:
         # Only entries in static tunnel mode are expected to have more than
         # one via.
         assert entry.inStaticTunnelMode
      if entry.inStaticTunnelMode:
         mode = root[ TunnelStaticSaveMode ].getOrCreateModeInstance(
            ( entry.name, entry.tep ) )
         cmds = mode[ 'Mpls.staticTunnelConfigVias' ]
         for via in entry.via:
            tokens = [ 'via', via.nexthop.stringValue, via.intfId ]
            tokens += _labelStackCliTokens( via.labels )
            cmds.addCommand( ' '.join( tokens ) )
      else:
         for via in entry.via:
            saveStaticTunnelConfigEntryVia( root, entry, via, False )
            break # This case will have only one via at most
      if entry.backupVia != MplsVia():
         saveStaticTunnelConfigEntryVia( root, entry, entry.backupVia, True )

def saveStaticTunnelConfigEntryVia( root, configEntry, via, backup ):
   """Save one global 'mpls [debug] [backup] tunnel static ...' command.

   The 'debug' keyword is added for backup vias and for vias whose intfId is
   a FEC-id or dynamic-tunnel interface. Those two kinds are saved with
   'resolving' and 'resolving-tunnel' respectively. All other vias are saved
   as '<tep> <nexthop> <intf>'. The label stack follows in every case.
   """
   resolving = FecIdIntfId.isFecIdIntfId( via.intfId )
   resolvingTunnel = TacDyTunIntfId.isDynamicTunnelIntfId( via.intfId )
   cmd = [ 'mpls' ]
   if backup or resolving or resolvingTunnel:
      cmd.append( 'debug' )
   if backup:
      cmd.append( 'backup' )
   cmd.append( 'tunnel static' )
   cmd.append( configEntry.name )
   if resolving:
      cmd.append( 'resolving' )
      cmd.append( configEntry.tep.stringValue )
   elif resolvingTunnel:
      cmd.append( 'resolving-tunnel' )
      cmd.append( configEntry.tep.stringValue )
      tunnelId = TacDyTunIntfId.tunnelId( via.intfId )
      tunnelId = Tac.Value( 'Tunnel::TunnelTable::TunnelId', tunnelId )
      tunType = tunnelTypeRevXlate[ tunnelId.tunnelType() ]
      cmd.append( 'tunnel-type %s' % tunType )
      tunIndex = tunnelId.tunnelIndex()
      cmd.append( 'tunnel-index %d' % tunIndex )
   else:
      cmd.append( configEntry.tep.stringValue )
      cmd.append( via.nexthop.stringValue )
      cmd.append( via.intfId )
   cmd += _labelStackCliTokens( via.labels )
   root[ 'Tunnel.static.config' ].addCommand( ' '.join( cmd ) )

@CliSave.saver( 'Mpls::VrfLabelConfigInput', 'routing/mpls/vrfLabel/input/cli' )
def saveVrfLabel( entity, root, sysdbRoot, options, requireMounts ):
   """Save one 'mpls static vrf-label <label> vrf <vrf>' command per entry."""
   template = 'mpls static vrf-label {} vrf {}'
   seq = root[ 'Mpls.staticL3Vpn' ]
   for entry in entity.vrfLabel.itervalues():
      seq.addCommand( template.format( entry.label, entry.vrfName ) )

class StaticMulticastSaveMode( StaticMulticastModeBase, CliSave.Mode ):
   """CliSave child mode for one static multicast LFIB entry.

   Instances are keyed by the incoming label.
   """

   def __init__( self, inLabel ):
      # Initialize both bases, in this order, with the same key.
      for baseCls in ( StaticMulticastModeBase, CliSave.Mode ):
         baseCls.__init__( self, inLabel )

# Register the static multicast child mode and its per-mode command sequence.
# NOTE(review): no saver in this file writes to 'Mpls.staticMulticast'; it
# appears to serve only as an ordering anchor for the child mode -- confirm.
CliSave.GlobalConfigMode.addCommandSequence( 'Mpls.staticMulticast',
                                             after=[ 'Mpls.config' ] )
CliSave.GlobalConfigMode.addChildMode( StaticMulticastSaveMode,
                                       after=[ 'Mpls.staticMulticast' ] )
# Holds the 'next-hop <addr> swap-label <label>' commands inside each mode.
StaticMulticastSaveMode.addCommandSequence( 'Mpls.routeConfig' )

@CliSave.saver( 'Mpls::LfibSysdbStatus', 'mpls/staticMcast/lfib' )
def saveStasticMcastLfib( entity, root, sysdbRoot, options, requireMounts ):
   """Save static multicast (mLDP) LFIB entries.

   A StaticMulticastSaveMode keyed by viaSet.label is created for each via
   set. One 'next-hop <addr> swap-label <label>' command is added per via,
   walking via indices from the last down to 0.
   """
   for viaSet in entity.mldpStaticViaSet.values():
      routeCmds = root[ StaticMulticastSaveMode ].getOrCreateModeInstance(
         viaSet.label )[ 'Mpls.routeConfig' ]
      lastIdx = viaSet.getViaCount() - 1
      for idx in range( lastIdx, -1, -1 ):
         via = viaSet.mplsVia[ idx ]
         routeCmds.addCommand( "next-hop %s swap-label %s" % (
            via.nextHop.stringValue, via.outLabel.topLabel() ) )


def Plugin( entityManager ):
   """CliSave plugin entry point; intentionally does nothing.

   All registration in this module happens at import time, through the
   module-level addCommandSequence/addChildMode calls and the
   @CliSave.saver decorators.
   """
