#!/usr/bin/env python
# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from Arnet import IpGenPrefix, IpGenAddr
import CliExtensions
import CliParser
from collections import namedtuple
from IpLibConsts import DEFAULT_VRF
from CliToken.Ipv4 import ipv4MatcherForShow
from CliToken.Ipv6 import ipv6MatcherForShow
import Tac
import LazyMount
from CliMatcher import KeywordMatcher
from CliCommand import Node, CliExpression, OrExpressionFactory, CliExpressionFactory

# Tac type for address families; this module references its ipv4, ipv6 and
# ipunknown members.
AddressFamily = Tac.Type( "Arnet::AddressFamily" )

def vrfFromMode( mode ):
   '''Return the VRF name carried by the CLI mode.

   Modes that expose a 'vrfName' attribute yield that value; any other mode
   yields DEFAULT_VRF.'''
   return mode.vrfName if hasattr( mode, 'vrfName' ) else DEFAULT_VRF

#----------------------------------------------------------------------
# Tells whether the platform supports (IPv4) multicast routing.
# Callers that already hold the 'routing/hardware/status' entity may pass
# it as routingHardwareStatus; otherwise it is looked up under sysdbRoot.
#---------------------------------------------------------------------
def mcastRoutingSupported( sysdbRoot, routingHardwareStatus=None ):
   '''Return the multicastRoutingSupported attribute of the IPv4 routing
   hardware status entity.'''
   hwStatus = routingHardwareStatus
   if hwStatus is None:
      # Only touch sysdbRoot when the caller did not hand us the entity.
      hwStatus = sysdbRoot[ 'routing' ][ 'hardware' ][ 'status' ]
   return hwStatus.multicastRoutingSupported

#-------------------------------------------------------------------------------
# Use this hook to provide access to the collections
# of interfaces that potentially will cause Vif
# allocation. For now these are only where either pim
# or igmp is configured (here 'configured' means that
# the config differs from the default one).
#
# Each mcastIfCollHook extension is invoked as hook( vrfName, af ) and its
# result is read through keys() (see getAllMcastIntfs).
#-------------------------------------------------------------------------------
mcastIfCollHook = CliExtensions.CliHook()
# NOTE(review): igmpStatusCollHook is not used in the visible part of this
# module; presumably it is registered/consumed by other plugins -- confirm.
igmpStatusCollHook = CliExtensions.CliHook()

McastConstants = Tac.Type( 'McastCommon::McastConstants' )
# pimReg uses one vif, so subtract one from maxVifDefault.
# MAX_VIFS is the interface limit enforced by createMcastIntfConfig.
MAX_VIFS = McastConstants.maxVifDefault - 1

def getAllMcastIntfs( vrfName, af ):
   '''Collect the names of all multicast interfaces for a VRF and address family.

   Every extension registered on mcastIfCollHook is called with
   ( vrfName, af ); the keys of each returned collection are merged into one
   set, which is returned.'''
   intfNames = set()
   for collHook in mcastIfCollHook.extensions():
      intfNames.update( collHook( vrfName, af ).keys() )
   return intfNames

def createMcastIntfConfig( mode, vrfName, af, intfName, createFcn, printError=True ):
   '''Create a multicast interface config unless the VIF limit would be exceeded.

   The interfaces already known for ( vrfName, af ) are counted via
   getAllMcastIntfs.  createFcn (a no-argument callable returning the config)
   is invoked, and its result returned, when either fewer than MAX_VIFS
   interfaces exist or intfName is already one of them.  Otherwise nothing is
   created, None is returned, and an error is added to mode when printError is
   True.'''
   knownIntfs = getAllMcastIntfs( vrfName, af )
   limitReached = len( knownIntfs ) >= MAX_VIFS
   # An interface that is already counted does not need a new vif.
   if limitReached and intfName not in knownIntfs:
      if printError:
         mode.addError(
            "Cannot configure more than %d interfaces for multicast routing" %
            MAX_VIFS )
      return None
   return createFcn()

