# Copyright (c) 2008-2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import subprocess
import sys

import Tac
import Tracing

from Arnet import NsLib
from Arnet.NsLib import DEFAULT_NS, NamespaceType, setns
import KernelVersion
# pkgdeps: import ArpTestLib (forced dependency, see AID 10)

# Module-level shorthands for the level-0 and level-1 trace functions.
t0 = Tracing.trace0
t1 = Tracing.trace1

# Name of the default VRF, taken from the L3::VrfName TAC type.
defaultVrfName = Tac.Type( "L3::VrfName" ).defaultVrf

# netns constants
# Directory holding one handle file per named network namespace; handles are
# stat'ed and bind-mounted at NETNS_RUN_DIR/<name> by the helpers below.
NETNS_RUN_DIR = "/var/run/netns"
# Per-namespace /etc overlay directory (iproute2 convention).
# NOTE(review): not referenced in the code visible here; confirm external users.
NETNS_ETC_DIR = "/etc/netns"
# The calling process's own network-namespace handle. Its existence is used
# as the test for kernel netns support (see vrfSupported).
PROC_SELF_NS_NET_FILE = "/proc/self/ns/net"

def vrfSupported():
   """Return True when the kernel exposes /proc/self/ns/net.

   The presence of that file indicates the network-namespace support that
   VRFs are built on.
   """
   hasNetNsHandle = os.path.exists( PROC_SELF_NS_NET_FILE )
   return hasNetNsHandle

def nsNameFromVrfName( vrfName ):
   """Map a VRF name to the name of the network namespace backing it.

   None and the default VRF map to DEFAULT_NS. Any other VRF lives in a
   namespace called "ns-<vrfName>". Passing an empty string is treated as a
   programming error and trips an assertion.
   """
   assert vrfName != ''
   isDefaultVrf = ( not vrfName ) or ( vrfName == defaultVrfName )
   return DEFAULT_NS if isDefaultVrf else "ns-" + vrfName

def createKernelNetNsInfoManager():
   """Create a KernelNetNsInfo::Root entity and a Manager bound to it.

   Returns a ( root, manager ) tuple.
   """
   nsInfoRoot = Tac.newInstance( "KernelNetNsInfo::Root", "root" )
   nsInfoManager = Tac.newInstance( 'KernelNetNsInfo::Manager', nsInfoRoot )
   return nsInfoRoot, nsInfoManager

def mountDefaultNetNs():
   """Expose the caller's current network namespace under the DEFAULT_NS name.

   The DEFAULT_NS handle is created with addNetNs() if it is not already
   listed. /proc/self/ns/net is then bind-mounted onto
   NETNS_RUN_DIR/DEFAULT_NS.
   """
   mountTarget = "%s/%s" % ( NETNS_RUN_DIR, DEFAULT_NS )
   if DEFAULT_NS not in netNsList():
      addNetNs( DEFAULT_NS )
   bindMountCmd = [ "mount", "--bind", PROC_SELF_NS_NET_FILE, mountTarget ]
   Tac.run( bindMountCmd, asRoot=True )

def addNetNs( netNsName ):
   """Create a new network namespace and wait (up to 60s) for it to be listed.

   For any namespace other than DEFAULT_NS, the loopback interface is brought
   up inside it. When the kernel supports it, the per-namespace IGMP
   membership limit is also set. Errors from Tac.run / Tac.waitFor / the
   NsLib helper propagate to the caller.
   """
   cmd = [ "ip", "netns", "add", netNsName ]
   t0( "Running", cmd )
   Tac.run( cmd, asRoot=True )
   # The description must be formatted with the namespace name; it was
   # previously passed with a literal, unfilled "%s".
   Tac.waitFor( lambda: netNsExists( netNsName ), timeout=60,
                description="netns %s to be added" % netNsName )
   if netNsName != DEFAULT_NS:
      NsLib.runMaybeInNetNs( netNsName, [ 'ifconfig', 'lo', '127.0.0.1', 'up' ] )

      # This sysctl node gets reset to the default value for every namespace.
      # The value set here should match the value given in
      # /etc/sysctl.d/Arora.conf (see BUG209248).
      if KernelVersion.supports( KernelVersion.SET_IGMP_MAX_MEM ):
         NsLib.runMaybeInNetNs( netNsName,
                                [ 'sysctl', 'net.ipv4.igmp_max_memberships=65535' ] )

def deleteNetNsNoWait( netNsName, allowBusy=False ):
   """Run 'ip netns delete' once, without waiting for the namespace to vanish.

   Returns True if the delete command succeeds. If allowBusy is set and the
   command fails with 'Device or resource busy', returns False so a caller
   (e.g. via Tac.waitFor) can retry. Any other failure re-raises
   Tac.SystemCommandError.
   """
   deleteCmd = [ "ip", "netns", "delete", netNsName ]
   t0( "Running", deleteCmd )
   try:
      Tac.run( deleteCmd, asRoot=True, stdout=Tac.CAPTURE, stderr=Tac.CAPTURE )
   except Tac.SystemCommandError as err:
      t1( "SystemCommandError:", err.output )
      if allowBusy and 'Device or resource busy' in err.output:
         return False
      raise
   return True

