# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# This file contains tunnel CAPI model definitions that are shared between
# multiple packages. Tunnel models specific to certain packages should be
# defined elsewhere.

from ArnetModel import IpGenericAddress, IpGenericPrefix
from CliModel import Bool, Dict, Enum, List, Int, Str
from CliModel import Model, Submodel
from CliPlugin.TunnelCliLib import (
   getDyTunTidFromIntfId,
)
from IntfModels import Interface
import IntfModel
import Tac
from TunnelTypeLib import (
encapTypeEnumValues,
   tunnelTypeEnumValues,
   tunnelTypesReverseStrDict,
   tunnelTypeStrDict,
)
from TypeFuture import TacLazyType

tableTypeAttrs = tunnelTypeStrDict.values()

TunnelIdType = TacLazyType( "Tunnel::TunnelTable::TunnelId" )
TunnelType = Enum( values=tableTypeAttrs, help="Tunnel type" )

class IpTunnelInfo( Model ):
   tunnelEncap = Enum( values=encapTypeEnumValues, help="Tunnel encapsulation type" )
   tunnelSource = IpGenericAddress(
      help="Location in the IP network where the IP tunnel starts" )
   tunnelDestination = IpGenericAddress(
      help="Location in the IP network where the IP tunnel terminates" )
   tunnelKey = Int( help="GRE IP tunnel key", optional=True )
   tunnelDscp = Int( help="Differentiated services code point", optional=True )
   tunnelHoplimit = Int( help="Hop limit", optional=True )
   tunnelTos = Int( help="Type of service", optional=True )
   tunnelTtl = Int( help="Time to live", optional=True )

class TunnelId( Model ):
   __revision__ = 2
   index = Int( help="Tunnel index per tunnel type" )
   type = TunnelType

   def renderTunnelIdStr( self, tunStr="tunnel index" ):
      tunnelViaStr = ( self.type + " " + tunStr + " " +
                       str( self.index ) )
      return tunnelViaStr

   def renderStr( self ):
      return '%s (%d)' % ( self.type, self.index )

   def degrade( self, dictRepr, revision ):
      if revision == 1:
         if dictRepr[ 'type' ] == 'Nexthop Group':
            dictRepr[ 'type' ] = 'Nexthop-Group'

      return dictRepr

   def toRawValue( self ):
      tunType = tunnelTypesReverseStrDict[ self.type ]
      return TunnelIdType.convertToTunnelValue( tunType, self.index )

   def fromRawValue( self, value ):
      tun = TunnelIdType( value )
      self.index = tun.tunnelIndex()
      self.type = tun.typeCliStr()

def getTunnelViaModelFromTunnelIntf( tunnelIntf ):
   tunnelId = getDyTunTidFromIntfId( tunnelIntf )
   recurTunnelId = Tac.Value( "Tunnel::TunnelTable::TunnelId",
                               tunnelId )
   tunnel = TunnelId(
      type=recurTunnelId.typeCliStr(),
      index=recurTunnelId.tunnelIndex() )
   return tunnel

class TunnelViaInfo( Model ):
   nexthopAddr = IpGenericAddress( help="Next hop IP address", optional=True )
   interface = IntfModel.Interface( help="Egress L3 interface of next hop",
                                    optional=True )
   nhgName = Str( help="Next-Hop-Group name", optional=True )
   nextHop = Str( help="Next-Hop name", optional=True )
   interfaceDescription = Str( help="Interface description", optional=True )
   labelStack = List( valueType=int, optional=True,
                      help="MPLS label stack (top-of-stack label first)" )
   resolvingTunnel = Submodel(
      valueType=TunnelId, help="Resolving tunnel information", optional=True )

class TunnelInfo( Model ):
   tunnelType = Enum( values=tunnelTypeEnumValues, help="Tunnel type" )
   tunnelIndex = Int( help="Tunnel table index" )
   tunnelEndPoint = IpGenericPrefix(
      help="Route prefix for MPLS underlay nexthop resolution",
      optional=True )
   tunnelAddressFamily = Enum( values=( 'IPv4', 'IPv6' ),
                               help="Address family for MPLS tunnels",
                               optional=True )
   staticInterfaceTunnelInfo = Submodel(
      valueType=IpTunnelInfo,
      help="Tunnel encap information for a static interface tunnel",
      optional=True )
   tunnelVias = List( valueType=TunnelViaInfo, help="List of tunnel vias",
                      optional=True )

class Via( Model ):
   type = Enum( values=[ 'ip', 'tunnel' ],
                 help="Type of via, indicating the expected attributes" )

class IpVia( Via ):
   nexthop = IpGenericAddress( help="Nexthop IP address" )
   interface = Interface( help="Egress L3 interface of the next-hop" )

   def degradeToV1( self, via ):
      """Degrade the Ip via dict to revision 1"""
      del via[ 'type' ]

   def renderStr( self ):
      return "via %s %s" % ( self.nexthop, self.interface )

   def render( self ):
      print self.renderStr()

