# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import binascii

import Tac

from Arnet import EthAddr
from eunuchs.if_ether_h import ETH_P_IP, ETH_P_IPV6, ETH_P_8021Q, ETH_P_MPLS_UC, \
                               ETH_P_MPLS_MC
from eunuchs.in_h import IPPROTO_TCP, IPPROTO_UDP

# Module-level aliases for the Tac value types used below to build and inspect
# PacketTracer requests.
Dot1qHeader = Tac.Type( 'PacketTracer::Dot1qHeader' )
IpGenAddr = Tac.Type( 'Arnet::IpGenAddr' )
L2Header = Tac.Type( 'PacketTracer::L2Header' )
L3Header = Tac.Type( 'PacketTracer::L3Header' )
L4Header = Tac.Type( 'PacketTracer::L4Header' )
Request = Tac.Type( 'PacketTracer::Request' )

# CLI argument names grouped by the header they configure. They are used both
# to infer the packet type from the fields the user has set and to build the
# list of fields that are required for a given packet type.
BaseFields = [ '<ingressIntf>' ]
L2Fields = [ '<srcMac>', '<dstMac>', '<etherType>' ]
# VLAN tags are never required, but setting one still marks the packet as
# Ethernet when inferring the packet type.
OptionalL2Fields = [ '<vlan>', '<innerVlan>' ]
IPv4Fields = [ '<srcIpv4>', '<dstIpv4>', '<ipTtl>', '<ipProto>' ]
IPv6Fields = [ '<srcIpv6>', '<dstIpv6>', '<hopLimit>', '<nextHeader>',
               '<flowLabel>' ]
L4Fields = [ '<srcL4Port>', '<dstL4Port>' ]

# Maps each CLI argument name to the human readable keyword shown to the user,
# e.g. when reporting which required fields are still missing.
ArgToLabel = {
   '<ingressIntf>': 'ingress-interface',
   '<rawPacket>': 'raw',
   '<srcMac>': 'src-mac',
   '<dstMac>': 'dst-mac',
   '<etherType>': 'eth-type',
   '<vlan>': 'vlan',
   '<innerVlan>': 'inner-vlan',
   '<srcIpv4>': 'src-ipv4',
   '<dstIpv4>': 'dst-ipv4',
   '<ipTtl>': 'ip-ttl',
   '<ipProto>': 'ip-protocol',
   '<srcL4Port>': 'src-l4-port',
   '<dstL4Port>': 'dst-l4-port',
   '<srcIpv6>': 'src-ipv6',
   '<dstIpv6>': 'dst-ipv6',
   '<hopLimit>': 'hop-limit',
   '<nextHeader>': 'next-header',
   '<flowLabel>': 'flow-label',
}

def printRequest( request ):
   """Print a human readable summary of the configured packet.

   One line is printed per configured header layer, for example:

   MAC Source aaaa.bbbb.cccc Destination 1234.4321.abcd Ethertype 0x800 802.1Q 10
   IPv4 Source 10.1.0.1 Destination 20.1.0.1 TTL 23 Protocol 6
   TCP Source Port 10000 Destination Port 20000

   The Ethertype shown is the one carried after the innermost configured
   802.1Q tag (or the L2 ethertype when untagged); each configured tag is
   appended as "802.1Q <vlanId>". The L3 line is printed only for IPv4/IPv6
   requests. The L4 line is labelled TCP when the L3 protocol/next header is
   TCP and UDP otherwise. Raw packet requests print nothing.
   """
   isRaw = request.packetOverWriteBytes or request.overwriteOffset > 0
   if isRaw:
      return

   l2 = request.l2Header
   untagged = Dot1qHeader()
   vlanTags = []
   innermostEtherType = l2.etherType
   # The inner tag only counts when the outer one is present.
   for tag in ( l2.dot1qHeader, l2.innerDot1qHeader ):
      if tag != untagged:
         vlanTags.append( ' 802.1Q {}'.format( tag.vlanId ) )
         innermostEtherType = tag.nextEtherType
      else:
         break
   print( 'MAC Source {} Destination {} Ethertype {}{}'.format(
             EthAddr( l2.srcMac ).displayString, EthAddr( l2.dstMac ).displayString,
             hex( innermostEtherType ), ''.join( vlanTags ) ) )

   nextProtocol = 0
   if request.hasL3:
      l3 = request.l3Header
      version = l3.ipVersion
      if version == 4:
         nextProtocol = l3.protocol
         print( 'IPv4 Source {} Destination {} TTL {} Protocol {}'.format(
                   l3.srcIp, l3.dstIp, l3.ttl, nextProtocol ) )
      elif version == 6:
         nextProtocol = l3.nextHeader
         ipv6Format = ( 'IPv6 Source {} Destination {} Hop Limit {} Flow Label {} '
                        'Next Header {}' )
         print( ipv6Format.format( l3.srcIp, l3.dstIp, l3.hopLimit,
                                   l3.flowLabel, nextProtocol ) )

   if not request.hasL4:
      return
   l4 = request.l4Header
   l4Name = 'TCP' if nextProtocol == IPPROTO_TCP else 'UDP'
   print( '{} Source Port {} Destination Port {}'.format(
             l4Name, l4.srcPort, l4.dstPort ) )

