#!/usr/bin/env python
# Copyright (c) 2009, 2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac, os
import Ethernet
import struct
import socket
import Tracing
import QuickTrace
from if_ether_arista import TC_PRIO_CONTROL
from zlib import crc32
import subprocess
import signal
from ctypes import cdll

# Short aliases for the Tracing levels used in this file.
hbt0 = Tracing.trace0 # Heartbeat header check
t2 = Tracing.trace2 # Events in method
t3 = Tracing.trace3 # Unusual events/error scenario
t4 = Tracing.trace4 # Events in FileTransfer manager

# Short aliases for the QuickTrace helpers used alongside the Tracing calls.
qv = QuickTrace.Var
qt3 = QuickTrace.trace3
qt4 = QuickTrace.trace4

# Port number used by Mlag.
# NOTE(review): the consumers of this port are not visible in this file.
MLAG_PORT = 4432
# Used during reload-delay to transfer ARP table from Primary->Secondary
MLAG_ARP_SYNC_PORT = 50002
# socket option not defined in socket.py
IP_RECVTTL = 12
# Set Mlag and MlagTunnel socket TTL to 255
# to pass through Acl filter for Mlag
MLAG_TTL = 255
# Set number of SYN retries to be low so tacc retries the
# connection more often than the kernel backoff algorithm
MLAG_SYN_RETRIES = 1
# Set Mlag and MlagTunnel socket TOS to 0xe0
MLAG_TOS = 0xe0
# Kill the process if it runs more than 120s.
# The value can be overridden through the PROCESS_MANAGER_POLLER_TIMEOUT
# environment variable; it is passed as the poller timeout in
# ProcessManager.run().
PROCESS_MANAGER_POLLER_TIMEOUT = int( os.environ.get(
   "PROCESS_MANAGER_POLLER_TIMEOUT", '120' ) )

# Callback priority levels.
# NOTE(review): not used in this file; presumably a lower value means the
# callback is served earlier - confirm against the callers.
CALLBACK_HIGH_PRIORITY = 0
CALLBACK_MEDIUM_PRIORITY = 1
CALLBACK_LOW_PRIORITY = 2

# The default heartbeat timeout is HEARTBEAT_TIMEOUT_MULTIPLIER times the
# heartbeat interval, capped at MAX_HEARTBEAT_TIMEOUT (see
# heartbeatTimeoutFromInterval()). The cap is expressed in the same unit as the
# interval; the unit itself is not stated in this file - TODO confirm.
HEARTBEAT_TIMEOUT_MULTIPLIER = 15
# NOTE(review): not used in this file; meaning and unit to be confirmed.
HEARTBEAT_INITIAL_WAIT = 300.0
MAX_HEARTBEAT_TIMEOUT = 75000

# Port id carried (shifted left by one bit) in the last field of the UDP
# heartbeat header.
UDPHEARTBEATPORTID = 4096
UDPHEARTBEAT_HDR_MAGICNUM_NO_TLV = 547 # <=V7
UDPHEARTBEAT_HDR_MAGICNUM = 548        # >= V8
# Heartbeat header layout, network byte order: 16-bit magic number, 32-bit
# CRC of the domain id, 16-bit field holding ( portId << 1 ) whose lsb says
# whether the packet came via the peer's local interface.
UDPHEARTBEAT_HDR_FORMAT = "!HIH"
UDPHEARTBEAT_HDR_SIZE = struct.calcsize( UDPHEARTBEAT_HDR_FORMAT )

# Version packet layout, network byte order: two unsigned 32-bit integers,
# minVersion followed by maxVersion (see newVersionPkt()).
VERSION_PKT_FORMAT = "!II"
VERSION_PKT_SIZE = struct.calcsize( VERSION_PKT_FORMAT )

# Three doubles in network byte order (24 bytes); same layout as the header
# packed by newPeerClockHdr().
PEER_TIME_OFFSET_FORMAT = "!ddd"
PEER_TIME_OFFSET_SIZE = struct.calcsize( PEER_TIME_OFFSET_FORMAT )

