#!/usr/bin/env python
# Copyright (c) 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import re

import BasicCli
import CliCommand
import CliMatcher
import ConfigMount
import CliParser
import CliPlugin.AgentCli as AgentCli
import LazyMount
import ShowCommand
from TableOutput import Format, TableFormatter, Headings
import Tac
import Tracing

# Short alias for Tracing.trace0.
t0 = Tracing.trace0

# Module-level mount handles. Both start as None and are assigned by
# Plugin(): faultPointConfig from 'agent/faultPoint/config' and
# faultPointStatus from 'agent/faultPoint/status'.
faultPointConfig = None
faultPointStatus = None

# Match a valid agent name.
#
# The agent may not be running, so we can't rely on AgentDirectory.
# We take it from sys/config/role. However, some agents append other
# info to the name (such as SandSlice-*), so we can't simply use that
# list.
#
# The matcher below matches a token if one of the agent names is a prefix
# of the token. It always returns the exact token with case preserved.
# Note this is used together with a DynamicKeywordsMatcher.
class KeywordsPrefixMatcher( CliMatcher.Matcher ):
   # Matcher accepting any token that extends one of a dynamic set of keywords.
   #
   # keywordsFn is called with the CLI mode on every match attempt and must
   # return an iterable of keyword strings. A token matches when it is strictly
   # longer than some keyword and starts with that keyword, compared
   # case-insensitively. The match result is the token exactly as typed.
   def __init__( self, keywordsFn, **kargs ):
      super( KeywordsPrefixMatcher, self ).__init__( helpdesc='', **kargs )
      self.keywordsFn_ = keywordsFn

   def match( self, mode, context, token ):
      # Returns a MatchResult carrying the original token, or
      # CliMatcher.noMatch when no keyword is a strict prefix of the token.
      tokenLow = token.lower()
      for k in self.keywordsFn_( mode ):
         # The strict length check excludes a token equal to a keyword.
         if len( tokenLow ) > len( k ) and tokenLow.startswith( k.lower() ):
            return CliMatcher.MatchResult( token, token )
      return CliMatcher.noMatch

   def completions( self, mode, context, token ):
      # With no input yet, advertise a generic WORD placeholder.
      if not token:
         return [ CliParser.Completion( 'WORD', 'Agent name', literal=False ) ]
      # match() takes ( mode, context, token ); the context must be forwarded.
      # Compare against the noMatch sentinel explicitly rather than relying on
      # its truth value.
      if self.match( mode, context, token ) is CliMatcher.noMatch:
         return []
      # We provide exactly one completion when the token matches: the token
      # itself (the text carried by the match result).
      return [ CliParser.Completion( token, 'Agent name', literal=False ) ]

# CLI expression for an agent name: either a name accepted by
# AgentCli.agentNameNewMatcher, or a longer token accepted by
# KeywordsPrefixMatcher over AgentCli.allAgentNames.
class AgentNameExpr( CliCommand.CliExpression ):
   expression = 'AGENT | AGENT_PREFIX'
   data = {
      'AGENT' : AgentCli.agentNameNewMatcher,
      # The prefix alternative is registered at CliParser.PRIO_LOW and aliased
      # to 'AGENT'; the handlers in this file only read args[ 'AGENT' ].
      # NOTE(review): presumably the alias stores this alternative's value
      # under 'AGENT' as well - confirm against CliCommand.Node.
      'AGENT_PREFIX' : CliCommand.Node(
         matcher=KeywordsPrefixMatcher(
            AgentCli.allAgentNames,
            priority=CliParser.PRIO_LOW ),
         alias='AGENT' )
   }

# Matcher for a fault point identifier: one or more ASCII letters, digits,
# underscores or hyphens. Shared by the show and config commands below.
matcherFaultPoint = CliMatcher.PatternMatcher( '[A-Za-z0-9_-]+', helpname='WORD',
      helpdesc='Fault Point Identifier' )

