#!/usr/bin/env python
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from PolicyMapModel import ClassMapModel, ClassMatchModel, ClassMatchAclModel, \
                           PolicyMapModel, PolicyClassMapModel, ClassMapCounter
from PolicyMapModel import ClassMapMatchOption
from PolicyMapModel import MplsRule
import AclCliModelImpl
import Tac
import collections
import AclLib

# Tac type of the base class-map status.  The default status filter built by
# getCmapStatusFilter() keeps only statuses whose exact type is NOT this one.
tacClassMapStatusType = Tac.Type( "PolicyMap::ClassMapStatus" )

# Timeout handed to Tac.waitFor() by waitForPmapCounters() for each policy-map
# status (units are whatever Tac.waitFor expects -- presumably seconds; confirm).
counterUpdateTimeout = 10
# mapType -> class-map status filter function; see registerCmapStatusFilter().
mapTypeToStatusFilter_ = {}

class ActionBase( object ):
   """Registry mapping a policy-action class to the class that builds its model.

   Factories are registered with addActionType(); createModel() looks the
   factory up by the action's exact class and delegates to its createModel().
   """
   # action class -> factory class exposing createModel( action )
   actionClasses_ = {}

   @staticmethod
   def addActionType( actType, clazz ):
      """Register `clazz` as the model factory for actions of class `actType`."""
      registry = ActionBase.actionClasses_
      registry[ actType ] = clazz

   @classmethod
   def createModel( cls, action ):
      """Build the model for `action`.

      Raises KeyError if no factory is registered for action's class.
      """
      factory = ActionBase.actionClasses_[ action.__class__ ]
      return factory.createModel( action )

class PolicyClassMapBase( object ):
   """Registry mapping a class-map status type to a policy-class-map model
   factory.

   createModel() picks the factory from the class of the given status (or of
   the first element when a sequence is given).  If no factory is registered,
   it falls back to a plain PolicyClassMapModel.
   """
   # status class -> factory class exposing createModel( statuses )
   statusClasses_ = {}

   @staticmethod
   def addStatusType( statusType, clazz ):
      """Register `clazz` as the model factory for statuses of `statusType`."""
      registry = PolicyClassMapBase.statusClasses_
      registry[ statusType ] = clazz

   @classmethod
   def createModel( cls, name, statuses ):
      """Build a model for class map `name` from `statuses` and set its name.

      `statuses` may be None, a single status object, or an indexable
      sequence of statuses (only the first element's class is inspected).
      """
      if statuses is None:
         statusType = None
      elif not hasattr( statuses, '__iter__' ):
         statusType = statuses.__class__
      elif len( statuses ):
         statusType = statuses[ 0 ].__class__
      else:
         statusType = None

      factory = PolicyClassMapBase.statusClasses_.get( statusType )
      if factory:
         model = factory.createModel( statuses )
      else:
         # No factory registered for this status type: use the default model.
         model = PolicyClassMapModel()
      model.name = name
      return model

def registerCmapStatusFilter( mapType, filterFunc ):
   """Register `filterFunc` as the class-map status filter for `mapType`,
   replacing any filter previously registered for it."""
   mapTypeToStatusFilter_.update( { mapType: filterFunc } )

def getCmapStatusFilter( mapType ):
   """Return the class-map status filter registered for `mapType`.

   If nothing (or a falsy value) is registered, a default filter is registered
   for `mapType` and returned.  The default keeps every status whose exact
   type is not tacClassMapStatusType, i.e. it drops plain base-class statuses.
   """
   existing = mapTypeToStatusFilter_.get( mapType, None )
   if existing:
      return existing

   def _defaultFilter( x ):
      return type( x ) != tacClassMapStatusType

   registerCmapStatusFilter( mapType, _defaultFilter )
   return _defaultFilter