# NEG_STATUS_ values now originate in Tac
NEG_STATUS_CONNECTED = Tac.Type( "Mlag::NegotiationStatus" ).connected
NEG_STATUS_CONNECTING = Tac.Type( "Mlag::NegotiationStatus" ).connecting
NEG_STATUS_DOMAIN_MISMATCH = Tac.Type( "Mlag::NegotiationStatus" ).domainMismatch
NEG_STATUS_INVALID_PEER = Tac.Type( "Mlag::NegotiationStatus" ).invalidPeer
NEG_STATUS_NEGOTIATION = Tac.Type( "Mlag::NegotiationStatus" ).negotiation
NEG_STATUS_VERS_INCOMPATIBLE = \
   Tac.Type( "Mlag::NegotiationStatus" ).versIncompatible
NEG_STATUS_HW_NOT_READY = Tac.Type( "Mlag::NegotiationStatus" ).hwNotReady
# Tuple of every negotiation status value defined above.
NEG_STATUS_ALL = \
                 ( NEG_STATUS_CONNECTED,
                   NEG_STATUS_CONNECTING,
                   NEG_STATUS_DOMAIN_MISMATCH,
                   NEG_STATUS_INVALID_PEER,
                   NEG_STATUS_NEGOTIATION,
                   NEG_STATUS_VERS_INCOMPATIBLE,
                   NEG_STATUS_HW_NOT_READY )

FORCEFUL_FAILOVER_TIMEOUT = 30
# Note: Keep this in sync with 'MlagShutdown' stage timeout within
# Asu EosStageEvents SysdbPlugin and keep this more than FORCEFUL_FAILOVER_TIMEOUT
# with enough time for Mlag.mlagAsuTimer_ to run.
MLAG_ASU_SHUTDOWN_STAGE_TIMEOUT = 60

def heartbeatTimeoutFromInterval( interval ):
   """
   Derive the heartbeat timeout from the heartbeat interval.

   The timeout is HEARTBEAT_TIMEOUT_MULTIPLIER times the interval, but never
   more than MAX_HEARTBEAT_TIMEOUT. The result is in the same unit as
   `interval`.
   """
   scaledTimeout = HEARTBEAT_TIMEOUT_MULTIPLIER * interval
   return min( scaledTimeout, MAX_HEARTBEAT_TIMEOUT )

def heartbeatTimeout( cfg, interval ):
   """
   Return the heartbeat timeout to use for the given configuration.

   A non-zero cfg.heartbeatTimeout is an explicit setting and is returned
   as is; a value of 0 means the timeout is derived from `interval` through
   heartbeatTimeoutFromInterval().
   """
   if cfg.heartbeatTimeout != 0:
      return cfg.heartbeatTimeout
   return heartbeatTimeoutFromInterval( interval )

def msi( macAddr ):
   """
   Return `macAddr` with the locally administered bit set.

   The address is converted to its packed form, bit 0x2 of the first octet
   is set, and the result is returned as a colon separated, two digit
   lowercase hex string. The bit is not expected to be set in the input;
   this is asserted.
   """
   packed = Ethernet.convertMacAddrToPackedString( macAddr )
   firstOctet = ord( packed[0] )
   # We don't expect the bit to be set already.
   assert firstOctet & 0x2 == 0
   # Set locally administered bit
   packed = chr( firstOctet | 0x2 ) + packed[ 1: ]
   octets = tuple( ord( octet ) for octet in packed )
   return '%02x:%02x:%02x:%02x:%02x:%02x' % octets

def minSystemId( macAddrs ):
   """
   Return the smallest msi() address among `macAddrs`.

   Every address is converted with msi() and the results are compared as
   strings. An empty `macAddrs` raises ValueError (from min()).
   """
   return min( msi( addr ) for addr in macAddrs )

