# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from PolicyMap import numericalRangeToRangeString
from ClassificationLib import numericalRangeToSet
import Tac
# Tac enum types, and the specific enum values this module compares structured
# filter and policy-action fields against when rendering CLI commands.
tacNeighborProtocol = Tac.Type( 'Classification::NeighborProtocol' )
# Compared against structuredFilter.neighborProtocol ('protocol neighbors bgp').
protocolBgp = tacNeighborProtocol.protocolBgp
tacMatchL4Protocol = Tac.Type( 'Classification::MatchL4Protocol' )
# Compared against structuredFilter.matchL4Protocol ('protocol bgp').
matchL4Bgp = tacMatchL4Protocol.matchL4Bgp

# Compared against the keys of the policyActions collection (deny, police, ...).
ActionType = Tac.Type( "PolicyMap::ActionType" )
# Its attribute names (excluding 'value') are used as the TCP flag names.
TcpFlag = Tac.Type( "Classification::TcpFlag" )
# Selects between 'fragment' (matchAll) and 'fragment offset ...' (matchOffset).
fragmentType = Tac.Type( "Classification::FragmentType" )

def _addrListToStr( addrs ):
   return " ".join( map( str, sorted( addrs ) ) )

def _numericalRangeToCmd( numRange, optionCmd, matchType ):
   """
   Render the ranges in numRange as '<optionCmd> <ranges>'.

   Returns '' when numericalRangeToRangeString produces an empty string, so
   callers can use the result's truthiness to decide whether to emit it.
   """
   rangeStr = numericalRangeToRangeString( list( numRange ), matchType )
   if not rangeStr:
      return ''
   return '%s %s' % ( optionCmd, rangeStr )

def tcpFlagAndMaskString( tcpFlagAndMask ):
   """
   For tcp flags we make use of Classification::TcpFlagAndMask.
   TcpFlagAndMask allows us to express states 0,1,X where X is don't care.
   If the mask value for a flag (e.g. 'est') is 1, we care about that flag and
   look at the tcpFlag value to see if it is 0 or 1. If the mask value is 0 we
   do not care about the flag and treat it as X.
   When outputting our match rule command we only consider flags with a mask of
   1: if the tcpFlag is 0 we match on 'not flag', and if it is 1 we match on
   'flag'.

   Returns a string with a leading space and no trailing space, e.g.
   " flags not established syn". Each group of names is sorted.
   """
   # CLI spellings for flags whose attribute names differ from the keyword.
   cliNames = { "est": "established", "init": "initial" }
   flagAttrs = [ attr.name for attr in TcpFlag.tacType.attributeQ
                 if attr.isIndependentDomainAttr and attr.name != "value" ]
   setFlags = []
   notFlags = []
   for attr in flagAttrs:
      if not getattr( tcpFlagAndMask.tcpFlagMask, attr, False ):
         continue  # mask bit clear: "don't care", nothing to emit
      name = cliNames.get( attr, attr )
      if getattr( tcpFlagAndMask.tcpFlag, attr ):
         setFlags.append( name )
      else:
         notFlags.append( name )
   cmd = " flags"
   if notFlags:
      cmd += " not %s" % " ".join( sorted( notFlags ) )
   # Only append the positive group when non-empty, so a rule made up solely
   # of 'not' flags does not end in a dangling space.
   if setFlags:
      cmd += " %s" % " ".join( sorted( setFlags ) )
   return cmd

