# Copyright (c) 2008, 2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

#-------------------------------------------------------------------------------
# This module contains the definition of Cli matchers for matching IP
# addresses and IP address prefixes.
#-------------------------------------------------------------------------------
import re

import Arnet
import CliCommand
import CliMatcher
import CliParser
from CliParserCommon import MatchResult, noMatch
import IpUtils
from Toggles import ArnetToggleLib

class IpAddrMatcher( CliMatcher.Matcher ):
   '''Matches a dotted-quad IPv4 address token and evaluates to the address
   string as typed by the administrator.  In partial mode, missing trailing
   components are filled in with '0'.

   partial         -- also accept one to three components (e.g. '10.255' for
                      a prefix written as 10.255/16).
   checkContinuous -- additionally require that Arnet.Mask() accepts the
                      resulting address without raising ValueError.
   '''
   # Exactly four dot-separated decimal components.
   ipAddrRe_ = re.compile( r'(\d+)\.(\d+)\.(\d+)\.(\d+)$' )
   # Every leading portion of a full address (each component and each dot is
   # optional, so even the empty string matches); used for completion while
   # the user is still typing.
   ipAddrCompletionRe_ = re.compile(
      r'(?:(\d+)(?:\.(?:(\d+)(?:\.(?:(\d+)(?:\.(?:(\d+))?)?)?)?)?)?)?$' )
   # One to four components.  Each dot must be followed by digits, so
   # '10.255.' (as in 10.255./16) is rejected while '10.255' is accepted.
   ipAddrPartialRe_ = re.compile(
      r'(\d+)(?:\.(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?$' )

   def __init__( self, helpdesc, partial=False, checkContinuous=False, **kargs ):
      super( IpAddrMatcher, self ).__init__( helpdesc=helpdesc, **kargs )
      self.completion_ = CliParser.Completion( 'A.B.C.D', helpdesc, False )
      self.checkContinuous_ = checkContinuous
      self.partial_ = partial

   def match( self, mode, context, token ):
      pattern = self.ipAddrPartialRe_ if self.partial_ else self.ipAddrRe_
      found = pattern.match( token )
      if found is None:
         return noMatch
      # Absent components (partial mode only) default to '0'.  This both
      # passes the range check and pads the result to four components.
      octets = found.groups( '0' )
      if any( int( octet ) > 255 for octet in octets ):
         return noMatch
      addrStr = '.'.join( octets )
      if self.checkContinuous_:
         # NOTE(review): only a ValueError from Arnet.Mask rejects the token.
         # _ipPrefixExpr also treats a None maskLen as invalid, so confirm
         # that Arnet.Mask raises for non-contiguous masks.
         try:
            Arnet.Mask( addrStr )
         except ValueError:
            return noMatch
      return MatchResult( addrStr, addrStr )

   def completions( self, mode, context, token ):
      found = self.ipAddrCompletionRe_.match( token )
      if found is None:
         return []
      # Components not typed yet default to '0' and therefore always pass.
      if any( int( part ) > 255 for part in found.groups( '0' ) ):
         return []
      return [ self.completion_ ]

   def __str__( self ):
      return '<IP address>'

# PREFIX_OVERLAP_ALLOW (the default) does not check the host bits. If the user
# enters '10.0.0.1/8', the rule accepts it. Be aware, however, that such a
# value cannot be written to the Arnet::Prefix type in Sysdb; normally the
# CliPlugin checks for this condition and prints a descriptive error message.
#
# PREFIX_OVERLAP_AUTOZERO accepts '10.0.0.1/8' but automatically zeroes out the
# host bits, turning it into '10.0.0.0/8'.
#
# PREFIX_OVERLAP_REJECT returns None (or the caller-supplied returnError value)
# when the host bits are not all zero, just as for an invalid network mask.
# The CliPlugin therefore cannot tell exactly why None was returned.

# Host-bit handling policies accepted by the `overlap` argument of prefixFormat.
( PREFIX_OVERLAP_ALLOW,     # 0: keep the address exactly as entered
  PREFIX_OVERLAP_AUTOZERO,  # 1: clear any host bits
  PREFIX_OVERLAP_REJECT,    # 2: return the error value if host bits are set
) = range( 3 )

