#!/usr/bin/env python
# Copyright (c) 2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

""" Misc. network-related helper functions. """

import Tac
import CliTestMode
import Tracing
from eunuchs.if_h import IFF_UP
import os
import re
import random
import socket

import Arnet
import Arnet.NsLib
from Arnet.NsLib import DEFAULT_NS, runMaybeInNetNs

# Tracing handle for this library. t0 is shorthand for traceHandle.trace0 and
# is the trace function used by the helpers below (e.g. TcpSession).
traceHandle = Tracing.Handle( "ArnetTestLib" )
t0 = traceHandle.trace0

# We allow bare 'except:' clauses (which catch every exception) since
# this is test code.
# pylint: disable-msg=W0702

# Shorthand for the Arnet::MplsLabel TAC type, looked up once rather than once
# per dictionary entry.
_MplsLabel = Tac.Type( 'Arnet::MplsLabel' )

# Reserved/special MPLS label values mapped to the human-readable strings used
# to display them.
mplsLabelToPrettyString = {
   _MplsLabel.explicitNullIpv4: 'exp-null-v4(0)',
   _MplsLabel.routerAlert: 'router-alert(1)',
   _MplsLabel.explicitNullIpv6: 'exp-null-v6(2)',
   _MplsLabel.implicitNull: 'imp-null(3)',
   _MplsLabel.entropyLabelIndicator: 'entropy-indicator(7)',
   _MplsLabel.null: 'N/A',
   _MplsLabel.any: '(any)',
   _MplsLabel.luWithdraw: '(lu-withdraw)',
   _MplsLabel.entropyLabel: '(entropy)',
}

# Inverse of mplsLabelToPrettyString. items() works on both Python 2 and 3,
# unlike the Python 2-only iteritems().
mplsPrettyStringToLabel = { v: k for k, v in mplsLabelToPrettyString.items() }

def devExists( intf, ns=DEFAULT_NS ):
   """Return True if `ip -o link show <intf>` in namespace `ns` lists the UP
   flag for the device.

   NOTE(review): despite the name, a device that exists but does not show UP
   in its flag list yields False. Any failure running the command (for
   example because the device does not exist) also yields False.
   """
   try:
      out = runMaybeInNetNs( ns, [ "ip", "-o", "link", "show", intf ],
                             stdout=Tac.CAPTURE, stderr=Tac.DISCARD )
      return re.search( "[<,]UP[,>]", out ) is not None
   except Exception: # pylint: disable-msg=W0703
      # Best effort: treat any command/parse failure as "not present", but let
      # KeyboardInterrupt and SystemExit propagate (a bare except caught them).
      return False

def devRecvExpPkt( dev, expectPktStr, showPkts=False, ignoreFlowLbl=False ):
   '''Receive one packet from tap device `dev` and report whether it matches
   `expectPktStr`.

   It is best to create the device in non-blocking mode and poll with
   Tac.waitFor, e.g.::

      Tac.waitFor( lambda: devRecvExpPkt( dev, expPkt ) )

   :param dev: object whose recv() returns the next packet and raises OSError
      when none can be read (as in non-blocking mode).
   :param expectPktStr: the expected raw packet.
   :param showPkts: if True, decode and print the received packet with scapy
      (if scapy is importable).
   :param ignoreFlowLbl: if True, bytes 15-17 are left out of the comparison.
      NOTE(review): assuming an untagged Ethernet frame carrying IPv6 (header
      at offset 14), those bytes hold the flow label plus the Traffic Class
      nibble that shares its byte; confirm for tagged frames.
   :returns: True on a match; False on a mismatch, when recv() returned None,
      or when recv() raised OSError.
   '''
   try:
      pkt = dev.recv()
   except OSError:
      return False

   if pkt and showPkts:
      # Hide import of scapy module from dependency generator, so
      # that scapy does not ship in vEOS.swi (license is not ok).
      try:
         import importlib
         scapyLayersInet = importlib.import_module( 'scapy.layers.inet' )
         scapyLayersInet.Ether( pkt ).show()
      except ImportError as e:
         print( e )

   if pkt is None:
      # Nothing received; slicing None below would raise TypeError.
      return False
   if ignoreFlowLbl:
      return pkt[ :15 ] == expectPktStr[ :15 ] and \
             pkt[ 18: ] == expectPktStr[ 18: ]
   return pkt == expectPktStr