def aclTypetoOption( aclType ):
   """Map an AclCliModelImpl.AclType value to its ClassMapMatchOption.

   Returns the ip / ipv6 / mac access-group match option for the
   corresponding ACL type, and None for any other type.  Callers in this
   module pass the result to ``match.get()``, where None simply matches
   nothing, so unknown types are skipped rather than treated as errors.
   """
   if aclType == AclCliModelImpl.AclType.ip:
      return ClassMapMatchOption.matchIpAccessGroup
   if aclType == AclCliModelImpl.AclType.ipv6:
      return ClassMapMatchOption.matchIpv6AccessGroup
   if aclType == AclCliModelImpl.AclType.mac:
      return ClassMapMatchOption.matchMacAccessGroup
   # No class-map match option exists for this ACL type.
   return None

def clearPolicyMapCheckpoint( policyName, checkpoint ):
   """Remove the checkpoint of policy `policyName` from every agent checkpoint.

   Does nothing if `checkpoint` is None or has no agent checkpoints.
   """
   agentCheckpoints = None if checkpoint is None \
                      else checkpoint.checkpointAgentStatus
   if not agentCheckpoints:
      return
   for agentCkpt in agentCheckpoints.itervalues():
      pmapCkpts = agentCkpt.pmapCkptStatus
      if policyName in pmapCkpts:
         del pmapCkpts[ policyName ]


def clearPolicyMapCounters( mode, policyName, cliConfig, mergedConfig,
                            status, checkpoint ):
   """Rebuild the counter checkpoint of policy map `policyName`.

   Steps:
   - Remove the policy's existing checkpoint entry from every agent checkpoint.
   - Request fresh counters, adding a warning to `mode` if they could not be
     obtained in time.
   - For each agent that has a status for this policy, copy the current
     per-class-map and per-ACL-rule packet counts into
     checkpoint.checkpointAgentStatus[ agent ].

   `status` is keyed by agent name, and each value's `.status` is keyed by
   policy name.

   Returns False if the policy map or its current config does not exist.
   Otherwise it returns None.

   NOTE(review): unlike clearPolicyMapCheckpoint(), the code below
   dereferences `checkpoint` unconditionally, so it must not be None here --
   confirm with callers.
   """
   pmapCfg = mergedConfig.pmapType.pmap.get( policyName )
   if not pmapCfg:
      return False
   pmap = pmapCfg.currCfg
   if not pmap:
      return False

   clearPolicyMapCheckpoint( policyName, checkpoint )
   pmapStatuses = [ pmapTypeStatus.status.get( policyName ) for pmapTypeStatus
                    in status.itervalues() ]
   pmapStatuses = filter( None, pmapStatuses )

   countersAreStale = not waitForPmapCounters( policyName, cliConfig.pmapType,
                                               pmapStatuses )
   if countersAreStale:
      mode.addWarning( 'Using stale counters' )

   # Create a checkpoint for the policy in each agentCheckpoint; the agent
   # name is the same as the key of `status`.
   for s in status:
      checkpoint.checkpointAgentStatus.newMember( s )

   def updatePktHitCounter( subclassStatus, ckptClassMapStatus ):
      """Copy class-map and per-ACL-rule packet counts of `subclassStatus`
      into `ckptClassMapStatus`."""
      aclTypeCkptStatus = ckptClassMapStatus.aclTypeCkptStatus
      ruleStatus = subclassStatus.ruleStatus
      ckptClassMapStatus.pktMatch += ruleStatus.pkts

      # Copy the id to the checkpoint too, so that stale class-map counters
      # in the checkpoint can be identified later.
      ckptClassMapStatus.version = subclassStatus.id

      # Copy the packet hit counter of every ACL rule into the checkpoint.
      for aclType in AclLib.aclTypes:
         cmapMatchStatus = subclassStatus.match.get( aclTypetoOption( aclType ) )
         if not cmapMatchStatus:
            continue
         ckpt = aclTypeCkptStatus[ aclType ].acl
         for acl in cmapMatchStatus.acl.itervalues():
            dest = ckpt[ acl.name ]
            dest.version = acl.version
            for r, v in acl.ruleStatus.iteritems():
               # Individual attributes of value types cannot be modified, so
               # a new value carrying the accumulated count is stored.
               prevPkts = dest.ruleStatus[ r ].pkts \
                          if r in dest.ruleStatus else 0
               dest.ruleStatus[ r ] = Tac.Value( "Acl::RuleCheckpointStatus",
                                                 pkts=prevPkts + v.pkts )

   # Update every agent's checkpoint independently.
   for agentCheckpoint in checkpoint.checkpointAgentStatus.itervalues():
      agentName = agentCheckpoint.name

      # A line card may have been pulled out; its checkpoint can still exist
      # without a corresponding status entity.
      if agentName not in status:
         continue

      # A policy may not be attached to an interface in every slice. Only
      # slices that have a status for this policy get a checkpoint.
      agentPmapStatuses = status[ agentName ].status
      if policyName not in agentPmapStatuses:
         continue
      pmapStatus = agentPmapStatuses[ policyName ]

      ckptStatus = agentCheckpoint.pmapCkptStatus.newMember( policyName )
      for _, className in pmap.classPrio.iteritems():
         ckptClassMapStatus = ckptStatus.cmapCkptStatus.newMember( className )
         ckptClassMapStatus.pktMatch = 0
         for aclType in AclLib.aclTypes:
            ckptClassMapStatus.newAclTypeCkptStatus( aclType )

         classStatus = pmapStatus.cmap.get( className )
         if not classStatus:
            continue

         # Create a checkpoint entry for each ACL matched by this class.
         aclTypeCkptStatus = ckptClassMapStatus.aclTypeCkptStatus
         for aclType in AclLib.aclTypes:
            cmapMatchStatus = classStatus.match.get( aclTypetoOption( aclType ) )
            if not cmapMatchStatus:
               continue
            ckpt = aclTypeCkptStatus[ aclType ].acl
            for acl in cmapMatchStatus.acl.itervalues():
               if acl.name not in ckpt:
                  ckpt.newMember( acl.name )

         updatePktHitCounter( classStatus, ckptClassMapStatus )

      ckptStatus.ckptTimestamp = Tac.now()

