# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import Tac
import random

# Tac type for controller message ids; its values are assigned to
# MountInterfaceUpdate.type by MountInterfaceUpdateMgr below.
MSGID = Tac.Type( "Controller::ControllerMessageId" )

class MountInterfaceMgr( object ):
   """Test helper that creates, deletes and verifies Controller service mount
   configuration: per-service ServiceMountConfig entities and the version
   filters, system filters, mount groups and mount configs they contain.

   serviceMountConfigDir holds the ServiceMountConfig entities keyed by
   service name (created with newEntity, removed with deleteEntity).
   serviceMountStatusDir is indexed by service name to check for the
   matching status.

   The verify* methods take a `wait` flag. With wait=True the check is handed
   to Tac.waitFor and the method returns None. With wait=False the check runs
   once and its result is returned.
   """
   def __init__( self, serviceMountConfigDir, serviceMountStatusDir ):
      self.serviceMountConfigDir = serviceMountConfigDir
      self.serviceMountStatusDir = serviceMountStatusDir

   # ---------------------------------------------------------------------
   # Internal helpers
   # ---------------------------------------------------------------------
   @staticmethod
   def _lookup( coll, key ):
      """Return coll[ key ], or None if the key is absent."""
      try:
         return coll[ key ]
      except KeyError:
         return None

   @staticmethod
   def _matches( present, deleted ):
      """Return whether an item's presence agrees with the expected state:
      present when `deleted` is false, absent when `deleted` is true."""
      return deleted if not present else not deleted

   @staticmethod
   def _waitOrCheck( waitFn, description, wait ):
      """Pass waitFn to Tac.waitFor (wait=True, returns None), or call it
      once and return its result (wait=False)."""
      if wait:
         Tac.waitFor( waitFn, description=description )
         return None
      return waitFn()

   # ---------------------------------------------------------------------
   # ServiceMountConfig
   # ---------------------------------------------------------------------
   def newServiceMountConfig( self, serviceName ):
      """Create (via newEntity) and return the ServiceMountConfig for
      serviceName, with its serviceName attribute set."""
      smc = self.serviceMountConfigDir.newEntity(
                  "Controller::ServiceMountConfig",
                  serviceName )
      smc.serviceName = serviceName
      return smc

   def delServiceMountConfig( self, serviceName ):
      """Delete the ServiceMountConfig entity for serviceName."""
      self.serviceMountConfigDir.deleteEntity( serviceName )

   def getServiceMountConfig( self, serviceName ):
      """Return the ServiceMountConfig for serviceName, or None if absent."""
      return self._lookup( self.serviceMountConfigDir, serviceName )

   def getServiceMountStatus( self, serviceName ):
      """Return the ServiceMountStatus for serviceName, or None if absent."""
      return self._lookup( self.serviceMountStatusDir, serviceName )

   # ---------------------------------------------------------------------
   # Filters and mount groups. These raise KeyError if serviceName (or a
   # named parent) does not exist. Deleting a missing child is a no-op.
   # ---------------------------------------------------------------------
   def newVersionFilter( self, serviceName, filterName ):
      """Create and return version filter filterName under serviceName."""
      smc = self.serviceMountConfigDir[ serviceName ]
      return smc.newVersionFilter( filterName )

   def delVersionFilter( self, serviceName, filterName ):
      """Delete version filter filterName, if present."""
      smc = self.serviceMountConfigDir[ serviceName ]
      try:
         del smc.versionFilter[ filterName ]
      except KeyError:
         pass

   def newSystemFilter( self, serviceName, filterName ):
      """Create and return system filter filterName under serviceName."""
      smc = self.serviceMountConfigDir[ serviceName ]
      return smc.newSystemFilter( filterName )

   def delSystemFilter( self, serviceName, filterName ):
      """Delete system filter filterName, if present."""
      smc = self.serviceMountConfigDir[ serviceName ]
      try:
         del smc.systemFilter[ filterName ]
      except KeyError:
         pass

   def newMountGroupConfig( self, serviceName, mgName, vfName=None, sfName=None ):
      """Create mount group mgName and point it at the named version/system
      filters (None clears the reference). Returns the mount group."""
      smc = self.serviceMountConfigDir[ serviceName ]
      # Resolve the filters before creating the group, so that an unknown
      # filter name raises KeyError without leaving a half-configured group.
      vf = smc.versionFilter[ vfName ] if vfName is not None else None
      sf = smc.systemFilter[ sfName ] if sfName is not None else None
      mgc = smc.newMountGroupConfig( mgName )
      mgc.versionFilter = vf
      mgc.systemFilter = sf
      return mgc

   def delMountGroupConfig( self, serviceName, mgName ):
      """Delete mount group mgName, if present."""
      smc = self.serviceMountConfigDir[ serviceName ]
      try:
         del smc.mountGroupConfig[ mgName ]
      except KeyError:
         pass

   def newIncludedVersion( self, serviceName, filterName, version ):
      """Add version to version filter filterName."""
      smc = self.serviceMountConfigDir[ serviceName ]
      vf = smc.versionFilter[ filterName ]
      vf.includedServiceVersion[ version ] = True

   def delIncludedVersion( self, serviceName, filterName, version ):
      """Remove version from version filter filterName, if present."""
      smc = self.serviceMountConfigDir[ serviceName ]
      vf = smc.versionFilter[ filterName ]
      try:
         del vf.includedServiceVersion[ version ]
      except KeyError:
         pass

   def newIncludedSystem( self, serviceName, filterName, system ):
      """Add system to system filter filterName."""
      smc = self.serviceMountConfigDir[ serviceName ]
      sf = smc.systemFilter[ filterName ]
      sf.includedSystem[ system ] = True

   def delIncludedSystem( self, serviceName, filterName, system ):
      """Remove system from system filter filterName, if present."""
      smc = self.serviceMountConfigDir[ serviceName ]
      sf = smc.systemFilter[ filterName ]
      try:
         del sf.includedSystem[ system ]
      except KeyError:
         pass

   def newMountConfig( self, serviceName, mgName, remotePath,
                       taccType, mountFlags, targetMountPath ):
      """Create and return mount config remotePath inside mount group
      mgName."""
      smc = self.serviceMountConfigDir[ serviceName ]
      mgc = smc.mountGroupConfig[ mgName ]
      return mgc.newMountConfig( remotePath, taccType, mountFlags, targetMountPath )

   def delMountConfig( self, serviceName, mgName, remotePath ):
      """Delete mount config remotePath from mount group mgName, if present."""
      smc = self.serviceMountConfigDir[ serviceName ]
      mgc = smc.mountGroupConfig[ mgName ]
      try:
         del mgc.mountConfig[ remotePath ]
      except KeyError:
         pass

   def randomServiceMountConfig( self, serviceName ):
      """Populate serviceName's ServiceMountConfig with random contents.

      Each candidate version filter (vf1..), system filter (sf1..) and mount
      group (mg1..) is randomly created or deleted. Every created filter or
      mount group gets at least one included version, included system or
      mount config. Returns the ServiceMountConfig.
      """
      maxItems = 4
      smc = self.newServiceMountConfig( serviceName )
      # randint( 1, maxItems ) feeds xrange( 1, n ), so each loop below
      # visits between 0 and maxItems - 1 names.
      nVf = random.randint( 1, maxItems )
      nSf = random.randint( 1, maxItems )
      nMg = random.randint( 1, maxItems )

      # Clear filter references held by any pre-existing mount groups before
      # the filters below are (possibly) deleted.
      for existingMg in smc.mountGroupConfig.itervalues():
         existingMg.versionFilter = None
         existingMg.systemFilter = None

      for i in xrange( 1, nVf ):
         filterName = "vf%d" % i
         if random.choice( [ True, False ] ):
            vf = self.newVersionFilter( serviceName, filterName )
            nIncVer = random.randint( 1, maxItems )
            for v in xrange( 1, nIncVer ):
               if random.choice( [ True, False ] ):
                  self.newIncludedVersion( serviceName, filterName, v )
               else:
                  self.delIncludedVersion( serviceName, filterName, v )

            # Create at least one version
            if len( vf.includedServiceVersion ) == 0:
               self.newIncludedVersion( serviceName, filterName, 1 )
         else:
            self.delVersionFilter( serviceName, filterName )

      for i in xrange( 1, nSf ):
         filterName = "sf%d" % i
         if random.choice( [ True, False ] ):
            sf = self.newSystemFilter( serviceName, filterName )
            nIncSys = random.randint( 1, maxItems )
            for s in xrange( 1, nIncSys ):
               if random.choice( [ True, False ] ):
                  self.newIncludedSystem( serviceName, filterName, "sys%d" % s )
               else:
                  self.delIncludedSystem( serviceName, filterName, "sys%d" % s )

            # Create at least one system
            if len( sf.includedSystem ) == 0:
               self.newIncludedSystem( serviceName, filterName, "sys1" )
         else:
            self.delSystemFilter( serviceName, filterName )

      for i in xrange( 1, nMg ):
         mgName = "mg%d" % i
         if random.choice( [ True, False ] ):
            vfName = random.choice( smc.versionFilter.keys() + [ None ] )
            sfName = random.choice( smc.systemFilter.keys() + [ None ] )
            mg = self.newMountGroupConfig( serviceName, mgName, vfName, sfName )
            nMc = random.randint( 1, maxItems )
            for m in xrange( 1, nMc ):
               mcName = "mc%d" % m
               mcData = random.choice( [ 'x', 'y', 'z' ] )
               # Always drop the old entry; recreate it half of the time.
               self.delMountConfig( serviceName, mgName, mcName )
               if random.choice( [ True, False ] ):
                  self.newMountConfig( serviceName, mgName, mcName,
                                       mcData, mcData, mcData )

            # Create at least one mountConfig
            if len( mg.mountConfig ) == 0:
               self.newMountConfig( serviceName, mgName, "mc1", "x", "y", "z" )
         else:
            self.delMountGroupConfig( serviceName, mgName )

      return smc

   # ---------------------------------------------------------------------
   # Verification. A missing parent (service, filter or mount group) counts
   # as its children being absent, so waits keep retrying instead of
   # raising.
   # ---------------------------------------------------------------------
   def verifyServiceMountConfig( self, serviceName, deleted=False, verifyStatus=True,
                                 wait=True ):
      """Check that serviceName's ServiceMountConfig exists, or is absent when
      deleted is True. If verifyStatus is set, check its ServiceMountStatus
      the same way."""
      def waitFn():
         smc = self.getServiceMountConfig( serviceName )
         ok = ( not deleted ) if smc else deleted
         if ok and verifyStatus:
            sms = self.getServiceMountStatus( serviceName )
            ok = ( not deleted ) if sms else deleted
         return ok

      return self._waitOrCheck( waitFn, "verifyServiceMountConfig", wait )

   def verifyVersionFilter( self, serviceName, filterName, deleted=False,
                            wait=True ):
      """Check that version filter filterName exists (or is absent when
      deleted is True)."""
      def waitFn():
         smc = self.getServiceMountConfig( serviceName )
         present = smc is not None and filterName in smc.versionFilter
         return self._matches( present, deleted )

      return self._waitOrCheck( waitFn, "verifyVersionFilter", wait )

   def verifySystemFilter( self, serviceName, filterName, deleted=False, wait=True ):
      """Check that system filter filterName exists (or is absent when
      deleted is True)."""
      def waitFn():
         smc = self.getServiceMountConfig( serviceName )
         present = smc is not None and filterName in smc.systemFilter
         return self._matches( present, deleted )

      return self._waitOrCheck( waitFn, "verifySystemFilter", wait )

   def verifyMountGroupConfig( self, serviceName, mgName, vfName=None, sfName=None,
                               deleted=False, wait=True ):
      """Check that mount group mgName exists and references exactly the
      named version/system filters (a falsy name means no filter). When
      deleted is True, check instead that mgName is absent."""
      def waitFn():
         smc = self.getServiceMountConfig( serviceName )
         if smc is None or mgName not in smc.mountGroupConfig:
            return deleted
         if deleted:
            return False
         # An expected filter that does not exist (yet) cannot match.
         if vfName and vfName not in smc.versionFilter:
            return False
         if sfName and sfName not in smc.systemFilter:
            return False
         mgc = smc.mountGroupConfig[ mgName ]
         expVf = smc.versionFilter[ vfName ] if vfName else None
         expSf = smc.systemFilter[ sfName ] if sfName else None
         return mgc.versionFilter == expVf and mgc.systemFilter == expSf

      return self._waitOrCheck( waitFn, "verifyMountGroupConfig", wait )

   def verifyIncludedVersion( self, serviceName, filterName, version,
                              deleted=False, wait=True ):
      """Check that version is included in version filter filterName (or is
      not, when deleted is True)."""
      def waitFn():
         smc = self.getServiceMountConfig( serviceName )
         if smc is None or filterName not in smc.versionFilter:
            return deleted
         vf = smc.versionFilter[ filterName ]
         return self._matches( version in vf.includedServiceVersion, deleted )

      return self._waitOrCheck( waitFn, "verifyIncludedVersion", wait )

   def verifyIncludedSystem( self, serviceName, filterName, system,
                             deleted=False, wait=True ):
      """Check that system is included in system filter filterName (or is
      not, when deleted is True)."""
      def waitFn():
         smc = self.getServiceMountConfig( serviceName )
         if smc is None or filterName not in smc.systemFilter:
            return deleted
         sf = smc.systemFilter[ filterName ]
         return self._matches( system in sf.includedSystem, deleted )

      return self._waitOrCheck( waitFn, "verifyIncludedSystem", wait )

   def verifyMountConfig( self, serviceName, mgName, remotePath,
                          taccType=None, mountFlags=None,
                          targetMountPath=None, deleted=False, wait=True ):
      """Check that mount config remotePath exists in mount group mgName with
      the given taccType/mountFlags/targetMountPath. When deleted is True,
      check instead that it is absent."""
      def waitFn():
         smc = self.getServiceMountConfig( serviceName )
         if smc is None or mgName not in smc.mountGroupConfig:
            return deleted
         mgc = smc.mountGroupConfig[ mgName ]
         if remotePath not in mgc.mountConfig:
            return deleted
         if deleted:
            return False
         mc = mgc.mountConfig[ remotePath ]
         return ( mc.taccType == taccType and
                  mc.mountFlags == mountFlags and
                  mc.targetMountPath == targetMountPath )

      return self._waitOrCheck( waitFn, "verifyMountConfig", wait )

   def _verifyFullVersionFilter( self, smc, expSmc ):
      """Return True iff smc's version filters and included versions match
      expSmc's exactly (nothing missing, nothing extra)."""
      serviceName = smc.serviceName
      for expVfName, expVf in expSmc.versionFilter.iteritems():
         if not self.verifyVersionFilter( serviceName, expVfName, wait=False ):
            return False
         vf = smc.versionFilter[ expVfName ]

         for expVer in expVf.includedServiceVersion:
            if not self.verifyIncludedVersion( serviceName, expVfName, expVer,
                                               wait=False ):
               return False

         for ver in vf.includedServiceVersion:
            if ( ver not in expVf.includedServiceVersion and
                 not self.verifyIncludedVersion( serviceName, expVfName, ver,
                                                 deleted=True, wait=False ) ):
               return False

      for vfName in smc.versionFilter:
         if ( vfName not in expSmc.versionFilter and
              not self.verifyVersionFilter( serviceName, vfName, deleted=True,
                                            wait=False ) ):
            return False

      return True

   def _verifyFullSystemFilter( self, smc, expSmc ):
      """Return True iff smc's system filters and included systems match
      expSmc's exactly (nothing missing, nothing extra)."""
      serviceName = smc.serviceName
      for expSfName, expSf in expSmc.systemFilter.iteritems():
         if not self.verifySystemFilter( serviceName, expSfName, wait=False ):
            return False
         sf = smc.systemFilter[ expSfName ]

         for expSys in expSf.includedSystem:
            if not self.verifyIncludedSystem( serviceName, expSfName, expSys,
                                              wait=False ):
               return False

         for system in sf.includedSystem:
            if ( system not in expSf.includedSystem and
                 not self.verifyIncludedSystem( serviceName, expSfName, system,
                                                deleted=True, wait=False ) ):
               return False

      for sfName in smc.systemFilter:
         if ( sfName not in expSmc.systemFilter and
              not self.verifySystemFilter( serviceName, sfName, deleted=True,
                                           wait=False ) ):
            return False

      return True

   def _verifyFullMountGroupConfig( self, smc, expSmc ):
      """Return True iff smc's mount groups (with their filter references)
      and mount configs match expSmc's exactly."""
      serviceName = smc.serviceName
      for expMgName, expMg in expSmc.mountGroupConfig.iteritems():
         expVfName = expMg.versionFilter.name if expMg.versionFilter else None
         expSfName = expMg.systemFilter.name if expMg.systemFilter else None
         if not self.verifyMountGroupConfig( serviceName, expMgName, expVfName,
                                             expSfName, wait=False ):
            return False
         mg = smc.mountGroupConfig[ expMgName ]

         for expRpath, expMc in expMg.mountConfig.iteritems():
            if not self.verifyMountConfig( serviceName, expMgName, expRpath,
                                           expMc.taccType, expMc.mountFlags,
                                           expMc.targetMountPath, wait=False ):
               return False

         for rpath in mg.mountConfig:
            if ( rpath not in expMg.mountConfig and
                 not self.verifyMountConfig( serviceName, expMgName, rpath,
                                             deleted=True, wait=False ) ):
               return False

      for mgName in smc.mountGroupConfig:
         if ( mgName not in expSmc.mountGroupConfig and
              not self.verifyMountGroupConfig( serviceName, mgName,
                                               deleted=True, wait=False ) ):
            return False

      return True

   def verifyFullServiceMountConfig( self, serviceName, expSmc ):
      """Wait (via Tac.waitFor) until serviceName's ServiceMountConfig and
      status exist and its filters and mount groups match expSmc exactly."""
      def waitFn():
         if not self.verifyServiceMountConfig( serviceName, wait=False ):
            return False
         smc = self.getServiceMountConfig( serviceName )
         return ( self._verifyFullVersionFilter( smc, expSmc ) and
                  self._verifyFullSystemFilter( smc, expSmc ) and
                  self._verifyFullMountGroupConfig( smc, expSmc ) )

      Tac.waitFor( waitFn, description="verifyFullServiceMountConfig" )
      