def devClearPkts( dev ):
   '''Drain the input queue of tap device `dev`.

   Keeps calling dev.recv(), discarding whatever it returns, until recv()
   raises OSError. Intended only for devices in non-blocking mode: the loop
   ends solely when recv() raises.'''
   drained = False
   while not drained:
      try:
         dev.recv()
      except OSError:
         drained = True


def enableDevice( devName ):
   """Set kernel network device `devName` up with `ip link set <dev> up`."""
   upCmd = [ "ip", "link", "set", devName, "up" ]
   Tac.run( upCmd )

def disableDevice( devName ):
   """Set kernel network device `devName` down with `ip link set <dev> down`."""
   downCmd = [ "ip", "link", "set", devName, "down" ]
   Tac.run( downCmd )

def changeDeviceMacAddr( devName, macAddr ):
   """Set the hardware (MAC) address of kernel network device `devName` to
   `macAddr`.

   The device is set down for the change and set up again afterwards. The
   re-enable runs in a `finally` block, so a failed address change does not
   leave the device down.
   """
   disableDevice( devName )
   try:
      # Same `ip link` tool as enableDevice/disableDevice, instead of the
      # deprecated ifconfig.
      Tac.run( [ 'ip', 'link', 'set', devName, 'address', macAddr ] )
   finally:
      enableDevice( devName )
   
def allDevices( ns=DEFAULT_NS ):
   """Return the set of names of all Linux kernel network devices that
   currently exist in namespace `ns`, i.e. the entries of /sys/class/net."""
   listing = runMaybeInNetNs( ns, [ 'ls', '/sys/class/net' ],
                              asRoot=True, stdout=Tac.CAPTURE )
   return set( listing.split() )

class DeviceMissingError( Exception ):
   """Raised when one or more expected kernel network devices are not found.

   :ivar devNames: the collection of missing device names.
   """
   def __init__( self, devNames ):
      # Pass devNames to Exception so .args is populated; this is required
      # for a useful repr() and for pickling/copying the exception.
      super( DeviceMissingError, self ).__init__( devNames )
      self.devNames = devNames

   def __str__( self ):
      # Wrap in a 1-tuple so that a tuple of names is not taken as the
      # format-argument tuple.
      return "Missing devices: %s" % ( self.devNames, )

def deviceStateCheck( devNames, ns, enabled ):
   """Return True if the up/down (IFF_UP) state of every kernel network device
   in `devNames` matches `enabled`.

   :param devNames: a single device name, or an iterable of device names
      (a generator is fine; it is consumed once).
   :param ns: network namespace in which to look the devices up.
   :param enabled: True to check that the devices are up, False for down.
   :returns: False as soon as a named device is found in the wrong state;
      True once every named device has been seen in the expected state.
   :raises DeviceMissingError: if some named device has no flags file under
      /sys/class/net in `ns` and no mismatch was found first.
   """
   if not hasattr( devNames, '__iter__' ):
      # A lone device name (on Python 2, str has no __iter__).
      devNames = [ devNames ]
   # Materialize once: membership below must not re-iterate devNames, which
   # would see an already-exhausted iterator.
   wanted = set( devNames )
   missingDevNames = set( wanted )
   # grep -H prints "<path>:<contents>" for every device's flags file.
   output = runMaybeInNetNs( ns, [ "grep -H 0x /sys/class/net/*/flags" ],
                             asRoot=True, stdout=Tac.CAPTURE, shell=True )
   for line in output.splitlines():
      path, sysfsAttribute = [ f.strip() for f in line.split( ':', 1 ) ]
      # path is /sys/class/net/<devName>/flags
      devName = path.split( '/' )[ 4 ]
      if devName not in wanted:
         continue
      flags = int( sysfsAttribute, 16 )
      if bool( flags & IFF_UP ) != enabled:
         return False
      missingDevNames.discard( devName )
      if not missingDevNames:
         return True

   if missingDevNames:
      raise DeviceMissingError( missingDevNames )
   return True