def prefixFormat( addr, masklen, resultType=str, overlap=PREFIX_OVERLAP_ALLOW,
                  returnError=None ):
   '''Build resultType( 'addr/masklen' ), applying a host-bit policy.

   addr        -- dotted-quad address string.
   masklen     -- prefix length; anything int() accepts (e.g. '24' or 24).
   resultType  -- callable applied to the formatted 'addr/masklen' string.
   overlap     -- PREFIX_OVERLAP_ALLOW leaves addr untouched;
                  PREFIX_OVERLAP_AUTOZERO clears any host bits in addr;
                  any other value (PREFIX_OVERLAP_REJECT) returns returnError
                  when a host bit is set.
   returnError -- value returned when the prefix is rejected.
   '''
   masklen = int( masklen )
   if overlap != PREFIX_OVERLAP_ALLOW:
      ip = Arnet.IpAddress( addr ).value
      # Mask of the (32 - masklen) host bits.  A plain int literal is used
      # deliberately: Python 2 promotes the shift result to long on overflow,
      # and the former '1L' literal is a syntax error under Python 3.
      hostmask = ( 1 << ( 32 - masklen ) ) - 1
      if ( ip & hostmask ) != 0:
         # Host bits are not all zero.
         if overlap == PREFIX_OVERLAP_AUTOZERO:
            addr = Arnet.IpAddress( ip & ~hostmask ).stringValue
         else:
            return returnError
   return resultType( '%s/%d' % ( addr, masklen ) )

class IpPrefixMatcher( CliMatcher.Matcher ):
   '''Matches an 'A.B.C.D/E' token and evaluates to
   prefixFormat( addr, E, resultType, overlap ), an IpUtils.Prefix by default.

   With partial=True the address part may have fewer than four components
   (e.g. a route prefix such as 10.255/16), padded with zeros.  Otherwise a
   full address is required (e.g. an interface address with mask length).
   '''
   # Whole token: address characters, '/', prefix length.
   prefixRe_ = re.compile( r'([\d.]+)/(\d+)$' )
   # Any leading portion of the above (every part optional), for completion.
   completionRe_ = re.compile( r'(?:([\d.]+)(?:/(?:(\d+))?)?)?$' )

   def __init__( self, helpdesc, resultType=IpUtils.Prefix, partial=False,
                 overlap=PREFIX_OVERLAP_ALLOW, **kargs ):
      super( IpPrefixMatcher, self ).__init__( helpdesc=helpdesc, **kargs )
      self.completion_ = CliParser.Completion( 'A.B.C.D/E', helpdesc, False )
      # Sub-matchers validating the address half and the length half.
      self.ipAddrMatcher_ = IpAddrMatcher( helpdesc, partial=partial )
      self.prefixLenMatcher_ = CliMatcher.IntegerMatcher(
         0, 32, helpname=None, helpdesc='Length of the prefix in bits' )
      self.resultType_ = resultType
      self.overlap_ = overlap

   def match( self, mode, context, token ):
      parsed = self.prefixRe_.match( token )
      if parsed is None:
         return noMatch
      addrText, lenText = parsed.group( 1, 2 )
      addrMatch = self.ipAddrMatcher_.match( mode, context, addrText )
      if addrMatch is noMatch:
         return noMatch
      # The (possibly zero-padded) dotted-quad produced by IpAddrMatcher.
      addr = addrMatch.result
      lenMatch = self.prefixLenMatcher_.match( mode, context, lenText )
      if lenMatch is noMatch:
         return noMatch
      # NOTE(review): with PREFIX_OVERLAP_REJECT and host bits set,
      # prefixFormat returns None, so this yields MatchResult( None, 'None' )
      # rather than noMatch -- confirm callers rely on that.
      prefix = prefixFormat( addr, lenText, resultType=self.resultType_,
                             overlap=self.overlap_ )
      return MatchResult( prefix, str( prefix ) )

   def completions( self, mode, context, token ):
      parsed = self.completionRe_.match( token )
      if parsed is None:
         return []
      addrText, lenText = parsed.group( 1, 2 )
      # Each part typed so far must itself still be completable.
      if addrText is not None and \
            self.ipAddrMatcher_.completions( mode, context, addrText ) == []:
         return []
      if lenText is not None and \
            self.prefixLenMatcher_.completions( mode, context, lenText ) == []:
         return []
      return [ self.completion_ ]

   def __str__( self ):
      return '<IP address prefix>'

# General-purpose matchers for an IP address and an IP address prefix.
#-------------------------------------------------------------------------------
# Shared default matchers for a full IPv4 address and an 'A.B.C.D/E' prefix.
ipAddrMatcher = IpAddrMatcher( helpdesc='IP address' )
ipPrefixMatcher = IpPrefixMatcher( helpdesc='IP address prefix' )

