#!/usr/bin/env python
# Copyright (c) 2010-2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import

import Tac, Tracing, Arnet
import os, re
from SysdbHelperUtils import SysdbPathHelper

# Tracing setup: a handle named "ZeroTouch" plus short aliases for the
# trace0 / trace8 functions used throughout this script.
__defaultTraceHandle__ = Tracing.Handle( "ZeroTouch" )
t0 = Tracing.trace0
t8 = Tracing.trace8

# Sysdb entities. Both are None until mountSysdb() assigns them.
# dhcpStatus is the "zerotouch/dhcp/status" entity the handlers write to.
dhcpStatus = None
# sysLoggingConfig is the "sys/logging/config" entity; it is mounted but not
# otherwise read by the code in this file.
sysLoggingConfig = None

# Names of the environment variables that getGenEnvVars() copies into module
# globals (a variable that is absent from the environment becomes None).
# The handlers below read these globals directly (reason, interface,
# new_ip_address, ...).
# Note: the last five entries are also listed in dhclient6EnvVars.
dhclientEnvVars = [
   'reason' ,
   'interface',
   'new_ip_address',
   'new_subnet_mask',
   'new_domain_name',
   'new_domain_name_servers',
   'new_log_servers',
   'new_domain_search',
   'old_ip_address',
   'alias_ip_address',
   'new_broadcast_address',
   'new_interface_mtu',
   'new_routers',
   'new_static_routes',
   'new_rfc3442_classless_static_routes',
   'new_host_name',
   'new_time_offset',
   'new_tftp_server_name',
   'new_bootfile_name',
   'new_ip6_address',
   'new_ip6_prefixlen',
   'new_dhcp6_name_servers',
   'new_dhcp6_domain_search',
   'new_dhcp6_bootfile_url' ]

# Names of the DHCPv6-related environment variables. getGenEnvVars() passes
# this list to getEnvVars() after dhclientEnvVars; every name here also
# appears in dhclientEnvVars, so the second pass re-assigns the same globals.
dhclient6EnvVars = [
   'new_ip6_address',
   'new_ip6_prefixlen',
   'new_dhcp6_name_servers',
   'new_dhcp6_domain_search',
   'new_dhcp6_bootfile_url' ]

def RouteKey( prefix, preference ):
   """Return a "Routing::RouteKey" Tac value for the given prefix and
   preference."""
   keyArgs = dict( prefix=prefix, preference=preference )
   return Tac.Value( "Routing::RouteKey", **keyArgs )

def Via( hop, intf ):
   """Return a "Routing::Via" Tac value for next hop `hop`.

   `intf` is used as the interface id when it is truthy; otherwise the
   interface id is the empty string.
   """
   intfId = intf if intf else ''
   return Tac.Value( "Routing::Via", hop=Arnet.IpGenAddr( hop ), intfId=intfId )

def IpAddr( addr ):
   """Return an "Arnet::IpAddr" Tac value built from `addr`."""
   ipAddrValue = Tac.Value( "Arnet::IpAddr", addr )
   return ipAddrValue

def classBits( ip ):
   """Return a prefix length for dotted-quad `ip` based on its trailing
   zero octets.

   Each trailing all-zero octet removes 8 bits from 32, so '10.0.0.0'
   gives 8, '192.168.1.0' gives 24 and '1.2.3.4' gives 32. The result is 0
   when the address is all zeroes or does not have exactly four
   dot-separated fields. A non-numeric field raises ValueError.
   """
   octets = ip.split( '.' )
   if len( octets ) != 4:
      return 0

   # Fold the four fields into a single integer, most significant first.
   ipInt = 0
   for octet in octets:
      ipInt = ( ipInt << 8 ) + int( octet )

   # Walk from the least significant byte upwards; the first non-zero byte
   # decides the prefix length.
   for shift in ( 0, 8, 16, 24 ):
      if ( ipInt >> shift ) & 0xff:
         return 32 - shift
   return 0

def getPrefix( ip, nm ):
   """Return the prefix length for address `ip` with netmask `nm`.

   Runs `ipcalc -s -p ip nm` and extracts the number from its
   `PREFIX=<n>` output. The prefix length is returned as a string.

   Returns None when either argument is empty/None, when ipcalc fails, or
   when its output does not start with `PREFIX=<digits>`.
   """
   if not ( ip and nm ):
      return None

   try:
      output = Tac.run( [ 'ipcalc', '-s', '-p', ip, nm ],
                        stdout=Tac.CAPTURE )
   except Tac.SystemCommandError:
      # Best effort: callers treat a missing prefix as "nothing to
      # configure", so record the failure and carry on.
      t0( "ipcalc failed for", ip, nm )
      return None

   m = re.match( r'PREFIX=(\d+)', output )
   if m:
      return str( m.group( 1 ) )
   return None

