#!/usr/bin/env python
# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import os
import fcntl
import struct
from collections import namedtuple
from contextlib import contextmanager
import threading

import Tac
import QuickTrace
import Arnet
import Arnet.NsLib
from Arnet.NsLib import DEFAULT_NS

# Ethertype of IPv4; used as the ethProtocol of the EthDevPam opened by EthPam
ETH_P_IP = 0x0800
# Length of an untagged Ethernet header (two MAC addresses + ethertype); the
# 802.1Q header is read/written at this offset
ETH_HLEN = 14

# Maximum number of bytes read from the tun device per os.read() call
BUFFER_SIZE = 1500
# Inactivity time after which a learnt mac table entry expires. It is added to
# Tac.now(), so presumably expressed in seconds.
MACTABLE_TIMEOUT = 300

# Network namespace the webauth tun device is moved into and configured in
WebAuthNs = 'webauthNs'
# Wire-side interface the Arnet::EthDevPam is opened on
WebAuthWireIntfName = 'webauth'
# Name of the tun device created by Tun
WebAuthTunIntfName = 'webauthTun'

# QuickTrace helpers: warn for dropped packets, trace for regular activity;
# bv wraps dynamic values passed to them
warn = QuickTrace.trace0
trace = QuickTrace.trace1
bv = QuickTrace.Var

# Tac types used to build and parse packets
EthType = Tac.Type( 'Arnet::EthType' )
ArnetEthHdr = Tac.Type( 'Arnet::EthHdrWrapper' )
ArnetEth8021QHdr = Tac.Type( 'Arnet::Eth8021QHdrWrapper' )
ArnetIpHdr = Tac.Type( 'Arnet::IpHdrWrapper' )
ArnetPkt = Tac.Type( 'Arnet::Pkt' )
PrivateTcpPorts = Tac.Type( "Arnet::PrivateTcpPorts" )
# Per-IP entry learnt from a wire frame: VLAN id plus the frame's source and
# destination MAC addresses (swapped by EthPam.sendIpPkt when answering)
MacEntry = namedtuple( 'MacEntry', [ 'vlan', 'src', 'dst' ] )

class MacTableCore( object ):
   '''
   This class owns the mac and expiration tables, and the lock.

   mactable maps an IP address to the entry learnt for it; expiration maps the
   same IP address to the time at which that entry expires.

   All functions here must run with mactableLock acquired, and can't use most
   Tac functions due to lock restrictions. That is why the current time and the
   timer's timeMin are passed in by the caller instead of being read here.
   '''

   def __init__( self ):
      self.mactable = {}
      self.expiration = {}
      self.mactableLock_ = threading.Lock()

   @contextmanager
   def mactableLock( self ):
      '''
      Protect self.mactable, as we use it in Dot1xWebAgentLib.HttpReqHandler
      instances that run in threads other than main
      '''
      # threading.Lock is itself a context manager: released even on exceptions
      with self.mactableLock_:
         yield

   def _renewEntry( self, ip, now ):
      '''
      Set the expiration of ip to now + MACTABLE_TIMEOUT and return it.

      Must be called within a mactableLock.
      '''
      expiration = now + MACTABLE_TIMEOUT
      self.expiration[ ip ] = expiration
      trace( 'renew expiration of ip', bv( ip ), 'to', bv( expiration ) )
      return expiration

   def getMac( self, now, ip ):
      '''
      Return the entry stored for ip and renew its expiration, or None if no
      entry is stored for ip.
      '''
      with self.mactableLock():
         entry = self.mactable.get( ip )
         if not entry:
            return None
         self._renewEntry( ip, now )
         return entry

   def setMac( self, now, timeMin, ip, value ):
      '''
      Store value for ip and renew its expiration.

      Returns the new expiration if it is earlier than timeMin (the time the
      caller's timer is currently armed for), so the caller can pull the timer
      in; otherwise returns None.
      '''
      with self.mactableLock():
         self.mactable[ ip ] = value
         expiration = self._renewEntry( ip, now )
         if expiration < timeMin:
            return expiration
         return None

   def handleExpirationTimer( self, now ):
      '''
      Drop every entry whose expiration is at or before now.

      Returns the earliest expiration among the entries that remain, or None
      when none remain, so the caller can re-arm its timer.
      '''
      with self.mactableLock():
         nextExpiration = None
         # Iterate over a snapshot: entries are deleted inside the loop, which
         # raises RuntimeError on a live dict view under Python 3.
         for ip, expiration in list( self.expiration.items() ):
            if expiration <= now:
               trace( 'expiring ip', bv( ip ), 'expiration', bv( expiration ),
                      'now', bv( now ) )
               del self.mactable[ ip ]
               del self.expiration[ ip ]
               # The entry is gone; its (past) expiration must not be used to
               # re-arm the timer.
               continue
            if nextExpiration is None or expiration < nextExpiration:
               nextExpiration = expiration
         return nextExpiration

   def finish( self ):
      '''Used in tests to prevent old tables from expiring IPs and polluting logs'''
      with self.mactableLock():
         self.mactable = {}
         self.expiration = {}

