# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
'''
Dot1xWebAgentLib.py has the core of the agent that handles dot1x web authentication
'''

import Agent
import Ark
import Tac
import QuickTrace
import SocketServer
import BaseHTTPServer
import threading
import SharedMem
import Smash
from multiprocessing.pool import ThreadPool
from Arnet.NsLib import socketAt, DEFAULT_NS
import ssl
import Toggles.Dot1xToggleLib as Dot1xToggle
import HttpServiceSsl
from GenericReactor import GenericReactor
from Dot1xWebL2ForwarderLib import WebAuthNs, WebAuthWireIntfName
from Dot1xWebL2ForwarderLib import Dot1xL2Forwarder

# QuickTrace shorthands used throughout this file: warn (QuickTrace level 0) is
# used for exceptions/unexpected conditions, trace (level 1) for normal flow, and
# bv wraps the variable values passed to either of them.
warn = QuickTrace.trace0
trace = QuickTrace.trace1
bv = QuickTrace.Var

# Tac type providing the private TCP port numbers; Dot1xWeb.getIpPort reads the
# dot1x web HTTP/HTTPS listening ports from it.
PrivateTcpPorts = Tac.Type( "Arnet::PrivateTcpPorts" )

def protoName( https ):
   '''Map the https flag to a protocol name: 'https' if truthy, 'http' otherwise.'''
   if not https:
      return 'http'
   return 'https'

def thId():
   '''
   Return the name of the thread running the caller, used to prefix trace output.

   The more "pythonic" approach would be to wrap trace* in a function that adds
   the thread name as the first argument, but that does not work: trace* plays
   tricks with the call frames (details: /src/Ark/BothTrace.py). Callers therefore
   pass thId() explicitly.
   '''
   current = threading.current_thread()
   return current.name