def parseClasslessStaticRoutes( bytesStr ):
   """Parse a classless static routes option value into a list of routes.

   `bytesStr` is a whitespace-separated string of decimal byte values.
   Each route is encoded as:

      <mask width> <significant destination octets...> <4 next-hop octets>

   where the number of destination octets is the mask width divided by 8,
   rounded up (so a width of 0, i.e. a default route, has no destination
   octets at all).

   Returns a list of ( Arnet.IpGenPrefix, nextHopString ) tuples. If any
   token is not an integer the result is []. Parsing stops at the first
   route that is truncated or has a mask width outside 0..32; routes
   parsed before that point are still returned.
   """
   try:
      _bytes = [ int( b ) for b in bytesStr.split() ]
   except ValueError:
      return []

   routes = []
   total = len( _bytes )
   pos = 0
   while pos < total:
      mask = _bytes[ pos ]
      pos += 1

      # A mask width that cannot describe an IPv4 prefix means the rest of
      # the input cannot be trusted.
      if mask < 0 or mask > 32:
         break

      # Number of significant destination octets: 0 for /0, 1 for /1../8,
      # 2 for /9../16, 3 for /17../24, 4 for /25../32.
      ipBytes = ( mask + 7 ) // 8

      # not enough bytes in input!
      if total - pos < ipBytes + 4:
         break

      # Pad the destination with zero octets up to a full dotted quad.
      destOctets = _bytes[ pos : pos + ipBytes ] + [ 0 ] * ( 4 - ipBytes )
      pos += ipBytes
      hopOctets = _bytes[ pos : pos + 4 ]
      pos += 4

      destNetwork = '%d.%d.%d.%d' % tuple( destOctets )
      nextHop = '%d.%d.%d.%d' % tuple( hopOctets )

      routes.append( ( Arnet.IpGenPrefix( '%s/%d' % ( destNetwork, mask ) ),
                       nextHop ) )
   return routes

def getEnvVars( dhclientGenEnvVars ):
   """Copy the named environment variables into this module's globals.

   For each name in `dhclientGenEnvVars`, a module-level global of the same
   name is set to the variable's value (and the value is traced), or to
   None when the variable is not present in the environment.
   """
   moduleVars = globals()
   for name in dhclientGenEnvVars:
      if name not in os.environ:
         moduleVars[ name ] = None
         continue
      value = os.environ.get( name, '' )
      t8( '%s=%s' % ( name, value ) )
      moduleVars[ name ] = value

def getGenEnvVars():
   """Load every known dhclient environment variable into module globals.

   Processes dhclientEnvVars first and dhclient6EnvVars second.
   """
   for envVarNames in ( dhclientEnvVars, dhclient6EnvVars ):
      getEnvVars( envVarNames )

def preinitHandler():
   """Handler for the PREINIT / PREINIT6 reasons: does nothing and
   returns 0."""
   retCode = 0
   return retCode

def arpcheckHandler():
   """Handler for the ARPCHECK / ARPSEND reasons.

   Runs arping in duplicate-address-detection mode (-D) for new_ip_address
   on interface. Returns 0 when either global is unset or when the arping
   command exits with a non-zero status; returns 1 when it exits with 0.

   NOTE(review): arping -D is documented to exit 0 when no duplicate is
   seen, which would make this return 1 for a free address -- confirm this
   polarity is what the caller expects.
   """
   if not ( new_ip_address and interface ):
      return 0

   cmd = "/usr/sbin/arping -q -f -c 2 -w 3 -D -I %s %s" % (
      interface, new_ip_address )

   status = os.system( cmd )
   return 0 if status else 1

def populateNameServers( new_domain_gen_name_servers, intfDhcpGenStatus ):
   """Store up to three name servers on the current interface's status.

   `new_domain_gen_name_servers` is a comma and/or space separated string
   of addresses. `intfDhcpGenStatus` is the per-interface status collection
   to update; it is indexed by the module-level `interface` name. The first
   three addresses are written to nameServer[ 0 ], [ 1 ] and [ 2 ].

   Empty tokens (e.g. from a leading or trailing separator) are skipped
   rather than being converted to addresses.
   """
   maxNameServers = 3
   tokens = re.split( r'[, ]\s*', new_domain_gen_name_servers )
   nameServers = [ tok for tok in tokens if tok ]

   intfStatus = intfDhcpGenStatus[ interface ]
   for idx, nameServer in enumerate( nameServers[ :maxNameServers ] ):
      intfStatus.nameServer[ idx ] = Arnet.IpGenAddr( nameServer )