def _getCheckpointAclStats( aclTypeCkptStatus, aclName, aclType, ruleid, version ):
   pktHits = 0
   if aclTypeCkptStatus is not None:
      aclCkptStatus = aclTypeCkptStatus.acl.get( aclName )
      if aclCkptStatus is not None:
         if aclCkptStatus.version == version:
            ruleStatus = aclCkptStatus.ruleStatus.get( ruleid )
            if ruleStatus is not None:
               pktHits = ruleStatus.pkts

   return pktHits

def _getAclRuleStatsCheckpoint( checkpoint, policyName, className, aclName,
                                ruleid, aclType, aclStatus ):
   pktHits = 0
   ckptTime = 0

   if checkpoint is None or not checkpoint.checkpointAgentStatus:
      return ckptTime, pktHits

   # Below check is only to get btests passing. We need to fix btests. See bug
   # 124637 for details.
   if not aclStatus.parent.parent.parent:
      return ckptTime, pktHits

   # Only the checkpoint of the slice whose policymap/classmap/acl we are cosnidering
   # now is relevant
   agentName = aclStatus.parent.parent.parent.name
   agentCheckpoint = checkpoint.checkpointAgentStatus.get( agentName, None )
   if not agentCheckpoint: 
      return ckptTime, pktHits
   
   if agentCheckpoint.pmapCkptStatus is not None: 
      pmapCkptStatus = agentCheckpoint.pmapCkptStatus.get( policyName )
      if pmapCkptStatus is not None:
         cmapCkptStatus = pmapCkptStatus.cmapCkptStatus.get( className )
         if cmapCkptStatus is not None:
            aclTypeCkptStatus = cmapCkptStatus.aclTypeCkptStatus[aclType]
            version = aclStatus.version
            pktHits = _getCheckpointAclStats( aclTypeCkptStatus, aclName,
                                              aclType, ruleid, version )
            ckptTime = pmapCkptStatus.ckptTimestamp

   return ckptTime, pktHits