def deviceEnabled( devNames, ns=DEFAULT_NS ):
   """Return True if every device in `devNames` is up (see deviceStateCheck)."""
   return deviceStateCheck( devNames, enabled=True, ns=ns )

def deviceDisabled( devNames, ns=DEFAULT_NS ):
   """Return True if every device in `devNames` is down (see deviceStateCheck)."""
   return deviceStateCheck( devNames, enabled=False, ns=ns )

def checkDeviceMcastAddr( devName, addr, ns=DEFAULT_NS ):
   """Return True if `ip -6 maddr show dev <devName>`, run in namespace `ns`,
   lists IPv6 multicast address `addr`.

   `addr` is compared textually, so it must be written the same way ip prints
   it. It is matched literally and must be followed by whitespace or the end
   of the line, so e.g. ff02::1 does not match ff02::1:ff00:1.
   """
   # An argv list keeps devName a single argument.
   cmd = [ "ip", "-6", "maddr", "show", "dev", devName ]
   mcastList = runMaybeInNetNs( ns, cmd, asRoot=True,
                                stdout=Tac.CAPTURE ).splitlines()
   exp = re.compile( r'[ \t]+inet6 %s(\s|$)' % re.escape( addr ) )
   return any( exp.match( l ) for l in mcastList )


def checkDeviceMtu( name, desiredMtu, ns=DEFAULT_NS ):
   """Return True if the MTU of Linux kernel network device `name` in
   namespace `ns` equals `desiredMtu`.

   The value comes from /sys/class/net/<name>/mtu. For a non-default namespace
   it is read with `cat` run in `ns`; otherwise the file is read directly.
   Returns False if the value cannot be read or is not an integer.
   """
   filename = os.path.join( '/sys/class/net', name, 'mtu' )
   try:
      if ns != DEFAULT_NS:
         sysfsAttribute = runMaybeInNetNs( ns, [ 'cat', filename ],
                                           stdout=Tac.CAPTURE )
      else:
         # Context manager so the sysfs file is closed promptly.
         with open( filename ) as f:
            sysfsAttribute = f.read()
   except ( Tac.SystemCommandError, IOError ):
      return False
   try:
      # int() ignores the surrounding whitespace/newline.
      return int( sysfsAttribute ) == desiredMtu
   except ValueError:
      # Unparsable value counts as a mismatch, whichever path read it.
      return False
   
def setDeviceMtu( devName, mtu ):
   """Set the MTU of kernel network device `devName` to `mtu` via `ip link`."""
   mtuArg = str( mtu )
   Tac.run( [ 'ip', 'link', 'set', devName, 'mtu', mtuArg ] )
   
def deviceOperstate( name, ns=DEFAULT_NS ):
   """Return the operstate of Linux kernel network device `name` in namespace
   `ns`, i.e. /sys/class/net/<name>/operstate with whitespace stripped."""
   path = os.path.join( '/sys/class/net', name, 'operstate' )
   raw = runMaybeInNetNs( ns, [ 'cat', path ], asRoot=True, stdout=Tac.CAPTURE )
   return raw.strip()

def deviceMacAddr( name, ns=DEFAULT_NS ):
   """Return the MAC address of Linux kernel network device `name` in
   namespace `ns`, i.e. /sys/class/net/<name>/address with whitespace
   stripped."""
   path = os.path.join( '/sys/class/net', name, 'address' )
   raw = runMaybeInNetNs( ns, [ 'cat', path ], asRoot=True, stdout=Tac.CAPTURE )
   return raw.strip()