def updateRequest( request, treeDict ):
   """Fill in a PacketTracer request from the user provided fields.

   Args:
      request: the request to populate; it is modified in place.
      treeDict: the configured packet fields, keyed by CLI argument name
         ( '<srcMac>', '<vlan>', '<packetType>', '<l4Type>', ... ).

   Returns:
      The same request object.

   When '<rawPacket>' (a hexadecimal string) is set, only the ingress interface
   and the raw overwrite bytes are filled in and all header fields are skipped.
   Otherwise the L2 header is always built. The L3 header is built only for the
   'ipv4' / 'ipv6' packet types, and the L4 header only for the 'tcp' / 'udp'
   L4 types.
   """
   request.ingressIntf = treeDict.get( '<ingressIntf>' )

   rawPacket = treeDict.get( '<rawPacket>' )
   if rawPacket:
      # binascii.unhexlify gives the same result as the Python 2-only
      # str.decode( 'hex' ) codec, and also works on Python 3.
      request.packetOverWriteBytes = binascii.unhexlify( rawPacket )
      request.overwriteOffset = 0
      return request

   l2Header = L2Header()
   l2Header.srcMac = treeDict.get( '<srcMac>' )
   l2Header.dstMac = treeDict.get( '<dstMac>' )
   l2Header.etherType = treeDict.get( '<etherType>' )

   vlan = treeDict.get( '<vlan>' )
   if vlan:
      # The user's ethertype moves into the innermost 802.1Q header. Each
      # enclosing header, and the L2 header itself, then carries ETH_P_8021Q.
      dot1qHeader = Dot1qHeader()
      dot1qHeader.vlanId = vlan
      dot1qHeader.nextEtherType = l2Header.etherType

      innerVlan = treeDict.get( '<innerVlan>' )
      if innerVlan:
         innerDot1qHeader = Dot1qHeader()
         innerDot1qHeader.vlanId = innerVlan
         innerDot1qHeader.nextEtherType = dot1qHeader.nextEtherType
         l2Header.innerDot1qHeader = innerDot1qHeader
         dot1qHeader.nextEtherType = ETH_P_8021Q

      l2Header.dot1qHeader = dot1qHeader
      l2Header.etherType = ETH_P_8021Q
   request.l2Header = l2Header

   packetType = treeDict.get( '<packetType>' )
   if packetType in ( 'ipv4', 'ipv6' ):
      request.hasL3 = True
      ipVersion = 4 if packetType == 'ipv4' else 6
      l3Header = L3Header( ipVersion )
      if ipVersion == 4:
         l3Header.srcIp = IpGenAddr( treeDict.get( '<srcIpv4>' ) )
         l3Header.dstIp = IpGenAddr( treeDict.get( '<dstIpv4>' ) )
         l3Header.ttl = treeDict.get( '<ipTtl>' )
         l3Header.protocol = treeDict.get( '<ipProto>' )
         protocol = treeDict.get( '<ipProto>' )
      else:
         # The IPv6 arguments expose .stringValue; IpGenAddr is built from
         # that string form.
         l3Header.srcIp = IpGenAddr( treeDict.get( '<srcIpv6>' ).stringValue )
         l3Header.dstIp = IpGenAddr( treeDict.get( '<dstIpv6>' ).stringValue )
         l3Header.hopLimit = treeDict.get( '<hopLimit>' )
         l3Header.nextHeader = treeDict.get( '<nextHeader>' )
         l3Header.flowLabel = treeDict.get( '<flowLabel>' )
         protocol = treeDict.get( '<nextHeader>' )
      request.l3Header = l3Header

      if treeDict.get( '<l4Type>' ) in ( 'tcp', 'udp' ):
         request.hasL4 = True
         l4Header = L4Header()
         # The L4 header records the IP protocol / next header it rides in.
         l4Header.protocol = protocol
         l4Header.srcPort = treeDict.get( '<srcL4Port>' )
         l4Header.dstPort = treeDict.get( '<dstL4Port>' )
         request.l4Header = l4Header

   return request