def newVersionPkt( minVersion, maxVersion ):
   """
   Build a version packet.

   Packs `minVersion` followed by `maxVersion` according to
   VERSION_PKT_FORMAT and returns the packed string.
   """
   return struct.pack( VERSION_PKT_FORMAT, minVersion, maxVersion )

def versionFieldsFromPkt( pkt ):
   """
   Extract the version range from the start of `pkt`.

   Only the first VERSION_PKT_SIZE bytes are decoded; any trailing bytes are
   ignored. `pkt` must hold at least VERSION_PKT_SIZE bytes (asserted).
   Returns the tuple ( minVersion, maxVersion ).
   """
   assert len( pkt ) >= VERSION_PKT_SIZE
   header = pkt[ : VERSION_PKT_SIZE ]
   fields = struct.unpack( VERSION_PKT_FORMAT, header )
   minVersion, maxVersion = fields
   return ( minVersion, maxVersion )

def newPeerClockHdr( pctSm, now ):
   """
   Encode the 24 byte peer clock tracker header in network byte order.

   pctSm is a PeerClockTrackerSm. The header consists of three doubles, in
   this order: pctSm.heartbeatLastReceivedTimestamp,
   pctSm.heartbeatLastReceivedAt and `now`.
   """
   lastTimestamp = pctSm.heartbeatLastReceivedTimestamp
   lastReceivedAt = pctSm.heartbeatLastReceivedAt
   return struct.pack( '!ddd', lastTimestamp, lastReceivedAt, now )

def newHeartbeatHdr( domainId, useTlv=True ):
   """
   Build the UDP heartbeat header for `domainId`.

   The header is packed according to UDPHEARTBEAT_HDR_FORMAT and holds:
   - the magic number: UDPHEARTBEAT_HDR_MAGICNUM when `useTlv` is true,
     UDPHEARTBEAT_HDR_MAGICNUM_NO_TLV otherwise,
   - the CRC32 of `domainId`, masked to an unsigned 32-bit value,
   - the heartbeat port id shifted left by one bit.
   """
   domainIdCrc = crc32( domainId ) & 0xffffffff
   if useTlv:
      magicNum = UDPHEARTBEAT_HDR_MAGICNUM
   else:
      magicNum = UDPHEARTBEAT_HDR_MAGICNUM_NO_TLV
   # The lsb for this field denotes whether the packet came via peer's local
   # interface (which is always false for UDP heartbeat packet) hence the
   # left shift.
   fromPeerIntfAndPortId = UDPHEARTBEATPORTID << 1
   return struct.pack( UDPHEARTBEAT_HDR_FORMAT, magicNum, domainIdCrc,
                       fromPeerIntfAndPortId )

def validateHeartbeatPkt( pktStr, domainId ):
   """
   Check the UDP heartbeat header at the start of `pktStr`.

   The packet is accepted when it is at least UDPHEARTBEAT_HDR_SIZE bytes
   long, carries one of the two known magic numbers, its domain id CRC equals
   the CRC32 of the local `domainId`, and its port id (header field shifted
   right by one bit) is UDPHEARTBEATPORTID.

   Returns the bytes following the header when the packet is accepted;
   otherwise traces the reason and returns None.
   """
   pktLen = len( pktStr )
   if pktLen < UDPHEARTBEAT_HDR_SIZE:
      hbt0( "Received packet with invalid length:", pktLen )
      return None

   magicNum, domainIdCrc, fromPeerIntfAndPortId = struct.unpack(
      UDPHEARTBEAT_HDR_FORMAT, pktStr[ : UDPHEARTBEAT_HDR_SIZE ] )

   if magicNum not in ( UDPHEARTBEAT_HDR_MAGICNUM_NO_TLV,
                        UDPHEARTBEAT_HDR_MAGICNUM ):
      hbt0( "Received packet with unexpected magic number: ", magicNum )
      return None

   expectedCrc = crc32( domainId ) & 0xffffffff
   if domainIdCrc != expectedCrc:
      hbt0( "Received packet with unexpected domainId CRC: ", domainIdCrc,
            " local: ", expectedCrc )
      return None

   # The lsb is the "came via peer's local interface" flag; drop it.
   portId = fromPeerIntfAndPortId >> 1
   if portId != UDPHEARTBEATPORTID:
      hbt0( "Received packet with unexpected portId: ", portId )
      return None

   return pktStr[ UDPHEARTBEAT_HDR_SIZE : ]

