# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
#
# Functions related to Vxlan Config Sanity Checking

import functools

import Arnet
from CliPlugin.IraCommonCli import AddressFamily
import IpUtils
import LazyMount
import Tac
import VxlanModel

# pylint: disable=W0212

# Generic explanation shown alongside configuration-sanity warnings.
warnExplainedMsg = ( 'Your configuration contains warnings. This does not mean'
                     ' misconfigurations. But you may wish to re-check your'
                     ' configurations.' )
# Warning text about dynamic VLAN-VNI mapping errors.
dynVlanVniWarnMsg = ( 'There are dynamic VLAN-VNI mapping errors. See syslog'
                      ' for more detail.' )

# %-style detail formats for individual check results.
vlanNotCreatedFmt = 'VLAN %d does not exist'
dynVlanVniConflictFmt = 'dynamic %s %d conflict'
noRemoteVtepVlanFloodlistFmt = 'No remote VTEP in VLAN %d'
noVniInVrfToVniFormat = 'No VRF-VNI mapping for VNI %d'
noVrfInVrfToVniFormat = 'No VRF-VNI mapping for VRF %s'

# VRF id of the underlay default VRF.
underlayDefaultVrfId = 0

#-------------------------------------------------------------------------------
# Decorators for sanity-check functions
#
# Usage:
#  @ConfigCheckItem( name, priority ) if the function fills in and returns a
#     single ConfigCheckItem; the decorator creates the item for it.
#  @GenConfigCheckItems( nameBase, priority ) if the function returns a list of
#     ConfigCheckItems. With this decorator the function itself must create
#     the item objects; the decorator only sets their priority.
#-------------------------------------------------------------------------------

# wrap a check function with ConfigCheckItem creation
def ConfigCheckItem( itemName, priority ):
   """Decorator factory for check functions that fill in one ConfigCheckItem.

   The returned decorator wraps a function ``f( item )``. The wrapper takes no
   arguments. It creates a ``VxlanModel.ConfigCheckItem``, sets its ``name``
   and ``_priority``, passes it to ``f`` and returns whatever ``f`` returns
   (normally the item itself). The wrapper keeps ``f``'s name and docstring,
   so decorated checks remain identifiable.

   Args:
      itemName: value assigned to the new item's ``name``.
      priority: value assigned to the new item's ``_priority``.
   """
   def nameWrap( f ):
      @functools.wraps( f )
      def funcWrap():
         item = VxlanModel.ConfigCheckItem()
         item._priority = priority
         item.name = itemName
         return f( item )
      return funcWrap
   return nameWrap

# wrap a config check function that generates ConfigCheckItems
def GenConfigCheckItems( itemNameBase, priority ):
   """Decorator factory for check functions that build their own items.

   The returned decorator wraps a function ``f( itemNameBase )`` that creates
   and returns a collection of ConfigCheckItem objects. The wrapper takes no
   arguments. It calls ``f``, sets ``_priority`` on every returned item and
   returns the same collection. The wrapper keeps ``f``'s name and docstring.

   NOTE(review): the collection is iterated once before being returned, so
   ``f`` should return a list (a generator would come back exhausted).

   Args:
      itemNameBase: passed through to ``f``, presumably as the prefix for the
         item names it generates.
      priority: value assigned to each item's ``_priority``.
   """
   def nameWrap( f ):
      @functools.wraps( f )
      def funcWrap():
         items = f( itemNameBase )
         for item in items:
            item._priority = priority
         return items
      return funcWrap
   return nameWrap

#-------------------------------------------------------------------------------
# helper classes
#
#-------------------------------------------------------------------------------

