# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
from operator import attrgetter
from Arnet import IpGenPrefix
from ArnetModel import IpGenericAddress
from ArnetModel import IpGenericPrefix
from CliModel import Str
from CliModel import Dict
from CliModel import Enum
from CliModel import List
from CliModel import Int
from CliModel import Bool
from CliModel import Model
from CliModel import Submodel
from TunnelModels import ( IpVia, TunnelId,
                           tunnelTypeEnumValues, 
                           tunnelTypesReverseStrDict )
import SmashLazyMount
import copy
import TypeFuture

# pylint: disable-msg=too-many-nested-blocks

# Tac type handles used by the models and helpers in this module.
tacTunnelViaStatus = Tac.Type( 'Tunnel::Hardware::TunnelViaStatus' )
tacTunnelType = Tac.Type( "Tunnel::TunnelTable::TunnelType" )
tacTunnelId = Tac.Type( "Tunnel::TunnelTable::TunnelId" )
# Mount info passed to every SmashLazyMount.mount() call in Plugin().
readerInfo = SmashLazyMount.mountInfo( 'reader' )
# Lazily-resolved Tac types (TypeFuture.TacLazyType).
tacMplsStackIndex = TypeFuture.TacLazyType( "Arnet::MplsStackEntryIndex" )
tacNexthopGroupIntfId = TypeFuture.TacLazyType( 'Arnet::NexthopGroupIntfId' )

# Maps Tac TunnelViaStatus values to the strings used in the CAPI model.
tunnelViaStatusTacToCapi = {
   tacTunnelViaStatus.unresolved: 'unresolved',
   tacTunnelViaStatus.usingPrimaryVias: 'usingPrimaryVias',
   tacTunnelViaStatus.usingBackupVias: 'usingBackupVias',
   tacTunnelViaStatus.unknown: 'unknown',
}
# Allowed values of TunnelFibEntry.tunnelViaStatus. Existing values must stay
# unchanged (new values may only be added) to keep the CAPI model revision
# compatible. Relies on Python 2, where dict.values() returns a list.
tunnelViaStatusCapiEnumVals = tunnelViaStatusTacToCapi.values() + [ 'notProgrammed' ]

# Smash mounts assigned by Plugin(); None until the plugin has been loaded.
tunnelFib = None
nhgStatus = None

def getTunnelViaStatusStr( tunnelViaStatus ):
   '''
   Translate a CAPI tunnel via status into the label shown after "forwarding":
   "Backup" for 'usingBackupVias', "Primary" for 'usingPrimaryVias', and
   "None" for anything else (including None).
   '''
   statusLabels = ( ( 'usingBackupVias', "Backup" ),
                    ( 'usingPrimaryVias', "Primary" ) )
   for capiStatus, label in statusLabels:
      if tunnelViaStatus == capiStatus:
         return label
   return "None"

class L3TunnelInfo ( Model ):
   '''
   Encapsulation parameters shared by IP-based tunnels: source and destination
   addresses plus optional header fields (tos/ttl for IPv4, dscp/hoplimit
   otherwise).
   '''
   srcAddr = IpGenericAddress(
         help="Location in the IP network where the tunnel starts" )
   dstAddr = IpGenericAddress(
         help="Location in the IP network where the tunnel terminates" )
   dscp = Int( optional=True, help="Differentiated services code point" )
   hoplimit = Int( optional=True, help="Hop limit" )
   tos = Int( optional=True, help="Type of service" )
   ttl = Int( optional=True, help="Time to live" )

   def updateOptionalAttributes(
         self, af, dscp=None, hoplimit=None, tos=None, ttl=None ):
      '''
      Copy the optional header fields relevant to the address family: tos and
      ttl when af is "ipv4", dscp and hoplimit for any other af. Falsy values
      (None or 0) are skipped and leave the attribute untouched.
      '''
      if af == "ipv4":
         candidates = ( ( 'tos', tos ), ( 'ttl', ttl ) )
      else:
         candidates = ( ( 'dscp', dscp ), ( 'hoplimit', hoplimit ) )
      for attrName, value in candidates:
         if value:
            setattr( self, attrName, value )

   def renderList( self ):
      '''
      Return the rendered attributes as a list of strings: destination and
      source first, then hoplimit, dscp (hex), ttl and tos (hex) when set.
      '''
      rendered = [ "destination " + str( self.dstAddr ),
                   "source " + str( self.srcAddr ) ]
      optionalFields = ( ( "hoplimit ", self.hoplimit, str ),
                         ( "dscp ", self.dscp, hex ),
                         ( "ttl ", self.ttl, str ),
                         ( "tos ", self.tos, hex ) )
      rendered.extend( prefix + formatter( value )
                       for prefix, value, formatter in optionalFields
                       if value )
      return rendered

   def renderStr( self ):
      '''Join the rendered attributes into a single comma-separated line.'''
      attributes = self.renderList()
      return ", ".join( attributes )

   def render( self ):
      '''Print the single-line rendering of this tunnel.'''
      print( self.renderStr() )

