#!/usr/bin/env python
# Copyright (c) 2009-2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Agent, Tracing
from GenericReactor import GenericReactor
import Tac, os, socket, errno, Cell, SharedMem, Smash
from MlagShared import MLAG_PORT, MLAG_TTL, MLAG_TOS, IP_RECVTTL
from if_ether_arista import TC_PRIO_CONTROL
from Arnet.NsLib import DEFAULT_NS
import MlagMountHelper

# Short aliases for the Tracing levels used by this agent.
t0, t1, t2 = Tracing.trace0, Tracing.trace1, Tracing.trace2

# Switch Tac's activity manager to epoll (set at import time, before the
# agent starts).
Tac.activityManager.useEpoll = True

class MlagTunnel( Agent.Agent ):
   """ The MlagTunnel agent allows an agent on one MLAG peer to send
   and/or receive protocol packets on the other peer's interfaces.

   A single Mlag::MlagTunnelSm owns a UDP pam (port MLAG_PORT) addressed to
   the configured MLAG peer. Extra TapPams (UDP heartbeat, STP stability)
   are registered in the Sm's manyPam collection, and per-interface state
   machines are created on tunnelAgentRoot_ depending on whether this switch
   is currently the MLAG primary or secondary. """

   def __init__( self, entityManager ):
      # All state built later is pre-initialized here, before
      # Agent.Agent.__init__ runs.
      self.mlagTunnelSm_ = None
      self.tapPamManager_ = None
      self.mlagConstants_ = Tac.Value( "Mlag::Constants" )
      # Device the tunnel's UDP socket is bound to ("" when not bound).
      self.udpSocketBoundTo_ = ""
      self.mlagConfigReactor_ = None
      self.mlagStatusReactor_ = None
      self.mlagTunnelStatusUpdater_ = None
      self.redundancyModeReactor_ = None
      self.localInterfaceStatusReactor_ = None
      self.kernelIntfStatusReactor_ = None
      self.mlagStatus_ = None
      self.mlagConfig_ = None
      self.protoStatus_ = None
      self.mlagTunnelStatus_ = None
      self.mlagTunnelCellStatus_ = None
      self.mlagTunnelConfig_ = None
      self.peerIntfs_ = None
      self.intfStatus_ = None
      self.intfStatusLocal_ = None
      self.ethIntfs_ = None
      # Pending Tac.ClockNotifiee that re-runs deferredCleanUp, or None.
      self.cleanupActivity_ = None
      self.kniStatus_ = None

      Agent.Agent.__init__( self, entityManager )

      self.tunnelAgentRoot_ = self.agentRoot_.newEntity(
            "Mlag::MlagTunnelAgentRoot", "root" )

      if "QUICKTRACEDIR" in os.environ:
         import QuickTrace
         # QUICKTRACEDIR is known to be set in this branch, so the "-%d"
         # suffix the old conditional expression offered was unreachable.
         qtfile = "%s.qt" % ( self.agentName, )
         QuickTrace.initialize( qtfile, "10,10,10,10,10,10,10,10,10,10" )

   def doInit( self, entityManager ):
      ''' Create local interface-status entities and mount the Sysdb and
      shared-memory state used by the agent. When the mount group closes,
      build the interface-status manager and the reactors that drive the
      agent (mlag config/state, local interface, redundancy mode). '''
      mg = entityManager.mountGroup()

      self.createLocalEntity( "EthIntfStatusDir",
                              "Interface::EthIntfStatusDir",
                              "interface/status/eth/intf" )
      self.createLocalEntity( "AllIntfStatusDir",
                              "Interface::AllIntfStatusDir",
                              "interface/status/all" )
      self.createLocalEntity( "AllIntfStatusLocalDir",
                              "Interface::AllIntfStatusLocalDir",
                              Cell.path( "interface/status/local" ) )
      self.localEntitiesCreatedIs( True )
      # Mount mlag/status, Mlag::Status and its dependent paths
      self.mlagStatus_ = MlagMountHelper.mountMlagStatus( mg )
      # Mount mlag/config, Mlag::Config and its dependent paths
      self.mlagConfig_ = MlagMountHelper.mountMlagConfig( mg )
      self.protoStatus_ = mg.mount( 'mlag/proto', 'Mlag::ProtoStatus', 'rO' )
      self.mlagTunnelStatus_ = mg.mount( 'mlag/tunnel/status',
                                         'Mlag::TunnelStatus', 'w' )
      self.mlagTunnelConfig_ = mg.mount( 'mlag/tunnel/config',
                                         'Mlag::TunnelConfig' )
      self.mlagTunnelCellStatus_ = mg.mount(
         Cell.mountpath( 'mlag/tunnel/status' ),
         'Mlag::TunnelCellStatus', 'fw' )
      self.peerIntfs_ = mg.mount( 'interface/status/eth/peer',
                                  'Interface::PeerIntfStatusDir', 'r' )
      self.intfStatus_ = mg.mount( 'interface/status/all',
                                   'Interface::AllIntfStatusDir' )
      self.intfStatusLocal_ = mg.mount( Cell.path( 'interface/status/local' ),
                                        'Interface::AllIntfStatusLocalDir' )
      self.ethIntfs_ = mg.mount( 'interface/status/eth/intf',
                                 'Interface::EthIntfStatusDir' )

      shmemEm = SharedMem.entityManager( sysdbEm=entityManager )

      self.kniStatus_ = shmemEm.doMount( "kni/ns/%s/status" % DEFAULT_NS,
                                         "KernelNetInfo::Status",
                                         Smash.mountInfo( 'keyshadow' ) )

      def _finishMounts():
         # Runs once every mount in the group has completed.
         t2( 'finishMounts' )
         self.tunnelAgentRoot_.intfStatusManager = (
            self.intfStatus_, self.intfStatusLocal_,
            self.tunnelAgentRoot_.deviceIntfStatusDir, self.kniStatus_ )
         self.tunnelAgentRoot_.intfStatusManager \
                              .intfStatusManager.omitNonRunningKernelDevices = True
         # Rebind the UDP socket when the local interface's kernel device
         # name changes or the configured local interface changes.
         self.kernelIntfStatusReactor_ = GenericReactor(
            self.tunnelAgentRoot_.intfStatusManager.kernelIntfStatusDir,
            [ 'intfStatusAndDeviceName' ], self.handleIntfStatusAndDeviceName )
         self.localInterfaceStatusReactor_ = GenericReactor(
            self.mlagStatus_, [ 'localInterface' ],
            self.handleIntfStatusAndDeviceName )
         self.mlagConfigReactor_ = GenericReactor(
            self.mlagConfig_, [ 'domainId', 'peerAddress' ],
            self.handleMlagConfig )
         self.mlagStatusReactor_ = GenericReactor(
            self.mlagStatus_, [ 'mlagState' ],
            self.handleState, callBackNow=True )
         self.redundancyModeReactor_ = GenericReactor(
            self.redundancyStatus(), [ 'mode' ],
            self.handleRedundancyMode, callBackNow=True )
      mg.close( _finishMounts )

   def primary( self ):
      ''' True when Mlag::Status reports this switch as MLAG primary. '''
      return self.mlagStatus_.mlagState == 'primary'

   def secondary( self ):
      ''' True when Mlag::Status reports this switch as MLAG secondary. '''
      return self.mlagStatus_.mlagState == 'secondary'

   def mlagEnabled( self ):
      ''' True for mlagState primary, secondary or inactive. '''
      return self.mlagStatus_.mlagState in ( 'primary', 'secondary', 'inactive' )

   def handleMlagConfig( self, notifiee ):
      ''' Rebuild the tunnel when domainId or peerAddress changes.

      If MLAG is unconfigured (no domainId, or peerAddress 0.0.0.0) nothing
      is done here; the mlagState transition handles cleanup. '''
      t2( "handleMlagConfig" )
      if not self.mlagConfig_.domainId or self.mlagConfig_.peerAddress == '0.0.0.0':
         t2( "Mlag is not configured. Let mlag state transition handle cleanup" )
         return
      if not self.mlagTunnelSm_:
         return
      if ( self.mlagTunnelSm_.domainId != self.mlagConfig_.domainId or
           self.mlagTunnelSm_.onePam.txDstIpAddr != self.mlagConfig_.peerAddress ):
         t2( "domainId or peerAddress changed! Deleting mlagTunnelSm." )
         self.cleanUp( complete=True )
         # Recreate necessary devices
         self.handleState( None )

   def handleState( self, notifiee ):
      ''' React to mlagState changes.

      - primary/secondary/inactive: make sure the tunnel Sm and its
        non-interface TapPams exist.
      - primary/secondary: mark the cell status running, ensure the
        TapPamManager exists and (re)create the interface state machines.
      - inactive: tear down interface state machines, keep the tunnel Sm.
      - anything else: complete cleanup (tunnel Sm included). '''
      t2( "MLAG state is " + self.mlagStatus_.mlagState )

      if self.mlagEnabled():
         self.initNonIntfTaps()
      if self.primary() or self.secondary():
         self.mlagTunnelCellStatus_.running = True
         self.initTapPamManager()
         if self.mlagTunnelSm_:
            self.initIntfStateMachines()
      elif self.mlagEnabled():
         self.cleanUp()
      else:
         self.cleanUp( complete=True )

   def handleRedundancyMode( self, notifiee=None ):
      ''' Keep a MlagTunnelStatusUpdater only while the redundancy mode is
      'active' and a tunnel Sm exists; drop it otherwise. '''
      redMode = self.redundancyStatus().mode
      t0( "Redundancy mode is " + redMode )
      if redMode == 'active' and self.mlagTunnelSm_ \
             and not self.mlagTunnelStatusUpdater_:
         t0( "Creating MlagTunnelStatusUpdater" )
         self.mlagTunnelStatusUpdater_ = MlagTunnelStatusUpdater(
               self.mlagTunnelConfig_, self.mlagTunnelStatus_,
               self.mlagTunnelSm_ )
      elif redMode != 'active' or not self.mlagTunnelSm_:
         t0( "Deleting (or not creating) MlagTunnelStatusUpdater" )
         self.mlagTunnelStatusUpdater_ = None

   def initTapPamManager( self ):
      ''' Create the Mlag::Agent::TapPamManager on first use. '''
      if not self.tapPamManager_:
         self.tapPamManager_ = Tac.newInstance(
            "Mlag::Agent::TapPamManager", "" )

   def _addPeerTapPam( self, pamName, devName, portId ):
      ''' Create an enabled Arnet::TapPam on devName and register it in
      mlagTunnelSm_ under portId, mapped to the same port id on the peer. '''
      pam = Tac.newInstance( "Arnet::TapPam", pamName, devName )
      pam.mode = 'enabled'
      self.mlagTunnelSm_.manyPam[ portId ] = pam
      self.mlagTunnelSm_.localToPeerPortId[ portId ] = portId

   def _removePeerTapPam( self, portId ):
      ''' Disable and unregister the TapPam stored under portId in
      mlagTunnelSm_ (if any), together with its localToPeerPortId entry. '''
      sm = self.mlagTunnelSm_
      pam = sm.manyPam.get( portId )
      if not pam:
         return
      pam.mode = 'disabled'
      del sm.manyPam[ portId ]
      if portId in sm.localToPeerPortId:
         del sm.localToPeerPortId[ portId ]

   def initNonIntfTaps( self ):
      ''' Do all the initialization that doesn't require access to the
      interface directories in Sysdb: create the MlagTunnelSm, the UDP
      heartbeat and STP-stability TapPams, bind the socket to the local
      interface if possible, and create the status updater if appropriate.
      No-op when the MlagTunnelSm already exists. Socket errors raised
      while creating the Sm are printed and re-raised. '''
      if self.mlagTunnelSm_:
         return

      # Create the MlagTunnelSm.
      try:
         self.mlagTunnelSm_ = self.newMlagTunnelSm()
      except socket.error, e:
         # e.errno may be None or an unrecognized value; a plain dict lookup
         # would raise KeyError here and hide the original socket error.
         print "Socket error when creating MlagTunnelSm:", \
             errno.errorcode.get( e.errno, e.errno )
         raise
      self.mlagTunnelSm_.mode = 'both'

      # Setup TapPam for UDP heartbeat
      self._addPeerTapPam( "MlagHeartbeat",
                           self.mlagConstants_.udpHeartbeatDevName,
                           self.mlagConstants_.udpHeartbeatPortId )

      # Setup TapPam for Stp stable device
      self.initStpStableDevice()

      # Bind to the local-interface if available
      self.handleIntfStatusAndDeviceName()

      # Create status updater if redundancy mode is active.
      self.handleRedundancyMode()

   def initStpStableDevice( self ):
      ''' Register the STP-stability TapPam with the tunnel Sm, unless there
      is no Sm yet or it is already registered. '''
      if not self.mlagTunnelSm_:
         # Wait for mlag state to become inactive
         return
      udpStpStablePortId = self.mlagConstants_.udpStpStablePortId
      if udpStpStablePortId in self.mlagTunnelSm_.manyPam:
         # Already set up, nothing to do
         return
      # Setup TapPam for Stp stability communication
      self._addPeerTapPam( "MlagStpStable",
                           self.mlagConstants_.udpStpStableDevName,
                           udpStpStablePortId )

   def initIntfStateMachines( self ):
      ''' Do all the initialization that depends on the interfaces
      directories in Sysdb.

      The from-peer Sm is (re)created on every call. The to-peer Sm is
      created for the current role (primary or secondary) if missing, and
      the Sm for the other role, if present, is disabled and dropped. '''
      self.tunnelAgentRoot_.fromPeerIntfSm = (
         self.peerIntfs_, self.mlagTunnelSm_,
         self.tapPamManager_.tapPams )
      self.tunnelAgentRoot_.fromPeerIntfSm.mode = 'enabled'

      if self.primary() and not self.tunnelAgentRoot_.toPeerIntfSmPrimary:
         if self.tunnelAgentRoot_.toPeerIntfSmSecondary:
            self.tunnelAgentRoot_.toPeerIntfSmSecondary.mode = 'disabled'
            self.tunnelAgentRoot_.toPeerIntfSmSecondary = None
         self.tunnelAgentRoot_.toPeerIntfSmPrimary = (
            self.mlagTunnelSm_, self.ethIntfs_,
            self.tunnelAgentRoot_.intfStatusManager, self.mlagStatus_ )
         self.tunnelAgentRoot_.toPeerIntfSmPrimary.mode = 'enabled'
      elif self.secondary() and not self.tunnelAgentRoot_.toPeerIntfSmSecondary:
         if self.tunnelAgentRoot_.toPeerIntfSmPrimary:
            self.tunnelAgentRoot_.toPeerIntfSmPrimary.mode = 'disabled'
            self.tunnelAgentRoot_.toPeerIntfSmPrimary = None
         # The secondary's to-peer Sm is constructed with the peer-link intf.
         assert self.mlagStatus_.peerLinkIntf
         self.tunnelAgentRoot_.toPeerIntfSmSecondary = (
            self.mlagTunnelSm_, self.ethIntfs_,
            self.tunnelAgentRoot_.intfStatusManager,
            self.mlagStatus_.peerLinkIntf, self.peerIntfs_, self.mlagStatus_ )
         self.tunnelAgentRoot_.toPeerIntfSmSecondary.mode = 'enabled'

   def deferredCleanUp( self ):
      ''' Finish a complete cleanUp() asynchronously.

      On a direct call (no poll pending) the TapPamManager is kicked once to
      start flushing its queue of device events. When that queue is empty,
      or there is no TapPamManager, mlagTunnelCellStatus_.running is
      cleared. Otherwise this method re-arms itself through a ClockNotifiee
      one second later. A pending poll is abandoned, leaving running alone,
      if MLAG has become primary or secondary again in the meantime. '''
      polling = bool( self.cleanupActivity_ )

      # Kick the TapPamManager once to start flushing its queue of devices.
      if not polling and self.tapPamManager_:
         self.tapPamManager_.processTapPamEventQueue()
         t0( 'deferred cleanup: %d devices remain queued' %
               len( self.tapPamManager_.tapPamEvent ) )

      # handleState() sets running=True when MLAG becomes primary/secondary.
      # If that happened while this poll was pending, clearing running now
      # would leave it False while the tunnel is in use.
      if polling and ( self.primary() or self.secondary() ):
         t0( 'deferred cleanup: abandoned, MLAG is active again' )
         self.cleanupActivity_ = None
         return

      # once no devices remain, it's safe to stop running.
      if not self.tapPamManager_ or len( self.tapPamManager_.tapPamEvent ) == 0:
         t0( 'deferred cleanup: complete' )
         self.mlagTunnelCellStatus_.running = False
         self.cleanupActivity_ = None
         return

      t0( 'waiting for cleanup: %d devices queued' %
               len( self.tapPamManager_.tapPamEvent ) )

      # poll once a second for the TapPamEvent queue to empty
      self.cleanupActivity_ = Tac.ClockNotifiee( self.deferredCleanUp )
      self.cleanupActivity_.timeMin = Tac.now() + 1.0

   def cleanUp( self, complete=False ):
      ''' Disable and drop the interface state machines.

      With complete=True, also disable and drop the MlagTunnelSm (and its
      heartbeat/STP-stability TapPams), drop the status updater, and start
      deferredCleanUp() to clear the running flag once the TapPamManager
      has drained. '''
      # If doing a complete cleanup, delete mlagTunnelSm
      # for initialization later.
      if complete and self.mlagTunnelSm_:
         # Completely remove mlagTunnelSm_
         self.mlagTunnelSm_.mode = 'disabled'
         self.mlagTunnelSm_.onePam.mode = 'disabled'
         self.udpSocketBoundTo_ = ""
         # Remove TapPams for UDP heartbeat and Stp stability
         self._removePeerTapPam( self.mlagConstants_.udpHeartbeatPortId )
         self._removePeerTapPam( self.mlagConstants_.udpStpStablePortId )
         self.mlagTunnelSm_ = None
         # Delete MlagTunnelStatusUpdater
         self.handleRedundancyMode()

      if self.tunnelAgentRoot_.fromPeerIntfSm:
         self.tunnelAgentRoot_.fromPeerIntfSm.mode = 'disabled'
      if self.tunnelAgentRoot_.toPeerIntfSmPrimary:
         self.tunnelAgentRoot_.toPeerIntfSmPrimary.mode = 'disabled'
      if self.tunnelAgentRoot_.toPeerIntfSmSecondary:
         self.tunnelAgentRoot_.toPeerIntfSmSecondary.mode = 'disabled'
      self.tunnelAgentRoot_.toPeerIntfSmPrimary = None
      self.tunnelAgentRoot_.toPeerIntfSmSecondary = None
      self.tunnelAgentRoot_.fromPeerIntfSm = None

      if complete:
         self.deferredCleanUp()

   def handleIntfStatusAndDeviceName( self, notif=None, intfName=None ):
      ''' Bind the tunnel's UDP socket to the MLAG local interface once its
      kernel device name is known.

      Invoked by reactors (when intfName is given, only the configured local
      interface is considered) and directly from initNonIntfTaps(). When
      the kernel device is unknown, the recorded binding is reset so that
      a later notification rebinds. '''
      localIntf = self.mlagStatus_.localInterface
      if not localIntf or ( intfName and intfName != localIntf.intfId ):
         return
      kernelIntfStatusDir = \
         self.tunnelAgentRoot_.intfStatusManager.kernelIntfStatusDir
      intfStatusAndDeviceName = kernelIntfStatusDir.intfStatusAndDeviceName.get(
         localIntf.intfId )
      t2( 'handleIntfStatusAndDeviceName', intfStatusAndDeviceName, intfName )
      if not intfStatusAndDeviceName or not intfStatusAndDeviceName.deviceName:
         self.udpSocketBoundTo_ = ""
         return
      # NOTE(review): this compares against the kernel device name, while
      # setSockOptsOnLocalInterface() binds to localInterface.deviceName;
      # if the two ever differ the socket is rebound on every notification.
      if self.udpSocketBoundTo_ != intfStatusAndDeviceName.deviceName:
         self.setSockOptsOnLocalInterface()

   def setSockOptsOnLocalInterface( self ):
      ''' Bind the tunnel's UDP socket to the local interface's device with
      SO_BINDTODEVICE and record it in udpSocketBoundTo_.

      ENODEV is tolerated: the binding stays unrecorded, so a later
      notification retries. Other socket errors propagate. No-op when there
      is no MlagTunnelSm. '''
      if not self.mlagTunnelSm_:
         return
      pam = self.mlagTunnelSm_.onePam
      # fromfd() duplicates the descriptor; options set on the duplicate
      # apply to the pam's socket. Close the duplicate explicitly rather
      # than relying on garbage collection.
      sock = socket.fromfd( pam.fileDesc.descriptor,
                            socket.AF_INET, socket.SOCK_DGRAM )
      devName = self.mlagStatus_.localInterface.deviceName
      t2( "setSockOptsOnLocalInterface", devName )
      try:
         sock.setsockopt( socket.SOL_SOCKET, socket.SO_BINDTODEVICE, devName )
         self.udpSocketBoundTo_ = devName
      except socket.error, e:
         if e.errno == errno.ENODEV:
            print "Device %s went away" % devName
            # Wait until it comes back.
         else:
            raise
      finally:
         sock.close()

   def newMlagTunnelSm( self ):
      ''' Build a Mlag::MlagTunnelSm around a new UDP pam configured with the
      MLAG TTL/TOS, IP_RECVTTL and control-priority socket options. '''
      socketOpts = [ ( socket.IPPROTO_IP, socket.IP_TTL, MLAG_TTL ),
                     ( socket.IPPROTO_IP, socket.IP_TOS, MLAG_TOS ),
                     ( socket.IPPROTO_IP, IP_RECVTTL, 1 ),
                     ( socket.SOL_SOCKET, socket.SO_PRIORITY, TC_PRIO_CONTROL) ]
      sm = Tac.newInstance( "Mlag::MlagTunnelSm",
                            self.newUdpPam( socketOptions=socketOpts ),
                            self.mlagConfig_.domainId,
                            self.protoStatus_ )
      return sm

   def newUdpPam( self, socketOptions=None ):
      ''' Create an Arnet::UdpPam in server mode that listens on MLAG_PORT
      (all addresses) and sends to the configured peer address.

      socketOptions: optional iterable of (level, option, value) triples
      applied with setsockopt() to the pam's socket. '''
      pam = Tac.newInstance( "Arnet::UdpPam", "udpPam" )
      pam.maxRxPktData = 9900
      pam.rxIpAddr = "0.0.0.0" # INADDR_ANY
      pam.rxPort = MLAG_PORT
      pam.txDstIpAddr = self.mlagConfig_.peerAddress
      # This environment variable can be set for testing purposes:
      pam.txPort = int( os.environ.get( 'MLAG_PEER_UDP_PORT',
                                        str( MLAG_PORT ) ) )
      pam.mode = 'server'
      if socketOptions:
         # Here we assume that pam socket is AF_INET/SOCK_DGRAM (i.e. IP/UDP)
         # but it does not really matter, because setsockopt() ignores these
         # parameters anyway. fromfd() duplicates the descriptor, so the
         # duplicate is closed once the options are applied.
         sock = socket.fromfd( pam.fileDesc.descriptor,
                               socket.AF_INET, socket.SOCK_DGRAM )
         try:
            for level, opt, val in socketOptions:
               sock.setsockopt( level, opt, val )
         finally:
            sock.close()
      return pam