class MacTable( object ):
   '''
   Tac-facing wrapper around a MacTableCore.

   This layer performs the Tac.now() calls and drives the expiration timer (a
   Tac.ClockNotifiee) from the values the core returns. Keeping those Tac calls
   out of the core, which runs under its lock, is meant to prevent lock races;
   such races are non-deterministic, hard to catch and reason about, and hard
   to solve.

   NOTE(review): setMac only ever moves the timer earlier. It therefore relies
   on timer.timeMin not staying at a past value after the timer fires with
   nothing left to expire. Confirm this against Tac.ClockNotifiee semantics.
   '''

   def __init__( self ):
      self.mactableCore = MacTableCore()
      self.timer = Tac.ClockNotifiee( handler=self.handleExpirationTimer )

   def _armTimer( self, when ):
      '''Move the expiration timer to `when`; a None `when` leaves it alone.'''
      if when is None:
         return
      self.timer.timeMin = when

   def getMac( self, ip ):
      '''Return the entry known for ip (renewing it), or None.'''
      now = Tac.now()
      return self.mactableCore.getMac( now, ip )

   def setMac( self, ip, value ):
      '''Store value for ip; pull the timer in if this entry expires first.'''
      now = Tac.now()
      currentTimeMin = self.timer.timeMin
      earlier = self.mactableCore.setMac( now, currentTimeMin, ip, value )
      self._armTimer( earlier )

   def handleExpirationTimer( self ):
      '''Timer handler: expire stale entries and re-arm for the next one.'''
      now = Tac.now()
      self._armTimer( self.mactableCore.handleExpirationTimer( now ) )

   def finish( self ):
      '''Clear the underlying tables.'''
      self.mactableCore.finish()

def createTun( intfName, ns=None ):
   '''
   Create a tun device named intfName and return its open file descriptor.

   The device is created with IFF_TUN (raw IP packets, no Ethernet header) and
   IFF_NO_PI (no packet-information prefix on reads/writes). If ns is given
   and is not DEFAULT_NS, the device is then moved into that namespace.

   If configuring the device fails after /dev/net/tun was opened, the
   descriptor is closed before the exception propagates, so it is not leaked.
   '''
   TUNSETIFF = 0x400454ca
   IFF_TUN = 0x0001
   IFF_NO_PI = 0x1000
   # struct ifreq begins with a 16-byte interface name; '16s' requires bytes on
   # Python 3 (on Python 2, str is already bytes and this is a no-op).
   ifrName = intfName
   if not isinstance( ifrName, bytes ):
      ifrName = ifrName.encode( 'ascii' )
   ifr = struct.pack( '16sH', ifrName, IFF_TUN | IFF_NO_PI )
   tunFd = os.open( '/dev/net/tun', os.O_RDWR )
   try:
      fcntl.ioctl( tunFd, TUNSETIFF, ifr )
      if ns and ns != DEFAULT_NS:
         Tac.run( [ 'ip', 'link', 'set', intfName, 'netns', ns ] )
   except BaseException:
      os.close( tunFd )
      raise
   return tunFd