def _getAclRuleStatsCkptLatest( policyName, className, aclName, ruleid, aclType,
                             checkpoints, aclStatus ):
   """Return the checkpointed packet count of an ACL rule from the newer of the
   two checkpoints.

   `checkpoints` is a ( checkpoint, sessionCheckpoint ) pair.  The same pair
   is unpacked the same way in _getRuleStatsCkptLatest.  The session value is
   used when its timestamp is at least as new as the global one.
   """
   checkpoint, sessionCheckpoint = checkpoints

   globalTime, globalPkts = _getAclRuleStatsCheckpoint( checkpoint,
                                             policyName, className, aclName,
                                             ruleid, aclType, aclStatus )
   sessTime, sessPkts = _getAclRuleStatsCheckpoint( sessionCheckpoint,
                                             policyName, className,
                                             aclName, ruleid, aclType,
                                             aclStatus )
   # Use the latest checkpoint (ties go to the session checkpoint).
   return sessPkts if sessTime >= globalTime else globalPkts

def _getRuleStatsCheckpoint( checkpoint, policyName, className, status ):
   ckptTime = 0
   pktHits = 0
 
   if checkpoint is None or not checkpoint.checkpointAgentStatus:
      return ckptTime, pktHits

   # Below check is only to get btests passing. We need to fix btests. See bug
   # 124637 for details.
   if not status.parent:
      return ckptTime, pktHits

   # Only the checkpoint of the slice whose policymap/classmap we are cosnidering
   # now is relevant
   agentName = status.parent.name
   agentCheckpoint = checkpoint.checkpointAgentStatus.get( agentName, None )
   if not agentCheckpoint: 
      return ckptTime, pktHits
 
   if agentCheckpoint.pmapCkptStatus is not None: 
      pmapCkptStatus = agentCheckpoint.pmapCkptStatus.get( policyName )
      if pmapCkptStatus is not None:
         cmapCkptStatus = pmapCkptStatus.cmapCkptStatus.get( className )
         if cmapCkptStatus is not None:
            # consider the check point pkt count only if version number match. This
            # is to ignore  stale check points that can exist after restart or
            # config replace
            if cmapCkptStatus.version == status.id:
               pktHits = cmapCkptStatus.pktMatch
               ckptTime = pmapCkptStatus.ckptTimestamp

   return ckptTime, pktHits

def _getRuleStatsCkptLatest( policyName, className, checkpoints, status ):
   """Return the checkpointed class-map packet count from the newer checkpoint.

   `checkpoints` is a ( checkpoint, sessionCheckpoint ) pair.  The session
   value wins when its timestamp is at least as new as the global one.
   """
   globalCkpt, sessionCkpt = checkpoints
   ( globalTime, globalPkts ), ( sessTime, sessPkts ) = [
      _getRuleStatsCheckpoint( ckpt, policyName, className, status )
      for ckpt in ( globalCkpt, sessionCkpt ) ]
   if sessTime >= globalTime:
      return sessPkts
   return globalPkts

def _getAclRuleStatsWrapper( policyName, className, checkpoints ):

   def _getAclRuleStats( aclStatus, ruleConfig, ruleid, aclType, chipName=None ):
      if not aclStatus or not len( aclStatus ) \
         or ruleConfig.action == 'remark':
         return ( None, None, None, None, None )
      
      if not hasattr( aclStatus, '__iter__' ):
         aclStatus = [ aclStatus ]

      pkts = 0
      ckptPkts = 0
      lastChangedTime = 0.0
      for a in aclStatus:
         ruleStatus = a.ruleStatus.get( ruleid )
         if ruleStatus is None:
            continue
         pkts += ruleStatus.pkts
         # lastChangedTime lets get the lowest value
         if ruleStatus.lastChangedTime:
            if not lastChangedTime or \
               lastChangedTime > ruleStatus.lastChangedTime:
               lastChangedTime = ruleStatus.lastChangedTime

         ckptPkts += _getAclRuleStatsCkptLatest( policyName, className,
                                                 a.name, ruleid, aclType, 
                                                 checkpoints, a )
      return ( pkts, 0, ckptPkts, 0, float( lastChangedTime ) )

   return _getAclRuleStats