def _ipPrefixExpr( name, addrdesc, maskdesc, prefixdesc,
                   resultType, partial=False, inverseMask=False,
                   maskKeyword=False, discontinuousMask=False,
                   overlap=PREFIX_OVERLAP_ALLOW, returnError=None ):
   '''Build a CliExpression class accepting an IPv4 prefix, either as a single
   'A.B.C.D/E' token or as an address followed by a mask ('A.B.C.D M.M.M.M',
   or 'A.B.C.D mask M.M.M.M' when maskKeyword is set).

   name              -- key under which the result ends up in args.  The
                        address/mask form uses the auxiliary keys
                        name + '_ADDR', name + '_MASK_KW' and name + '_MASK',
                        which the adapter removes again.
   addrdesc, maskdesc, prefixdesc -- help strings for the address, mask and
                        'A.B.C.D/E' tokens.
   resultType        -- callable used to build the result (passed to
                        IpPrefixMatcher and prefixFormat).
   partial           -- allow a partial address in the 'A.B.C.D/E' form.
   inverseMask       -- interpret the mask as an inverse (wildcard) mask.
   maskKeyword       -- require the 'mask' keyword between address and mask.
   discontinuousMask -- also accept non-contiguous masks; such a mask yields
                        resultType( addr, integerMask ).
   overlap           -- one of the PREFIX_OVERLAP_* host-bit policies.
   returnError       -- value produced when the mask is invalid or the prefix
                        is rejected by the overlap policy.

   Returns the CliExpression subclass itself (not an instance).
   '''
   # Matcher for the address half of the address/mask form.
   addrMather = IpAddrMatcher( helpdesc=addrdesc )
   # Contiguity is enforced at match time only when neither an inverse mask
   # nor a non-contiguous mask is permitted.
   checkContinuous = not inverseMask and not discontinuousMask
   maskMatcher = IpAddrMatcher( helpdesc=maskdesc,
                                checkContinuous=checkContinuous )
   addrName = name + '_ADDR'
   maskKwName = name + '_MASK_KW'
   maskName = name + '_MASK'

   # The 'A.B.C.D/E' form.  The matcher itself is built with the default
   # overlap (PREFIX_OVERLAP_ALLOW).  The value callback re-runs prefixFormat
   # on a string value, so that overlap and returnError also apply to this
   # form.  Non-string values pass through unchanged.
   # NOTE(review): presumably `match` is the matched 'addr/len' text when it
   # is a string -- confirm against the Cli framework's `value=` contract.
   prefixMatcher = IpPrefixMatcher( prefixdesc, resultType, partial,
                                    value=lambda mode, match:
                                    prefixFormat( *tuple( match.split( '/' ) ),
                                    resultType=resultType,
                                    overlap=overlap,
                                    returnError=returnError )
                                    if isinstance( match, str ) else match )

   def maskFormat( addr, mask ):
      '''Convert an (address, mask) pair from the address/mask form into the
      result value, or returnError if the mask cannot be used.'''
      # state space: prefix or full? wildcard or static? contiguous or not?
      # If the mask is actually a prefix, or a contiguous full mask
      try:
         masklen = Arnet.Mask( mask, inverse=inverseMask ).maskLen
         # A None length is treated like a constructor failure, so it falls
         # through to the non-contiguous handling below.
         if masklen is None:
            raise ValueError
         # Note that a ValueError raised inside prefixFormat is also caught
         # by the except clause below.
         return prefixFormat( addr, masklen, resultType=resultType, 
                              overlap=overlap, returnError=returnError )
      except ValueError:
         # The netmask was invalid.
         # try a discontiguous mask, if appropriate
         if discontinuousMask:
            i_mask = Arnet.intFromMask( mask )
            if i_mask is not None:
               # NOTE(review): for inverse masks ~i_mask is a negative Python
               # int; confirm resultType accepts that rather than expecting a
               # 32-bit unsigned value.
               return resultType( addr, ~i_mask if # pylint: disable-msg=E1130
                                  inverseMask else i_mask )
      return returnError

   def adapterFunc( mode, args, argsList ):
      '''If the address/mask form matched, replace its auxiliary args entries
      with a single maskFormat() result stored under `name`.  Single values
      and parallel lists of addresses/masks are both handled.  When the
      'A.B.C.D/E' form matched, args is left untouched.'''
      if addrName in args:
         addr = args.pop( addrName )
         mask = args.pop( maskName )
         # The 'mask' keyword (if present) carries no value of its own.
         args.pop( maskKwName, None )
         isList = isinstance( addr, list )
         if not isList:
            addr = [ addr ]
            mask = [ mask ]
         newPrefix = []
         for a, m in zip( addr, mask ):
            newPrefix.append( maskFormat( a, m ) )
         if not isList:
            newPrefix = newPrefix[ 0 ]
         assert name not in args
         args[ name ] = newPrefix

   # Variant used when maskKeyword is set: 'A.B.C.D/E' | 'A.B.C.D mask M.M.M.M'.
   class IpPrefixWithMaskExpr( CliCommand.CliExpression ):
      expression = '%s | ( %s %s %s )' % ( name, addrName, maskKwName, maskName )
      data = { name : prefixMatcher,
               addrName : addrMather,
               maskKwName : CliMatcher.KeywordMatcher( 'mask',
                  helpdesc='Network mask' ),
               maskName : maskMatcher }
      adapter = adapterFunc

   # Default variant: 'A.B.C.D/E' | 'A.B.C.D M.M.M.M'.
   class IpPrefixExpr( CliCommand.CliExpression ):
      expression = '%s | ( %s %s )' % ( name, addrName, maskName )
      data = { name : prefixMatcher,
               addrName : addrMather,
               maskName : maskMatcher }
      adapter = adapterFunc

   if maskKeyword:
      return IpPrefixWithMaskExpr
   else:
      return IpPrefixExpr