def generatePacketTypes( treeDict ):
   """Infer the packet type and the L4 type from the fields configured so far.

   Returns a ( packetType, l4Type ) tuple:
   - ( 'raw', 'none' ) when the packet type is 'raw' or a raw packet is set.
   - The stored '<packetType>' / '<l4Type>' values unchanged when both are set.
   - Otherwise packetType becomes 'ipv6' or 'ipv4' if any of that version's
     fields are set (IPv6 is checked first). Failing that it becomes 'ethernet'
     if any L2 field is set, and failing that it keeps the stored value
     (possibly None).
   - l4Type becomes 'tcp' / 'udp' when the IP protocol (or next header) is TCP
     / UDP, 'none' for any other integer protocol, and otherwise keeps the
     stored value (possibly None).
   """
   storedPacketType = treeDict.get( '<packetType>' )
   storedL4Type = treeDict.get( '<l4Type>' )

   # A raw packet bypasses all header based inference.
   if storedPacketType == 'raw' or '<rawPacket>' in treeDict:
      return ( 'raw', 'none' )

   # Types that were both decided earlier are reused as-is.
   if storedPacketType and storedL4Type:
      return ( storedPacketType, storedL4Type )

   def anyConfigured( keys ):
      return any( key in treeDict for key in keys )

   packetType = storedPacketType
   if anyConfigured( L2Fields + OptionalL2Fields ):
      packetType = 'ethernet'

   # IP fields take precedence over the plain Ethernet classification above.
   if anyConfigured( IPv6Fields ):
      packetType, protocol = 'ipv6', treeDict.get( '<nextHeader>' )
   elif anyConfigured( IPv4Fields ):
      packetType, protocol = 'ipv4', treeDict.get( '<ipProto>' )
   else:
      protocol = None

   if protocol == IPPROTO_TCP:
      l4Type = 'tcp'
   elif protocol == IPPROTO_UDP:
      l4Type = 'udp'
   elif isinstance( protocol, int ):
      # A non TCP/UDP protocol means there is no L4 header to prompt for.
      l4Type = 'none'
   else:
      l4Type = storedL4Type
   return ( packetType, l4Type )

def generateRequiredFields( packetType, l4Type ):
   """Return the list of CLI fields that must be set for the given types.

   A raw packet only needs the raw bytes and the ingress interface. Every other
   packet needs the ingress interface and the L2 fields. The ethertype is only
   required for plain 'ethernet' packets. IPv4/IPv6 packets add that version's
   fields, and a 'tcp' / 'udp' l4Type adds the L4 port fields.
   """
   if packetType == 'raw':
      return [ '<rawPacket>', '<ingressIntf>' ]

   needsEtherType = packetType == 'ethernet'
   required = list( BaseFields )
   required.extend( field for field in L2Fields
                    if needsEtherType or field != '<etherType>' )

   l3Fields = { 'ipv4': IPv4Fields, 'ipv6': IPv6Fields }.get( packetType, [] )
   required.extend( l3Fields )

   if l4Type in ( 'tcp', 'udp' ):
      required.extend( L4Fields )
   return required

def updateEthertypeAndProtocol( treeDict ):
   """Fill in the ethertype and IP protocol implied by the chosen types.

   For 'ipv4' / 'ipv6' packets '<etherType>' is set to ETH_P_IP / ETH_P_IPV6.
   When '<l4Type>' is 'tcp' or 'udp', '<ipProto>' (IPv4) or '<nextHeader>'
   (IPv6) is also set to the matching IPPROTO value. Other packet types are
   left untouched. treeDict is modified in place; nothing is returned. Raises
   KeyError if '<packetType>' is absent, or if '<l4Type>' is absent for an IP
   packet.
   """
   packetType = treeDict[ '<packetType>' ]
   if packetType == 'raw':
      return

   if packetType == 'ipv4':
      etherType, protocolKey = ETH_P_IP, '<ipProto>'
   elif packetType == 'ipv6':
      etherType, protocolKey = ETH_P_IPV6, '<nextHeader>'
   else:
      return
   treeDict[ '<etherType>' ] = etherType

   l4Type = treeDict[ '<l4Type>' ]
   if l4Type == 'tcp':
      treeDict[ protocolKey ] = IPPROTO_TCP
   elif l4Type == 'udp':
      treeDict[ protocolKey ] = IPPROTO_UDP