def _protoPortCmds( protoStr, flagsStr, port, matchType ):
   """
   Return one 'protocol <proto>[ flags ...][ source port ...]
   [ destination port ...]' command per entry in port.

   flagsStr is the (possibly empty) output of tcpFlagAndMaskString, appended
   right after the protocol.
   """
   cmds = []
   for portField in port.itervalues():
      sport = list( portField.sport )
      dport = list( portField.dport )
      sportFieldSet = " ".join( portField.sportFieldSet )
      dportFieldSet = " ".join( portField.dportFieldSet )
      # cannot have l4 ports and field-sets set
      assert not ( sport and sportFieldSet )
      assert not ( dport and dportFieldSet )
      cmd = "protocol %s%s" % ( protoStr, flagsStr )
      if sport:
         cmd += " source port %s" % numericalRangeToRangeString( sport,
                                                                 matchType )
      if sportFieldSet:
         cmd += " source port field-set %s" % sportFieldSet
      if dport:
         cmd += " destination port %s" % numericalRangeToRangeString( dport,
                                                                      matchType )
      if dportFieldSet:
         cmd += " destination port field-set %s" % dportFieldSet
      cmds.append( cmd )
   return cmds

def _protoIcmpCmds( protoStr, typesWithoutCodes, typesWithCodes, icmpTypeColl,
                    matchType ):
   """
   Return the 'protocol <proto> type <type> code <code>' commands for one
   protocol. All types without explicit codes are combined into a single
   'code all' command, emitted first; each type with codes then gets its own
   command, in sorted type order.
   """
   cmds = []
   if typesWithoutCodes:
      typeStr = numericalRangeToRangeString( typesWithoutCodes, matchType )
      cmds.append( "protocol %s type %s code all" % ( protoStr, typeStr ) )
   for aType in sorted( typesWithCodes ):
      codes = list( icmpTypeColl[ aType ].icmpCode )
      typeStr = numericalRangeToRangeString( [ aType ], matchType )
      # The numeric type value is passed along when rendering its codes.
      typeVal = list( numericalRangeToSet( [ aType ] ) )[ 0 ]
      codeStr = numericalRangeToRangeString( codes, matchType, typeVal )
      cmds.append( "protocol %s type %s code %s" % ( protoStr, typeStr, codeStr ) )
   return cmds

def _protoToCmds( proto, optionCmd, matchType ):
   """
   Render the protocol collection of a structured filter as a list of
   'protocol ...' commands, processed in sorted protocol order.

   Protocols carrying tcp flags, ports or icmp types each produce their own
   commands. All remaining "bare" protocols are collapsed into one final
   'protocol <ranges>' command.

   optionCmd is not used; it is kept for interface compatibility.
   """
   cmds = []
   bareProtos = []  # protocols without additional fields
   for p in sorted( proto ):
      protoField = proto[ p ]
      tcpFlagAndMask = protoField.tcpFlagAndMask
      port = protoField.port
      icmpTypeColl = protoField.icmpType
      typesWithCodes = []
      typesWithoutCodes = []
      for aType in icmpTypeColl:
         if icmpTypeColl[ aType ].icmpCode:
            typesWithCodes.append( aType )
         else:
            typesWithoutCodes.append( aType )

      if not ( tcpFlagAndMask or port or typesWithCodes or typesWithoutCodes ):
         bareProtos.append( p )
         continue

      protoStr = numericalRangeToRangeString( [ p ], matchType )
      flagsStr = tcpFlagAndMaskString( tcpFlagAndMask ) if tcpFlagAndMask else ""
      if tcpFlagAndMask and not port:
         # Flags with no ports: the flags go on a standalone command.
         cmds.append( "protocol %s%s" % ( protoStr, flagsStr ) )
      cmds.extend( _protoPortCmds( protoStr, flagsStr, port, matchType ) )
      cmds.extend( _protoIcmpCmds( protoStr, typesWithoutCodes, typesWithCodes,
                                   icmpTypeColl, matchType ) )
   if bareProtos:
      cmds.append( "protocol %s" %
                   numericalRangeToRangeString( bareProtos, matchType ) )
   return cmds

def _policeActionToCmd( policeAction ):
   cmd = "police rate %s" % ( policeAction.rateLimit.stringValue )
   return cmd