#--------------------------------------------------------------------------------
# show fault-point AGENT [ FAULT_POINT ]
#--------------------------------------------------------------------------------
class ShowFaultPointCmd( ShowCommand.ShowCliCommandClass ):
   # Hidden show command that prints the fault points published under
   # faultPointStatus for one agent, either all of them or only FAULT_POINT.
   syntax = 'show fault-point AGENT [ FAULT_POINT ]'
   data = {
      'fault-point' : 'Show fault point',
      'AGENT' : AgentNameExpr,
      'FAULT_POINT' : matcherFaultPoint,
   }
   hidden = True

   @staticmethod
   def handler( mode, args ):
      # Prints nothing when the agent has no entry in faultPointStatus, and
      # reports a CLI error when FAULT_POINT is given but unknown.
      agent = args[ 'AGENT' ]
      faultPointId = args.get( 'FAULT_POINT' )
      tf = TableFormatter()
      # NOTE(review): 'Action' is the only heading given without a format
      # code; kept exactly as it was - confirm that is intended.
      heading = ( ( 'Fault Point', 'l' ), ( 'Enabled', 'l' ),
                  ( 'Persistent', 'l' ), ( 'No Restarts', 'l' ),
                  ( 'Max Restarts', 'l' ), 'Action',
                  ( 'Conditions', 'l' ) )
      hd = Headings( heading )
      hd.doApplyHeaders( tf )
      faultInjStatus = faultPointStatus.get( agent )
      if faultInjStatus is None:
         return

      # Looked up once and shared by every row.
      faultInjAction = Tac.Value( 'FaultPt::FaultAction' )

      def printRow( tf, faultId, faultInjPt ):
         # Emit one table row describing a single fault point.
         tf.startRow()
         tf.newCell( faultId )
         tf.newCell( 'Yes' if faultInjPt.enabled else 'No' )
         tf.newCell( faultInjPt.persistent )
         tf.newCell( faultInjPt.numConsecutiveRestarts )
         tf.newCell( faultInjPt.maxConsecutiveRestarts )
         if faultInjPt.action == faultInjAction.supSwitchover:
            tf.newCell( 'Switchover' )
         elif faultInjPt.action == faultInjAction.agentRestart:
            tf.newCell( 'Restart' )
         else:
            tf.newCell( 'Embedded' )
         condOutput = ', '.join(
            '%s=%s' % ( condName, faultInjPt.condition[ condName ] )
            for condName in faultInjPt.condition )
         # No Conditions cell is added when the fault point has no conditions.
         if condOutput:
            tf.newCell( condOutput )

      if faultPointId:
         # An unknown identifier raises KeyError on lookup; report it instead
         # of letting the exception escape the show command.
         try:
            faultInjPt = faultInjStatus.faultPoint[ faultPointId ]
         except KeyError:
            mode.addError( 'FaultPoint %s does not exist' % faultPointId )
            return
         printRow( tf, faultPointId, faultInjPt )
      else:
         for faultId in faultInjStatus.faultPoint:
            printRow( tf, faultId, faultInjStatus.faultPoint[ faultId ] )
      # Single-argument call form: prints the same text whether print is the
      # Python 2 statement or the print function.
      print( tf.output() )

# Register the show command class with BasicCli.
BasicCli.addShowCommandClass( ShowFaultPointCmd )

#--------------------------------------------------------------------------------
# [ no | default ] fault-point AGENT FAULT_POINT action ACTION
#                             persistent PERSISTENT restart RESTART [ { KEY_VALUE } ]
#--------------------------------------------------------------------------------
def doFaultPointUpdate( mode, agent, faultPointId, faultInjAction,
                        persistent, maxRestart, keyValuePairs, no=False ):
   '''Write one fault point's settings into the agent's fault point config.

   Does nothing when faultPointConfig has no entry for agent.

   With no=False the fault point is created via newFaultPoint() and enabled:
   maxRestart is converted with int(), persistent is true only for the string
   'True', and every 'key=value' string in keyValuePairs (which may be None)
   is stored in the condition collection.

   With no=True the fault point is disabled: persistent is cleared and all
   conditions are removed. A fault point that has no config entry is left
   alone, since there is nothing to disable.

   In both cases action is set to faultInjAction and the 'updated' attribute
   is toggled last. mode is accepted but not used.
   '''
   faultInjConfig = faultPointConfig.get( agent )
   if faultInjConfig is None:
      return
   if no:
      # The entry is only ever created by the non-'no' path, so it may be
      # absent here; indexing would raise KeyError.
      try:
         faultInjPt = faultInjConfig.faultPoint[ faultPointId ]
      except KeyError:
         return
   else:
      faultInjConfig.newFaultPoint( faultPointId )
      faultInjPt = faultInjConfig.faultPoint[ faultPointId ]
   faultInjPt.action = faultInjAction
   if no:
      faultInjPt.persistent = False
      # Snapshot the keys first: entries must not be deleted from the
      # collection while it is being iterated.
      for key in list( faultInjPt.condition ):
         del faultInjPt.condition[ key ]
   else:
      faultInjPt.maxConsecutiveRestarts = int( maxRestart )
      faultInjPt.persistent = persistent == 'True'
      for kv in keyValuePairs or ():
         # Both groups are greedy, so the split happens at the last '=' that
         # still leaves a non-empty value.
         res = re.match( r'(.+)=(.+)', kv )
         faultInjPt.condition[ res.group( 1 ) ] = res.group( 2 )
   faultInjPt.enabled = not no
   # NOTE(review): presumably this flip is how a reader of the config notices
   # the change - confirm against the consumer of 'updated'.
   faultInjPt.updated = not faultInjPt.updated