class GreTunnelInfo ( L3TunnelInfo ):
   '''GRE tunnel encapsulation: the L3 tunnel fields plus GRE-specific ones.'''
   sequence = Bool( help="Sequence numbers in use", default=False )
   checksum = Bool( help="Checksum in use", default=False )
   key = Int( optional=True, help="GRE key" )

   def renderList( self ):
      '''
      Return the base L3 attributes followed by "key <n>" (when a non-zero key
      is set), then "sequence" and "checksum" when those flags are true.
      '''
      baseAttributes = super( GreTunnelInfo, self ).renderList()
      greAttributes = [ "key " + str( self.key ) ] if self.key else []
      greAttributes.extend( flagName for flagName in ( "sequence", "checksum" )
                            if getattr( self, flagName ) )
      return baseAttributes + greAttributes

   def renderStr( self ):
      '''Render as one line: "GRE" followed by the rendered attributes.'''
      parts = self.renderList()
      parts.insert( 0, "GRE" )
      return ", ".join( parts )

class IpsecTunnelInfo ( L3TunnelInfo ):
   '''
   IPsec (VTI) tunnel encapsulation. The attribute list is inherited unchanged
   from L3TunnelInfo.renderList; only the rendered prefix differs.
   '''
   def renderStr( self ):
      '''Render as one line: "IPsec" followed by the rendered attributes.'''
      parts = self.renderList()
      parts.insert( 0, "IPsec" )
      return ", ".join( parts )

class IpsecGreTunnelInfo ( GreTunnelInfo ):
   '''
   GRE over IPsec tunnel encapsulation. The attribute list is inherited
   unchanged from GreTunnelInfo.renderList; only the rendered prefix differs.
   '''
   def renderStr( self ):
      '''Render as one line: "GRE over IPsec" followed by the attributes.'''
      parts = self.renderList()
      parts.insert( 0, "GRE over IPsec" )
      return ", ".join( parts )

class MplsTunnelInfo ( Model ):
   '''MPLS tunnel encapsulation: the label stack, rendered as strings.'''
   labelStack = List( valueType=str, help="MPLS label stack" )

   def renderStr( self ):
      '''Return "label <l1> <l2> ..." or an empty string if there are no labels.'''
      if not self.labelStack:
         return ""
      return "label " + " ".join( self.labelStack )

   def render( self ):
      '''Print the single-line rendering of the label stack.'''
      print( self.renderStr() )

