#!/usr/bin/env python
# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac

# Module-level caches. The "_pim*" names hold IPv4 state and the "_pim6*" names
# hold IPv6 state (the getters below dispatch AddressFamily.ipv4 / .ipv6 to
# them). All of them start empty, are filled lazily by the getters, and are
# reset to these initial values by cleanup().

# TacSmash::Mount shared by both families (see getSmashMount).
_smashMount = None
# IPv4: mode -> Routing::Pim::StatusColl (see getPimStatusColl).
_pimStatusColl = {}
# IPv4: Routing::Pim::AllStatusColl (see getPimAllStatusColl).
_pimAllStatusColl = None
# IPv4: mounted Routing::Pim::GlobalStatus (see getPimGlobalStatus).
_pimGlobalStatus = None
# IPv4: Routing::Pim::Smash::GlobalStatusReaderSm.
_pimGlobalStatusReaderSm = None
# IPv4: mode -> Routing::Pim::GlobalStatusModeFilterSm.
_pimGlobalStatusModeFilterSm = {}
# IPv6 counterparts of the caches above.
_pim6StatusColl = {}
_pim6AllStatusColl = None
_pim6GlobalStatus = None
_pim6GlobalStatusReaderSm = None
_pim6GlobalStatusModeFilterSm = {}

# Tac enum type whose ipv4 / ipv6 values select which cache set is used.
AddressFamily = Tac.Type( "Arnet::AddressFamily" )

def _initPimGlobalStatusReader( entityManager, af ):
   """Build a Routing::Pim::Smash::GlobalStatusReaderSm for `af` and cache it.

   Any previously cached reader for the same family is replaced.

   Args:
      entityManager: entity manager used to obtain the global status and the
         smash mount.
      af: AddressFamily.ipv4 or AddressFamily.ipv6.

   Raises:
      AssertionError: if `af` is neither ipv4 nor ipv6. The check runs before
         any object is constructed.
   """
   global _pimGlobalStatusReaderSm, _pim6GlobalStatusReaderSm

   # Validate first, and with an explicit raise rather than `assert`. The
   # check then survives `python -O` and no SM is built for a bad family.
   if af not in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      raise AssertionError( "Unsupported address family" )

   # NOTE(review): the literal False passed as the fourth constructor argument
   # has a meaning defined by the Tac type; TODO document it here.
   pgsr = Tac.newInstance(
         "Routing::Pim::Smash::GlobalStatusReaderSm",
         af, getPimGlobalStatus( entityManager, af ),
         getPimAllStatusColl( af ), False, Tac.activityManager.clock,
         getSmashMount( entityManager ) )

   if af == AddressFamily.ipv4:
      _pimGlobalStatusReaderSm = pgsr
   else:
      _pim6GlobalStatusReaderSm = pgsr

def _initPimGlobalStatusModeFilter( af, mode ):
   """Build the GlobalStatusModeFilterSm for (`af`, `mode`) and cache it.

   The SM is constructed from the family's AllStatusColl and its per-mode
   StatusColl (both created on demand by their getters). It is stored under
   `mode` in the family's filter dict, replacing any existing entry.

   Raises:
      AssertionError: if `af` is neither ipv4 nor ipv6.
   """
   if af == AddressFamily.ipv4:
      filterSms = _pimGlobalStatusModeFilterSm
   elif af == AddressFamily.ipv6:
      filterSms = _pim6GlobalStatusModeFilterSm
   else:
      # Explicit raise (not `assert`): under `python -O` the assert was
      # stripped and the code below failed with an UnboundLocalError instead.
      raise AssertionError( "Unsupported address family" )

   filterSms[ mode ] = Tac.newInstance(
         "Routing::Pim::GlobalStatusModeFilterSm",
         getPimAllStatusColl( af ), getPimStatusColl( af, mode ),
         mode )