def mcast6RoutingSupported( sysdbRoot, routingHardwareStatus=None ):
   '''Return the multicastRoutingSupported attribute of the IPv6 routing
   hardware status entity.

   routingHardwareStatus may be supplied directly; when it is None the entity
   is read from sysdbRoot[ 'routing6' ][ 'hardware' ][ 'status' ].'''
   hwStatus = routingHardwareStatus
   if hwStatus is None:
      hwStatus = sysdbRoot[ 'routing6' ][ 'hardware' ][ 'status' ]
   return hwStatus.multicastRoutingSupported

def mcastRoutingSupportedGuard( mode, token ):
   '''Guard for IPv4 multicast routing.

   Returns None when mcastRoutingSupported reports support for
   mode.sysdbRoot, otherwise CliParser.guardNotThisPlatform.  token is not
   used.'''
   if not mcastRoutingSupported( mode.sysdbRoot ):
      return CliParser.guardNotThisPlatform
   return None

def mcastRoutingBoundarySupportedGuard( mode, token ):
   '''Guard for IPv4 multicast boundaries.

   Returns None only when the IPv4 routing hardware status reports
   multicastBoundarySupported and IPv4 multicast routing is supported;
   otherwise CliParser.guardNotThisPlatform.  token is not used.'''
   hwStatus = mode.sysdbRoot[ 'routing' ][ 'hardware' ][ 'status' ]
   supported = ( hwStatus.multicastBoundarySupported and
                 mcastRoutingSupported( mode.sysdbRoot ) )
   return None if supported else CliParser.guardNotThisPlatform

def mcast6RoutingSupportedGuard( mode, token ):
   '''Guard for IPv6 multicast routing.

   Returns None when mcast6RoutingSupported reports support for
   mode.sysdbRoot, otherwise CliParser.guardNotThisPlatform.  token is not
   used.'''
   if not mcast6RoutingSupported( mode.sysdbRoot ):
      return CliParser.guardNotThisPlatform
   return None

def mcast6RoutingBoundarySupportedGuard( mode, token ):
   '''Guard for IPv6 multicast boundaries.

   Returns None only when the IPv6 routing hardware status reports
   multicastBoundarySupported and IPv6 multicast routing is supported;
   otherwise CliParser.guardNotThisPlatform.  token is not used.'''
   hwStatus = mode.sysdbRoot[ 'routing6' ][ 'hardware' ][ 'status' ]
   supported = ( hwStatus.multicastBoundarySupported and
                 mcast6RoutingSupported( mode.sysdbRoot ) )
   return None if supported else CliParser.guardNotThisPlatform

def mcastGenRoutingSupportedGuard( mode, token ):
   '''Guard satisfied by multicast routing support in either address family.

   Returns None when IPv4 or IPv6 multicast routing is supported (IPv4 is
   checked first), otherwise CliParser.guardNotThisPlatform.  token is not
   used.'''
   anySupported = ( mcastRoutingSupported( mode.sysdbRoot ) or
                    mcast6RoutingSupported( mode.sysdbRoot ) )
   if anySupported:
      return None
   return CliParser.guardNotThisPlatform

def mcastRoutingSupportedIntfGuard( mode, token ):
   '''Interface-level guard for IPv4 multicast routing.

   Returns None when IPv4 multicast routing is supported, otherwise
   CliParser.guardNotThisPlatform.  token is not used.'''
   supported = mcastRoutingSupported( mode.sysdbRoot )
   return None if supported else CliParser.guardNotThisPlatform

def mcast6RoutingSupportedIntfGuard( mode, token ):
   '''Interface-level guard for IPv6 multicast routing.

   Returns None when IPv6 multicast routing is supported, otherwise
   CliParser.guardNotThisPlatform.  token is not used.'''
   supported = mcast6RoutingSupported( mode.sysdbRoot )
   return None if supported else CliParser.guardNotThisPlatform

def mcastGenRoutingSupportedIntfGuard( mode, token ):
   '''Interface-level guard satisfied by either address family.

   Returns None as soon as IPv4 or (checked second) IPv6 multicast routing is
   reported as supported; otherwise CliParser.guardNotThisPlatform.  token is
   not used.'''
   for isSupported in ( mcastRoutingSupported, mcast6RoutingSupported ):
      if isSupported( mode.sysdbRoot ):
         return None
   return CliParser.guardNotThisPlatform