def _splitDhcpList( value ):
   """Split a comma and/or space separated option value into its non-empty
   tokens."""
   return [ tok for tok in re.split( r'[, ]\s*', value ) if tok ]

def dhconfigBoundHandler():
   """Publish the DHCPv4 lease held in the module globals to
   dhcpStatus.intfDhcpStatus[ interface ].

   Does nothing (returns 0) unless an address is present and a prefix
   length can be derived from new_ip_address / new_subnet_mask. Otherwise
   it writes, when present: address with mask, MTU (only values above 576),
   static routes, up to three gateways, host name, domain name, name
   servers, log servers, TFTP server name and boot file name, and finally
   increments genId. Always returns 0.

   When classless static routes are supplied and parse to at least one
   route, they are used instead of new_static_routes (RFC 3442 tells the
   client to ignore the Static Routes option in that case); otherwise the
   classful new_static_routes are used.
   """
   prefix = getPrefix( new_ip_address, new_subnet_mask )
   if not prefix or not new_ip_address:
      return 0

   intfStatus = dhcpStatus.intfDhcpStatus[ interface ]

   # ip address and mask
   intfStatus.addrWithMask = Arnet.IpGenAddrWithMask(
      '%s/%s' % ( new_ip_address, prefix ) )

   # interface mtu: the environment value is a string, so convert it before
   # comparing against the 576 lower bound.
   if new_interface_mtu:
      try:
         mtu = int( new_interface_mtu )
      except ValueError:
         t0( "Ignoring invalid interface mtu", new_interface_mtu )
      else:
         if mtu > 576:
            intfStatus.mtu = mtu

   # static routes: classless routes take precedence over classful ones
   routes = []
   if new_rfc3442_classless_static_routes:
      routes = parseClasslessStaticRoutes( new_rfc3442_classless_static_routes )
   if not routes and new_static_routes:
      # The classful option is a flat list of target/gateway pairs.
      tokens = _splitDhcpList( new_static_routes )
      for target, gateway in zip( tokens[ ::2 ], tokens[ 1::2 ] ):
         routePrefix = Arnet.IpGenPrefix(
            '%s/%d' % ( target, classBits( target ) ) )
         routes.append( ( routePrefix, gateway ) )
   for idx, ( routePrefix, gateway ) in enumerate( routes ):
      intfStatus.staticRoute[ idx ] = Tac.Value( "ZeroTouch::Route",
                                                 key=RouteKey( routePrefix, 1 ),
                                                 via=Via( gateway, None ) )

   # gateways (at most three)
   if new_routers:
      gateways = _splitDhcpList( new_routers )
      for idx, gateway in enumerate( gateways[ :3 ] ):
         intfStatus.gateway[ idx ] = Arnet.IpGenAddr( gateway )

   # host name
   if new_host_name:
      intfStatus.hostname = new_host_name

   # domain name
   if new_domain_name:
      intfStatus.domainName = new_domain_name

   # name servers
   if new_domain_name_servers:
      populateNameServers( new_domain_name_servers, dhcpStatus.intfDhcpStatus )

   # log servers
   if new_log_servers:
      # Every host uses the default port of LogMgr::LoggingHost over udp.
      loggingHostType = Tac.Type( "LogMgr::LoggingHost" )
      port = loggingHostType().portDefault
      for ipAddrOrHostname in _splitDhcpList( new_log_servers ):
         hostInfo = Tac.Value( "LogMgr::LoggingHost",
                               ipAddrOrHostname=ipAddrOrHostname,
                               protocol="udp",
                               ports={ port: port } )
         intfStatus.loggingHost.addMember( hostInfo )

   # dhcp options 66 and 67
   if new_tftp_server_name:
      intfStatus.serverName = new_tftp_server_name

   if new_bootfile_name:
      intfStatus.bootFileName = new_bootfile_name

   # Note: genId has to be last assigned attr as this triggers the
   # state machine
   intfStatus.genId += 1

   return 0

def doDAD():
   """Placeholder for duplicate address detection; currently does nothing."""
   return None