def invalidFaultPointId( mode, agent, faultPointId ):
   '''Return True when faultPointId cannot be used for agent, else False.

   True is returned without any message when faultPointStatus has no entry
   for agent. When the agent is present but looking up faultPointId raises
   KeyError, an error is added to mode and True is returned.
   '''
   agentStatus = faultPointStatus.get( agent )
   if agentStatus is not None:
      try:
         # The lookup is done only for its KeyError; the value is not used.
         agentStatus.faultPoint[ faultPointId ]
      except KeyError:
         mode.addError( 'FaultPoint %s does not exist' % faultPointId )
      else:
         return False
   return True

class FaultPointCmd( CliCommand.CliCommandClass ):
   # Hidden config command that enables a fault point for an agent, and its
   # no/default form that disables it. Both forms first check the identifier
   # with invalidFaultPointId() and then delegate to doFaultPointUpdate().
   syntax = ( 'fault-point AGENT FAULT_POINT action ACTION persistent PERSISTENT '
              'restart RESTART [ { KEY_VALUE } ]' )
   noOrDefaultSyntax = 'fault-point AGENT FAULT_POINT ...'
   data = {
      'fault-point' : 'Configure fault point',
      'AGENT' : AgentNameExpr,
      'FAULT_POINT' : matcherFaultPoint,
      'action' : 'Configure action to be taken',
      'ACTION' : CliMatcher.EnumMatcher( {
         'embedded' : "Trigger faultpoint's built-in action if any",
         'switchover' : 'Trigger switchover of supervisor',
         'restart' : 'Trigger restart of Agent',
      } ),
      'persistent' : 'Configure fault point persistence',
      'PERSISTENT' : CliMatcher.EnumMatcher( {
         'True' : 'Persist the fault point state across agent reload',
         'False' : 'Disable the fault point after the agent reload',
      } ),
      'restart' : 'Configure the number of maximum consecutive agent restart',
      'RESTART' : CliMatcher.IntegerMatcher( 1, 10,
         helpdesc='Maximum number of consecutive agent restart' ),
      'KEY_VALUE' : CliMatcher.PatternMatcher(
         pattern='[a-zA-Z0-9-]+=[a-zA-Z0-9-/:,]+',
         helpdesc='Key=Pair to indicate condition', helpname='WORD' ),
   }
   hidden = True

   @staticmethod
   def handler( mode, args ):
      # Enable the fault point with the action, persistence, restart limit
      # and optional conditions given on the command line.
      agentName = args[ 'AGENT' ]
      pointId = args[ 'FAULT_POINT' ]
      actionKeyword = args[ 'ACTION' ]
      persistFlag = args[ 'PERSISTENT' ]
      restartLimit = args[ 'RESTART' ]
      conditions = args.get( 'KEY_VALUE' )

      if invalidFaultPointId( mode, agentName, pointId ):
         return
      # Map the CLI keyword to the FaultAction attribute name; any keyword
      # other than these two selects the embedded action.
      attrByKeyword = {
         'switchover' : 'supSwitchover',
         'restart' : 'agentRestart',
      }
      actionEnum = Tac.Value( 'FaultPt::FaultAction' )
      chosenAction = getattr(
         actionEnum, attrByKeyword.get( actionKeyword, 'actionEmbedded' ) )
      doFaultPointUpdate( mode, agentName, pointId, chosenAction,
                          persistFlag, restartLimit, conditions )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      # Disable the fault point; its action is set to the embedded action.
      agentName = args[ 'AGENT' ]
      pointId = args[ 'FAULT_POINT' ]
      if not invalidFaultPointId( mode, agentName, pointId ):
         embeddedAction = Tac.Value( 'FaultPt::FaultAction' ).actionEmbedded
         doFaultPointUpdate( mode, agentName, pointId, embeddedAction,
                             None, None, None, no=True )

# Make the fault-point command available in global config mode.
BasicCli.GlobalConfigMode.addCommandClass( FaultPointCmd )

# Plugin entry point: mounts the fault point config and status directories
# and stores the resulting handles in the module-level faultPointConfig and
# faultPointStatus used by the commands above.
def Plugin( entityManager ):
   global faultPointConfig, faultPointStatus
   # Config goes through ConfigMount with flags 'wi'.
   # NOTE(review): presumably 'w' means writable - confirm the flag meanings.
   faultPointConfig = ConfigMount.mount( entityManager, 'agent/faultPoint/config',
         'Tac::Dir', 'wi' )
   # Status goes through LazyMount with flags 'ri'; this file only reads it.
   faultPointStatus = LazyMount.mount( entityManager, 'agent/faultPoint/status',
         'Tac::Dir', 'ri' )