class IpPrefixExprFactory( CliCommand.CliExpressionFactory ):
   '''Expression factory producing, per argument name, the prefix expression
   built by _ipPrefixExpr.  Extra keyword arguments are stored and forwarded
   to _ipPrefixExpr unchanged.'''

   def __init__( self, addrdesc, maskdesc, prefixdesc, resultType,
                 **kwargs ):
      # Remember everything generate() needs to build the expression later.
      self.addrdesc_, self.maskdesc_, self.prefixdesc_ = (
         addrdesc, maskdesc, prefixdesc )
      self.resultType_ = resultType
      self.kwargs_ = kwargs
      CliCommand.CliExpressionFactory.__init__( self )

   def generate( self, name ):
      '''Return the expression class for the argument called `name`.'''
      positional = ( name, self.addrdesc_, self.maskdesc_, self.prefixdesc_,
                     self.resultType_ )
      return _ipPrefixExpr( *positional, **self.kwargs_ )

def ipPrefixExpr( addrdesc, maskdesc, prefixdesc, **kwargs ):
   '''Return an IpPrefixExprFactory whose expressions evaluate to
   IpUtils.Prefix values.  Keyword arguments (partial, inverseMask,
   maskKeyword, overlap, returnError, ...) are passed through to the factory.'''
   prefixType = IpUtils.Prefix
   return IpPrefixExprFactory( addrdesc, maskdesc, prefixdesc, prefixType,
                               **kwargs )

def ipAddrWithMaskExpr( addrdesc, maskdesc, prefixdesc, **kwargs ):
   '''Return an IpPrefixExprFactory whose expressions evaluate to
   Arnet.AddrWithMask values.  Keyword arguments are passed through to the
   factory.'''
   addrWithMaskType = Arnet.AddrWithMask
   return IpPrefixExprFactory( addrdesc, maskdesc, prefixdesc,
                               addrWithMaskType, **kwargs )

def ipAddrWithFullMaskExpr( addrdesc, maskdesc, prefixdesc, **kwargs ):
   '''Return an IpPrefixExprFactory whose expressions evaluate to
   Arnet.AddrWithFullMask values and always accept non-contiguous masks
   (discontinuousMask=True).  Callers must not pass discontinuousMask.'''
   assert 'discontinuousMask' not in kwargs
   kwargs.update( discontinuousMask=True )
   return IpPrefixExprFactory( addrdesc, maskdesc, prefixdesc,
                               Arnet.AddrWithFullMask, **kwargs )

class IpAddrOrPrefixExprFactory( CliCommand.CliExpressionFactory ):
   '''Expression factory accepting either a plain IPv4 address or an IPv4
   prefix (any form produced by ipPrefixExpr).  Either way the value ends up
   in args under the generated argument name.'''

   def __init__( self, addrdesc, maskdesc, prefixdesc,
         partial=False, overlap=PREFIX_OVERLAP_AUTOZERO ):
      CliCommand.CliExpressionFactory.__init__( self )
      self.addrdesc_ = addrdesc
      self.maskdesc_ = maskdesc
      self.prefixdesc_ = prefixdesc
      self.partial_ = partial
      self.overlap_ = overlap

   def generate( self, name ):
      '''Accept either an IPv4 address or a prefix'''
      addrKey = 'IP_ADDR_%s' % name
      prefixKey = 'IP_PREFIX_%s' % name
      prefixFactory = ipPrefixExpr( self.addrdesc_, self.maskdesc_,
                                    self.prefixdesc_, partial=self.partial_,
                                    overlap=self.overlap_ )
      # The address alternative is aliased straight to `name`.
      addrNode = CliCommand.Node(
         matcher=IpAddrMatcher( 'Match this IP address' ), alias=name )

      class IpAddrOrPrefixExpr( CliCommand.CliExpression ):
         expression = '%s | %s' % ( addrKey, prefixKey )
         data = {
            addrKey : addrNode,
            prefixKey : prefixFactory,
         }

         @staticmethod
         def adapter( mode, args, argsList ):
            # Copy the prefix alternative's value under `name` as well.
            if prefixKey in args:
               args[ name ] = args[ prefixKey ]

      return IpAddrOrPrefixExpr