class Dot1xHttpServer( BaseHTTPServer.HTTPServer, Tac.Notifiee ):
   '''
   Dot1xHttpServer: server class that adapts python's BaseHTTPServer.HTTPServer to
   work with a listening socket that is registered in tacc's activity loop and to use
   a thread pool instead of handling the requests inline.

   Division of work:
   - main thread (tacc activity loop): handleReadableCount accepts the connection
     and process_request hands it to the thread pool;
   - pool thread: process_request_in_pool reads and answers the request through
     HttpReqHandler.

   We also prevent instances from resolving host names.
   '''
   notifierTypeName = 'Tac::FileDescriptor'

   # TCPServer.server_bind only sets SO_REUSEADDR if this flag is already true when
   # it runs, so it must be set before binding (HTTPServer already defines it; it is
   # restated here to make the intent explicit). Assigning it on the instance after
   # server_bind, as was done before, has no effect.
   allow_reuse_address = True

   def __init__( self, agent=None, threadPool=None, https=False ):
      '''
      Replace the constructor of BaseHTTPServer.HTTPServer to register the listening
      socket FD in tacc so that we can handle new connections in the main thread.

      As we are replacing the constructor, might as well move some other
      initializations here:
      - HTTPS socket wrapping if applicable
      - ThreadPool
      - Hard-code server_address and HttpReqHandler

      Args:
         agent: the Dot1xWeb agent; provides getIpPort() and the state read by
                HttpReqHandler. Required despite the None default.
         threadPool: multiprocessing.pool.ThreadPool that handles the requests.
         https: serve HTTPS instead of plain HTTP.

      References:
      - Parent BaseHTTPServer.HTTPServer:
        https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L114
      - SocketServer.TCPServer
        https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L161
      - multiprocessing doc (ThreadPool):
        https://docs.python.org/2/library/multiprocessing.html
      '''
      self.https = https
      server_address = agent.getIpPort( https=self.https )
      # The following code comes from TCPServer.__init__
      # We are importing it to replace the socket creation, because we need it to
      # be created inside a WebAuthNs
      # Note: TCPServer is our grandparent; our parent, BaseHTTPServer.HTTPServer,
      # doesn't have a constructor.
      # Refs:
      # https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L102
      # https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L413
      # {{{
      # pylint: disable=non-parent-init-called
      SocketServer.BaseServer.__init__( self, server_address, HttpReqHandler )
      self.socket = socketAt( self.address_family, self.socket_type, ns=WebAuthNs )
      try:
         self.server_bind()
         self.server_activate()
      except:
         self.server_close()
         raise
      # }}}
      if self.https:
         certfile = HttpServiceSsl.getCertFilepath()
         keyfile = HttpServiceSsl.getKeyFilepath()
         # do_handshake_on_connect=False: accept() is called from
         # handleReadableCount, i.e. in the main thread, and the accepted socket has
         # no timeout yet. With the default (True) the TLS handshake runs inside
         # accept(), so a client that connects and never finishes the handshake
         # would block tacc's activity loop. Deferring it makes OpenSSL perform the
         # handshake on the first read, in the pool thread, after
         # HttpReqHandler.timeout has been applied to the connection.
         self.socket = ssl.wrap_socket( self.socket, certfile=certfile,
                                        keyfile=keyfile, server_side=True,
                                        do_handshake_on_connect=False )
      self.threadPool = threadPool
      self.agent = agent
      # Initialize tac reactor aspect: watch the listening FD for readability
      # (level-triggered), which drives handleReadableCount.
      self.tacFd = Tac.newInstance( 'Tac::FileDescriptor',
                                    'fd%d' % ( self.fileno() ) )
      self.tacFd.nonBlocking = True
      self.tacFd.notificationInterface = 'levelTriggered'
      self.tacFd.notifyOnReadable = True
      self.tacFd.descriptor = self.fileno()
      Tac.Notifiee.__init__( self, self.tacFd )

   def server_bind( self ):
      '''
      Replace HTTPServer.server_bind to prevent instances from trying to resolve host
      names: server_name is the numeric bound address.

      Original function:
      https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L106
      '''
      SocketServer.TCPServer.server_bind( self )
      host, port = self.socket.getsockname()[ : 2 ]
      self.server_name = host
      self.server_port = port

   def process_request_in_pool( self, request, client_address ):
      '''
      Method executed in a secondary thread from the pool to handle a request.

      Same structure as python's ThreadingMixIn.process_request_thread (in the
      try/except/finally form python 3 uses): handler errors are reported through
      handle_error, and the request is shut down exactly once, whether or not the
      handler succeeded.
      https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L585
      '''
      trace( thId(), 'process_request_in_pool', bv( client_address ), 'start' )
      try:
         self.finish_request( request, client_address )
      # pylint: disable=broad-except
      except Exception as e:
         # process_request discards the AsyncResult, so anything not handled here
         # would be silently dropped by the ThreadPool.
         warn( thId(), 'handler got exception', bv( str( e ) ) )
         self.handle_error( request, client_address )
      finally:
         try:
            self.shutdown_request( request )
         finally:
            trace( thId(), 'process_request_in_pool', bv( client_address ), 'done' )

   def process_request( self, request, client_address ):
      '''
      Use the thread pool to handle incoming requests

      Original:
      https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L315

      This is equivalent to what SocketServer.ThreadingMixin does:
      https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L685

      It's unfortunate that SocketServer doesn't offer a ThreadPoolMixIn.
      '''
      trace( thId(), 'process_request', bv( client_address ), 'scheduling' )
      self.threadPool.apply_async( self.process_request_in_pool,
                                   ( request, client_address ) )
      trace( thId(), 'process_request', bv( client_address ), 'scheduled' )

   @Tac.handler( 'readableCount' )
   def handleReadableCount( self ):
      '''
      Tac reactor for new connections on the listening socket. Called from tac's
      activity loop in the main thread.

      It calls SocketServer.handle_request:
      https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L251
      which in turn executes:
      - a timeout check and a select: harmless, we got here because a connection
        is already pending
      - _handle_request_noblock, which calls:
        - self.get_request, a simple "return self.socket.accept()"
          https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L461
        - self.verify_request, which is just a "return True"
          https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L307
        - self.process_request, defined above, which hands the connection to the
          thread pool.
      For HTTPS the TLS handshake is not done in accept() (see __init__), so
      nothing here waits on the client.
      '''
      trace( 'Dot1xHttpServer.handleReadableCount got new connection' )
      self.handle_request()

   def shutdown( self ):
      '''
      Close the listening socket.

      Overrides BaseServer.shutdown, which is meant to stop a serve_forever() loop;
      this class never runs serve_forever (connections come from tacc), so closing
      the socket is all that is needed.
      '''
      trace( 'Dot1xHttpServer.shutdown: calling server_close' )
      self.server_close()
      trace( 'Dot1xHttpServer.shutdown: done' )