class MountInterfaceUpdateMgr( object ):
   """Builds and verifies ControllerCluster::MountInterfaceUpdate messages.

   If miUpdate is given, every new*Update call fills in and returns that same
   update object, reusing its sub-messages. Otherwise each call creates a
   fresh MountInterfaceUpdate carrying one newly created sub-message. In both
   cases the update's `type` is set to the matching MSGID value.
   """
   def __init__( self, miUpdate=None ):
      self.miUpdate = miUpdate

   # ---------------------------------------------------------------------
   # Internal helpers
   # ---------------------------------------------------------------------
   def _subMessage( self, attrName, typeName ):
      """Return the sub-message to fill in: self.miUpdate.<attrName> when
      reusing an update, otherwise a new Tac instance of typeName."""
      if self.miUpdate is None:
         return Tac.newInstance( typeName, None )
      return getattr( self.miUpdate, attrName )

   def _update( self, marker=None, includedVersion=None, includedSystem=None,
                mountConfig=None ):
      """Return self.miUpdate, or a new MountInterfaceUpdate constructed from
      the given sub-messages (positional order: marker, included version,
      included system, mount config)."""
      if self.miUpdate is not None:
         return self.miUpdate
      return Tac.newInstance( "ControllerCluster::MountInterfaceUpdate",
                              marker, includedVersion, includedSystem,
                              mountConfig )

   @staticmethod
   def _assertSameFields( act, exp, fields ):
      """Assert that act and exp agree on every attribute named in fields."""
      for field in fields:
         actVal = getattr( act, field )
         expVal = getattr( exp, field )
         assert actVal == expVal, "%s: %r != %r" % ( field, actVal, expVal )

   # ---------------------------------------------------------------------
   # Builders
   # ---------------------------------------------------------------------
   def newMarkerUpdate( self, start=True ):
      """Return an update carrying a mount-interface marker message whose
      `start` flag is set from `start`."""
      marker = self._subMessage( "mountInterfaceMarker",
                                 "Controller::CvxMountInterfaceMarkerMessage" )
      marker.empty = True
      # Assign unconditionally: when self.miUpdate is reused, the marker is
      # shared across calls. Setting it only when start is true would leave
      # a stale True after a later start=False call.
      marker.start = bool( start )

      miUpdate = self._update( marker=marker )
      miUpdate.type = MSGID.cvxMountInterfaceMarkerMessage
      return miUpdate

   def newIncludedVersionUpdate( self, serviceName, filterName, version=None,
                                 deleted=False ):
      """Return an update carrying an included-version message for
      serviceName/filterName. `version` is only assigned when not None."""
      incVerUpdate = self._subMessage( "includedVersion",
                                       "Controller::CvxIncludedVersionMessage" )
      incVerUpdate.empty = True
      incVerUpdate.serviceName = serviceName
      incVerUpdate.filterName = filterName
      # NOTE(review): when self.miUpdate is reused, passing version=None keeps
      # the value from a previous call -- confirm that is intended.
      if version is not None:
         incVerUpdate.version = version
      incVerUpdate.deleted = deleted

      miUpdate = self._update( includedVersion=incVerUpdate )
      miUpdate.type = MSGID.cvxIncludedVersionMessage
      return miUpdate

   def newIncludedSystemUpdate( self, serviceName, filterName, system=None,
                                deleted=False ):
      """Return an update carrying an included-system message for
      serviceName/filterName. `system` is only assigned when not None."""
      incSysUpdate = self._subMessage( "includedSystem",
                                       "Controller::CvxIncludedSystemMessage" )
      incSysUpdate.empty = True
      incSysUpdate.serviceName = serviceName
      incSysUpdate.filterName = filterName
      # NOTE(review): as above, a reused update keeps the old `system` when
      # system=None is passed.
      if system is not None:
         incSysUpdate.system = system
      incSysUpdate.deleted = deleted

      miUpdate = self._update( includedSystem=incSysUpdate )
      miUpdate.type = MSGID.cvxIncludedSystemMessage
      return miUpdate

   def newMountConfigUpdate( self, serviceName, mgName, vfName=None, sfName=None,
                             remotePath=None, taccType=None, mountFlags=None,
                             targetMountPath=None, deleted=False ):
      """Return an update carrying a mount-config message for mount group
      mgName of serviceName. Optional fields are only assigned when not
      None."""
      mcUpdate = self._subMessage( "mountConfig",
                                   "Controller::CvxMountConfigMessage" )
      mcUpdate.empty = True
      mcUpdate.serviceName = serviceName
      mcUpdate.mountGroupName = mgName

      optionalFields = ( ( "versionFilterName", vfName ),
                         ( "systemFilterName", sfName ),
                         ( "remotePath", remotePath ),
                         ( "taccType", taccType ),
                         ( "mountFlags", mountFlags ),
                         ( "targetMountPath", targetMountPath ) )
      for attr, value in optionalFields:
         if value is not None:
            setattr( mcUpdate, attr, value )

      mcUpdate.deleted = deleted

      miUpdate = self._update( mountConfig=mcUpdate )
      miUpdate.type = MSGID.cvxMountConfigMessage
      return miUpdate

   # ---------------------------------------------------------------------
   # Verification (raises AssertionError on mismatch)
   # ---------------------------------------------------------------------
   def getActualUpdate( self, actUpdate ):
      """Return actUpdate, defaulting to self.miUpdate. Asserts that the
      result is truthy."""
      if actUpdate is None:
         actUpdate = self.miUpdate
      assert actUpdate
      return actUpdate

   def verifyMarkerUpdate( self, expUpdate, actUpdate=None ):
      """Assert the marker messages agree on `start`."""
      actUpdate = self.getActualUpdate( actUpdate )
      self._assertSameFields( actUpdate.mountInterfaceMarker,
                              expUpdate.mountInterfaceMarker,
                              ( "start", ) )

   def verifyIncludedVersionUpdate( self, expUpdate, actUpdate=None ):
      """Assert the included-version messages agree field by field."""
      actUpdate = self.getActualUpdate( actUpdate )
      self._assertSameFields( actUpdate.includedVersion,
                              expUpdate.includedVersion,
                              ( "serviceName", "filterName", "version",
                                "deleted" ) )

   def verifyIncludedSystemUpdate( self, expUpdate, actUpdate=None ):
      """Assert the included-system messages agree field by field."""
      actUpdate = self.getActualUpdate( actUpdate )
      self._assertSameFields( actUpdate.includedSystem,
                              expUpdate.includedSystem,
                              ( "serviceName", "filterName", "system",
                                "deleted" ) )

   def verifyMountConfigUpdate( self, expUpdate, actUpdate=None ):
      """Assert the mount-config messages agree field by field."""
      actUpdate = self.getActualUpdate( actUpdate )
      self._assertSameFields( actUpdate.mountConfig,
                              expUpdate.mountConfig,
                              ( "serviceName", "mountGroupName",
                                "versionFilterName", "systemFilterName",
                                "remotePath", "taccType", "mountFlags",
                                "targetMountPath", "deleted" ) )