class TunnelFibVia( IpVia ):
   '''
   CAPI model for one via of a tunnel FIB entry. Extends IpVia with the via's
   encapsulation submodels (getEncap() reports the first one that is set), an
   optional nexthop-group name, an optional resolving tunnel, and a private
   flag marking backup vias for rendering.
   '''
   __revision__ = 2
   encapId = Int( optional=True, help="Encapsulation identifier" )
   mplsEncap = Submodel(
         optional=True, valueType=MplsTunnelInfo,
         help="Encapsulation parameters for MPLS tunnel" )
   ipsecEncap = Submodel(
         optional=True, valueType=IpsecTunnelInfo,
         help="Encapsulation parameters for IPsec tunnel" )
   greEncap = Submodel(
         optional=True, valueType=GreTunnelInfo,
         help="Encapsulation parameters for GRE tunnel" )
   ipsecGreEncap = Submodel(
         optional=True, valueType=IpsecGreTunnelInfo,
         help="Encapsulation parameters for GRE over IPsec tunnel" )
   nhgName = Str( optional=True, help="Nexthop Group name for NHG tunnel" )
   _isBackupVia = Bool( help="Backup Via", default=False )
   resolvingTunnel = Submodel( valueType=TunnelId, optional=True,
                      help="Resolving tunnel Information" )

   def degrade( self, dictRepr, revision ):
      '''
      Degrade the dict form of this via. For revision 1, drop
      'resolvingTunnel' unless the via's interface is a dynamic tunnel
      interface (presumably those vias are flattened by
      TunnelFibEntry.degradeToV1 instead). Returns dictRepr.
      '''
      from TunnelCliLib import isDyTunIntfId
      if revision == 1:
         if not isDyTunIntfId( dictRepr.get( 'interface' ) ) and \
            dictRepr.get( 'resolvingTunnel' ):
            dictRepr.pop( 'resolvingTunnel' )
      return dictRepr

   def getEncap( self ):
      '''
      Return the first encapsulation submodel that is set, checked in the
      order MPLS, IPsec, GRE, GRE over IPsec; None if none is set.
      '''
      for encap in ( self.mplsEncap, self.ipsecEncap, self.greEncap,
                     self.ipsecGreEncap ):
         if encap is not None:
            return encap
      return None

   def getEncapStr( self ):
      '''Return " <encap rendering>" for the via's encapsulation, or "".'''
      encap = self.getEncap()
      if encap:
         return " " + encap.renderStr()
      return ""

   def renderFibVia( self, indent=3 ):
      '''
      Return this via as one line indented by `indent` spaces:
      "[backup ]via <nhgName>" for nexthop-group vias, otherwise
      "[backup ]via <nexthop>, <interface>", then "encapId <n>" when an
      encapsulation and encapId are both present, then the encap rendering.
      '''
      if self._isBackupVia:
         viaStr = 'backup via '
      else:
         viaStr = 'via '

      if self.nhgName:
         attributes = [ viaStr + self.nhgName ]
      else:
         attributes = [ viaStr + str( self.nexthop ), str( self.interface ) ]
      encap = self.getEncap()
      if encap and self.encapId is not None:
         # encapId is only populated in debug context (getTunnelFibViaModel
         # sets it when called with debug=True).
         attributes.append( "encapId " + str( self.encapId ) )
      viaStr = " " * indent + ", ".join( attributes ) + self.getEncapStr()
      return viaStr

   def getNhgEntry( self ):
      '''Return a map of nexthop-group id to nexthop-group name from nhgStatus.'''
      # pylint: disable-msg=E1101
      return { entry.nhgId: key.nhgName()
               for key, entry in nhgStatus.nexthopGroupEntry.iteritems() }

   def renderTunnelVia( self, cachedNHG, visitedTunnelIds, indent=3 ):
      '''
      This function will recurse over resolving tunnels to print the details
      of the further tunnels.
      It will fetch the tunnel Id from the models interface and fetch the tunnel
      entry from the tunnel fib using this tunnelId. It will print the current tunnel
      information and it will recurse over this tunnel fib entry and repeat
      this process until it gets a nexthop and an interface without any
      further tunnels. It then prints the final nexthop and interface and returns.
      e.g.:
      Type 'IS-IS SR', index 4, endpoint 26::1:1:0/112, forwarding None
         via TI-LFA tunnel index 4 label 350 340 330
            via 47::2:3:9, 'Test47' label 330 320 310
            backup via 48::2:3:9, 'Test48' label 340 330 320

      cachedNHG: map of nexthop-group id to name (see getNhgEntry).
      visitedTunnelIds: dict keyed by the tunnel ids on the current
         resolution path. A tunnel id is added while its vias are being
         rendered and removed afterwards. Only a genuine loop back onto the
         path is reported as a cycle. Sibling vias resolving over the same
         tunnel are not reported as cycles.
      indent: number of leading spaces for this via's line(s).
      '''
      from CliPlugin.TunnelCliLib import ( isDyTunIntfId,
                                           getDyTunTidFromIntfId )
      from CliPlugin.TunnelModels import getTunnelViaModelFromTunnelIntf
      if not self.resolvingTunnel:
         # No further tunnel: print the final nexthop and interface.
         print( self.renderFibVia( indent ) )
         return

      nextTunnelId = getDyTunTidFromIntfId( self.interface )
      tunnelIdStr = self.resolvingTunnel.renderTunnelIdStr()
      print( " " * indent + "via " + tunnelIdStr + self.getEncapStr() )
      if nextTunnelId in visitedTunnelIds:
         print( " " * indent + "cycle detected at " + tunnelIdStr )
         return
      tunnelFibEntry = tunnelFib.entry.get( nextTunnelId )
      if not tunnelFibEntry:
         print( " " * indent + tunnelIdStr + " not present in Tunnel FIB" )
         return

      # Mark the tunnel as on the current path while its vias are rendered.
      visitedTunnelIds[ nextTunnelId ] = True
      for via in sorted( tunnelFibEntry.tunnelVia.itervalues() ):
         resolvingTunnel = None
         if isDyTunIntfId( via.intfId ):
            resolvingTunnel = getTunnelViaModelFromTunnelIntf( via.intfId )
         viaModel = getTunnelFibViaModel( via, cachedNHG,
                                          resolvingTunnel=resolvingTunnel )
         viaModel.renderTunnelVia( cachedNHG, visitedTunnelIds, indent + 3 )
      # NOTE(review): backup vias are built without a resolvingTunnel, so a
      # backup via over another tunnel is printed as a plain via -- confirm
      # this is intended.
      for via in sorted( tunnelFibEntry.backupTunnelVia.itervalues() ):
         viaModel = getTunnelFibViaModel( via, cachedNHG, backup=True )
         viaModel.renderTunnelVia( cachedNHG, visitedTunnelIds, indent + 3 )
      # Leaving this tunnel: it is no longer on the current resolution path.
      del visitedTunnelIds[ nextTunnelId ]

   def renderTunnelFibVia( self, selfTunnelId ):
      '''
      Render this via, and recursively any tunnels it resolves over, for the
      tunnel FIB entry whose tunnel id is selfTunnelId.
      '''
      cachedNHG = self.getNhgEntry()
      # Seed the path with the owning entry so that a via resolving back over
      # it is reported as a cycle.
      visitedTunnelIds = { selfTunnelId: True }
      self.renderTunnelVia( cachedNHG, visitedTunnelIds )