def checkDevMcastAddr( devName, addr ):
   """Return True if `ip -6 maddr show dev <devName>` (run in the current
   namespace) lists IPv6 multicast address `addr`.

   `addr` is compared textually, so it must be written the same way ip prints
   it. It is matched literally and must be followed by whitespace or the end
   of the line, so e.g. ff02::1 does not match ff02::1:ff00:1.
   """
   # An argv list keeps devName a single argument.
   cmd = [ "ip", "-6", "maddr", "show", "dev", devName ]
   mcastList = Tac.run( cmd, stdout=Tac.CAPTURE ).splitlines()
   exp = re.compile( r'[ \t]+inet6 %s(\s|$)' % re.escape( addr ) )
   return any( exp.match( l ) for l in mcastList )

def socketAt( ns, proto, sockType ):
   """Create a socket in network namespace `ns`; a test wrapper around
   Arnet.NsLib.socketAt. `proto` is the address family (e.g. AF_INET) and
   `sockType` the socket type."""
   return Arnet.NsLib.socketAt( ns=ns, family=proto, type=sockType )

def _isPortUnused( port, sockType, ns=None ):
   """Return True if an AF_INET socket of type `sockType` created in namespace
   `ns` can be bound to ('localhost', `port`).

   The probe socket is closed whether or not the bind succeeds. On failure
   the port and the error are printed.
   """
   s = socketAt( ns, socket.AF_INET, sockType )
   try:
      s.bind( ( 'localhost', port ) )
   except socket.error as e:
      # Same text the Python 2 `print 'Port ', port, e` statement produced.
      print( 'Port  %s %s' % ( port, e ) )
      return False
   finally:
      # Release the probe socket on both success and failure.
      s.close()
   return True

def _unusedPortsGenerator():
   '''Yield every port in [1024, 65535) exactly once, in random order.

   These are only candidates: callers must still check that a port is really
   free (see _isPortUnused).'''
   # list() is required on Python 3, where range() cannot be shuffled.
   candidates = list( range( 1024, 65535 ) )
   random.shuffle( candidates )
   while candidates:
      yield candidates.pop()

# One module-wide generator, so that successive picks in this process never
# reuse a candidate port.
unusedPorts = _unusedPortsGenerator()

def _pickUnusedPort( ns, socketType=None ):
   """Return the next candidate from unusedPorts that is free in `ns`.

   With `socketType`, only that socket type is probed. Without it, the port
   must be bindable as both SOCK_STREAM and SOCK_DGRAM (probed in that order).
   Raises Exception once the candidates are exhausted."""
   if socketType:
      typesToProbe = ( socketType, )
   else:
      typesToProbe = ( socket.SOCK_STREAM, socket.SOCK_DGRAM )
   for candidate in unusedPorts:
      if all( _isPortUnused( candidate, t, ns ) for t in typesToProbe ):
         return candidate
   raise Exception( 'No more unused ports available' )

def hopefullyUnusedTcpPort( ns=None ):
   """Return a port that was bindable for TCP in `ns` when probed."""
   return _pickUnusedPort( ns, socketType=socket.SOCK_STREAM )
def hopefullyUnusedUdpPort( ns=None ):
   """Return a port that was bindable for UDP in `ns` when probed."""
   return _pickUnusedPort( ns, socketType=socket.SOCK_DGRAM )
def hopefullyUnusedTcpAndUdpPort( ns=None ):
   """Return a port that was bindable for both TCP and UDP in `ns` when
   probed."""
   return _pickUnusedPort( ns, socketType=None )

