#!/usr/bin/env python
# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import Tac
from AclLib import ( tcpServiceByName, udpServiceByName, commonServiceByName_ )
from ClassificationLib import ( extraIpv4Protocols, extraIpv6Protocols, icmpV4Types,
                                icmpV6Types, icmpV4Codes, icmpV6Codes )
from AclCliLib import ( genericIpProtocols, genericIp6Protocols )

# Tac value types for the Classification range objects; maybeConvert()
# dispatches on these with isinstance() to pick a name table.
PortRange = Tac.Type( "Classification::PortRange" )
ProtocolRange = Tac.Type( "Classification::ProtocolRange" )
IcmpTypeRange = Tac.Type( "Classification::IcmpTypeRange" )
IcmpCodeRange = Tac.Type( "Classification::IcmpCodeRange" )

# Fully-qualified Tac name of the class-map match-option enum type.
ClassMapMatchOptionTypeName = 'PolicyMap::ClassMapMatchOption'

# The PolicyMap::ClassMapMatchOption enum type; its values are used by the
# match-option string <-> enum helpers in this module.
tacMatchOption = Tac.Type( ClassMapMatchOptionTypeName )
def matchOptionToEnum( matchOption ):
   '''
   Translate a match-option keyword into its PolicyMap::ClassMapMatchOption
   enum value.

   Accepts 'ip', 'ipv6', 'mac' and 'mpls'. 'mpls' maps to
   matchMplsAccessGroup, so this is the inverse of matchOptionToStr(),
   which renders that value as 'mpls'. Any other value returns None.
   '''
   if matchOption == 'ip':
      return tacMatchOption.matchIpAccessGroup
   elif matchOption == 'ipv6':
      return tacMatchOption.matchIpv6AccessGroup
   elif matchOption == 'mac':
      return tacMatchOption.matchMacAccessGroup
   elif matchOption == 'mpls':
      return tacMatchOption.matchMplsAccessGroup
   return None

# PolicyMap::MapType enum and the two map-type values that mapTypeToEnum()
# translates from the strings 'pbr' and 'tapagg'.
tacClassMapType = Tac.Type( "PolicyMap::MapType" )
pbrMapType = tacClassMapType.mapPbr
tapaggMapType = tacClassMapType.mapTapAgg

def mapTypeToEnum( mapType ):
   '''
   Map a map-type string to its PolicyMap::MapType enum value.

   'pbr' yields pbrMapType, 'tapagg' yields tapaggMapType, and anything
   else yields None.
   '''
   return ( pbrMapType if mapType == 'pbr' else
            tapaggMapType if mapType == 'tapagg' else
            None )

# Module-level aliases for the ClassMapMatchOption enum values so they can be
# imported directly. tacMatchOption is bound once, near the top of the
# module, from ClassMapMatchOptionTypeName; it is not rebound here.
matchIpAccessGroup = tacMatchOption.matchIpAccessGroup
matchIpv6AccessGroup = tacMatchOption.matchIpv6AccessGroup
matchMplsAccessGroup = tacMatchOption.matchMplsAccessGroup
matchMacAccessGroup = tacMatchOption.matchMacAccessGroup

def matchOptionToStr( matchOption ):
   '''
   Render a PolicyMap::ClassMapMatchOption enum value as its keyword.

   Returns 'ip', 'ipv6' or 'mac' for the matching access-group option and
   'mpls' for every other value.
   '''
   # Checked in order; the first equal enum value decides the keyword.
   knownOptions = (
      ( tacMatchOption.matchIpAccessGroup, 'ip' ),
      ( tacMatchOption.matchIpv6AccessGroup, 'ipv6' ),
      ( tacMatchOption.matchMacAccessGroup, 'mac' ),
   )
   for enumValue, keyword in knownOptions:
      if matchOption == enumValue:
         return keyword
   # NOTE(review): any unrecognised value (not only matchMplsAccessGroup)
   # falls through to 'mpls' -- confirm that is intended.
   return 'mpls'