def dhconfig6BoundHandler():
   """Publish the DHCPv6 lease held in the module globals to
   dhcpStatus.intfDhcp6Status[ interface ].

   Returns 0 without touching anything when new_ip6_address or
   new_ip6_prefixlen is unset. Otherwise writes the address with mask and,
   when present, the domain search value, name servers and boot file URL,
   then increments genId. Always returns 0.
   """
   if not ( new_ip6_address and new_ip6_prefixlen ):
      return 0

   # XXX-sarangs
   # Still to be implemented
   # Do DAD Bug 210715 tracks this
   doDAD()

   # ip6 address and mask
   addrWithMask = Arnet.IpGenAddrWithMask(
      '/'.join( ( new_ip6_address, new_ip6_prefixlen ) ) )
   intf6Status = dhcpStatus.intfDhcp6Status[ interface ]
   intf6Status.addrWithMask = addrWithMask

   # domain name
   if new_dhcp6_domain_search:
      intf6Status.domainName = new_dhcp6_domain_search

   # name servers
   if new_dhcp6_name_servers:
      populateNameServers( new_dhcp6_name_servers, dhcpStatus.intfDhcp6Status )

   # dhcp6 option 59
   if new_dhcp6_bootfile_url:
      intf6Status.bootFileName = new_dhcp6_bootfile_url

   # Note: genId has to be last assigned attr as this triggers the
   # state machine
   intf6Status.genId += 1

   return 0

def dhconfig6RenewHandler():
   """Refresh the DHCPv6 options on dhcpStatus.intfDhcp6Status[ interface ].

   Writes, when present, the domain search value, name servers and boot
   file URL from the module globals; the address itself is not touched.
   genId is incremented unconditionally. Always returns 0.
   """
   intf6Status = dhcpStatus.intfDhcp6Status[ interface ]

   # domain name
   if new_dhcp6_domain_search:
      intf6Status.domainName = new_dhcp6_domain_search

   # name servers
   if new_dhcp6_name_servers:
      populateNameServers( new_dhcp6_name_servers, dhcpStatus.intfDhcp6Status )

   # dhcp6 option 59
   if new_dhcp6_bootfile_url:
      intf6Status.bootFileName = new_dhcp6_bootfile_url

   # Note: genId has to be last assigned attr as this triggers the
   # state machine
   intf6Status.genId += 1

   return 0

def dhconfigHandler():
   """DHCPv4 configuration handler.

   Calls dhconfigBoundHandler() when reason is BOUND or REBOOT; any other
   reason is a no-op. Always returns 0.
   """
   if reason in ( 'BOUND', 'REBOOT' ):
      dhconfigBoundHandler()
   return 0

def dhconfig6Handler():
   """DHCPv6 configuration handler.

   Dispatches BOUND6 to dhconfig6BoundHandler() and RENEW6 to
   dhconfig6RenewHandler(); any other reason is a no-op. Always returns 0.
   """
   # Dhclient if started with -6 -S option always invokes the zerotouch-dhclient
   # script with RENEW6 reason
   reasonHandlers = { 'BOUND6': dhconfig6BoundHandler,
                      'RENEW6': dhconfig6RenewHandler }
   handler = reasonHandlers.get( reason )
   if handler is not None:
      handler()
   return 0

def downHandler():
   """Handler for the lease-down reasons: does nothing and returns 0."""
   retCode = 0
   return retCode

def timeoutHandler():
   """Handler for the TIMEOUT reason: does nothing and returns 0."""
   retCode = 0
   return retCode

def cleanupHandler():
   """Handler for the CLEANUP reason: delete every entry from
   dhcpStatus.intfDhcpStatus. Always returns 0."""
   # Iterate over a snapshot of the keys so that deleting entries cannot
   # disturb the iteration.
   for intf in list( dhcpStatus.intfDhcpStatus ):
      t0( "del intfDhcpStatus for", intf )
      del dhcpStatus.intfDhcpStatus[ intf ]

   return 0

def cleanup6Handler():
   """Handler for the CLEANUP6 reason: delete every entry from
   dhcpStatus.intfDhcp6Status. Always returns 0."""
   # Iterate over a snapshot of the keys so that deleting entries cannot
   # disturb the iteration.
   for intf in list( dhcpStatus.intfDhcp6Status ):
      # The trace names the collection that is actually being emptied.
      t0( "del intfDhcp6Status for", intf )
      del dhcpStatus.intfDhcp6Status[ intf ]

   return 0