class Tun( Tac.Notifiee ):
   '''
   Notifiee on the file descriptor of the webauth tun device.

   On construction the tun device is created in WebAuthNs and configured there:
   - the link is brought up and given an address;
   - a default route is added via the device;
   - two sysctls are set;
   - an iptables DNAT rule redirects TCP arriving on the device to the webauth
     HTTP port.

   Packets read from the device are passed to the attached EthPam's sendIpPkt.
   '''
   notifierTypeName = 'Tac::FileDescriptor'

   def __init__( self ):
      self.tunFd = createTun( intfName=WebAuthTunIntfName, ns=WebAuthNs )
      for argv in self._setupCmds():
         trace( 'running', bv( ' '.join( argv ) ) )
         Arnet.NsLib.runMaybeInNetNs( WebAuthNs, argv )
      fileDesc = Tac.newInstance( 'Tac::FileDescriptor', 'tun' )
      fileDesc.descriptor = self.tunFd
      self.tunFileDesc = fileDesc
      self.ethPam = None
      Tac.Notifiee.__init__( self, fileDesc )

   @staticmethod
   def _setupCmds():
      '''Return the commands (argv lists) that configure the tun device.'''
      tunIntf = WebAuthTunIntfName
      webauthIp = '127.1.1.3'
      httpDest = '%s:%d' % ( webauthIp, PrivateTcpPorts.dot1xWebHttpPort )
      natCmd = [ 'iptables', '-t', 'nat' ]
      return [
         [ 'ip', 'link', 'set', tunIntf, 'up' ],
         [ 'ip', 'addr', 'add', webauthIp + '/24', 'dev', tunIntf ],
         [ 'ip', 'route', 'add', 'default', 'dev', tunIntf ],
         [ 'sysctl', '-w', 'net.ipv4.conf.%s.route_localnet=1' % tunIntf ],
         [ 'sysctl', '-w', 'net.ipv4.conf.%s.forwarding=0' % tunIntf ],
         natCmd + [ '-F', 'PREROUTING' ],
         natCmd + [ '-A', 'PREROUTING', '-i', tunIntf, '-p', 'tcp',
                    '-j', 'DNAT', '--to-destination', httpDest ],
         # TODO: HTTPS (PrivateTcpPorts.dot1xWebHttpsPort)
      ]

   def finish( self ):
      '''Close the tun descriptor explicitly; required in cohab tests.'''
      os.close( self.tunFd )

   def setEthPam( self, ethPam ):
      '''Attach the EthPam that receives packets read from the tun device.'''
      self.ethPam = ethPam

   def write( self, data ):
      '''Write one raw IP packet into the tun device.'''
      os.write( self.tunFd, data )

   @Tac.handler( 'readableCount' )
   def handleReadableCount( self ):
      # Always drain the packet, even if nothing is attached to forward it to
      packet = os.read( self.tunFd, BUFFER_SIZE )
      if packet and self.ethPam:
         self.ethPam.sendIpPkt( packet )

class EthPam( Tac.Notifiee ):
   '''
   Notifiee on an Arnet::EthDevPam opened on the webauth wire interface with
   ethProtocol ETH_P_IP.

   Frames read from the wire are handled in two steps:
   - their VLAN and MAC addresses are recorded in the mac table under the
     source IP;
   - their IP payload is written to the tun device.

   IP packets coming back from the tun device are wrapped in an 802.1Q-tagged
   Ethernet header built from that recorded entry and sent to the wire.
   '''
   notifierTypeName = "Arnet::EthDevPam"

   def __init__( self, mactable ):
      self.mactable = mactable
      self.ethPam = Tac.newInstance( 'Arnet::EthDevPam', WebAuthWireIntfName )
      self.ethPam.ethProtocol = ETH_P_IP
      self.ethPam.enabled = True
      self.tun = None
      Tac.Notifiee.__init__( self, self.ethPam )

   def setTun( self, tun ):
      '''Attach the Tun that receives IP payloads read from the wire.'''
      self.tun = tun

   def sendIpPkt( self, data ):
      '''
      Send the raw IP packet `data` (as read from the tun device) to the wire.

      IPv6 packets are dropped silently. A packet whose destination IP has no
      mac table entry is dropped with a warning.
      '''
      pkt = ArnetPkt()
      pkt.stringValue = data
      iphdr = ArnetIpHdr( pkt, 0 )
      if iphdr.version == 6:
         # Can happen in btest
         return
      # Read the addresses before head room is added to pkt below. This keeps
      # the traces independent of how the wrapper behaves after data is
      # prepended.
      ipSrc = iphdr.src
      ipDst = iphdr.dst
      macentry = self.mactable.getMac( ipDst )
      if not macentry:
         warn( 'EthPam, no macentry for ip dst', bv( ipDst ),
               'src', bv( ipSrc ) )
         return
      # Room for the Ethernet header plus one 4-byte 802.1Q tag (14 + 4 = 18)
      pkt.newSharedHeadData = ETH_HLEN + 4
      ethhdr = ArnetEthHdr( pkt, 0 )
      # Invert src and dst as we are answering:
      ethhdr.src = macentry.dst
      ethhdr.dst = macentry.src
      ethhdr.ethType = 'ethTypeDot1Q'
      dot1q = ArnetEth8021QHdr( pkt, ETH_HLEN )
      dot1q.tagControlPriority = 0
      dot1q.tagControlCfi = False
      dot1q.tagControlVlanId = macentry.vlan
      dot1q.ethType = 'ethTypeIp'
      trace( 'EthPam, to wire:',
             'eth src', bv( ethhdr.src ), 'dst', bv( ethhdr.dst ),
             'vlan', bv( dot1q.tagControlVlanId ),
             'ip src', bv( ipSrc ), 'dst', bv( ipDst ),
             'len', bv( len( pkt.stringValue ) ) )
      self.ethPam.txPkt = pkt

   @Tac.handler( 'readableCount' )
   def handleReadableCount( self ):
      '''
      Receive one frame from the wire, record its source, and forward its IP
      payload to the tun device.
      '''
      pkt = self.ethPam.rxPkt()
      if not pkt or not self.tun:
         return
      ethhdr = ArnetEthHdr( pkt, 0 )
      # NOTE(review): the 802.1Q header is parsed at ETH_HLEN unconditionally,
      # so this assumes frames on the webauth interface are always VLAN-tagged.
      dot1q = ArnetEth8021QHdr( pkt, ETH_HLEN )
      iphdr = ArnetIpHdr( pkt, ethhdr.ipHdrOffset )
      # Dynamic values are wrapped in bv(), as in every other trace here
      trace( 'EthPam, from wire:',
             'eth src', bv( ethhdr.src ), 'dst', bv( ethhdr.dst ),
             'vlan', bv( dot1q.tagControlVlanId ),
             'ip src', bv( iphdr.src ), 'dst', bv( iphdr.dst ) )
      self.mactable.setMac( iphdr.src, MacEntry( vlan=dot1q.tagControlVlanId,
                                                 src=ethhdr.src, dst=ethhdr.dst ) )
      l3pkt = pkt.stringValue[ ethhdr.ipHdrOffset : ]
      self.tun.write( l3pkt )