class HttpReqHandler( BaseHTTPServer.BaseHTTPRequestHandler ):
   """Our HTTP request handler

   Socket handling is done by functions in the
   BaseHTTPServer.BaseHTTPRequestHandler base class.
   The read() done in that class, that gets the request proper, blocks - and this is
   the reason we chose to leave Tac's activity loop in the main thread and handle
   the HTTP(S) requests via thread pool.

   Every GET is answered with a 302 redirection to the captive portal selected for
   the client (see getCaptivePortal); a client whose request times out also gets
   one (see log_error)."""

   timeout = 60
   # ^ Timeout for reading the client request - copied from nginx:
   # http://nginx.org/en/docs/http/ngx_http_core_module.html#client_header_timeout

   def __init__( self, request, client_address, server ):
      '''
      Look up the client's MAC and captive portal, then run the base constructor.

      The lookups must come first: BaseRequestHandler.__init__ processes the whole
      request (setup/handle/finish) before returning, and do_GET and log_error read
      self.cp.
      '''
      trace( thId(), 'Http Request handler init entry' )
      self.mac = self.getClientMac( server, client_address[ 0 ] )
      self.cp = self.getCaptivePortal( server, self.mac )
      BaseHTTPServer.BaseHTTPRequestHandler.__init__( self, request,
                                                      client_address,
                                                      server )

   def address_string( self ):
      '''
      Replace BaseHTTPRequestHandler.address_string to prevent instances from
      trying to resolve host names

      Original function:
      https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L500
      '''
      host, _ = self.client_address[ : 2 ]
      return host

   @Tac.withActivityLock
   def getClientMac( self, server, clientIp ):
      '''
      Return the MAC address that the agent's Dot1xL2Forwarder associates with
      clientIp (the .src of the entry getMac returns), or "" if getMac returns
      nothing.

      Runs in a pool thread; decorated withActivityLock, like getCaptivePortal, so
      the lookup doesn't race with the main thread.
      '''
      trace( thId(), 'getClientMac', bv( clientIp ) )
      mac = server.agent.dot1xL2Forwarder.getMac( clientIp )
      if not mac:
         return ""
      return mac.src

   @Tac.withActivityLock
   def getCaptivePortal( self, server, mac ):
      '''
      Get the captive portal config from sysdb

      Returns the captivePortal of the supplicant with this MAC on any dot1x
      interface; if the MAC is unknown ("") or no supplicant matches, returns the
      global dot1xStatus.captivePortal.

      Decorate withActivityLock so that we don't read from sysdb while the main
      thread is updating the value.
      '''
      dot1xStatus = server.agent.dot1xStatus
      if mac == "":
         trace( thId(), 'getCaptivePortal : no mac ', bv( mac ) )
         return dot1xStatus.captivePortal
      trace( thId(), 'getCaptivePortal for', bv( mac ) )
      for intf in dot1xStatus.dot1xIntfStatus:
         dis = dot1xStatus.dot1xIntfStatus[ intf ]
         if mac in dis.supplicant:
            supp = dis.supplicant[ mac ]
            trace( thId(), 'getCaptivePortal : got ', bv( supp.captivePortal ) )
            return supp.captivePortal
      return dot1xStatus.captivePortal

   def do_GET( self ):
      """Invoked for every HTTP GET request. Defined by the base class.

      Redirects the client to self.cp, the captive portal resolved in __init__;
      answers 500 if building the redirection fails."""
      trace( thId(), 'do_GET entry' )
      try:
         trace( thId(), 'do_GET : doing a redirect to', bv( self.cp ) )
         # temporary redirection code 302
         self.send_response( 302 )
         self.send_header( "Location", self.cp )
         self.end_headers()

      # pylint: disable=broad-except
      except Exception as e:
         # internal server error
         warn( thId(), 'do_GET in exception', bv( str( e ) ) )
         self.send_response( 500 )
         # Terminate the header block, otherwise the client never receives a
         # complete HTTP response.
         self.end_headers()
      trace( thId(), 'do_GET exit' )

   # Ignore pylint warning caused by the "format" argument:
   # pylint: disable=redefined-builtin
   def log_message( self, format, *args ):
      '''
      Replace BaseHTTPRequestHandler.log_message to redirect messages to trace

      Ref: https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L449
      '''
      msg = format % args
      trace( thId(), self.client_address[ 0 ], msg )

   def log_error( self, format, *args ):
      '''
      Replace BaseHTTPRequestHandler.log_error: trace the message like
      log_message, and when the base class reports that reading the request timed
      out ("Request timed out: ..."), still send the client a redirection to its
      captive portal.
      '''
      self.log_message( format, *args )
      msg = format % args
      # Send out a redirection if we are timing out the client:
      if msg.startswith( 'Request timed out' ):
         # Just send a hard-coded redirection
         self.wfile.write( 'HTTP/1.0 302 Found\r\n'
                           'Location: %s\r\n\r\n' % self.cp )
         trace( thId(), 'log_error timeout, sent redirection' )


