#!/usr/bin/env python
# Copyright (c) 2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
import netns, socket, os, sys
import Tac
import ctypes
import ctypes.util

class NamespaceType( object ):
   """Namespace-type flags for setns(), mirroring the kernel's CLONE_NEW* values.

   Every attribute is a plain int constant; the class is never instantiated.
   """
   # Values copied from linux/sched.h.  Each one is a single, distinct bit.
   MOUNT = 0x00020000    # CLONE_NEWNS
   UTS = 0x04000000      # CLONE_NEWUTS (allows changing the hostname etc)
   IPC = 0x08000000      # CLONE_NEWIPC
   USER = 0x10000000     # CLONE_NEWUSER
   PID = 0x20000000      # CLONE_NEWPID
   NETWORK = 0x40000000  # CLONE_NEWNET

   # Zero is used by setns to allow any type of namespace to be joined.
   ANY = 0
   VALID_TYPES = frozenset( ( MOUNT, UTS, IPC, USER, PID, NETWORK ) )
   # The flags are disjoint bits, so their sum equals their bitwise OR.
   ALL = sum( VALID_TYPES )

   @staticmethod
   def procNsFilename( nsType ):
      """Map one namespace type to its entry name under /proc/<pid>/ns.

      Args:
        - nsType: exactly one of the VALID_TYPES constants.

      Returns:
        The bare filename, e.g. "net" for NETWORK.

      Raises:
        KeyError: if nsType is not one of VALID_TYPES (ANY and ALL included).
      """
      pairs = (
         ( NamespaceType.IPC, "ipc" ),
         ( NamespaceType.MOUNT, "mnt" ),
         ( NamespaceType.NETWORK, "net" ),
         ( NamespaceType.PID, "pid" ),
         ( NamespaceType.USER, "user" ),
         ( NamespaceType.UTS, "uts" ),
      )
      lookup = dict( pairs )
      return lookup[ nsType ]

# setns() is taken from libnsutil.  use_errno=True makes ctypes save the C
# errno immediately after each call through this library, so
# ctypes.get_errno() reports the value left by *that* call even if the
# interpreter runs other C code before we get around to reading it.
_setns = ctypes.CDLL( "libnsutil.so.0.0.0", use_errno=True ).setns
_setns.argtypes = ( ctypes.c_int, ctypes.c_int )
_setns.restype = ctypes.c_int

# Direct pointer to libc's errno, used by getErrno().
# pylint: disable-msg=W0212
_errnoLocation = ctypes.CDLL( ctypes.util.find_library( "c" ) ).__errno_location
# pylint: enable-msg=W0212
_errnoLocation.restype = ctypes.POINTER( ctypes.c_int )


def getErrno():
   """Return the current value of libc's `errno'.

   The value is read live through libc's __errno_location(), so it reflects
   whatever C code last set errno.  That may not be the call you care about
   if the interpreter has done other work since.  For failures of _setns,
   setns() below uses ctypes.get_errno() instead.
   """
   return _errnoLocation()[ 0 ]


def setns( fd, nsType=NamespaceType.ANY ):
   """Changes this process's namespace.

   Args:
     - fd: A file descriptor (integer) referencing the namespace to change to.
     - nsType: The type of namespace to change (see NamespaceType); ANY (0)
       accepts whatever namespace type fd refers to.

   Returns:
     None, on success.

   Raises:
     OSError: if the system call failed; errno/strerror describe the failure.
   """
   rv = _setns( fd, nsType )
   if rv != 0:
      # Use ctypes' saved copy of errno (see use_errno above) rather than the
      # live errno, which later C calls could already have overwritten.
      e = ctypes.get_errno()
      raise OSError( e, os.strerror( e ) )