class TunnelFibEntry( Model ):
   '''
   CAPI model for one tunnel FIB entry. Holds the tunnel type and index, the
   primary and backup vias, and optional fields: endpoint and forwarding
   status, plus identifiers shown in debug context (interface, tunnelId,
   seqNo, tunnelFecId).
   '''
   __revision__ = 2
   tunnelType = Enum( values=tunnelTypeEnumValues, help="Tunnel type" )
   tunnelIndex = Int( help="Tunnel index within tunnel type" )
   vias = List( valueType=TunnelFibVia, help="List of tunnel vias" )
   backupVias = List( valueType=TunnelFibVia,
                      help="List of tunnel backup vias" )
   endpoint = IpGenericPrefix( optional=True,
         help="Route prefix used for underlay nexthop resolution" )
   tunnelViaStatus = Enum( values=tunnelViaStatusCapiEnumVals,
                     help="Tunnel programming status", optional=True )
   interface = Str( optional=True, help="Interface name for tunnel" )
   tunnelId = Int( optional=True, help="Tunnel identifier" )
   tunnelFecId = Int( optional=True, help="FEC identifier" )
   seqNo = Int( optional=True, help="Sequence number of tunnel entry" )

   def degradeToV1( self, dictRepr, revision ):
      '''
      Degrade the dict form of this entry to revision 1.

      For revision 1, each primary via that has both a 'resolvingTunnel' and
      an 'mplsEncap' is replaced by one via per resolved via of the tunnel it
      resolves over, as obtained from a Tunnel::TunnelFib::TunnelResolver.
      Each replacement takes the resolved via's interface and nexthop and
      drops 'resolvingTunnel'. Its label stack is the resolved labels
      followed by the via's own labels, or only the via's own labels when
      the combined stack would exceed tacMplsStackIndex.max.
      Vias with a 'resolvingTunnel' but no 'mplsEncap' are left unchanged,
      and backup vias are not touched.

      dictRepr: dict form of this model; modified in place.
      revision: requested revision; only 1 causes any change.
      Returns dictRepr.
      Relies on self.tunnelId being set, since it seeds the resolver.
      '''
      from TunnelCliLib import getDyTunTidFromIntfId
      # Created even when revision != 1, although only used for revision 1.
      tr = Tac.newInstance( "Tunnel::TunnelFib::TunnelResolver", tunnelFib )

      if revision == 1:
         # Indices into dictRepr[ 'vias' ] to remove, and replacement vias.
         viaListToRemove = []
         ansList = []
         for i, tunnelFibVia in enumerate( dictRepr[ 'vias' ] ):
            # Reset the resolver's bookkeeping for every via, seeded with this
            # entry's own tunnel id (presumably so that a resolution looping
            # back to this entry is cut short -- confirm in TunnelResolver).
            tr.renderedTunnelId.clear()
            tr.renderedTunnelId[ self.tunnelId ] = True
            if tunnelFibVia.get( 'resolvingTunnel' ):
               nextTunnelId = getDyTunTidFromIntfId( tunnelFibVia.get(
                  'interface' ) )
               nextTunnelTableId = Tac.Value( "Tunnel::TunnelTable::TunnelId",
                                              nextTunnelId )
               # Only MPLS-encapsulated vias are flattened; others stay as-is.
               if not tunnelFibVia.get( 'mplsEncap' ):
                  continue
               # Populate the resolver's consolidated entry on demand.
               if not tr.consolidatedEntry.get( nextTunnelId ):
                  tr.updateFlattenedInfo( nextTunnelTableId )
               processList = [ consEntry.viaInfo for consEntry in
                               tr.consolidatedEntry.values()
                               if consEntry.tunnelId == nextTunnelId ]
               for setVia in processList:
                  for mplsVia in setVia:
                     # One replacement via per resolved via.
                     fibViaCopy = copy.deepcopy( tunnelFibVia )
                     fibViaCopy[ 'interface' ] = mplsVia.intfId
                     fibViaCopy[ 'nexthop' ] = mplsVia.nexthop.stringValue
                     labelOp = mplsVia.labels
                     labels = []
                     fibLabelStack = fibViaCopy[ 'mplsEncap' ][ 'labelStack' ]
                     maxStackSize = tacMplsStackIndex.max
                     stackSize = labelOp.stackSize
                     # Resolved labels (highest stack index first, as in
                     # getTunnelFibViaModel) followed by this via's labels;
                     # if too large, keep only this via's own labels.
                     if stackSize + len( fibLabelStack ) <= maxStackSize:
                        for index in range( stackSize ):
                           labels.append( str( labelOp.labelStack(
                              stackSize - index - 1 ) ) )
                        labels.extend( fibLabelStack )
                     else:
                        labels = fibLabelStack
                     fibViaCopy[ 'mplsEncap' ] = MplsTunnelInfo(
                                                labelStack=labels ).toDict()
                     fibViaCopy.pop( 'resolvingTunnel' )
                     ansList.append( fibViaCopy )
               # Remove all the vias having further tunnel information
               viaListToRemove.append( i )
         # Remove all the vias with further tunnels in them (highest index
         # first so earlier indices stay valid)
         for i in sorted( viaListToRemove, reverse=True ):
            dictRepr[ 'vias' ].pop( i )
         # Add the resolved vias back for further tunnels.
         dictRepr[ 'vias' ].extend( ansList )
      return dictRepr

   def render( self ):
      '''
      Print this entry. Output is a blank line, then a header line (type,
      index, endpoint, forwarding status). An optional debug line
      (interface, id, sequence number) and the FEC id line follow when set.
      Last come the primary vias, then the backup vias, each list sorted by
      (nexthop, interface).
      '''
      attributes = list()
      def adjustCase( text, attrs ):
         # Uppercases first letter of text in case it starts a newline
         return text.lower() if attrs else text.capitalize()

      # Debug context adds interface name and tunnelId to the output
      debugAttributes = list()
      if self.interface:
         # pylint: disable-msg=E1101
         debugAttributes.append( self.interface )
      if self.tunnelId is not None:
         debugAttributes.append(
               "%s %d" % ( adjustCase( "id", debugAttributes ), self.tunnelId ) )
      if self.seqNo is not None:
         debugAttributes.append( "sequence num " + str( self.seqNo ) )

      # All contexts
      attributes.append( "%s '%s'" % ( adjustCase( "type", attributes ),
                                       self.tunnelType ) )
      attributes.append( "index %d" % self.tunnelIndex )
      if self.endpoint:
         attributes.append( "endpoint " + str( self.endpoint ) )
      if self.tunnelViaStatus:
         attributes.append( "forwarding " + getTunnelViaStatusStr(
            self.tunnelViaStatus ) )

      print "\n", ", ".join( attributes )
      if debugAttributes:
         print ", ".join( debugAttributes )

      # Debug context adds the tunnel's FEC id to the output. This will
      # also cause the via rendering to be indented more than for non-
      # debug context. The FEC id goes on a separate line because it is
      # too long to fit on the same line as the tunnel type/index when
      # in debug context.
      if self.tunnelFecId is not None:
         print " " * 3 + "FEC id " + str( self.tunnelFecId )

      # All contexts
      from TunnelCliLib import getTunnelIdFromIndex
      # pylint: disable-msg=no-member
      af = self.endpoint.af if self.endpoint else None
      # Fall back to deriving the tunnel id from type/index when tunnelId is
      # not set; it seeds cycle detection while rendering vias.
      selfTunnelId = self.tunnelId or getTunnelIdFromIndex(
         tunnelTypesReverseStrDict[ self.tunnelType ], self.tunnelIndex, af )
      for via in sorted( self.vias, key=attrgetter( 'nexthop', 'interface' ) ):
         via.renderTunnelFibVia( selfTunnelId )
      for via in sorted( self.backupVias, key=attrgetter( 'nexthop', 'interface' ) ):
         via.renderTunnelFibVia( selfTunnelId )