def deleteNetNs( netNsName, allowBusy=True ):
   """Delete a network namespace and block until it is no longer listed.

   With allowBusy (the default), the delete is retried through Tac.waitFor
   while the namespace reports busy. Otherwise a single attempt is made and
   any failure propagates.
   """
   waitDescription = "netns %s to be deleted" % netNsName
   if not allowBusy:
      deleteNetNsNoWait( netNsName )
   else:
      Tac.waitFor( lambda: deleteNetNsNoWait( netNsName, allowBusy=allowBusy ),
                   description=waitDescription )
   Tac.waitFor( lambda: not netNsExists( netNsName ),
                description=waitDescription )

def netNsList():
   """Return the names of the existing network namespaces.

   Only the first whitespace-separated field of each output line is kept.
   Newer iproute2 releases print "name (id: N)" for namespaces that have an
   nsid assigned, and those trailing tokens must not be reported as
   namespace names. Blank lines are ignored.
   """
   cmd = [ "ip", "netns", "list" ]
   t0( "Running", cmd )
   output = Tac.run( cmd, asRoot=True, stdout=Tac.CAPTURE )
   return [ line.split()[ 0 ] for line in output.splitlines() if line.strip() ]

def netNsExists( name ):
   """Return True if a network namespace called `name' is currently listed."""
   knownNamespaces = netNsList()
   return name in knownNamespaces

def netNsReady( ns ):
   """Probe whether namespace `ns' is usable by querying its loopback device.

   Returns False, rather than raising, when the probe command fails with
   Tac.SystemCommandError.
   """
   loopbackProbe = [ 'ifconfig', 'lo' ]
   try:
      NsLib.runMaybeInNetNs( ns, loopbackProbe )
   except Tac.SystemCommandError:
      return False
   return True

def _iNodeFromNetNs( netNs ):
   """Return the inode number of the namespace handle NETNS_RUN_DIR/<netNs>.

   Raises Tac.SystemCommandError if `stat' fails, or ValueError if its output
   is not an integer. Both are traced at level 1 before being re-raised.
   """
   handlePath = '%s/%s' % ( NETNS_RUN_DIR, netNs )
   statCmd = [ 'stat', '--format=%i', handlePath ]
   try:
      rawInode = Tac.run( statCmd, stdout=Tac.CAPTURE )
   except Tac.SystemCommandError as err:
      t1( "Failed to stat inode for netNs %s" % netNs )
      t1( err )
      raise
   try:
      return int( rawInode )
   except ValueError:
      t1( "Unable to convert inode for ns %s to int (%s)" % ( netNs, rawInode ) )
      raise

def _iNodeFromPid( pid ):
   """Return the inode of /proc/<pid>/ns/net for process `pid'.

   `stat -L' follows the symlink so the namespace inode itself is reported.
   Raises Tac.SystemCommandError if `stat' fails, or ValueError if its output
   is not an integer. Both are traced at level 1 before being re-raised.
   """
   procNsPath = '/proc/%s/ns/net' % pid
   statCmd = [ 'stat', '-L', '--format=%i', procNsPath ]
   try:
      rawInode = Tac.run( statCmd, stdout=Tac.CAPTURE )
   except Tac.SystemCommandError as err:
      t1( "Failed to stat inode for pid %d" % pid )
      t1( err )
      raise
   try:
      return int( rawInode )
   except ValueError:
      t1( "Unable to convert inode for pid %d to int (%s)" % ( pid, rawInode ) )
      raise

def processExistsInNetNs( processName, netNs ):
   """Check that a process is running in a given network namespace.

   The PID is read from /var/run/<processName>.pid. The inode of that
   process's network namespace is then compared with the inode of the netNs
   handle. Returns False when the pid file is missing or unreadable, when its
   content is not an integer, when either inode cannot be determined, or when
   the inodes differ.
   """
   pidFile = '/var/run/%s.pid' % processName
   if not os.path.exists( pidFile ):
      t1( pidFile, "does not exist" )
      return False
   try:
      pidText = Tac.run( [ 'cat', pidFile ], asRoot=True,
                         stdout=Tac.CAPTURE ).rstrip()
   except Tac.SystemCommandError as e:
      t1( e )
      return False
   try:
      pid = int( pidText )
   except ValueError:
      t1( "Unable to convert %s to int (from %s)" % ( pidText, pidFile ) )
      return False
   try:
      iNodeFromPid = _iNodeFromPid( pid )
      iNodeFromNetNs = _iNodeFromNetNs( netNs )
   except ( Tac.SystemCommandError, ValueError ):
      return False
   if iNodeFromPid != iNodeFromNetNs:
      # Trace both inodes so the actual mismatch is visible in the logs.
      t1( "pid %d is at inode %d" % ( pid, iNodeFromPid ) )
      t1( "ns %s is at inode %d" % ( netNs, iNodeFromNetNs ) )
      return False
   return True