# Dispatch table for DHCPv4: maps the dhclient `reason` string to the handler
# that main() calls for it. Each handler takes no arguments and returns the
# integer that main() uses as the process exit status.
# Note: BOUND, RENEW, REBIND and REBOOT all map to dhconfigHandler, which only
# acts on BOUND and REBOOT. main() returns early for PREINIT, so that entry is
# never dispatched from there.
dhcpoptions = { 'PREINIT':  preinitHandler,
                'ARPCHECK': arpcheckHandler,
                'ARPSEND':  arpcheckHandler,
                'BOUND':    dhconfigHandler,
                'RENEW':    dhconfigHandler,
                'REBIND':   dhconfigHandler,
                'REBOOT':   dhconfigHandler,
                'EXPIRE':   downHandler,
                'FAIL':     downHandler,
                'RELEASE':  downHandler,
                'STOP':     downHandler,
                'TIMEOUT':  timeoutHandler,
                'CLEANUP':  cleanupHandler }

# Dispatch table for DHCPv6: maps the dhclient `reason` string to the handler
# that main() calls for it. Each handler takes no arguments and returns the
# integer that main() uses as the process exit status.
# Note: BOUND6, RENEW6, REBIND6 and DEPREF6 all map to dhconfig6Handler, which
# only acts on BOUND6 and RENEW6. main() returns early for PREINIT6, so that
# entry is never dispatched from there.
dhcp6options = { 'PREINIT6': preinitHandler,
                 'BOUND6':   dhconfig6Handler,
                 'RENEW6':   dhconfig6Handler,
                 'REBIND6':  dhconfig6Handler,
                 'DEPREF6':  dhconfig6Handler,
                 'EXPIRE6':  downHandler,
                 'RELEASE6': downHandler,
                 'STOP6':    downHandler,
                 'CLEANUP6': cleanup6Handler }

def mountSysdb():
   """Look up the Sysdb entities this script uses and store them in the
   module globals dhcpStatus and sysLoggingConfig.

   The system name comes from the SYSNAME environment variable, defaulting
   to "ar". Raises Exception when the "zerotouch/dhcp/status" entity cannot
   be obtained; the "sys/logging/config" lookup is not checked.
   """
   global dhcpStatus, sysLoggingConfig

   # mount sysdb
   helper = SysdbPathHelper( os.environ.get( "SYSNAME", "ar" ) )

   dhcpStatus = helper.getEntity( "zerotouch/dhcp/status" )
   if not dhcpStatus:
      raise Exception( "Failed to mount dhcpStatus" )

   sysLoggingConfig = helper.getEntity( "sys/logging/config" )

def main():
   """Script entry point: load the dhclient environment, run the handler
   for `reason`, record the reason on the interface status and exit.

   Returns without doing anything when reason is unset, PREINIT or
   PREINIT6. Otherwise the process ends via os._exit() with the handler's
   return value (0 when the reason has no handler).
   """
   t0( "Starting..." )

   # Extract known environment variables
   getGenEnvVars()

   #XXX: HACK: Don't do anything for 'PREINIT'
   # PREINIT is called for every interface and
   # mounting sysdb 64 (or 384!) times is just
   # a bad idea.
   if not reason or reason in ( 'PREINIT', 'PREINIT6' ):
      return

   # Mount sysdb
   mountSysdb()

   # Dhclient invokes dhclient-script with reason set to BOUND, RENEW etc. while
   # Dhclient v6 sets the reason to BOUND6, RENEW6 etc.
   isV6 = "6" in reason
   if isV6:
      handlers = dhcp6options
   else:
      handlers = dhcpoptions

   # Create dhcp and dhcp6 interface status, if one doesn't exist
   if interface:
      if isV6:
         t0( "Creating intf dhcp6 status", interface )
         dhcpStatus.newIntfDhcp6Status( interface )
      else:
         t0( "Creating intf dhcp status", interface )
         dhcpStatus.newIntfDhcpStatus( interface )

   # Handle 'reason'
   retCode = 0
   if reason in handlers:
      retCode = handlers[ reason ]()

   # Record the reason on the interface status after the handler has run.
   if interface:
      if isV6:
         dhcpStatus.intfDhcp6Status[ interface ].reason = reason
      else:
         dhcpStatus.intfDhcpStatus[ interface ].reason = reason

   Tac.flushEntityLog()
   os._exit( retCode )

if __name__ == "__main__":
   main()