class Dot1xWeb( Agent.Agent ):
   '''
   The Dot1xWeb agent: redirects web clients to their dot1x captive portal.

   Lifecycle, as implemented below:
   - doInit mounts the dot1x config/status, the agent's webAgentStatus and the
     default namespace's KernelNetInfo status;
   - once the mounts complete, reactors are installed on the captive portal config
     and on the kniStatus interfaces;
   - the Dot1xL2Forwarder is created once the webauth interface shows up in
     kniStatus, and the HTTP(S) servers are created once the forwarder exists and
     the captive portal is enabled;
   - disabling the captive portal stops the servers and their thread pool.
   '''
   def __init__( self, entityManager ):
      trace( 'Dot1x Web Agent init entry' )
      self.agentName = name()
      Agent.Agent.__init__( self, entityManager, agentName=self.agentName )
      self.dot1xConfig = None
      self.dot1xStatus = None
      self.webAgentStatus = None
      # proto name ('http'/'https') -> Dot1xHttpServer
      self.httpServers = {}
      self.warm_ = None
      self.arpStatus = None
      self.cpReactor = None
      self.intfReactor = None
      self.kniStatus = None
      # Shared by all servers; created lazily by maybeCreateServers
      self.threadPool = None
      self.sysname = entityManager.sysname()
      self.sEm = SharedMem.entityManager( self.sysname, entityManager.local() )
      self.dot1xL2Forwarder = None
      trace( 'Dot1x Web Agent init exit' )

   def doInit( self, entityManager ):
      '''
      Mount dot1x/config and dot1x/status (mode 'r'), dot1x/webAgentStatus (mode
      'w') and the default namespace's kni status from shared memory; when the
      mount group completes, mountDone installs the reactors and may create the
      L2 forwarder and the servers.
      '''
      trace( thId(), 'doInit entry' )
      mg = entityManager.mountGroup()
      Ark.configureLogManager( self.agentName )
      self.dot1xConfig = mg.mount( 'dot1x/config', 'Dot1x::Config', 'r' )
      self.dot1xStatus = mg.mount( 'dot1x/status', 'Dot1x::Status', 'r' )
      self.webAgentStatus = mg.mount( 'dot1x/webAgentStatus',
                                      'Dot1x::WebAgentStatus', 'w' )
      shMemEm = SharedMem.entityManager( sysdbEm=entityManager )
      self.kniStatus = shMemEm.doMount( "kni/ns/%s/status" % DEFAULT_NS,
                                        "KernelNetInfo::Status",
                                        Smash.mountInfo( 'keyshadow' ) )

      def mountDone():
         trace( thId(), 'mountDone entry' )
         trace( "Captive portal data is ", bv( self.dot1xConfig.captivePortal ) )
         self.cpReactor = GenericReactor( self.dot1xConfig, [ 'captivePortal' ],
                                          self.handleCaptivePortal )
         self.intfReactor = GenericReactor( self.kniStatus, [ 'interface' ],
                                            self.handleIntf )
         self.webAgentStatus.running = False
         # Call maybeCreateL2Forwarder - it calls maybeCreateServers if
         # appropriate
         self.maybeCreateL2Forwarder()
         self.warm_ = True
         trace( 'mountDone exit' )

      trace( 'closing mount group' )
      mg.close( mountDone )
      trace( 'doInit exit' )

   # no sysdb read or write
   def warm( self ):
      '''Return True once mountDone has run (None before that).'''
      return self.warm_

   def getIpPort( self, https ):
      '''
      Return the ( address, port ) the server for this protocol listens on: all
      IPv4 addresses, on the private dot1x web HTTPS or HTTP port.
      '''
      return ( '0.0.0.0',
               PrivateTcpPorts.dot1xWebHttpsPort if https
               else PrivateTcpPorts.dot1xWebHttpPort )

   def maybeCreateServers( self ):
      '''
      Start web server if it's enabled

      Servers are only created when the captive portal is enabled and the
      Dot1xL2Forwarder exists; HTTPS is added only if its toggle is on. Servers
      that already exist are kept, so this is safe to call repeatedly.

      We don't need the activity lock here because only the main thread is running.
      '''
      if not self.dot1xConfig.captivePortal.enabled or not self.dot1xL2Forwarder:
         return
      protos = [ 'http' ]
      if Dot1xToggle.toggleDot1xWebAuthHttpsEnabled():
         protos.append( 'https' )
      if not self.threadPool:
         self.threadPool = ThreadPool( processes=2 )
      for proto in protos:
         https = proto == 'https'
         if proto not in self.httpServers:
            trace( "Creating", name(),
                   "for protocol", bv( proto ) )
            self.httpServers[ proto ] = Dot1xHttpServer(
               agent=self, threadPool=self.threadPool, https=https )
         self.webAgentStatus.running = True

   def stopHttpServer( self ):
      '''
      Stop web servers if they are running

      The thread pool is closed and joined first (waiting for in-flight requests),
      then every listening socket is closed and the servers are forgotten.

      We can't get the activity lock here because the handler threads might want to
      get the captive portal config after a timeout, and we wait for them to finish
      when we self.threadPool.join - that would cause a deadlock.
      '''
      if self.threadPool:
         trace( 'stopHttpServer close threads' )
         self.threadPool.close()
         trace( 'stopHttpServer join threads' )
         self.threadPool.join()
         trace( 'stopHttpServer joined threads' )
      for proto, server in self.httpServers.items():
         trace( 'stopHttpServer shutting down', proto )
         server.shutdown()
      self.httpServers = {}
      self.threadPool = None
      trace( 'stopHttpServer done, ready for agent shutdown' )
      self.webAgentStatus.running = False

   def handleCaptivePortal( self, notifiee=None ):
      '''Start or stop the servers following dot1xConfig.captivePortal.enabled.'''
      trace( thId(), 'handle captive portal' )
      if self.dot1xConfig.captivePortal.enabled:
         self.maybeCreateServers()
      else:
         self.stopHttpServer()
      trace( thId(), 'handle captive portal done' )

   def maybeCreateL2Forwarder( self ):
      '''
      Create Dot1xL2Forwarder if it doesn't exist

      Notes:
      - We need this because Dot1xL2Forwarder can't be created before the platform
        agent creates the webauth interface; this function checks that by looking
        at kniStatus.
      - maybeCreateL2Forwarder is called when the Dot1xWeb agent starts and when a
        new interface pops up, via handleIntf.
      - Dot1xL2Forwarder is never destroyed (webauth intf is never destroyed).
      - We only instantiate Dot1xHttpServer after Dot1xL2Forwarder => we are
        guaranteed to be running a single thread when this runs.
      '''
      if self.dot1xL2Forwarder:
         return
      for intf in self.kniStatus.interface.values():
         if intf.deviceName == WebAuthWireIntfName:
            trace( thId(), 'maybeCreateL2Forwarder found intf',
                   intf.deviceName )
            self.dot1xL2Forwarder = Dot1xL2Forwarder()
            self.maybeCreateServers()
            return
      trace( thId(), 'maybeCreateL2Forwarder could not find intf',
             WebAuthWireIntfName )

   def handleIntf( self, notifiee=None, key=None ):
      '''
      Reacts to interfaces appearing in kniStatus (i.e. in Linux)

      Calls maybeCreateL2Forwarder if Dot1xL2Forwarder has not been instantiated yet
      and "key" points to the webauth interface. Without a key, all interfaces are
      rescanned.

      key may not be present in kniStatus.interface (presumably when the reactor
      fires for a removed interface - depends on GenericReactor), so membership is
      checked before the entry is dereferenced. Trace arguments are evaluated
      eagerly, so the trace must come after the check.
      '''
      if self.dot1xL2Forwarder:
         return
      if key is None:
         self.maybeCreateL2Forwarder()
         return
      if key not in self.kniStatus.interface:
         trace( thId(), 'handleIntf: key not in kniStatus', bv( str( key ) ) )
         return
      deviceName = self.kniStatus.interface[ key ].deviceName
      trace( thId(), 'handleIntf', bv( deviceName ) )
      if deviceName == WebAuthWireIntfName:
         self.maybeCreateL2Forwarder()

def name():
   ''' Return the agent name, 'Dot1xWeb'.

   Call this to establish an explicit dependency on the Dot1xWeb
   agent executable, to be discovered by static analysis. '''
   # NOTE(review): the literal return below is presumably what the static analysis
   # looks for - keep it in this exact form.
   return 'Dot1xWeb'