def _policyActionsToCmds( policyActions ):
   """
   Return the sorted list of CLI commands for the actions in policyActions.

   policyActions is iterated for its action types. The 'permit' type produces
   no command. Any action type not handled here raises NotImplementedError.
   """
   actionsCmds = []
   for actionType in policyActions:
      if actionType == ActionType.deny:
         actionsCmds.append( 'drop' )
      elif actionType == ActionType.police:
         actionsCmds.append( _policeActionToCmd( policyActions[ actionType ] ) )
      elif actionType == 'count':
         cmd = "count"
         counterName = policyActions[ actionType ].counterName
         if counterName:
            cmd += " %s" % counterName
         actionsCmds.append( cmd )
      elif actionType == 'log':
         actionsCmds.append( 'log' )
      elif actionType == 'sample':
         actionsCmds.append( 'sample' )
      elif actionType == 'sampleAll':
         actionsCmds.append( 'sample all' )
      elif actionType == 'setDscp':
         actionsCmds.append( 'set dscp %d' % policyActions[ actionType ].dscp )
      elif actionType == 'setTc':
         actionsCmds.append( 'set traffic class %d' %
                             policyActions[ actionType ].tc )
      elif actionType == 'permit':
         pass  # permit produces no command
      elif actionType == ActionType.loadBalance:
         nhgAction = policyActions[ actionType ]
         actionsCmds.append( 'load-balance nexthop-group %s' %
                             nhgAction.nexthopGroupName )
      else:
         # Call form works on both Python 2 and 3; 'raise E, msg' is a
         # SyntaxError on Python 3.
         raise NotImplementedError( "Unsupported action" )
   return sorted( actionsCmds )

def structuredFilterToCmds( structuredFilter, policyActions, matchType ):
   '''
   Takes a structured filter and returns a dictionary containing all commands
   used to create the filter.

   Only fields that are set get an entry. Most values are a single command
   string. 'protocol' and 'actions' map to lists of command strings, and the
   'actions' list is sorted.
   Raises NotImplementedError if policyActions holds an unsupported action.
   '''
   structuredFilterCmds = {}

   # Address fields: the dictionary key equals the structured-filter attribute.
   for attr, fmt in ( ( 'source', 'source prefix %s' ),
                      ( 'srcPrefixSet', 'source prefix field-set %s' ),
                      ( 'destination', 'destination prefix %s' ),
                      ( 'dstPrefixSet', 'destination prefix field-set %s' ) ):
      addrs = _addrListToStr( getattr( structuredFilter, attr ) )
      if addrs:
         structuredFilterCmds[ attr ] = fmt % addrs

   if policyActions:
      structuredFilterCmds[ 'actions' ] = _policyActionsToCmds( policyActions )

   protoCmds = _protoToCmds( structuredFilter.proto, 'protocol', matchType )
   if protoCmds:
      structuredFilterCmds[ 'protocol' ] = protoCmds

   # Numeric-range fields: ( attribute / dictionary key, CLI keyword ).
   for attr, optionCmd in ( ( 'ttl', 'ttl' ),
                            ( 'dscp', 'dscp' ),
                            ( 'length', 'ip length' ) ):
      cmd = _numericalRangeToCmd( getattr( structuredFilter, attr ), optionCmd,
                                  matchType )
      if cmd:
         structuredFilterCmds[ attr ] = cmd

   fragment = structuredFilter.fragment
   if fragment:
      # The two fragment types are mutually exclusive.
      if fragment.fragmentType == fragmentType.matchAll:
         structuredFilterCmds[ 'fragment' ] = 'fragment'
      elif fragment.fragmentType == fragmentType.matchOffset:
         fragOffsetCmd = _numericalRangeToCmd( fragment.offset,
                                               'fragment offset', matchType )
         if fragOffsetCmd:
            structuredFilterCmds[ 'fragment' ] = fragOffsetCmd

   if structuredFilter.matchIpOptions:
      structuredFilterCmds[ 'matchIpOptions' ] = 'ip options'

   if structuredFilter.neighborProtocol == protocolBgp:
      structuredFilterCmds[ 'protoNeighborsBgp' ] = 'protocol neighbors bgp'

   if structuredFilter.matchL4Protocol == matchL4Bgp:
      structuredFilterCmds[ 'protoBgp' ] = 'protocol bgp'
   return structuredFilterCmds
