#!/usr/bin/env python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import getopt
import json
import os
import sys
import time

from EventHelper import EventHealthHelper
import Tac

class DropCounterMonitor( object ):
   """Flag drop counters that repeatedly jump by at least a threshold.

   Each invocation of run() works as follows:
   - It takes one sample of the drop counters and compares it with the sample
     saved by the previous invocation.
   - It records a violation for every counter whose increase is >= threshold.
   - It prints "1" if some counter not present in logEvents has at least
     `count` violations inside the sliding time window, otherwise "0".

   State is carried between invocations in JSON files under /tmp. The code is
   kept runnable under both Python 2 and Python 3.
   """

   def __init__( self ):
      self.window = 900     # sliding window in seconds (15 minutes)
      self.threshold = 100  # minimum increase between two samples to count
      self.count = 3        # violations within the window that trigger
      self.show = False     # -s: only print the stored violations
      self.keep = False     # -k: keep violations older than the window

      self.eventHealthHelper = None  # created in run()

      self.counters = {}          # { "name":value, ... } current sample
      self.countersToChips = {}   # { "name":chipName, ... }
      self.previousCounters = {}  # { "name":value, ... } previous sample
      self.violations = []        # [ [ "name", timestamp, val, delta ], ... ]
      self.repeatedViolators = {} # { 'violators':{ "name":count, ... }, ... }
      self.repeatedViolators[ 'violators' ] = {}
      # { "name":timestamp, ... } -- counters listed here do not make
      # countViolations() return True while their entry is inside the window
      self.logEvents = {}

      self.countersPath = "/tmp/DropCounters"
      self.violationsPath = "/tmp/DropCountersViolations"
      self.repeatedViolatorsPath = "/tmp/DropCountersRepeatedViolators"
      self.logEventsPath = "/tmp/DropCountersLogEvents"

      self.timestamp = time.time()

   def usage( self ):
      """Print command-line help to stdout."""
      print( """
      Monitors the Drop Adverse Counters and creates a list of the counters that
      cross a certain threshold repeatedly a number of times within a time window.
      Must be called at regular polling intervals to examine this condition.

      Usage:
          DropCounterMonitor.py [ options ]

      Created Files (JSON format):
          /tmp/DropCounters
          /tmp/DropCountersViolations
          /tmp/DropCountersRepeatedViolators

      Options:
        -c, --count=violation-count   The number of threshold violations within
                                      the time window that will trigger an event
        -h, --help                    Print this message
        -t, --threshold=threshold     Counter threshold
        -w, --window=time-window      Time window in seconds
        -k, --keep                    Keep the history of all violations even if
                                      outside window
        -s, --show                    Print the list of violations and their
                                      timestamps
      """ )

   def _intOption( self, opt, val ):
      """Return val as an int, or print usage and exit(2) if it is not one."""
      try:
         return int( val )
      except ValueError:
         print( "option %s requires an integer argument, got %r" % ( opt, val ) )
         self.usage()
         sys.exit( 2 )

   def parseOptions( self ):
      """Parse sys.argv into the monitor settings.

      Exits with status 0 after printing help for -h/--help. Exits with
      status 2 on an unrecognized option or a non-integer numeric argument.
      """
      try:
         # "help" must be listed here, otherwise getopt rejects --help.
         opts, _ = getopt.getopt( sys.argv[ 1 : ], "hw:t:c:sk",
                                  [ "help", "window=", "threshold=", "count=",
                                    "show", "keep" ] )
      except getopt.GetoptError as err:
         # print help information and exit:
         print( str( err ) )
         self.usage()
         sys.exit( 2 )
      for opt, val in opts:
         if opt in ( "-h", "--help" ):
            self.usage()
            sys.exit()
         elif opt in ( "-w", "--window" ):
            self.window = self._intOption( opt, val )
         elif opt in ( "-t", "--threshold" ):
            self.threshold = self._intOption( opt, val )
         elif opt in ( "-c", "--count" ):
            self.count = self._intOption( opt, val )
         elif opt in ( "-s", "--show" ):
            self.show = True
         elif opt in ( "-k", "--keep" ):
            self.keep = True
         else:
            # Defensive: getopt only returns options declared above.
            self.usage()
            sys.exit( 3 )

   def jsonLoad( self, filepath ):
      """Return the JSON document stored at filepath, or None.

      None is returned when the path cannot be opened (missing, a directory,
      unreadable) or does not contain valid JSON. Callers then fall back to
      an empty default.
      """
      try:
         with open( filepath, "r" ) as fp:
            return json.load( fp )
      except ( IOError, OSError, ValueError ):
         return None

   def jsonSave( self, filepath, data ):
      """Write data to filepath as JSON followed by a newline, atomically.

      The document goes to a sibling ".tmp" file that is then renamed over
      filepath (an atomic replace on POSIX). An interrupted run therefore
      cannot leave a truncated file. A truncated file would make jsonLoad()
      return None, and the next run would compare against empty history.
      """
      tmpPath = filepath + ".tmp"
      with open( tmpPath, "w" ) as fp:
         json.dump( data, fp )
         fp.write( "\n" )
      os.rename( tmpPath, filepath )

   def trimLogEvents( self ):
      """Drop logEvents entries older than the sliding window (in place)."""
      # Collect first: deleting while iterating a dict view fails on Python 3.
      expired = [ counter for counter, eventTimestamp in self.logEvents.items()
                  if self.timestamp - eventTimestamp > self.window ]
      for counter in expired:
         del self.logEvents[ counter ]

   def readValues( self ):
      """Load previousCounters, violations and logEvents from their files.

      A missing or unparsable file yields an empty container.
      """
      self.previousCounters = self.jsonLoad( self.countersPath ) or {}
      self.violations = self.jsonLoad( self.violationsPath ) or []
      self.logEvents = self.jsonLoad( self.logEventsPath ) or {}

   def findViolations( self ):
      """Record a violation for every counter that grew by >= threshold.

      Each counter is compared with its previous sample. A counter missing
      from the previous sample (e.g. on the first run) is compared against 0,
      so its absolute value counts as the delta. Every violation is appended
      to self.violations as [ name, timestamp, value, delta ] and reported to
      eventHealthHelper.addViolation().
      """
      eventTimestamp = int( self.timestamp )
      for counter, val in self.counters.items():
         delta = val - self.previousCounters.get( counter, 0 )
         if delta < self.threshold:
            continue
         chipName = self.countersToChips[ counter ]
         self.violations.append( [ counter, self.timestamp, val, delta ] )
         desc = "[%s] %s value=%s delta=%s" % ( eventTimestamp, counter, val,
                                                delta )
         self.eventHealthHelper.addViolation( counter, eventTimestamp,
                                              chipName, desc )

   def trimViolations( self ):
      """Drop violations older than the sliding window."""
      self.violations = [ violation for violation in self.violations
                          if self.timestamp - violation[ 1 ] <= self.window ]

   def saveValues( self ):
      """Persist the current counters and the violations list."""
      self.jsonSave( self.countersPath, self.counters )
      self.jsonSave( self.violationsPath, self.violations )

   def saveRepeatedViolators( self ):
      """Persist repeatedViolators together with the settings that produced it."""
      self.repeatedViolators[ "window" ] = self.window
      self.repeatedViolators[ "threshold" ] = self.threshold
      self.repeatedViolators[ "count" ] = self.count
      self.repeatedViolators[ "keep" ] = self.keep
      self.repeatedViolators[ "timestamp" ] = time.ctime( self.timestamp )
      self.jsonSave( self.repeatedViolatorsPath, self.repeatedViolators )

   def countViolations( self ):
      """Count in-window violations per counter and publish events.

      Rebuilds repeatedViolators[ 'violators' ] as { name: count }, saves it
      and calls eventHealthHelper.publishEvents(). Returns True if any
      counter has >= self.count in-window violations and is not already in
      logEvents.
      """
      self.repeatedViolators = { 'violators' : {} }
      violators = self.repeatedViolators[ 'violators' ]
      repeatedViolations = False
      for violation in self.violations:
         counter = violation[ 0 ]
         violationTimestamp = violation[ 1 ]
         if self.timestamp - violationTimestamp > self.window:
            continue   # only relevant when -k keeps old violations
         violators[ counter ] = violators.get( counter, 0 ) + 1
         if violators[ counter ] >= self.count and \
            counter not in self.logEvents:
            repeatedViolations = True
      self.saveRepeatedViolators()
      self.eventHealthHelper.publishEvents( int( self.timestamp ),
                                            windowSize=self.window,
                                            thresholdCount=self.count )
      return repeatedViolations

   def getCounterValues( self ):
      """Hook that should fill self.counters and self.countersToChips.

      A no-op here; presumably overridden by a platform-specific subclass
      (TODO confirm).
      """
      pass

   def run( self ):
      """Entry point: parse options, then show or evaluate violations.

      With -s, the stored violations are printed one per line. Otherwise a
      new sample is evaluated, the state files are updated, and "1" is
      printed if countViolations() found a repeated violator, else "0".
      """
      self.parseOptions()
      self.eventHealthHelper = EventHealthHelper( 'ar' )
      self.readValues()
      self.trimLogEvents()
      if self.show:
         for violation in self.violations:
            counter = violation[ 0 ]
            violationTimestamp = violation[ 1 ]
            counterVal = violation[ 2 ]
            counterDelta = violation[ 3 ]
            print( "%s %s %s %s" % ( time.ctime( violationTimestamp ), counter,
                                     counterVal, counterDelta ) )
      else:
         self.getCounterValues()
         self.timestamp = time.time()
         self.findViolations()
         if not self.keep:
            self.trimViolations()
         self.saveValues()
         print( "1" if self.countViolations() else "0" )

            