def fetchConfiguredPacket( mode ):
   """Return the packet being configured in this CLI session.

   Returns a ( treeDict, newConfiguration ) tuple. If the session's
   'PacketTracer.Packet' data is falsy (missing or an empty dict), a fresh
   empty dict is stored in the session and returned with newConfiguration
   True. Otherwise the stored dict is returned with newConfiguration False.
   """
   session = mode.session
   stored = session.sessionData( 'PacketTracer.Packet' )
   if stored:
      return ( stored, False )

   fresh = {}
   session.sessionDataIs( 'PacketTracer.Packet', fresh )
   return ( fresh, True )

def clearConfiguredPacket( mode ):
   """Discard the packet configured in this CLI session.

   Resets the session's 'PacketTracer.Packet' data to None.
   """
   session = mode.session
   session.sessionDataIs( 'PacketTracer.Packet', None )

def validateFields( mode, treeDict, packetTracerHwStatus, packetTracerSwStatus ):
   """Validate the user's specified fields against what the platform supports.

   Checks the following:
   - MPLS ethertypes are only used when the software status reports MPLS
     support.
   - A raw packet, if given, is supported by the hardware.
   - The raw packet consists of whole bytes (an even number of hex characters).
   - The raw packet lies within the hardware's minimum/maximum packet size.

   Args:
      mode: CLI mode; problems are reported through mode.addError.
      treeDict: the configured packet fields, keyed by CLI argument name.
      packetTracerHwStatus: provides rawPacketSupported, maximumPacketSize and
         minimumPacketSize.
      packetTracerSwStatus: provides mplsPacketSupported.

   Returns:
      True if the fields are valid, False after reporting an error. When MPLS
      or raw packets are unsupported, the configured packet is also cleared
      from the session.
   """
   etherType = treeDict.get( '<etherType>' )
   if ( not packetTracerSwStatus.mplsPacketSupported and
        etherType in [ ETH_P_MPLS_UC, ETH_P_MPLS_MC ] ):
      mode.addError( 'Invalid request: MPLS packets are unsupported. Clearing '
                     'packet configuration.' )
      clearConfiguredPacket( mode )
      return False

   rawPacket = treeDict.get( '<rawPacket>' )
   if not rawPacket:
      return True

   if not packetTracerHwStatus.rawPacketSupported:
      mode.addError( 'Invalid request: Raw packets are unsupported.' )
      clearConfiguredPacket( mode )
      return False

   rawPacketLen = len( rawPacket )
   if rawPacketLen % 2:
      # All bytes need to be specified in full, 2 hex characters each.
      mode.addError( 'Hexadecimal string is an odd length' )
      return False

   # Floor division: this module enables true division, so '/' would produce a
   # float that shows up as e.g. "64.0 bytes" in the messages below.
   packetLen = rawPacketLen // 2
   if packetLen > packetTracerHwStatus.maximumPacketSize:
      mode.addError( 'Raw packet of {} bytes is larger than the maximum '
                     'supported {} bytes'.format(
                        packetLen,
                        packetTracerHwStatus.maximumPacketSize ) )
      return False
   elif packetLen < packetTracerHwStatus.minimumPacketSize:
      mode.addError( 'Raw packet of {} bytes is smaller than the minimum '
                     'supported {} bytes'.format(
                        packetLen,
                        packetTracerHwStatus.minimumPacketSize ) )
      return False

   return True

def checkForMissingFields( mode, requiredFields, treeDict ):
   """Report any required fields that have not been configured.

   Missing fields are reported once each, through mode.addError. They appear in
   the order they occur in requiredFields, rather than arbitrary set order, so
   the message is stable. Each field is shown by its human readable label from
   ArgToLabel, falling back to the raw argument name if it has no label.

   Returns True if any field is missing, False otherwise.
   """
   missingFields = []
   for field in requiredFields:
      if field not in treeDict and field not in missingFields:
         missingFields.append( field )
   if not missingFields:
      return False

   labels = [ ArgToLabel.get( field, field ) for field in missingFields ]
   mode.addError( 'Missing field(s): ' + ', '.join( labels ) )
   return True