def _getRuleStatsWrapper( policyName, className, checkpoints ):

   def _getRuleStats( status, config, ruleid, aclType, chipName=None ):
      if status is None or not len( status ):
         return ( None, None, None, None, None )
      if not hasattr( status, '__iter__' ):
         status = [ status ]
      
      pkts = 0
      ckptPkts = 0
      lastChangedTime = 0.0
      for s in status:
         if not s:
            continue
         ruleStatus = s.ruleStatus
         pkts += ruleStatus.pkts
         if ruleStatus.lastChangedTime:
            if not lastChangedTime or \
               lastChangedTime > ruleStatus.lastChangedTime:
               lastChangedTime = ruleStatus.lastChangedTime
         ckptPkts += _getRuleStatsCkptLatest( policyName, className,
                                              checkpoints, s )
      return ( pkts, 0, ckptPkts, 0, float( lastChangedTime ) )

   return _getRuleStats

def matchOptionToAclType( option ):
   """Map a ClassMapMatchOption to its AclCliModelImpl.AclType value.

   Returns the ip / ipv6 / mac ACL type for the corresponding access-group
   match option, and None for any other option (e.g. an MPLS match).
   """
   if option == ClassMapMatchOption.matchIpAccessGroup:
      return AclCliModelImpl.AclType.ip
   if option == ClassMapMatchOption.matchIpv6AccessGroup:
      return AclCliModelImpl.AclType.ipv6
   if option == ClassMapMatchOption.matchMacAccessGroup:
      return AclCliModelImpl.AclType.mac
   # No ACL type corresponds to this match option.
   return None

def populateAclListModel( policyName, className, aclConfig, prio, status, option,
                          aclRule, aclIndex, checkpoints ):
   """Build the ACL-list model for ACL `aclRule` matched at `prio` in a class.

   Returns None when the ACL is not configured or `status` is None.
   Otherwise it collects each status's ACL status stored under `prio` for
   this match `option` (the entry may be None), and returns
   AclCliModelImpl.getAclListModel() with a checkpoint-aware
   counter callback.  `aclIndex` is not used.
   """
   aclType = matchOptionToAclType( option )
   aclListConfig = aclConfig.config[ aclType ].acl.get( aclRule )
   if aclListConfig is None or status is None:
      return None

   matchStatuses = ( s.match.get( option ) for s in status )
   aclListStatus = [ m.acl.get( prio ) for m in matchStatuses if m is not None ]

   statsHandler = _getAclRuleStatsWrapper( policyName, className, checkpoints )
   return AclCliModelImpl.getAclListModel( aclListConfig, aclListStatus,
                                           aclType, aclRule, statsHandler )