class Dot1xL2Forwarder( object ):
   '''
   Forwards IP traffic between the webauth wire interface and the webauth tun
   device.

   It owns the MacTable, the EthPam (wire side) and the Tun (tun side). The
   EthPam and the Tun are linked to each other so each can hand packets to the
   other.
   '''

   def __init__( self ):
      trace( 'Dot1xL2Forwarder initializing' )
      table = MacTable()
      wire = EthPam( table )
      tunDev = Tun()
      # Cross-link the two halves so each can forward to the other
      wire.setTun( tunDev )
      tunDev.setEthPam( wire )
      self.mactable = table
      self.ethPam = wire
      self.tun = tunDev
      trace( 'Dot1xL2Forwarder initialized' )

   def finish( self ):
      '''Finish the mac table, then the tun device; helps in cohab tests.'''
      for part in ( self.mactable, self.tun ):
         part.finish()

   def getMac( self, ip ):
      '''Return the mac table entry known for ip, or None.'''
      table = self.mactable
      return table.getMac( ip )

def name():
   ''' Call this to establish an explicit dependency on the Dot1xWeb
   agent executable, to be discovered by static analysis. '''
   # Returns the literal 'Dot1xL2Forwarder'. The return is kept as a plain
   # string literal in case the static analysis keys on its exact form.
   # NOTE(review): the docstring names the Dot1xWeb executable, while the value
   # returned is 'Dot1xL2Forwarder'. Confirm which one is meant.
   return 'Dot1xL2Forwarder'

# This is used for netns setup in both EOS and stests:

def runCmd( cmd, ns=None ):
   '''
   Run the whitespace-separated command string cmd through
   Arnet.NsLib.runMaybeInNetNs, passing ns along as the namespace.

   Returns True on success and False if the command raised
   Tac.SystemCommandError.
   '''
   # split() rather than split( ' ' ): repeated, leading or trailing spaces
   # would otherwise produce empty argv entries that make the command fail.
   argv = cmd.split()
   try:
      Arnet.NsLib.runMaybeInNetNs( ns, argv )
   except Tac.SystemCommandError:
      return False
   return True

def setupWebauthNs():
   '''
   Create the webauthNs network namespace unless it already exists.

   To support agent restart and stests, we don't do what is already done. The
   namespace's existence is probed by running 'true' inside it through runCmd.
   runCmd takes ( cmd, ns=None ), where cmd is a string, and returns True for
   success or False for error. A failure to create the namespace is reported
   with a warning rather than raised.

   NOTE(review): nothing here moves the webauth interface into the namespace.
   Confirm whether that is done elsewhere.
   '''
   trace( 'Setting up WebAuthNs' )
   if runCmd( 'true', ns=WebAuthNs ):
      trace( 'Namespace', WebAuthNs, 'exists, skipping creation' )
   else:
      trace( 'Creating namespace for webauth', WebAuthNs )
      if not runCmd( 'ip netns add %s' % WebAuthNs ):
         warn( 'Failed to create namespace', bv( WebAuthNs ) )
   trace( 'Webauth setup done' )