class MlagTunnelStatusUpdater( object ):
   ''' Publishes the MlagTunnelSm's packet counters into a TunnelStatus.

   statusUpdate() runs from a Tac.ClockNotifiee and re-arms itself one
   second ahead each time it runs. clearCounters() reacts to the config's
   'clearCounters' attribute and zeroes the published counters. '''

   def __init__( self, cfg, status, sm ):
      self.cfg_ = cfg
      self.status_ = status
      self.sm_ = sm
      self.statusUpdateActivity_ = Tac.ClockNotifiee( self.statusUpdate )
      self.clearCountersReactor_ = GenericReactor(
         self.cfg_, [ 'clearCounters' ], self.clearCounters )

   def __del__( self ):
      # NOTE(review): the ClockNotifiee and GenericReactor above are given
      # bound methods of self. If they hold strong references, this object
      # sits in a reference cycle, and Python 2 never collects cycles that
      # contain an object with __del__ (they end up in gc.garbage). In that
      # case this method would not run. Confirm Tac/GenericReactor semantics.
      self.clearCountersReactor_ = None
      self.statusUpdateActivity_ = None
      self.sm_ = None

   def statusUpdate( self ):
      ''' Move the Sm's packet counters into self.status_, then reschedule
      one second in the future.

      Setting Pam/Sm counters to zero only changes the counters, but
      updating them in Sysdb may trigger more actions. So, to be safe, the
      Sm-side counters are read and reset first, and only then are the
      status counters updated, which keeps the two sides consistent.
      Per-name counters are summed, so counters with the same counterName
      in different collections are all accounted for. '''
      sm = self.sm_
      status = self.status_
      pam = sm.onePam

      # for now we only have packet counters to update
      rx = pam.rxPkts
      pam.rxPkts = 0
      tx = pam.txPkts
      pam.txPkts = 0
      enc = sm.encPkts
      sm.encPkts = 0
      dec = sm.decPkts
      sm.decPkts = 0

      allCounters = ( list( sm.encDecByMacPkts.values() ) +
                      list( sm.encDecByProtoPkts.values() ) +
                      list( sm.encDecByStrPkts.values() ) )

      # Harvest and reset per-counter values, accumulated by counterName and
      # remembered in first-seen order.
      names = []
      ePkts = {}
      dPkts = {}
      for counter in allCounters:
         name = counter.counterName
         if name not in ePkts:
            names.append( name )
            ePkts[ name ] = 0
            dPkts[ name ] = 0
         ePkts[ name ] += counter.encPkts
         counter.encPkts = 0
         dPkts[ name ] += counter.decPkts
         counter.decPkts = 0

      status.rxPkts += rx
      status.txPkts += tx
      status.encPkts += enc
      status.decPkts += dec
      for name in names:
         if name not in status.encDecPkts:
            status.encDecPkts.newMember( name )
         counter = status.encDecPkts[ name ]
         counter.encPkts += ePkts[ name ]
         counter.decPkts += dPkts[ name ]

      # Reschedule one second in the future.
      self.statusUpdateActivity_.timeMin = Tac.now() + 1

   def clearCounters( self, notifiee ):
      ''' Zero every counter published in self.status_. '''
      status = self.status_
      status.rxPkts = 0
      status.txPkts = 0
      status.encPkts = 0
      status.decPkts = 0
      for key in status.encDecPkts:
         status.encDecPkts[ key ].encPkts = 0
         status.encDecPkts[ key ].decPkts = 0