class TunnelFibCategory( Model ):
   '''Tunnel FIB entries of a single tunnel category, keyed by entry id.'''
   __revision__ = 2
   entries = Dict( keyType=long, valueType=TunnelFibEntry,
                   help="Tunnel FIB entries for a tunnel type" )

   def degradeToV1( self, dictRepr, revision ):
      '''
      Degrade the dict form of this category. For revision 1, every entry is
      degraded through TunnelFibEntry.degradeToV1, and entries left without
      any vias are dropped. dictRepr is modified in place and returned.
      '''
      from TunnelCliLib import getTunnelIdFromIndex
      if revision != 1:
         return dictRepr

      entries = dictRepr[ 'entries' ]
      emptyKeys = []
      for entryKey in sorted( entries ):
         entry = entries[ entryKey ]
         endpoint = entry.get( 'endpoint' )
         # pylint: disable-msg=no-member
         af = IpGenPrefix( endpoint ).af if endpoint else None
         entryTunnelId = entry.get( 'tunnelId' )
         if not entryTunnelId:
            # No explicit tunnel id: derive it from the type and index.
            entryTunnelId = getTunnelIdFromIndex(
               tunnelTypesReverseStrDict[ entry[ 'tunnelType' ] ],
               entry[ 'tunnelIndex' ], af )

         degraded = TunnelFibEntry( tunnelId=entryTunnelId ).degradeToV1(
            entry, revision )
         if degraded[ 'vias' ]:
            entries[ entryKey ] = degraded
         else:
            emptyKeys.append( entryKey )

      for entryKey in emptyKeys:
         del entries[ entryKey ]
      return dictRepr

   def render( self ):
      '''Render every entry of the category in ascending key order.'''
      for _, entry in sorted( self.entries.items(),
                              key=lambda item: item[ 0 ] ):
         entry.render()