class ArnetTestCaseMixin( object ):
   '''Extra assertions about network concepts such as IP addresses and
   subnets, for unittest.TestCase subclasses.

   fail() is stubbed here only to keep pylint happy. List TestCase ahead of
   this mixin in the bases so that TestCase's real fail() is the one used::

      class FooBarTest( unittest.TestCase, ArnetTestCaseMixin):
         pass
   '''

   def fail( self, msg ):
      # Placeholder; shadowed by TestCase.fail when mixed in as documented.
      raise NotImplementedError

   @staticmethod
   def _prefixContains( prefix, addr ):
      '''Return whether subnet string `prefix` contains address string `addr`.'''
      return Arnet.AddrWithMask( prefix ).contains( Arnet.IpAddress( addr ) )

   def assertIpInSubnet( self, prefix, addr ):
      '''Fail unless `addr` lies inside the subnet defined by `prefix`.

      :param prefix: A string
      :param addr: A string
      '''
      if self._prefixContains( prefix, addr ):
         return
      self.fail( '%s is not in %s' % ( addr, prefix ) )

   def assertIpNotInSubnet( self, prefix, addr ):
      '''Fail if `addr` lies inside the subnet defined by `prefix`.

      :param prefix: A string
      :param addr: A string
      '''
      if not self._prefixContains( prefix, addr ):
         return
      self.fail( '%s is in %s' % ( addr, prefix ) )

# Patterns for the three kinds of line that ip6KernelAddrCommon understands,
# compiled once at import. Raw strings keep the regex escapes (\S, \d) from
# being treated as (invalid) string escapes.
#
# Interface line, e.g.:
#  5: macvlan-bond0@bond0: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 9000
_ip6IntfLineRe = re.compile(
   r'\d+: (?P<name>[\S^:]+): <(?P<flags>[\S^>]+)> (mtu (?P<mtu>\d+))?' )
# Address line, e.g. (the keyword 'dynamic' is optional; the scope is any
# non-space token, e.g. global):
#  inet6 2000:1:1:1:5478:bfff:fe16:3c99/64 scope global dynamic
_ip6AddrLineRe = re.compile(
   r' +inet6 (?P<addr>[0-9,a-f,:]+/[0-9]+) (scope (?P<scope>\S+))'
   r' (?P<dynamic>dynamic)?' )
# Lifetime line that must follow every address line, e.g.:
#  valid_lft 2591990sec preferred_lft 604790sec
_ip6AddrLftLineRe = re.compile(
   r' +valid_lft (?P<validLft>((?P<validLftVal>[0-9]+)sec)|(forever))'
   r' preferred_lft (?P<prefLft>((?P<prefLftVal>[0-9]+)sec)|(forever))' )

def ip6KernelAddrCommon( ip6AddrCmdOutput ):
   """Parse IPv6 address listing output (interface, inet6 and lifetime lines)
   into a nested dict.

   :param ip6AddrCmdOutput: iterable of output lines; empty lines are skipped.
   :returns: addrDict, mapping interface name -> intfDict.
      intfDict has keys 'name', 'flags', 'mtu' (a string, or None when absent)
      and 'addr', a list with one addrEntryDict per IPv6 address on the
      interface. addrEntryDict has keys 'addr' ('<address>/<len>'), 'scope',
      'dynamic' (present, and True, only for dynamic addresses), and
      'validLft'/'prefLft' (the number of seconds as a string, or 'forever').

   Unexpected input trips an AssertionError: every address line must come
   after an interface line and be followed by a lifetime line, and every other
   non-empty line must be an interface line.
   """
   addrDict = {}
   intfDict = {}
   addrEntryDict = {}
   addrLftLineExpected = False
   for line in ip6AddrCmdOutput:
      if not line:
         continue

      if addrLftLineExpected:
         # The previous line was an address line; this one holds its lifetimes.
         assert addrEntryDict
         match = _ip6AddrLftLineRe.match( line )
         assert match
         # '<n>sec' yields just '<n>'; otherwise the value is 'forever'.
         validLftVal = match.group( 'validLftVal' ) or match.group( 'validLft' )
         assert validLftVal
         addrEntryDict[ 'validLft' ] = validLftVal
         prefLftVal = match.group( 'prefLftVal' ) or match.group( 'prefLft' )
         assert prefLftVal
         addrEntryDict[ 'prefLft' ] = prefLftVal
         # The address entry is complete: attach it to the current interface.
         intfDict[ 'addr' ].append( addrEntryDict )
         addrEntryDict = {}
         addrLftLineExpected = False
         continue

      match = _ip6AddrLineRe.match( line )
      if match:
         assert intfDict and not addrEntryDict
         addrEntryDict[ 'addr' ] = match.group( 'addr' )
         addrEntryDict[ 'scope' ] = match.group( 'scope' )
         if match.group( 'dynamic' ):
            addrEntryDict[ 'dynamic' ] = True
         addrLftLineExpected = True
         continue

      # Anything else must start a new interface.
      match = _ip6IntfLineRe.match( line )
      assert match
      if intfDict:
         addrDict[ intfDict[ 'name' ] ] = intfDict
         intfDict = {}
      intfDict[ 'name' ] = match.group( 'name' )
      intfDict[ 'flags' ] = match.group( 'flags' )
      intfDict[ 'mtu' ] = match.group( 'mtu' )
      intfDict[ 'addr' ] = []

   if intfDict:
      addrDict[ intfDict[ 'name' ] ] = intfDict
   return addrDict
 