def splitIpAddrToInts( ipAddrStr ):
   '''Return the dot-separated components of ipAddrStr as a list of ints.

   No count or range checking is done here.  int() raises ValueError for any
   component that is not an integer literal.'''
   return list( map( int, ipAddrStr.split( '.' ) ) )

def isValidIpAddr( addr ):
   '''Test whether addr is a valid dotted-quad IPv4 address string.

   Returns True if addr consists of exactly four dot-separated integers, each
   in the range 0-255.  Returns False otherwise, including when a component is
   empty or not an integer (e.g. 'abc' or '1..2.3'), which previously raised
   ValueError instead.
   '''
   #  this can be used once we get a valid ip address using ipAddr rule.
   try:
      octets = splitIpAddrToInts( addr )
   except ValueError:
      # int() rejected a component: a validity check should report "not
      # valid" rather than raise.
      return False
   return len( octets ) == 4 and all( 0 <= octet <= 255 for octet in octets )

def isLoopbackIpAddr( addr ):
   '''Return True if addr passes isValidIpAddr and lies in the loopback range
   127.0.0.0/8 (first octet 127), False otherwise.'''
   # Validate first so that only well-formed dotted quads are split.
   return isValidIpAddr( addr ) and splitIpAddrToInts( addr )[ 0 ] == 127

def isReserved( addr ):
   '''Return True if addr passes isValidIpAddr and lies in
   240.0.0.0-255.255.255.255 (reserved, e.g. Class E addresses), False
   otherwise.  No toggle is consulted here.'''
   if isValidIpAddr( addr ):
      # isValidIpAddr already bounds every octet to at most 255, so only the
      # lower bound of the first octet needs checking.
      return splitIpAddrToInts( addr )[ 0 ] >= 240
   return False

def isReservedIpAddr( addr ):
   '''Return True if isReserved( addr ) holds and the Ipv4Routable240ClassE
   toggle is disabled.  With the toggle enabled this always returns False.'''
   # isReserved is evaluated first (it also validates addr), then the toggle.
   reserved = isReserved( addr )
   classERoutable = ArnetToggleLib.toggleIpv4Routable240ClassEEnabled()
   return False if classERoutable else reserved

def isMartianIpAddr( addr ):
   '''Return True if addr passes isValidIpAddr and either has a first octet
   of 0 or 127, or has a first octet of 240 or above while the
   Ipv4Routable240ClassE toggle is disabled.  Return False otherwise.'''
   if not isValidIpAddr( addr ):
      return False
   firstOctet = splitIpAddrToInts( addr )[ 0 ]
   if firstOctet in ( 0, 127 ):
      return True
   # 240.0.0.0 and above only count as martian when Class E is not routable.
   classERoutable = ArnetToggleLib.toggleIpv4Routable240ClassEEnabled()
   return not classERoutable and firstOctet >= 240

def validateMulticastIpAddr( addr, allowReserved=True ):
   '''Check that addr is an acceptable IPv4 multicast address.

   Returns None when it is acceptable, otherwise an error message:
     'Invalid Address.'          -- addr fails isValidIpAddr;
     'Invalid Multicast Range.'  -- first octet outside 224-239;
     'Reserved Multicast Range.' -- allowReserved is False and addr lies in
                                    224.0.0.0-224.0.0.255.

   For reference:
     224.0.0.0-224.0.0.255     reserved for special well-known multicast
                               addresses;
     224.0.1.0-238.255.255.255 globally-scoped (Internet-wide) multicast;
     239.0.0.0-239.255.255.255 administratively-scoped (local) multicast.
   '''
   # Returning a message rather than using a dedicated multicast matcher lets
   # the caller present a specific error.
   if not isValidIpAddr( addr ):
      return 'Invalid Address.'
   octets = splitIpAddrToInts( addr )
   if not 224 <= octets[ 0 ] <= 239:
      return 'Invalid Multicast Range.'
   if not allowReserved and tuple( octets[ :3 ] ) == ( 224, 0, 0 ):
      return 'Reserved Multicast Range.'
   return None