class RunInOuterNetworkNamespace( object ):
   """Context manager to run Python code in the outer network namespace.

   Artest.runMeInNetworkNamespace places the calling program in its own
   network namespace.  Inside a `with RunInOuterNetworkNamespace():` block the
   process is switched (via setns) to the enclosing namespace, which is useful
   when we need to make network connections to the outside world (e.g. the
   benchmark db).  On exit it is switched back to the namespace it was in when
   the instance was created.

   The instance opens handles on both namespaces in __init__ and closes them
   on exit, so each instance should be used for a single `with` block.
   """
   def __init__( self ):
      hiddenNetnsPath = '/var/run/hiddenNetns/isolateOuterNetNamespace'
      # Handle on the current (inner) namespace so __exit__ can return to it.
      self.inner = open( '/proc/self/ns/net' )
      self.innerFd = self.inner.fileno()
      try:
         if os.path.exists( hiddenNetnsPath ):
            # Exceptions to the /proc/1 rule below are programs run inside the
            # "isolate" utility; that setup provides the outer namespace here.
            self.outer = open( hiddenNetnsPath )
         else:
            # The pid-1 is the init process for this pid namespace. Typically
            # for programs running in Artest.runMeInNetworkNamespace, this
            # would also be the PID of the first process inside the chroot or
            # container. Usually its network namespace would have network
            # reachability to the outside world.
            self.outer = open( '/proc/1/ns/net' )
      except BaseException:
         # Don't leak the inner handle if the outer one cannot be opened.
         self.inner.close()
         raise
      self.outerFd = self.outer.fileno()

   def __enter__( self ):
      if self.outerFd is None:
         return self
      try:
         setns( self.outerFd, nsType=NamespaceType.NETWORK )
      except BaseException:
         # Switch back to the inner namespace (harmless if we never left) and
         # release both handles before propagating the error.
         self.__exit__( *sys.exc_info() )
         raise
      return self

   def __exit__( self, exception, execType, stackTrace ):
      # The three arguments are the usual (type, value, traceback) triple
      # despite their names; returning None never suppresses the exception.
      if self.outerFd is None:
         return
      try:
         setns( self.innerFd, nsType=NamespaceType.NETWORK )
      finally:
         # Release both namespace handles, even if switching back failed.
         try:
            self.inner.close()
         finally:
            self.outer.close()

def enterNamespace( name ):
   """Move this process into new namespaces and start netnsd in them.

   Steps performed:
     1. Bind an abstract unix stream socket (leading NUL byte) named
        '/netns/<name>' while still in the original network namespace --
        presumably so it stays reachable from there; TODO confirm.
     2. unshare() new network, mount and UTS namespaces.
     3. Remount /sys and /proc and bring up the loopback interface inside
        the new namespaces (best effort: command exit codes are not checked).
     4. Set NSNAME to <name> and append '/<name>' to NSPATH in os.environ.
     5. Fork a child that puts the bound socket on fd 3 and execs
        `netnsd -f '' -s <name>`.  Before exec it calls
        Tac.setpdeathsig( SIGTERM ), presumably so netnsd is signalled when
        this process dies.

   Args:
     - name: name of the namespace; used in the socket name, NSNAME, NSPATH
       and netnsd's -s argument.

   Returns:
     None, in the parent.  The child never returns: it either becomes netnsd
     or exits with status 1.

   Raises:
     OSError: if the namespaces cannot be created (on EPERM a hint about
       needing root is printed first).
   """
   s = socket.socket( socket.AF_UNIX, socket.SOCK_STREAM )
   # pylint: disable=anomalous-backslash-in-string
   s.bind( '\0/netns/%s' % name )
   try:
      netns.unshare( netns.CLONE_NEWNET|netns.CLONE_NEWNS|netns.CLONE_NEWUTS)
   except OSError as e:
      import errno
      if e.errno == errno.EPERM:
         print( "***\n***" )
         print( "Insufficient privileges to create a network namespace." )
         print( "You probably should be using runMeAsRoot() or sudo." )
         print( "***\n***" )
      # Always propagate: carrying on after a failed unshare would remount
      # /sys and /proc and reconfigure lo in the *original* namespaces.
      raise
   import subprocess
   for (t, d) in [("sysfs", "/sys"),("proc", "/proc")]:
      subprocess.call(["umount", d])
      subprocess.call(["mount", "-t", t, "none", d])
   subprocess.call(["ifconfig", "lo", "up"])
   os.environ['NSNAME'] = name
   nspath = os.environ.get('NSPATH', '') + '/' + name
   os.environ['NSPATH'] = nspath
   cpid = os.fork()
   if not cpid:
      # Child.  os.execvp raises instead of returning when it fails, so guard
      # the whole sequence: the child must never fall back into the caller's
      # code as a duplicate process.
      try:
         SIGTERM = 15
         Tac.setpdeathsig( SIGTERM )
         os.dup2( s.fileno(), 3 )
         # Start
         os.execvp( 'netnsd', ['netnsd', '-f', '', '-s', name] )
      except BaseException as err: # pylint: disable=broad-except
         print( "enterNamespace: failed to start netnsd: %s" % err,
                file=sys.stderr )
         # os._exit skips stdio flushing, so flush the diagnostic explicitly.
         sys.stderr.flush()
      # pylint: disable=protected-access
      os._exit(1)

if __name__ == "__main__":
   # Usage: <script> name -- create namespace <name>, start netnsd, then wait.
   if len( sys.argv ) < 2:
      # A usage error goes to stderr with a non-zero exit status.
      print( "usage:", sys.argv[ 0 ], "name", file=sys.stderr )
      sys.exit( 1 )
   enterNamespace( sys.argv[ 1 ] )
   import signal
   # Block until a signal arrives, keeping this process (netnsd's parent)
   # alive.
   signal.pause()