class TunnelFib( Model ):
   '''Top-level tunnel FIB model: categories of tunnel entries keyed by name.'''
   __revision__ = 2
   categories = Dict( keyType=str, valueType=TunnelFibCategory,
         help="Collections of tunnels in the FIB, grouped by tunnel service type" )

   def degrade( self, dictRepr, revision ):
      '''
      Degrade the dict form of the whole tunnel FIB. For revision 1, each
      category is degraded via TunnelFibCategory.degradeToV1, and categories
      left with no entries are dropped. dictRepr is modified in place and
      returned.
      '''
      if revision != 1:
         return dictRepr

      categories = dictRepr[ 'categories' ]
      emptyCategories = []
      for categoryName in sorted( categories ):
         degraded = TunnelFibCategory().degradeToV1(
            categories[ categoryName ], revision )
         if degraded[ 'entries' ]:
            categories[ categoryName ] = degraded
         else:
            emptyCategories.append( categoryName )

      for categoryName in emptyCategories:
         del categories[ categoryName ]
      return dictRepr

   def render( self ):
      '''Render every category in ascending name order.'''
      # pylint: disable-msg=E1101
      for categoryName in sorted( self.categories ):
         self.categories[ categoryName ].render()

def _mplsLabels( encapId ):
   '''
   Return the labels of the MPLS label-stack encap identified by encapId, as
   strings, from the highest stack index down to index 0. An encapId missing
   from tunnelFib.labelStackEncap falls back to a default-constructed
   Tunnel::TunnelFib::LabelStackEncap.
   '''
   labelStackEncap = tunnelFib.labelStackEncap.get(
      encapId, Tac.Value( "Tunnel::TunnelFib::LabelStackEncap" ) )
   labelOp = labelStackEncap.labelStack
   return [ str( labelOp.labelStack( mplsStackIndex ) )
            for mplsStackIndex in reversed( range( labelOp.stackSize ) ) ]