# Alias shared by both address-family keyword nodes below.
ipFamilyAlias = "AF"
ipv4KwMatcher = KeywordMatcher( 'ipv4', helpdesc='IPv4 version' )
# 'ipv4' keyword node; guarded by mcastRoutingSupportedGuard.
ipv4Node = Node( matcher=ipv4KwMatcher,
   guard=mcastRoutingSupportedGuard,
   alias=ipFamilyAlias )
ipv6KwMatcher = KeywordMatcher( 'ipv6', helpdesc='IPv6 version' )
# 'ipv6' keyword node; guarded by mcast6RoutingSupportedGuard.
ipv6Node = Node( matcher=ipv6KwMatcher,
   guard=mcast6RoutingSupportedGuard,
   alias=ipFamilyAlias )

class IpFamilyExpr( CliExpression ):
   '''CLI expression accepting the 'ipv4' or the 'ipv6' keyword, bound to the
   guarded ipv4Node / ipv6Node of this module.'''
   expression = "ipv4 | ipv6"
   # Maps each token of the expression to the node that matches it.
   data = {
      'ipv4': ipv4Node,
      'ipv6': ipv6Node
   }

# Node built from ipv4MatcherForShow; guarded by mcastRoutingSupportedGuard.
ipv4NodeForShow = Node(
   matcher=ipv4MatcherForShow,
   guard=mcastRoutingSupportedGuard )

# Node built from ipv6MatcherForShow; guarded by mcast6RoutingSupportedGuard.
ipv6NodeForShow = Node(
   matcher=ipv6MatcherForShow,
   guard=mcast6RoutingSupportedGuard )

class ShowAddressFamilyExpr( CliExpression ):
   '''CLI expression accepting 'ipv4' or 'ipv6' via the show-command nodes.

   The adapter replaces the matched keyword in args with an 'addressFamily'
   entry holding the corresponding AddressFamily value.'''
   expression = "ipv4 | ipv6"
   data = {
      'ipv4': ipv4NodeForShow,
      'ipv6': ipv6NodeForShow
   }

   @staticmethod
   def adapter( mode, args, argsList ):
      '''Pop the address-family keyword from args and, if it was set, store
      the matching AddressFamily under args[ 'addressFamily' ].

      'ipv4' is examined first; 'ipv6' is only popped when 'ipv4' was absent
      or falsy.  args is left without 'addressFamily' when neither matched.'''
      familyByKeyword = ( ( 'ipv4', AddressFamily.ipv4 ),
                          ( 'ipv6', AddressFamily.ipv6 ) )
      for keyword, family in familyByKeyword:
         if args.pop( keyword, None ):
            args[ 'addressFamily' ] = family
            break

class IpFamilySubExpr( CliExpressionFactory ):
   '''Expression factory for a single address-family keyword.

   af selects which keyword node the generated expression uses: ipv4Node for
   AddressFamily.ipv4, ipv6Node for anything else.'''
   def __init__( self, af ):
      self.af = af
      CliExpressionFactory.__init__( self )

   def generate( self, name ):
      '''Build and return a CliExpression subclass whose single token is
      name + '_ipFamily', mapped to the node for this factory's family.'''
      familyNode = ipv4Node if self.af == AddressFamily.ipv4 else ipv6Node
      tokenName = name + '_ipFamily'

      class IpFamilySubExpr_( CliExpression ):
         expression = tokenName
         data = { tokenName: familyNode }

      return IpFamilySubExpr_

class IpFamilyExpression( OrExpressionFactory ):
   '''Or-expression over the address-family keywords.

   With a falsy af both 'ipv4' and 'ipv6' alternatives are added (in that
   order); with a specific af only the matching alternative is added.'''
   def __init__( self, af=None ):
      OrExpressionFactory.__init__( self )

      alternatives = ( ( 'ipv4', AddressFamily.ipv4 ),
                       ( 'ipv6', AddressFamily.ipv6 ) )
      for keyword, family in alternatives:
         if not af or af == family:
            self |= ( keyword, IpFamilySubExpr( family ) )

# Bundle of the IPv4 and IPv6 config/status objects; read by
# getVrfNameFromIntf through its ipConfig and ip6Config fields.
IpGenState = namedtuple( 'IpGenState',
                         'ipConfig ipStatus ip6Config ip6Status' )