def randomIp():
   '''Return a random IPv4 address paired with a random 32-bit mask.

   The mask is 32 arbitrary random bits, not necessarily a contiguous prefix.

   :rtype: Arnet::IpAddrWithFullMask
   '''
   ip = random.getrandbits( 32 )
   mask = random.getrandbits( 32 )
   octets = [ ( ip >> shift ) & 0xff for shift in ( 24, 16, 8, 0 ) ]
   dotted = '.'.join( '%d' % octet for octet in octets )
   return Arnet.AddrWithFullMask( dotted, mask )

def randomIp6( prefixLen=0, maxPrefixLen=128 ):
   '''Return a random IPv6 prefix.

   :param prefixLen: prefix length to use; 0 (the default) picks one uniformly
      at random from [8, maxPrefixLen].
   :param maxPrefixLen: upper bound for a randomly chosen prefix length.
   :rtype: Arnet::Ip6AddrWithMask

   Bytes entirely inside the prefix are random and bytes after the prefix are
   zero. When the prefix ends mid-byte, that byte is ~word & 0xff, where word
   holds prefix % 8 random bits.
   NOTE(review): that sets the byte's high bits and puts the random bits in
   its low (host) bits; confirm this is intended.
   '''
   prefix = prefixLen if prefixLen != 0 else random.randint( 8, maxPrefixLen )
   # Floor division: `/` is true division on Python 3, which for prefixes that
   # are not a multiple of 8 would let one extra random byte past the prefix.
   prefixBytes = prefix // 8
   prefixBits = prefix % 8
   octets = []
   for i in range( 16 ):
      if i < prefixBytes:
         octets.append( random.getrandbits( 8 ) )
      elif prefixBits:
         word = random.getrandbits( prefixBits )
         octets.append( ~word & 0xff )
         prefixBits = 0
      else:
         octets.append( 0 )
   # Two bytes (four hex digits) per colon-separated group.
   groups = [ '%.2x%.2x' % ( octets[ i ], octets[ i + 1 ] )
              for i in range( 0, 16, 2 ) ]
   return Arnet.Ip6AddrWithMask( '%s/%d' % ( ':'.join( groups ), prefix ) )

def randomPrefixIp():
   '''Return a random IPv4 prefix as an "a.b.c.d/len" string.

   The length is uniform in [0, 32] and every bit below it is cleared.

   :rtype: string
   '''
   addr = random.getrandbits( 32 )
   masklen = random.randint( 0, 32 )
   hostBits = 32 - masklen
   addr = ( addr >> hostBits ) << hostBits
   dotted = '.'.join( str( ( addr >> s ) & 0xFF ) for s in ( 24, 16, 8, 0 ) )
   return '%s/%d' % ( dotted, masklen )

def ip6FullMaskFromPrefix( prefixLen ):
   """Return the IPv6 netmask for `prefixLen` written out in full, e.g.
   64 -> 'ffff:ffff:ffff:ffff:0000:0000:0000:0000'.

   :raises ValueError: if `prefixLen` is not in [0, 128]. Out-of-range
      lengths would otherwise yield a malformed or wrong mask.
   """
   if not 0 <= prefixLen <= 128:
      raise ValueError( 'IPv6 prefix length out of range: %r' % ( prefixLen, ) )
   allOnes = ( 1 << 128 ) - 1
   # Clear the (128 - prefixLen) low-order host bits.
   mask = allOnes ^ ( ( 1 << ( 128 - prefixLen ) ) - 1 )
   hexString = '%032x' % mask
   return ':'.join( hexString[ i : i + 4 ] for i in range( 0, 32, 4 ) )