def socketOpts( synCntOption=True, deviceName=None ):
   """
   Return the security and QoS options for Mlag sockets.

   The result is a list of ( level, option, value ) tuples. TTL, TOS and
   socket priority are always present. TCP_SYNCNT is added when
   `synCntOption` is true, and the local-interface options are added when
   `deviceName` is given.
   """
   options = [
      ( socket.IPPROTO_IP, socket.IP_TTL, MLAG_TTL ),
      ( socket.IPPROTO_IP, socket.IP_TOS, MLAG_TOS ),
      ( socket.SOL_SOCKET, socket.SO_PRIORITY, TC_PRIO_CONTROL ),
   ]
   if synCntOption:
      options.append(
         ( socket.IPPROTO_TCP, socket.TCP_SYNCNT, MLAG_SYN_RETRIES ) )
   if deviceName:
      options.extend( localInterfaceSocketOpts( deviceName ) )
   return options

def localInterfaceSocketOpts( deviceName ):
   """
   Return the socket options specific to the local-interface.

   The result is a one element list holding the ( level, option, value )
   tuple that binds a socket to `deviceName` with SO_BINDTODEVICE.
   """
   bindToDevice = ( socket.SOL_SOCKET, socket.SO_BINDTODEVICE, deviceName )
   return [ bindToDevice ]