def populateClassMapModel( cmap, policyName=None, status=None, aclConfig=None,
                           checkpoints=None, classMapModelType=ClassMapModel ):
   """Build a class-map model (an instance of `classMapModelType`) for `cmap`.

   The model carries:
   - the match condition and name;
   - class-map counters computed from `status`;
   - one ClassMatchModel per match option, with its ACLs and
     ip / ipv6 / mac rules.

   ACL-list models are filled in only when `aclConfig` is given.
   """
   className = cmap.className
   cmapModel = classMapModelType()
   cmapModel.matchCondition = cmap.matchCondition
   cmapModel.name = className
   cmapCounter = ClassMapCounter()
   # The counter callback depends only on (policyName, className, checkpoints)
   # and holds no per-call state, so one instance serves the class map and
   # every rule below.
   getRuleStatsHandler = _getRuleStatsWrapper( policyName, className,
                                               checkpoints )
   ( cmapCounter.packetCount, _,
     cmapCounter.checkpointPacketCount, _,
     cmapCounter.lastChangedTime ) = getRuleStatsHandler( status, None,
                                                          None, None )
   cmapModel.counterData = cmapCounter

   for option, clMatch in cmap.match.iteritems():
      cmatchModel = ClassMatchModel()
      cmatchModel.option = option

      if option == 'matchMplsAccessGroup':
         # MPLS matches reuse the class-map counters.
         mplsRule = MplsRule()
         mplsRule.counterData = cmapModel.counterData
         cmatchModel.mplsRule = mplsRule

      for aclIndex, ( prio, aclRule ) in enumerate( clMatch.acl.iteritems() ):
         prio = long( prio )
         cMatchAclModel = ClassMatchAclModel()
         cMatchAclModel.name = aclRule
         cmatchModel.acl[ prio ] = cMatchAclModel
         if aclConfig is None:
            continue
         aclListModel = populateAclListModel( policyName, className,
                                              aclConfig, prio,
                                              status, option,
                                              aclRule, aclIndex,
                                              checkpoints )
         if option == 'matchIpAccessGroup':
            cMatchAclModel.ipAclList = aclListModel
         elif option == 'matchIpv6AccessGroup':
            cMatchAclModel.ipv6AclList = aclListModel
         elif option == 'matchMacAccessGroup':
            cMatchAclModel.macAclList = aclListModel

      for prio, ipRuleCfg in clMatch.ipRule.iteritems():
         prio = long( prio )
         cmatchModel.ipRule[ prio ] = AclCliModelImpl.getIp4Rule(
            prio, ipRuleCfg, status, prio, False, 'ip', getRuleStatsHandler )

      for prio, ip6RuleCfg in clMatch.ip6Rule.iteritems():
         prio = long( prio )
         cmatchModel.ip6Rule[ prio ] = AclCliModelImpl.getIp6Rule(
            prio, ip6RuleCfg, status, prio, False, 'ipv6', getRuleStatsHandler )

      for prio, macRuleCfg in clMatch.macRule.iteritems():
         prio = long( prio )
         cmatchModel.macRule[ prio ] = AclCliModelImpl.getMacRule(
            prio, macRuleCfg, status, prio, False, 'mac', getRuleStatsHandler )

      cmapModel.match[ option ] = cmatchModel

   return cmapModel

def waitForPmapCounters( policyName, pmapType, pmapStatuses ):
   """Request fresh counters for `policyName` and wait for each status.

   Writes Tac.now() into pmapType.counterUpdateRequestTime[ policyName ].
   Then, for each entry of `pmapStatuses`, it waits (timeout
   counterUpdateTimeout per status) until that status's counterUpdateTime is
   at least the request time.  The request entry is always removed before
   returning, including when an unexpected exception propagates.

   Returns True if there was nothing to wait for or every status caught up.
   Returns False if a wait timed out or was interrupted (KeyboardInterrupt).
   """
   if not pmapStatuses:
      return True
   pmapType.counterUpdateRequestTime[ policyName ] = Tac.now()
   try:
      for pmapStatus in pmapStatuses:
         try:
            # waitFor runs synchronously, so the predicate always sees the
            # pmapStatus of the current iteration.
            Tac.waitFor(
               lambda: pmapType.counterUpdateRequestTime.get( policyName, 0 ) <=
               pmapStatus.counterUpdateTime,
               description='counters to be available',
               warnAfter=None, sleep=True, maxDelay=0.1,
               timeout=counterUpdateTimeout )
         except ( Tac.Timeout, KeyboardInterrupt ):
            return False
      return True
   finally:
      # The entry may already be gone (the predicate above also tolerates its
      # absence via .get), so delete it only if it is still present.
      if policyName in pmapType.counterUpdateRequestTime:
         del pmapType.counterUpdateRequestTime[ policyName ]