def _l3EncapModel( modelClass, tunnelInfo ):
   '''
   Build an encap model of type modelClass (an L3TunnelInfo subclass) from a
   Tac tunnel info value. Copies src/dst and the address-family specific
   optional fields via L3TunnelInfo.updateOptionalAttributes.
   '''
   encap = modelClass( srcAddr=tunnelInfo.src, dstAddr=tunnelInfo.dst )
   encap.updateOptionalAttributes(
         af=tunnelInfo.src.af, dscp=tunnelInfo.dscp,
         hoplimit=tunnelInfo.hoplimit, tos=tunnelInfo.tos, ttl=tunnelInfo.ttl )
   return encap

def _setGreKey( encap, tunnelInfo ):
   '''
   Set encap.key from tunnelInfo.oKey, falling back to tunnelInfo.iKey;
   leave it unset when both are zero.
   '''
   greKey = tunnelInfo.oKey or tunnelInfo.iKey
   if greKey:
      encap.key = greKey

def getTunnelFibViaModel( tunnelVia, cachedNHG, debug=False, backup=False,
                          resolvingTunnel=None ):
   '''
   Build a TunnelFibVia CAPI model from a Tac tunnel FIB via.

   tunnelVia: Tac via providing nexthop, intfId and encapId.
   cachedNHG: map of nexthop-group id to name, used for nexthop-group vias
      (unknown ids render as "NexthopGroupUnknown").
   debug: when True, also record the encapsulation id (encapId).
   backup: when True, mark the model as a backup via.
   resolvingTunnel: optional TunnelId model of the tunnel this via resolves
      over.
   Returns the populated TunnelFibVia. The encapsulation submodel is filled
   from the tunnelFib mount according to encapId.encapType.
   '''
   isNhgTunnel = tacNexthopGroupIntfId.isNexthopGroupIntfId( tunnelVia.intfId )
   # Nexthop-group vias are rendered by name (nhgName), not by interface.
   intf = "" if isNhgTunnel else tunnelVia.intfId
   viaModel = TunnelFibVia( nexthop=tunnelVia.nexthop,
                            interface=intf, type='ip',
                            resolvingTunnel=resolvingTunnel )
   if backup:
      # pylint: disable-msg=protected-access
      viaModel._isBackupVia = True
   if debug:
      viaModel.encapId = tunnelVia.encapId.encapIdValue

   encapId = tunnelVia.encapId
   encapType = encapId.encapType
   if encapType == 'mplsEncap':
      labels = _mplsLabels( encapId )
      if labels:
         viaModel.mplsEncap = MplsTunnelInfo( labelStack=labels )

   elif encapType == 'greEncap':
      greEncap = tunnelFib.greEncap.get(
         encapId, Tac.Value( 'Tunnel::TunnelFib::GreEncap' ) )
      greInfo = greEncap.tunnelInfo
      encap = _l3EncapModel( GreTunnelInfo, greInfo )
      _setGreKey( encap, greInfo )
      viaModel.greEncap = encap

   elif encapType == 'ipSecEncap':
      if encapId in tunnelFib.ipsecVtiEncap:
         ipsecInfo = tunnelFib.ipsecVtiEncap[ encapId ].tunnelInfo
      else:
         ipsecInfo = Tac.Value( 'Tunnel::IpsecVtiTunnelInfo' )
      viaModel.ipsecEncap = _l3EncapModel( IpsecTunnelInfo, ipsecInfo )

   elif encapType == 'ipSecGreEncap':
      if encapId in tunnelFib.ipsecGreEncap:
         ipsecGreInfo = tunnelFib.ipsecGreEncap[ encapId ].tunnelInfo
      else:
         ipsecGreInfo = Tac.Value( 'Tunnel::IpsecGreTunnelInfo' )
      encap = _l3EncapModel( IpsecGreTunnelInfo, ipsecGreInfo )
      _setGreKey( encap, ipsecGreInfo )
      viaModel.ipsecGreEncap = encap

   if isNhgTunnel:
      nhgId = tacNexthopGroupIntfId.nexthopGroupId( tunnelVia.intfId )
      viaModel.nhgName = cachedNHG.get( nhgId, "NexthopGroupUnknown" )

   return viaModel

def Plugin( entityManager ):
   '''
   CLI plugin entry point. Mounts the tunnel FIB and the nexthop-group entry
   status through SmashLazyMount with the 'reader' mount info, and stores
   them in the module globals tunnelFib and nhgStatus.
   '''
   global tunnelFib, nhgStatus

   tunnelFibPath = 'tunnel/tunnelFib'
   tunnelFib = SmashLazyMount.mount( entityManager, tunnelFibPath,
                                     'Tunnel::TunnelFib::TunnelFib', readerInfo )

   nhgStatusPath = "routing/nexthopgroup/entrystatus"
   nhgStatus = SmashLazyMount.mount( entityManager, nhgStatusPath,
                                     "NexthopGroup::EntryStatus", readerInfo )