def moveIntfNetNs( intfName, netNs ):
   """Move interface intfName into the destination network namespace netNs.

   The interface is assumed to be visible from the caller's current network
   namespace. Its presence is checked with `ifconfig' first, and the
   Tac.SystemCommandError from that check is re-raised if it is missing.
   """
   try:
      Tac.run( [ "ifconfig", intfName ], asRoot=True, stdout=Tac.DISCARD )
   except Tac.SystemCommandError:
      t0( "%s intf device not found" % intfName )
      raise
   Tac.run( [ "ip", "link", "set", intfName, "netns", netNs ], asRoot=True )

def netNsBashCmd( netNs, execCmd ):
   """Build a 'bash ...' command string that runs execCmd inside netNs.

   For the default namespace ("" or DEFAULT_NS), execCmd is only prefixed
   with 'bash '. Otherwise it is wrapped in 'sudo ip netns exec <netNs>'.
   """
   inDefaultNs = netNs == '' or netNs == DEFAULT_NS
   if inDefaultNs:
      prefix = 'bash '
   else:
      prefix = 'bash sudo ip netns exec %s ' % netNs
   return prefix + execCmd

# Helper class that runs test functions only if the kernel is netns-capable,
# i.e. it can access PROC_SELF_NS_NET_FILE and so is capable of mounting an
# fd to the network namespace. It also cleans up namespaces left behind by
# previous tests to get a clean start, creates the required namespaces, and
# cleans up after the tests have run.
class NetNsBreadthTestDriver:
   """Run a list of test callables against freshly created network namespaces.

   All the work happens in the constructor. When PROC_SELF_NS_NET_FILE exists,
   the constructor:
      - deletes the namespaces in createNsList that already exist;
      - deletes a pre-existing DEFAULT_NS, unless noAutoDefault is set;
      - mounts DEFAULT_NS via mountDefaultNetNs();
      - creates every non-default namespace in createNsList;
      - calls each test as test( *testArgs, **testKwargs ).
   cleanup() runs afterwards whether or not a test raised.

   Args:
      createNsList: list of namespace names to (re)create.
      tests: list of callables to invoke.
      noAutoDefault: if True, leave a pre-existing DEFAULT_NS alone.
   """
   def __init__( self, createNsList=None, tests=None, noAutoDefault=False,
                 *testArgs, **testKwargs ):
      # Validate before entering the try/finally. cleanup() needs
      # self.createNsList, so a bad argument must surface as an AssertionError
      # instead of being masked by cleanup() failing on a half-built object.
      assert isinstance( createNsList, list )
      assert isinstance( tests, list )
      self.createNsList = createNsList
      self.noAutoDefault = noAutoDefault
      try:
         if os.path.exists( PROC_SELF_NS_NET_FILE ):
            # Delete namespaces left behind by previous runs.
            preExistNslist = netNsList()
            for netNs in self.createNsList:
               if netNs in preExistNslist and netNs != DEFAULT_NS:
                  deleteNetNs( netNs )

            if DEFAULT_NS in preExistNslist and not noAutoDefault:
               deleteNetNs( DEFAULT_NS )

            # Create the required namespaces.
            mountDefaultNetNs()
            for netNs in self.createNsList:
               if netNs != DEFAULT_NS:
                  addNetNs( netNs )

            # Now invoke the tests.
            for _test in tests:
               _test( *testArgs, **testKwargs )
      finally:
         self.cleanup()

   def cleanup( self ):
      """Delete the listed namespaces that currently exist.

      Also deletes DEFAULT_NS unless noAutoDefault was set. Does nothing when
      PROC_SELF_NS_NET_FILE is absent.
      """
      if os.path.exists( PROC_SELF_NS_NET_FILE ):
         # Delete the namespaces we created.
         preExistNslist = netNsList()
         for netNs in self.createNsList:
            if netNs in preExistNslist and netNs != DEFAULT_NS:
               # The lambdas are consumed immediately by Tac.waitFor within
               # this iteration, so capturing the loop variable is safe.
               Tac.waitFor( lambda: deleteNetNsNoWait( netNs, allowBusy=True ),
                            description="netns %s to be deleted" % netNs )
               Tac.waitFor( lambda: not netNsExists( netNs ), timeout=60,
                            description="netns %s to disappear" % netNs )

         if DEFAULT_NS in preExistNslist and not self.noAutoDefault:
            deleteNetNs( DEFAULT_NS )