class TunnelVia( Via ):
   tunnelId = Submodel( valueType=TunnelId, help="Tunnel Identifier" )

   def degradeToV1( self, via ):
      """Degrade the Tunnel via dict to revision 1"""
      via[ 'nexthop' ] = '0.0.0.0'
      tidType = tunnelTypesReverseStrDict[ via[ 'tunnelId' ][ 'type' ] ]

      afBitMask = 0x0
      if tidType in [ 'staticV6Tunnel', 'srV6Tunnel' ]:
         # FIXME BUG200023 Can remove this afBitMask after convertToTunnelValue
         # is refactored.
         # if ID is a V6 channel, need to flip the address family bit to 1
         TunnelIdConstants = Tac.Type( "Tunnel::TunnelTable::TunnelIdConstants" )
         afBitMask = TunnelIdConstants.tunnelAfMask
      if tidType in [ 'staticV4Tunnel', 'staticV6Tunnel' ]:
         tidType = 'staticTunnel'
      elif tidType in [ 'srV4Tunnel', 'srV6Tunnel' ]:
         tidType = 'srTunnel'
      tid = Tac.Type( 'Tunnel::TunnelTable::TunnelId' ).convertToTunnelValue(
            tidType, via[ 'tunnelId' ][ 'index' ] ) | afBitMask
      intfId = Tac.Type( 'Arnet::DynamicTunnelIntfId' ).tunnelIdToIntfId( tid )
      via[ 'interface' ] = intfId
      del via[ 'type' ]
      del via[ 'tunnelId' ]

class MplsVia( IpVia ):
   labels = List( valueType=str, help="Label stack" )
   _isBackupVia = Bool( help="Backup Via", default=False )
   # This attribute will be populated if a tunnel is being protected
   # by another tunnel. For e.g SR tunnel being protected by TI-LFA tunnels
   backupTunnelInfo = Submodel( valueType=TunnelInfo,
                                help="Backup Tunnel Information",
                                optional=True )

   def render( self ):
      viaStr = "backup via" if self._isBackupVia else "via"
      labelsStr = " ".join( self.labels )
      print " " * 3 + "%s %s, %s" % ( viaStr, self.nexthop, self.interface )
      print " " * 6 + "label stack %s" % labelsStr

# Please maintain this. __cmp__ is not allowed to be defined
def mplsViaCmp( self, other ):
   return cmp( ( self.type, self.nexthop, self.interface, self.labels ),
               ( other.type, other.nexthop, other.interface, other.labels ) )

def mplsTunnelViaCmp( self, other ):
   return cmp( ( self.type, self.tunnelId.type, self.tunnelId.index, self.labels ),
               ( other.type, other.tunnelId.type, other.tunnelId.index,
                 other.labels ) )

def viaCmp( self, other ):
   if self.type != other.type:
      return cmp( self.type, other.type )
   elif self.type == 'ip':
      return mplsViaCmp( self, other )
   elif self.type == 'tunnel':
      return mplsTunnelViaCmp( self, other )
   else:
      assert False, "handle new via type"
      return 0

class MplsTunnelVia( TunnelVia ):
   labels = List( valueType=str, help="Label stack" )

class TunnelTableEntry( Model ):
   endpoint = IpGenericPrefix( help="Endpoint of the tunnel" )

class TunnelTableEntryWithMplsVias( TunnelTableEntry ):
   vias = List( valueType=MplsVia, help="List of nexthops" )

   def renderIsisTunnelTableEntry( self, table, tunnelIndex ):
      from CliPlugin import TunnelCliLib
      nhStr = intfStr = labelsStr = '-'
      backupTunnelStr = "-"
      if self.vias:
         firstVia = self.vias[ 0 ]
         nhStr, intfStr = TunnelCliLib.getNhAndIntfStrs( firstVia )
         if firstVia.backupTunnelInfo:
            tunnelInfo = firstVia.backupTunnelInfo
            backupTunnelStr = str( tunnelInfo.tunnelIndex )
         labelsStr = '[ ' + ' '.join( firstVia.labels ) + ' ]'
      table.newRow( tunnelIndex, str( self.endpoint ), nhStr, intfStr, labelsStr,
                    backupTunnelStr )
      for via in self.vias[ 1 : ]:
         nhStr, intfStr = TunnelCliLib.getNhAndIntfStrs( via )
         labelsStr = '[ ' + ' '.join( via.labels ) + ' ]'
         backupTunnelStr = "-"
         if via.backupTunnelInfo:
            tunnelInfo = via.backupTunnelInfo
            backupTunnelStr = str( tunnelInfo.tunnelIndex )
         table.newRow( '-', '-', nhStr, intfStr, labelsStr, backupTunnelStr )

class LdpIsisSrTunnelTableCommon( Model ):
   entries = Dict( keyType=long, valueType=TunnelTableEntryWithMplsVias,
                   help="Tunnel table entries keyed by tunnel index" )

   def render( self ):
      raise NotImplementedError( "Missing render method" )

   def getTunnelIdFromIndex( self, index, addrFamily ):
      raise NotImplementedError( "Missing getTunnelIdFromIndex method" )