class ProcessManager():
   """
   Instantiates a process based on args passed and ensures child process exits on
   parent death.
   1. It embeds a Tac poller to monitor the forked process.
   2. Implements retry mechanism with clock notifiee: a run is successful when
      the process return code is 0; any other return code (or a poller
      timeout) triggers a retry after retryTimeout, up to maxRetry retries.
   3. callback is added for caller to take relevant action on process exit code.
      It is invoked with the return code once the process succeeded or the
      retries are exhausted. When no callback was supplied nothing is invoked.
   """
   def __init__( self, args, maxRetry=0, retryTimeout=60.0, callback=None,
                 stdout=subprocess.PIPE ):
      # args: passed unchanged to subprocess.Popen.
      # maxRetry: number of re-runs allowed after a failed run.
      # retryTimeout: delay added to Tac.now() before a re-run.
      # callback: callable taking the return code, or None.
      # stdout: passed unchanged to subprocess.Popen.
      self.args = args
      self.maxRetry = maxRetry
      self.retryTimeout = retryTimeout
      self.returnCode = None
      self.proc = None
      self.callback = callback
      self.stdout = stdout
      # State to track number of retries before we bail out.
      # Note that nothing in this class resets it once incremented.
      self.numRetry = 0
      self.poller = None
      self.procRetryActivity = Tac.ClockNotifiee( self.handleRetryTimeout )
      self.procRetryActivity.timeMin = Tac.endOfTime
      t4( "Initialized procMgr with args", args )

   def _notifyCallback( self, returnCode ):
      # callback defaults to None; only notify when the caller supplied one
      # instead of failing with a TypeError.
      if self.callback is not None:
         self.callback( returnCode )

   def cleanup( self ):
      # cleanup all local state except numRetry as we need to re-run the process
      t4( "ProcMgr cleanup. returnCode", self.returnCode )
      # Disarm any pending retry.
      self.procRetryActivity.timeMin = Tac.endOfTime
      if self.poller:
         # Remove from clockNotifiee
         self.poller.cancel()
         self.poller = None
      if self.returnCode is None and self.proc:
         # No exit status was observed yet, so the process may still be alive.
         t3( "Killing process before it has a chance to exit." )
         qt3( "Killing process before it has a chance to exit." )
         self.proc.kill()
      self.proc = None
      self.returnCode = None

   def handleRetryTimeout( self ):
      # Called when we want to re-run the process
      t3( "Running the process for", self.numRetry, "times" )
      qt3( "Running the process for", qv( self.numRetry ), "times" )
      self.run()

   def run( self ):
      """
      Start the process and monitor it until it exits or times out.

      Any process still managed from a previous run is cleaned up first. When
      the MLAG_PROC_DIRECT_CALLBACK_CODE environment variable holds a value
      other than -1, no process is started and the callback is invoked
      directly with that value.
      """
      def createProcess():
         # Returns a Popen object for child process which will die on parent exit.
         PR_SET_PDEATHSIG = 1
         def onParentExit( signame='SIGKILL' ):
            """
            Returns a function to be called after fork and before exec to modify
            prctl for PDEATHSIG
            """
            signum = getattr( signal, signame )
            def setParentDeathSig():
               result = cdll[ 'libc.so.6' ].prctl( PR_SET_PDEATHSIG, signum )
               if result != 0:
                  t3( "prctl failed with code", result )
                  qt3( "prctl failed with code", qv( result ) )
            return setParentDeathSig

         # NOTE(review): stderr (and stdout by default) is a pipe that this
         # class never reads; a child producing a lot of output could block
         # on a full pipe - confirm the managed commands stay quiet.
         proc = subprocess.Popen( self.args,
                                  stdout=self.stdout,
                                  stderr=subprocess.PIPE,
                                  close_fds=True,
                                  preexec_fn=onParentExit() )
         return proc

      def procReturnCode():
         # Poller condition handler to check process exit status.
         # poll() returns None for as long as the process is running.
         self.returnCode = self.proc.poll()
         return self.returnCode is not None

      def handleProcReturnCode( dummy=None ):
         # Poller callback handler when condition is True
         t4( "handleProcReturnCode called:", self.returnCode )
         self.poller = None
         self.numRetry += 1
         # Compare by value: identity comparison against an int literal is
         # implementation dependent.
         if self.returnCode == 0 or self.numRetry > self.maxRetry:
            t4( "Invoking process exit callback" )
            self._notifyCallback( self.returnCode )
         else:
            t4( "Need a clock notifiee to retry after some timeout" )
            self.procRetryActivity.timeMin = Tac.now() + self.retryTimeout

      def handleProcTimeout():
         # Subprocess has run too long without exiting. Timeout handler.
         t4( "Timed out in file transfer" )
         self.proc.kill()
         # Note the failure
         self.returnCode = -1
         handleProcReturnCode()

      # Ensure we don't call run while we are already managing a process
      self.cleanup()
      returnCode = int( os.environ.get( "MLAG_PROC_DIRECT_CALLBACK_CODE", '-1' ) )
      if returnCode != -1:
         self._notifyCallback( returnCode )
      else:
         self.proc = createProcess()
         self.poller = Tac.Poller( procReturnCode, handleProcReturnCode,
                                   timeoutHandler=handleProcTimeout,
                                   timeout=PROCESS_MANAGER_POLLER_TIMEOUT,
                                   description="process to exit", maxDelay=5.0 )

def configCheckDir( peerRoot ):
   """
   Return the 'configCheck' directory found under `peerRoot`.

   The default collection is consulted first through peerRoot.get(). When
   that yields nothing, the lookup falls back to
   peerRoot.subdirDeprecated[ 'configCheck' ].
   """
   configCheck = peerRoot.get( 'configCheck' )
   if configCheck:
      return configCheck
   # subdirDeprecated collection is being deprecated from EOS.
   # To be compatible with past/future releases the default collection
   # ( entityPtr ) is looked up first, see message on top of
   # MlagConfigCheckPlugin/Mlag.py regarding subdirDeprecated lookup/creation.
   return peerRoot.subdirDeprecated[ 'configCheck' ]