def _initPimGlobalStatus( entityManager, af, blocking ):
   """Mount Routing::Pim::GlobalStatus for `af` read-only and cache it.

   The entity returned by the mount is stored in the module cache right away.
   The mount group is then closed with a callback that builds the family's
   GlobalStatusReaderSm.

   Args:
      entityManager: entity manager that provides the mount group.
      af: AddressFamily.ipv4 or AddressFamily.ipv6.
      blocking: forwarded unchanged to the mount group's close().
         NOTE(review): presumably this controls whether close() waits for the
         mount to complete before returning; confirm against the Tac docs.

   Raises:
      AssertionError: if `af` is neither ipv4 nor ipv6. The check runs before
         any mount group is opened.
   """
   global _pimGlobalStatus, _pim6GlobalStatus

   # Validate before touching the entity manager, so a bad family does not
   # leave a half-used mount group behind. The explicit raise also survives
   # `python -O`.
   if af not in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      raise AssertionError( "Unsupported address family" )

   mg = entityManager.mountGroup()
   path = getPath( "Routing::Pim::GlobalStatus", af )
   pgs = mg.mount( path, "Routing::Pim::GlobalStatus", "r" )

   if af == AddressFamily.ipv4:
      _pimGlobalStatus = pgs
   else:
      _pim6GlobalStatus = pgs

   mg.close( lambda: _initPimGlobalStatusReader( entityManager, af ), blocking )

def getPimStatusModeFilterSm( af, mode ):
   """Return the GlobalStatusModeFilterSm for (`af`, `mode`).

   The SM is created on the first request for a given pair and cached.

   Raises:
      AssertionError: if `af` is neither ipv4 nor ipv6.
   """
   if af == AddressFamily.ipv4:
      coll = _pimGlobalStatusModeFilterSm
   elif af == AddressFamily.ipv6:
      coll = _pim6GlobalStatusModeFilterSm
   else:
      # Explicit raise (not `assert`) so the check survives `python -O`.
      raise AssertionError( "Unsupported address family" )

   if mode not in coll:
      # The initializer stores the new SM into this same per-family dict.
      _initPimGlobalStatusModeFilter( af, mode )

   return coll[ mode ]

def getPimGlobalStatusReaderSm( entityManager, af ):
   """Return the cached GlobalStatusReaderSm for `af`.

   The reader is built on first request.

   Args:
      entityManager: entity manager used if the reader must be created.
      af: AddressFamily.ipv4 or AddressFamily.ipv6.

   Raises:
      AssertionError: if `af` is neither ipv4 nor ipv6.
   """
   if af == AddressFamily.ipv4:
      reader = _pimGlobalStatusReaderSm
   elif af == AddressFamily.ipv6:
      reader = _pim6GlobalStatusReaderSm
   else:
      # Explicit raise (not `assert`) so the check survives `python -O`.
      raise AssertionError( "Unsupported address family" )

   if reader is None:
      _initPimGlobalStatusReader( entityManager, af )
      # The initializer publishes the new SM through the module globals, so
      # re-read the one for this family.
      # NOTE(review): if the global status was not mounted yet, the
      # initializer triggers _initPimGlobalStatus, whose mount-group close
      # callback calls _initPimGlobalStatusReader again. Whenever that
      # callback runs, it replaces the cached reader with a second instance,
      # and callers may keep holding this first one. Confirm this is intended.
      if af == AddressFamily.ipv4:
         reader = _pimGlobalStatusReaderSm
      else:
         reader = _pim6GlobalStatusReaderSm

   return reader

def getPimGlobalStatus( entityManager, af, blocking=False ):
   """Return the cached Routing::Pim::GlobalStatus entity for `af`.

   The entity is mounted on first request.

   Args:
      entityManager: entity manager used if the status must be mounted.
      af: AddressFamily.ipv4 or AddressFamily.ipv6.
      blocking: passed to _initPimGlobalStatus when a mount is started. It is
         ignored when the status is already cached.

   Raises:
      AssertionError: if `af` is neither ipv4 nor ipv6.
   """
   if af == AddressFamily.ipv4:
      pgs = _pimGlobalStatus
   elif af == AddressFamily.ipv6:
      pgs = _pim6GlobalStatus
   else:
      # Explicit raise (not `assert`) so the check survives `python -O`.
      raise AssertionError( "Unsupported address family" )

   if pgs is None:
      _initPimGlobalStatus( entityManager, af, blocking )
      # The initializer stores the mounted entity in the module globals.
      if af == AddressFamily.ipv4:
         pgs = _pimGlobalStatus
      else:
         pgs = _pim6GlobalStatus

   return pgs