def ipGenStatusInit( ipConfig, ipStatus, ip6Config, ip6Status ):
   '''Pack the four IPv4/IPv6 config and status objects into an IpGenState.'''
   return IpGenState( ipConfig=ipConfig, ipStatus=ipStatus,
                      ip6Config=ip6Config, ip6Status=ip6Status )

def getVrfNameFromIntf( ipGenState, intfName, af=AddressFamily.ipunknown ):
   '''Return the VRF an interface belongs to, or DEFAULT_VRF if unknown.

   For AddressFamily.ipv6 the interface is looked up in
   ipGenState.ip6Config.intf; for every other af (including the ipunknown
   default) in ipGenState.ipConfig.ipIntfConfig.  A missing or falsy entry
   yields DEFAULT_VRF.'''
   if af == AddressFamily.ipv6:
      intfConfig = ipGenState.ip6Config.intf.get( intfName )
   else:
      intfConfig = ipGenState.ipConfig.ipIntfConfig.get( intfName )
   return intfConfig.vrf if intfConfig else DEFAULT_VRF

def validateMulticastAddress( address ):
   '''Check that address is a multicast address.

   address is converted with str() unless it already is a plain str.  Returns
   None when IpGenAddr reports the address as multicast, otherwise an error
   string.'''
   # str() hands back an exact str unchanged and converts anything else.
   if not IpGenAddr( str( address ) ).isMulticast:
      return 'Invalid Multicast Range.'
   #TODO: Check for reserved range
   return None

# Per-family prefixes handed out by defaultPrefix.  Note the asymmetry: the
# IPv4 one is a /32 while the IPv6 one is a /0 (toPrefix's docstring attributes
# this zero-source representation to AleMroute issues).
_defaultPrefixIpv4 = IpGenPrefix( "0.0.0.0/32" )
_defaultPrefixIpv6 = IpGenPrefix( "::/0" )

def defaultPrefix( af ):
   '''Return the module-level default IpGenPrefix for an address family.

   Raises ValueError for anything other than AddressFamily.ipv4 or
   AddressFamily.ipv6.'''
   if af == AddressFamily.ipv4:
      return _defaultPrefixIpv4
   if af == AddressFamily.ipv6:
      return _defaultPrefixIpv6
   raise ValueError( "Unknown af: %s" % af )

# Default multicast prefixes, as strings, returned by defaultMcastPrefixStr.
_defaultMcastPrefixStrIpv4 = '224.0.0.0/4'
_defaultMcastPrefixStrIpv6 = 'ff00::/8'

def defaultMcastPrefixStr( af ):
   '''Return the default multicast prefix string for an address family.

   Raises ValueError for anything other than AddressFamily.ipv4 or
   AddressFamily.ipv6.'''
   if af == AddressFamily.ipv4:
      return _defaultMcastPrefixStrIpv4
   if af == AddressFamily.ipv6:
      return _defaultMcastPrefixStrIpv6
   raise ValueError( "Unknown af: %s" % af )

def defaultMcastPrefix( af ):
   '''Return the default multicast prefix for af as an IpGenPrefix.

   Propagates the ValueError raised by defaultMcastPrefixStr for an unknown
   af.'''
   prefixStr = defaultMcastPrefixStr( af )
   return IpGenPrefix( prefixStr )

def toPrefix( address ):
   ''' Converts a v4 or v6 address to a /32 or /128 IpGenPrefix.
      Note: Due to AleMroute issues a zero source is 0.0.0.0/32 in v4
      and ::/0 in v6.

      address is converted with str() unless it already is a plain str; a
      ':' in the text selects the v6 handling.'''
   # str() hands back an exact str unchanged and converts anything else.
   addrStr = str( address )

   if ':' not in addrStr:
      return IpGenPrefix( addrStr + '/32' )

   if IpGenAddr( addrStr ).isAddrZero:
      # The v6 zero address maps to the v6 default prefix, not to a /128.
      return defaultPrefix( AddressFamily.ipv6 )
   return IpGenPrefix( addrStr + '/128' )


# Results of LazyMount.mount keyed by ( address family, type name ); filled by
# doReadMounts and read by getMount.
mounts = {}