# The following functions are used to access namespaces created with netnsd,
# like in Namespace Duts.
def pidOfNamespace( namespaceName ):
   """Returns the PID of the netnsd serving the given namespace.

   For example: pidOfNamespace( "host42" ) -> 5169

   Raises:
      - ValueError: if the given namespace name doesn't exist, we failed
        to communicate with its netnsd, or its reply is not an integer.
   """
   query = subprocess.Popen( [ "sudo", "netns", "-q", namespaceName ],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE )
   # Pylint is confused about life here, it thinks communicate() returned 2 lists.
   # pylint: disable-msg=E1103
   replyOut, replyErr = query.communicate()
   if query.returncode == 0:
      nsPid = int( replyOut.strip() )
      assert nsPid > 0, "Invalid PID returned by netnsd: %d" % nsPid
      return nsPid
   msg = ( "Failed to get the PID of netnsd for %r: %s"
           % ( namespaceName, replyErr.strip() ) )
   if "Connection refused" in replyErr:
      msg += " -- does the namespace exist?"
   raise ValueError( msg )

class RunInNamespace( object ):
   """Context manager to run Python code in a given namespace.

   Args:
      - namespaceName: The name of the namespace according to netnsd.
      - nsTypes: The type of namespace to enter (see NamespaceType).
        If `None', then this context manager will be a no-op.

   Here are some examples:

      with RunInNamespace( "host42", NamespaceType.NETWORK ):
         sock = socket.socket()
         # now 'sock' was created inside of host42's network namespace.
         # (see also `socketAt' for an easier way to do this)

      with RunInNamespace( "host42", NamespaceType.PID ):
         child = subprocess.Popen( ... )
         # now 'child' runs inside of the PID namespace of host42.

      with RunInNamespace( "host42", NamespaceType.NETWORK | NamespaceType.PID ):
         # Combines both previous capabilities into one.

   Raises:
      - ValueError: if the given namespace name doesn't exist or we failed
        to communicate with its netnsd.
      - OSError: we failed to enter the given namespace, possibly due to a
        permission problem or the namespace going away right at the point we
        tried to enter it.
   """

   def __init__( self, namespaceName, nsTypes ):
      self.namespaceName = namespaceName
      if nsTypes is not None:
         # In Python `&' binds tighter than `!=': this requires at least one
         # bit of NamespaceType.ALL to be set.
         assert nsTypes & NamespaceType.ALL != 0, ( "Invalid namespace type: 0x%x"
                                                    % nsTypes )
      self.nsTypes = nsTypes
      # Maps a NamespaceType to the fd of the namespace we came from, so we
      # can go back into that namespace upon exiting the context manager.
      self.origNsFd = {}

   def __enter__( self ):
      """Switch into the requested namespaces of the target netnsd.

      Returns None.
      """
      if self.nsTypes is None:
         return
      netnsdPid = pidOfNamespace( self.namespaceName )
      # We must first open all our FDs before calling setns() in case
      # we're about to enter a mount namespace, because after it'll be
      # too late, as /proc will be quite different then.
      newNsFd = {}  # Maps a namespace type to the fd to switch to.
      try:
         # First open.
         for nsType in NamespaceType.VALID_TYPES:
            if nsType & self.nsTypes:
               typeName = NamespaceType.procNsFilename( nsType )
               newNsFd[ nsType ] = open( "/proc/%d/ns/%s" % ( netnsdPid, typeName ) )
               self.origNsFd[ nsType ] = open( "/proc/self/ns/%s" % typeName )
         # Then switch.
         for nsType, fd in newNsFd.items():
            typeName = NamespaceType.procNsFilename( nsType )
            t0( "Entering", typeName, "namespace", self.namespaceName )
            setns( fd.fileno(), nsType=nsType )
      except:  # Deliberately bare: back out on anything, then re-raise.
         # If we failed to enter one of the namespaces, we back out here.
         self.__exit__( *sys.exc_info() )
         raise
      finally:
         for fd in newNsFd.values():
            fd.close()

   def __exit__( self, exception, execType, stackTrace ):
      """Return to every namespace recorded in origNsFd.

      Every restore is attempted and every saved fd is closed even if an
      earlier setns() fails. The first such failure is re-raised afterwards.
      Exceptions raised inside the with-block are never suppressed (this
      method returns None).
      """
      # NOTE: despite their names, the arguments are the standard
      # ( type, value, traceback ) triple; they are not used here.
      firstError = None
      try:
         for nsType, fd in self.origNsFd.items():
            try:
               setns( fd.fileno(), nsType=nsType )
            except Exception as e:  # pylint: disable-msg=W0703
               if firstError is None:
                  firstError = e
            finally:
               fd.close()
      finally:
         self.origNsFd.clear()
      if firstError is not None:
         raise firstError