def getPimAllStatusColl( af ):
   """Return the Routing::Pim::AllStatusColl for `af`.

   The collection is created on first request and cached.

   Raises:
      AssertionError: if `af` is neither ipv4 nor ipv6.
   """
   global _pimAllStatusColl, _pim6AllStatusColl

   if af == AddressFamily.ipv4:
      if _pimAllStatusColl is None:
         _pimAllStatusColl = Tac.newInstance( "Routing::Pim::AllStatusColl" )
      return _pimAllStatusColl

   if af == AddressFamily.ipv6:
      if _pim6AllStatusColl is None:
         _pim6AllStatusColl = Tac.newInstance( "Routing::Pim::AllStatusColl" )
      return _pim6AllStatusColl

   # Explicit raise rather than `assert False`. Under `python -O` the assert
   # was stripped and this function silently returned None for a bad family.
   raise AssertionError( "Unsupported address family" )

def getPimStatusColl( af, mode ):
   """Return the Routing::Pim::StatusColl cached for (`af`, `mode`).

   An empty collection is created the first time a pair is requested.

   Raises:
      AssertionError: if `af` is neither ipv4 nor ipv6.
   """
   if af == AddressFamily.ipv4:
      coll = _pimStatusColl
   elif af == AddressFamily.ipv6:
      coll = _pim6StatusColl
   else:
      # Explicit raise (not `assert`) so the check survives `python -O`.
      raise AssertionError( "Unsupported address family" )

   # Membership test rather than dict.setdefault(): setdefault would build a
   # throwaway Tac instance on every call.
   if mode not in coll:
      coll[ mode ] = Tac.newInstance( 'Routing::Pim::StatusColl' )
   return coll[ mode ]

def getSmashMount( entityManager ):
   """Return the module-wide TacSmash::Mount.

   The mount is created from `entityManager.cEntityManager()` on the first
   call. Later calls return the cached instance and ignore `entityManager`.
   """
   global _smashMount
   # TODO - We should convert Routing::Pim::Smash::GlobalStatusReaderSm to store a
   # TacSharedMem::EntityManager::Ptr instead of the deprecated TacSmash::Mount
   if _smashMount is not None:
      return _smashMount

   cEm = entityManager.cEntityManager()
   _smashMount = Tac.newInstance( 'TacSmash::Mount', cEm )
   return _smashMount

def getPath( typeName, *args ):
   """Return the mount path reported by Tac type `typeName`.

   Any extra positional `args` are forwarded to the type's mountPath().
   """
   tacType = Tac.Type( typeName )
   mountPath = tacType.mountPath( *args )
   return mountPath

def cleanup():
   """Reset every module-level cache to its initial, empty state.

   The next call to any getter rebuilds what it needs from scratch.
   """
   global _smashMount
   global _pimStatusColl, _pimAllStatusColl, _pimGlobalStatus
   global _pimGlobalStatusReaderSm, _pimGlobalStatusModeFilterSm
   global _pim6StatusColl, _pim6AllStatusColl, _pim6GlobalStatus
   global _pim6GlobalStatusReaderSm, _pim6GlobalStatusModeFilterSm

   # Single-object caches go back to None.
   _smashMount = None
   _pimAllStatusColl = _pim6AllStatusColl = None
   _pimGlobalStatus = _pim6GlobalStatus = None
   _pimGlobalStatusReaderSm = _pim6GlobalStatusReaderSm = None

   # Per-mode caches each get their own fresh dict. They are deliberately not
   # chain-assigned, which would make the families share one dict.
   _pimStatusColl = {}
   _pim6StatusColl = {}
   _pimGlobalStatusModeFilterSm = {}
   _pim6GlobalStatusModeFilterSm = {}