class PolicyMapModelContainer( object ):
   """Builds PolicyMapModel objects for policy maps and appends them to
   policyMapAllModel.

   Interface information (configured, configured-as-fallback, and
   successfully applied) is read lazily, once, by readPolicyIntf().
   """
   def __init__( self, cliConfig, cliIntfConfig, intfConfig, config, status,
                 aclConfig, hwStatus, sliceHwStatus, pmapType, direction,
                 policyMapAllModel=None ):
      self.cliIntfConfig = cliIntfConfig
      self.cliConfig = cliConfig
      self.intfConfig = intfConfig
      self.config = config
      # Per-agent (slice) status collection; each value's `.status` is keyed
      # by policy name.
      self.status = status
      self.aclConfig = aclConfig
      self.hwStatus = hwStatus
      self.sliceHwStatus = sliceHwStatus
      self.mapType = pmapType
      self.direction = direction
      self.policyMapAllModel = policyMapAllModel
      # policy name -> set of interfaces; filled in by readPolicyIntf().
      self.configuredPolicyIntf = None
      self.configuredFallbackPolicyIntf = None
      self.appliedPolicyIntf = None

   def readPolicyIntf( self ):
      """Fill the policy -> interfaces maps; later calls are no-ops.

      - configuredPolicyIntf: interfaces from intfConfig.intf.
      - configuredFallbackPolicyIntf: interfaces from intfConfig.intfFallback.
      - appliedPolicyIntf: interfaces whose every `installed` entry is
        'success'.  An interface that failed in any slice is excluded.
      """
      if self.configuredPolicyIntf is not None:
         return
      if self.configuredFallbackPolicyIntf is not None:
         return
      self.configuredPolicyIntf = collections.defaultdict( set )
      self.configuredFallbackPolicyIntf = collections.defaultdict( set )
      self.appliedPolicyIntf = collections.defaultdict( set )

      for intf, policy in self.intfConfig.intf.iteritems():
         self.configuredPolicyIntf[ policy ].add( intf )

      for intf, policy in self.intfConfig.intfFallback.iteritems():
         self.configuredFallbackPolicyIntf[ policy ].add( intf )

      failedIntfs = collections.defaultdict( set )
      for sliceStatus in self.status.itervalues():
         for status in sliceStatus.status.itervalues():
            installedIntfs = []
            for intf, intfStatus in status.intfStatus.iteritems():
               installed = all( x.installed == 'success'
                                for x in intfStatus.installed.itervalues() )
               if not installed:
                  # Failed in this slice: remember it and drop it even if
                  # another slice already reported success.
                  failedIntfs[ status.name ].add( intf )
                  self.appliedPolicyIntf[ status.name ].discard( intf )
               elif intf not in failedIntfs[ status.name ]:
                  installedIntfs.append( intf )
            self.appliedPolicyIntf[ status.name ].update( installedIntfs )

   def populateAll( self, mode, summary, checkpoints ):
      """Populate models for every configured policy map, in sorted order."""
      for name in sorted( self.config.pmapType.pmap ):
         self.populatePolicyMap( mode, name, summary, checkpoints )

   def populatePolicyMap( self, mode, name, summary, checkpoints ):
      """Build the model for policy map `name` and append it to
      policyMapAllModel.

      With `summary` set, only the name, map type and interface lists are
      filled in.  Otherwise fresh counters are requested first, and a warning
      is added to `mode` if they are stale.  Then every class map of the
      policy is added with its match criteria, counters and configured
      actions.

      Returns False if the policy map or its current config does not exist,
      and True otherwise.
      """
      self.readPolicyIntf()
      pmapCfg = self.config.pmapType.pmap.get( name )
      if not pmapCfg:
         return False
      pmap = pmapCfg.currCfg
      if not pmap:
         return False

      # Get all the status objects for this policy.
      pmapStatuses = [ status.status.get( name ) for status in
                       self.status.itervalues() ]
      pmapStatuses = filter( None, pmapStatuses )

      pmapModel = PolicyMapModel()
      pmapModel.name = pmap.policyName
      pmapModel.mapType = pmapCfg.type
      pmapModel.configuredIngressIntfs = list( self.configuredPolicyIntf[ name ] )
      pmapModel.configuredIngressIntfsAsFallback = list(
                   self.configuredFallbackPolicyIntf[ name ] )
      pmapModel.appliedIngressIntfs = list( self.appliedPolicyIntf[ name ] )

      if summary:
         # For a summary model, this is all we return.
         self.policyMapAllModel.append( name, pmapModel )
         return True

      countersAreStale = not waitForPmapCounters( name, self.cliConfig.pmapType,
                                                  pmapStatuses )
      if countersAreStale:
         # somehow we timed out, but we'll show whatever we have
         mode.addWarning( 'Displaying stale counters' )

      # XXX This code does not currently look at the AclStatus objects.
      # Once it does, it will need to account for the keys in the collection
      # of AclStatus objects perhaps not matching up with the configured sequence
      # numbers, if the user has run a resequence operation. Since the values should
      # still be in the same order though, we can use the keys from the configuration
      # for entering into the model, and just pair up those values with the list of
      # AclStatus objects.
      for prio, className in pmap.classPrio.iteritems():
         if className in pmap.rawClassMap:
            # Class map defined inline in the policy map.
            cmapModel = ClassMapModel()
            cmap = pmap.rawClassMap[ className ]
            cmapModel.name = className
            cmapModel.matchCondition = cmap.matchCondition
            pmapModel.rawClassMap[ className ] = cmapModel
         else:
            classConfig = self.config.cmapType.cmap.get( className )
            if not classConfig:
               continue
            cmap = classConfig.currCfg
            if not cmap:
               continue

         classStatuses = [ st.cmap.get( className ) for st in pmapStatuses ]
         classStatuses = filter( None, classStatuses )

         # Find the status objects that are a subclass of the base status, if any.
         subclassStatusList = filter( getCmapStatusFilter( self.mapType ),
                                      classStatuses )

         classMapModel = PolicyClassMapBase.createModel(
                                             className, subclassStatusList )
         classMapModel.classMap = populateClassMapModel( cmap, name,
                           subclassStatusList, self.aclConfig, checkpoints )
         classAction = pmap.classAction[ className ]
         for actType, action in classAction.policyAction.iteritems():
            actionModel = ActionBase.createModel( action )
            classMapModel.configuredAction[ actType ] = actionModel
         pmapModel.classMap[ long( prio ) ] = classMapModel
      self.policyMapAllModel.append( name, pmapModel )
      return True

class ClassMapModelContainer( object ):
   """Builds class-map models for configured class maps and appends them to
   classMapAllModel, tagging that collection with the class-map type."""
   def __init__( self, config, hwStatus, cmapType, classMapAllModel ):
      self.config = config
      self.hwStatus = hwStatus
      self.mapType = cmapType
      self.classMapAllModel = classMapAllModel
      classMapAllModel.mapType = cmapType

   def populateAll( self ):
      """Populate models for every configured class map, in sorted order."""
      configuredCmaps = self.config.cmapType.cmap
      for cmapName in sorted( configuredCmaps ):
         self.populateClassMap( cmapName )

   def populateClassMap( self, name ):
      """Append the model for class map `name`.

      Returns False if the class map or its current config is missing, and
      True otherwise.
      """
      cmapCfg = self.config.cmapType.cmap.get( name, None )
      cmap = cmapCfg.currCfg if cmapCfg else None
      if not cmap:
         return False
      self.classMapAllModel.append( name, populateClassMapModel( cmap ) )
      return True