def ip6PrefixLength( mask ):
   """Return the prefix length of IPv6 netmask string `mask`.

   :returns: the number of leading one bits if `mask` is a run of ones
      followed only by zeros; otherwise the `invalid` value of
      Arnet::Ip6PrefixLen.
   :raises socket.error: if `mask` is not a valid IPv6 address string.
   """
   packed = socket.inet_pton( socket.AF_INET6, mask )
   # bytearray yields ints on both Python 2 and 3, unlike the Python 2-only
   # str.encode( 'hex' ).
   value = 0
   for byte in bytearray( packed ):
      value = ( value << 8 ) | byte
   binaryString = format( value, '0128b' )
   m = re.match( '(1*)0*$', binaryString )
   if m:
      return len( m.group( 1 ) )
   return Tac.Type( 'Arnet::Ip6PrefixLen' ).invalid

def randomIp6Addr():
   """Return a random IPv6 address string: eight colon-separated groups, each
   a random 16-bit value in lowercase hex without leading zeros."""
   groups = []
   for _ in range( 8 ):
      groups.append( "%x" % random.randint( 0, 0xffff ) )
   return ":".join( groups )

def ipv6CanonicalAddr( addr ):
   """Return IPv6 address string `addr` rewritten in the canonical form that
   inet_ntop produces, e.g. '2001:0db8:0:0::1' -> '2001:db8::1'."""
   family = socket.AF_INET6
   return socket.inet_ntop( family, socket.inet_pton( family, addr ) )