def doReadMounts( entityManager, types ):
   '''
   Mounts every requested config type for both the IPv4 and the IPv6
   address family and records the result in the module-level mounts dict
   under the key ( af, typeName ).
   types: List of Tac type names to be mounted; each type must provide
          mountPath( af ).
   '''
   for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      for typeName in types:
         mountedType = Tac.Type( typeName )
         assert hasattr( mountedType, 'mountPath' )
         mountKey = ( family, typeName )
         mounts[ mountKey ] = LazyMount.mount(
            entityManager, mountedType.mountPath( family ), typeName, 'w' )


# Address families that getMount accepts.
families = [ AddressFamily.ipv4, AddressFamily.ipv6 ]
def getMount( af, mType=None ):
   '''Return the mount recorded for ( af, mType ), or None if there is none.

   af must be one of the entries of the module-level families list.'''
   assert af in families
   mountKey = ( af, mType )
   return mounts.get( mountKey )

def getAddressFamilyFromMode( mode, legacy=False ):
   ''' Returns mode.af when the mode has an 'af' attribute and legacy is
      false; otherwise AddressFamily.ipv4.
      Per the original note, a mode's af may be ipv4, ipv6 or ipunknown,
      where ipunknown is used for something common with ipv4 and ipv6.'''
   if legacy or not hasattr( mode, 'af' ):
      return AddressFamily.ipv4
   return mode.af

def getMountFromMode( mode, mType=None, legacy=False ):
   ''' Returns the mount of type mType for the address family derived from
      mode (see getAddressFamilyFromMode).'''
   return getMount( getAddressFamilyFromMode( mode, legacy=legacy ), mType )

def getVrfEntity( af, vrfName, mType=None, collectionName='vrfConfig' ):
   '''Return the entry for vrfName in a per-VRF collection of a mount.

   The mount is obtained with getMount( af, mType ); when it is missing or
   falsy None is returned.  Otherwise the attribute named collectionName is
   read from the mount (it must exist) and its get( vrfName ) result is
   returned.'''
   mount = getMount( af, mType=mType )
   if not mount:
      return None

   assert hasattr( mount, collectionName )
   return getattr( mount, collectionName ).get( vrfName )

def getVrfNameFromMode( mode ):
   '''Return mode.vrfName when the mode has that attribute, else DEFAULT_VRF.'''
   if not hasattr( mode, 'vrfName' ):
      return DEFAULT_VRF
   return mode.vrfName

def getVrfEntityFromMode( mode, mType=None, legacy=False,
                          collectionName='vrfConfig' ):
   ''' Returns a VRF entity based on the current mode: the address family
   and VRF name are derived from mode and passed on to getVrfEntity.
   mType is defaulted to None to play nice with functools.partial'''
   af = getAddressFamilyFromMode( mode, legacy=legacy )
   vrfName = getVrfNameFromMode( mode )
   return getVrfEntity( af, vrfName, mType=mType,
                        collectionName=collectionName )

def getAfFromIpFamilyRule( af, legacy ):
   ''' Infers the applicable address family from the arguments the IpFamily
       rule passed to the CLI handler:
         legacy equal to True -> AddressFamily.ipv4
         af is None           -> AddressFamily.ipunknown (common config)
         anything else        -> af unchanged
   '''
   # Equality with True (not plain truthiness) is the existing contract.
   if legacy == True:
      return AddressFamily.ipv4
   if af is None:
      # Applies to common config
      return AddressFamily.ipunknown
   return af

def isMulticastRoutingEnabled( mode, vrfName, af ):
   '''Tell whether multicast routing is enabled for a VRF.

   A falsy af is treated as AddressFamily.ipunknown, which checks both the
   IPv4 and the IPv6 multicast VRF config; ipv4 / ipv6 check only their own.
   Returns False for a None or empty vrfName, True when any checked config
   has an entry for vrfName with routing set, and False otherwise.'''
   af = af or AddressFamily.ipunknown

   assert af in [ AddressFamily.ipv4, AddressFamily.ipv6, AddressFamily.ipunknown ]

   if vrfName is None or vrfName == '':
      return False

   configPaths = []
   if af in ( AddressFamily.ipv4, AddressFamily.ipunknown ):
      configPaths.append( 'routing/multicast/vrf/config' )
   if af in ( AddressFamily.ipv6, AddressFamily.ipunknown ):
      configPaths.append( 'routing6/multicast/vrf/config' )

   # Resolve every selected entity up front, before any of them is inspected.
   root = mode.sysdbRoot
   vrfConfigs = [ root.entity[ path ] for path in configPaths ]

   return any( vrfName in cfg.config and cfg.config[ vrfName ].routing
               for cfg in vrfConfigs )