def numericalRangeToRangeString( numRangeList, matchType, icmpType=None ):
   '''
   Given a list of Classification::NumericalRange objects, produce a range
   in string form: either numeric, such as "3, 7, 90-95", or well-known
   names, such as "bgp ldp".

   Names are used only when maybeConvert() can name every entry. Otherwise
   each range's stringValue() is joined with ', '. The ranges are sorted
   first; the caller's list is left untouched.

   matchType: for protocol and ICMP lists, 'ipv4' selects the IPv4 name
              tables and any other value selects the IPv6 ones.
   icmpType: ICMP type whose code names are used for IcmpCodeRange lists.
   '''
   # Sort a copy rather than calling numRangeList.sort(), which would
   # reorder the caller's list as a side effect.
   sortedRanges = sorted( numRangeList )
   rangeStr = maybeConvert( sortedRanges, matchType, icmpType )
   if rangeStr:
      return rangeStr
   return ', '.join( x.stringValue() for x in sortedRanges )

def maybeConvert( numRangeList, matchType, icmpType ):
   '''
   Try to render numRangeList as space-separated well-known names.

   Lists converted to names cannot contain ranges and must have a name for
   each list value. Returns the joined names (e.g. 'bgp ldp') when that
   holds. Otherwise returns None, so the caller can fall back to the numeric
   form.

   numRangeList: list of Classification range objects. The name tables are
                 chosen from the type of the first element, so all entries
                 are assumed to share that type.
   matchType: for protocol and ICMP lists, 'ipv4' selects IPv4 name tables
              and any other value selects IPv6 ones. Unused for ports.
   icmpType: ICMP type whose code-name table is used for IcmpCodeRange
             entries. If it has no table (e.g. None), None is returned.
   '''
   if not numRangeList or rangeFound( numRangeList ):
      return None

   isIpv4 = matchType == 'ipv4'
   first = numRangeList[ 0 ]
   if isinstance( first, PortRange ):
      nameList = [ tcpServiceByName, udpServiceByName, commonServiceByName_ ]
   elif isinstance( first, ProtocolRange ):
      nameList = ( [ genericIpProtocols, extraIpv4Protocols ] if isIpv4
                   else [ genericIp6Protocols, extraIpv6Protocols ] )
   elif isinstance( first, IcmpTypeRange ):
      nameList = [ icmpV4Types ] if isIpv4 else [ icmpV6Types ]
   elif isinstance( first, IcmpCodeRange ):
      codeTables = icmpV4Codes if isIpv4 else icmpV6Codes
      try:
         nameList = [ codeTables[ icmpType ] ]
      except KeyError:
         # No code names for this ICMP type (including icmpType=None, the
         # default in numericalRangeToRangeString): show the codes
         # numerically instead of raising.
         return None
   else:
      return None

   convertedList = []
   for numRange in numRangeList:
      namedValue = convertRange( numRange, nameList )
      if not namedValue:
         # One unnamed value means the whole list stays numeric.
         return None
      convertedList.append( namedValue )
   return ' '.join( convertedList )

def rangeFound( numRangeList ):
   '''
   Returns True if any numerical range object in numRangeList contains a
   value range (as reported by its hasRange(), e.g. 5-8), else False.

   Thin wrapper around hasRanges().
   '''
   return hasRanges( numRangeList )

def hasRanges( numRangeList ):
   '''
   Returns True if any object in numRangeList reports hasRange() as true
   (i.e. holds a value range such as 5-8 rather than a single value),
   else False. An empty list returns False.
   '''
   # Test the hasRange() result itself. Filtering on hasRange() and then
   # testing the range object's own truthiness would give the wrong answer
   # for falsy objects.
   return any( numRange.hasRange() for numRange in numRangeList )

def convertRange( numRange, nameList ):
   '''
   Given a single numerical range object, return its value as a well-known
   name, or None if no table in nameList names that value.

   numRange: a range object holding a single value (hasRange() must be
             false); only its rangeStart is consulted.
   nameList: sequence of name tables searched in order. Each maps a name
             to a sequence whose first element is the numeric value.

   The first matching name wins. If several names in a table share a value,
   which one is returned depends on that table's iteration order.
   '''
   # Invariant: maybeConvert rejects lists containing ranges (rangeFound)
   # before calling this.
   assert not numRange.hasRange()
   target = numRange.rangeStart
   for service in nameList:
      # items() instead of iteritems(): same pairs on Python 2 dicts, and
      # iteritems() does not exist on Python 3.
      for name, value in service.items():
         if target == value[ 0 ]:
            return name
   return None