class TcpSession( object ):
   """
   A simple interface for starting and stopping a TCP session. The "server"
   side is a netcat-based echo server (/bin/cat). The "client" is a Python
   socket, driven through a python-mode CLI on the client device, that
   connects to the server and sends 10000 zero bytes. The session stays open
   until stop() is called.

   This class is used to validate MSS clamping configuration (see aid/7371),
   so it exposes socket stats (via 'ss') for the client and server MSS values
   (i.e. the tp->ack.rcv_mss value).

   :param client: device providing newCli() and tcpSocketStatistics().
   :param clientIp: the client's address.
   :param server: device providing newCli() and tcpSocketStatistics().
   :param serverIp: the server address the client connects to.
   :param listenPort: the TCP port the server listens on.
   """
   def __init__( self, client, clientIp, server, serverIp, listenPort ):
      self.client = client
      self.server = server
      self.clientIp = clientIp
      self.serverIp = serverIp
      self.listenPort = listenPort
      self.serverPid = 0

      # Use a number of CLIs for different devices and modes
      self.serverCli = self.server.newCli()
      self.serverCli.gotoMode( CliTestMode.bashSuMode )
      self.clientCli = self.client.newCli()
      self.clientCli.gotoMode( CliTestMode.bashSuMode )
      self.pyshell = self.client.newCli()
      self.pyshell.gotoMode( CliTestMode.pythonMode )

      # NOTE: The scripts are deliberately chosen to accomplish 2 goals:
      #
      # 1) Send data: This is only so that the Linux kernel updates the
      #    tp->ack.rcv_mss value (see
      #    https://elixir.bootlin.com/linux/latest/ident/tcp_measure_rcv_mss) for
      #    both the server and client. This is shown by 'ss' and is used to verify
      #    the feature is working as intended.
      # 2) The TCP session persists until killed: This is only to eliminate any
      #    race conditions querying the session state. This was impossible to
      #    achieve with netcat alone, and so a custom Python TCP client is used.
      self.serverScript = 'nc -k -l {port} -e /bin/cat'.format(
         port=self.listenPort )
      self.clientScript = [
            "import socket",
            "sock = socket.socket( socket.AF_INET, socket.SOCK_STREAM )",
            "sock.connect( ( '{addr}', {port} ) )".format( addr=self.serverIp,
                                                           port=self.listenPort ),
            "data = bytearray( [ 0 ] * {dataSize} )".format( dataSize=10000 ),
            "sock.sendall( data )",
            ]

   def _runPythonScript( self, script ):
      """Run each line of `script` in the client's python-mode CLI, asserting
      that each command produces no output other than its echo."""
      for cmd in script:
         t0( "Running '%s' on DUT" % cmd )
         # If pyshell has changed modes, the reference to 'sock' is lost.
         assert self.pyshell.mode.name == 'python'
         output = self.pyshell.runCmd( cmd )
         assert output == cmd

   def _sessionUp( self, cli ):
      """Return True if netstat run via `cli` shows an IPv4 TCP socket on
      listenPort in the LISTEN or ESTABLISHED state."""
      cmd = 'netstat --tcp -a -n -4 | grep ":{port} "'.format( port=self.listenPort )
      return any( 'LISTEN' in line or 'ESTABLISHED' in line for line in
                  cli.runCmd( cmd, ignoreOutputErrors=True ).splitlines() )

   @property
   def serverUp( self ):
      t0( 'Checking if server is listening' )
      return self._sessionUp( self.serverCli )

   def _startServer( self ):
      """
      Start netcat in the background, listening (-l) on listenPort and staying
      alive (-k), simply echoing back anything the TCP client sends.
      """
      t0( 'Run netcat to listen for incoming TCP sessions' )
      cmd = "%s &" % self.serverScript
      self.serverCli.runCmd( cmd )
      self.serverPid = self.serverCli.runCmd( "echo $!", omitEcho=True )
      Tac.waitFor( lambda: self.serverUp )

   def _stopServer( self ):
      t0( 'Kill server running on %s' % self.server )
      self.serverCli.runCmd( "kill -9 %s" % self.serverPid )
      Tac.waitFor( lambda: not self.serverUp )

   @property
   def _clientUp( self ):
      t0( 'Checking if client is running' )
      return self._sessionUp( self.clientCli )

   def _startClient( self ):
      assert self.serverUp
      assert not self._clientUp
      t0( 'connect() TCP socket to %s' % self.server )
      self._runPythonScript( self.clientScript )
      Tac.waitFor( lambda: self._clientUp )

   def _stopClient( self ):
      t0( 'close() TCP socket to %s' % self.server )
      self._runPythonScript( [ "sock.close()" ] )
      Tac.waitFor( lambda: not self._clientUp )

   def _sockStat( self, dev, stat ):
      """Return integer tcpinfo statistic `stat` of the single ESTAB socket of
      this session, as seen from `dev` (self.client or self.server)."""
      assert self.serverUp
      assert self._clientUp

      if dev == self.client:
         output = self.client.tcpSocketStatistics( self.clientIp, self.serverIp,
                                                   dstPort=str( self.listenPort ) )
      else:
         output = self.server.tcpSocketStatistics( self.serverIp, self.clientIp,
                                                   srcPort=str( self.listenPort ) )
      ss = [ sockStat for sockStat in output if sockStat[ 'state' ] == 'ESTAB' ]
      assert len( ss ) == 1
      ss = ss[ 0 ]
      assert 'tcpinfo' in ss
      value = ss[ 'tcpinfo' ].get( stat )
      # Fail with a readable message rather than int( None )'s TypeError.
      assert value is not None, 'no %r in tcpinfo %r' % ( stat, ss[ 'tcpinfo' ] )
      return int( value )

   def clientMss( self ):
      return self._sockStat( self.client, 'mss' )

   def serverMss( self ):
      return self._sockStat( self.server, 'mss' )

   def start( self ):
      self._startServer()
      self._startClient()

   def stop( self ):
      """Close the client socket, then kill the server. The server is killed
      even if closing the client fails; the client's error then propagates."""
      try:
         self._stopClient()
      finally:
         self._stopServer()