class RouteTrie( object ):
   """Longest-prefix-match route lookup for IPv4 and IPv6.

   Holds one ``Routing::TrieGen`` per address family, each paired with a
   ``Routing::TrieGenBuilder`` constructed over the given routing/forwarding
   status objects. The tries are created lazily on the first ``getRoute`` call
   or explicitly by ``refreshRoutes``.
   """
   def __init__( self, routingStatus, forwardingStatus,
                 routing6Status, forwarding6Status ):
      self.routingStatus = routingStatus
      self.forwardingStatus = forwardingStatus
      self.routing6Status = routing6Status
      self.forwarding6Status = forwarding6Status
      # Make sure lazily-mounted routing status is actually mounted before use.
      if isinstance( routingStatus, LazyMount._Proxy ):
         LazyMount.force( routingStatus )
      if isinstance( routing6Status, LazyMount._Proxy ):
         LazyMount.force( routing6Status )
      self.trie = None
      self.trie6 = None
      self.trieBuilder = None
      self.trieBuilder6 = None

   def _newTrie( self, af ):
      """Create an empty Routing::TrieGen for address family ``af``."""
      return Tac.newInstance( 'Routing::TrieGen', 'trie', af )

   def _newTrieBuilder( self, af='ipv4' ):
      """Create a TrieGenBuilder over this object's status for ``af``.

      ``af`` is 'ipv4' for the IPv4 builder; any other value selects IPv6.
      The matching trie (``self.trie`` / ``self.trie6``) must already exist.
      """
      if af == 'ipv4':
         return Tac.newInstance( 'Routing::TrieGenBuilder', self.routingStatus, None,
                                 self.forwardingStatus, None, None, None, self.trie )
      else:
         return Tac.newInstance( 'Routing::TrieGenBuilder', None,
                                 self.routing6Status, None, self.forwarding6Status,
                                 None, None, self.trie6 )

   def _ensureTries( self ):
      """Create any trie or trie builder that does not exist yet."""
      if not self.trie:
         self.trie = self._newTrie( AddressFamily.ipv4 )
      if not self.trie6:
         self.trie6 = self._newTrie( AddressFamily.ipv6 )
      if not self.trieBuilder:
         self.trieBuilder = self._newTrieBuilder()
      if not self.trieBuilder6:
         self.trieBuilder6 = self._newTrieBuilder( af='ipv6' )

   def refreshRoutes( self ):
      """Discard the current tries and builders and create fresh ones."""
      # Drop the old objects first. Presumably this is so that builders holding
      # the old tries are released before new ones are built.
      self.trie = None
      self.trie6 = None
      self.trieBuilder = None
      self.trieBuilder6 = None
      self.trie = self._newTrie( AddressFamily.ipv4 )
      self.trie6 = self._newTrie( AddressFamily.ipv6 )
      self.trieBuilder = self._newTrieBuilder()
      self.trieBuilder6 = self._newTrieBuilder( af='ipv6' )

   @staticmethod
   def _keepHighBits( num, length, width ):
      """Return ``num`` keeping only its ``length`` most significant bits.

      ``width`` is the address width in bits (32 for IPv4, 128 for IPv6); all
      lower bits, and any bits above ``width``, are cleared.
      """
      allOnes = ( 1 << width ) - 1
      hostBits = ( 1 << ( width - length ) ) - 1
      return num & ( allOnes ^ hostBits )

   def getRoute( self, prefix ):
      """Find the route covering ``prefix`` and its interface vias.

      Repeatedly takes the trie's longest match for ``prefix``. If that prefix
      has an entry in the routing status, the lookup resolves it. Otherwise the
      match is shortened by one bit and the lookup retried, until the null or
      zero-length prefix is reached.

      Returns:
         ``( route, vias )``, where ``vias`` lists the entries of the route's
         FEC ``via`` collection that have a truthy ``intfId``. Returns
         ``( None, None )`` in any of these cases:
         - nothing matches;
         - the status for the prefix's address family is missing;
         - the matched route entry is falsy;
         - the route's FEC is not found;
         - the address family is neither IPv4 nor IPv6.
      """
      self._ensureTries()

      while True:
         if prefix.af == AddressFamily.ipv4:
            prefixMatch = self.trie.longestMatch( prefix ).v4Prefix
            maxLen = 32
            routingStatus = self.routingStatus
            forwardingStatus = self.forwardingStatus
         elif prefix.af == AddressFamily.ipv6:
            prefixMatch = self.trie6.longestMatch( prefix ).v6Prefix
            maxLen = 128
            routingStatus = self.routing6Status
            forwardingStatus = self.forwarding6Status
         else:
            # Unsupported family: previously this fell through and raised
            # UnboundLocalError on prefixMatch.
            return None, None

         if not routingStatus or not forwardingStatus:
            return None, None
         if prefixMatch.isNullPrefix:
            break

         if routingStatus.route.has_key( prefixMatch ):
            route = routingStatus.route[ prefixMatch ]
            if not route:
               return None, None
            fec = forwardingStatus.fec.get( route.fecId )
            if not fec:
               return None, None
            vias = [ fec.via[ i ] for i in range( len( fec.via ) )
                     if fec.via[ i ].intfId ]
            return route, vias

         if prefixMatch.len == 0:
            break

         # No route for this exact prefix: retry with it shortened by one bit.
         # The mask must span the full address width. The old fixed 32-bit
         # mask zeroed the high-order bits of IPv6 addresses.
         # NOTE(review): assumes IpUtils.IpAddress round-trips 128-bit integers
         # for IPv6 addresses — confirm.
         newLen = prefixMatch.len - 1
         addrNum = IpUtils.IpAddress( prefixMatch.address ).toNum()
         ip = IpUtils.IpAddress( self._keepHighBits( addrNum, newLen, maxLen ) )
         prefix = Arnet.IpGenPrefix( str( ip ) + '/' + str( newLen ) )

      return None, None