def validateMulticastRouting( mode, vrfName, af, msgType='error' ):
   '''Check multicast routing for a VRF and report on mode when it is off.

   A falsy af is treated as AddressFamily.ipunknown.  Returns True when
   isMulticastRoutingEnabled succeeds.  Otherwise a message is added to mode
   as an error (msgType 'error') or a warning (msgType 'warn') and False is
   returned; the address family is named in the message unless it is
   ipunknown.'''
   af = af or AddressFamily.ipunknown
   assert af in [ AddressFamily.ipv4, AddressFamily.ipv6, AddressFamily.ipunknown ]
   assert msgType in [ 'error', 'warn' ]

   if isMulticastRoutingEnabled( mode, vrfName, af ):
      return True

   msgParts = [ 'Multicast routing not configured for vrf %s ' % vrfName ]
   if af != AddressFamily.ipunknown:
      msgParts.append( 'in %s' % af )

   report = mode.addError if msgType == 'error' else mode.addWarning
   report( ''.join( msgParts ) )
   return False

def isUnicastRoutingEnabled( mode, vrfName, af ):
   '''Tell whether unicast routing is enabled for a VRF and address family.

   vrfName: VRF to check; None means DEFAULT_VRF.
   af: AddressFamily.ipv4 or AddressFamily.ipv6; None means ipv4.

   For the default VRF the 'routing/config' (or 'routing6/config') entity is
   consulted; for any other VRF the vrfConfig collection of
   'routing/vrf/config' (or 'routing6/vrf/config') is searched for vrfName.

   Always returns a bool, matching isMulticastRoutingEnabled: False when the
   VRF has no config entry or its routing attribute is falsy, True
   otherwise.'''
   if vrfName is None:
      vrfName = DEFAULT_VRF
   if af is None:
      af = AddressFamily.ipv4
   assert af in [ AddressFamily.ipv4, AddressFamily.ipv6 ]

   root = mode.sysdbRoot.entity

   if af == AddressFamily.ipv4:
      defaultVrfPath, vrfCollPath = 'routing/config', 'routing/vrf/config'
   else:
      defaultVrfPath, vrfCollPath = 'routing6/config', 'routing6/vrf/config'

   if vrfName == DEFAULT_VRF:
      config = root[ defaultVrfPath ]
   else:
      config = root[ vrfCollPath ].vrfConfig.get( vrfName )

   # get() yields None for an unknown VRF; coerce so callers of this
   # predicate never see None or the raw operand of the 'and'.
   return bool( config and config.routing )


def validateUnicastRouting( mode, vrfName, af, msgType='error' ):
   '''Check unicast routing for a VRF and report on mode when it is off.

   af must be AddressFamily.ipv4 or AddressFamily.ipv6.  Returns True when
   isUnicastRoutingEnabled succeeds.  Otherwise a message is added to mode as
   an error (msgType 'error') or a warning (msgType 'warn') and False is
   returned.'''
   assert af in [ AddressFamily.ipv4, AddressFamily.ipv6 ]
   assert msgType in [ 'error', 'warn' ]

   if isUnicastRoutingEnabled( mode, vrfName, af ):
      return True

   report = mode.addError if msgType == 'error' else mode.addWarning
   report( 'Unicast routing not configured for vrf %s in %s' % ( vrfName, af ) )
   return False


def validateRouting( mode, vrfName, af, msgType='error' ):
   '''Validate unicast routing and then multicast routing for a VRF.

   Multicast routing is only checked (and reported on) when the unicast check
   passed.  Returns True only when both checks pass.'''
   if not validateUnicastRouting( mode, vrfName, af, msgType=msgType ):
      return False
   return validateMulticastRouting( mode, vrfName, af, msgType=msgType )

