#!/bin/python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import RcfAst
from RcfMetadata import RcfKeywords
from antlr4.error.DiagnosticErrorListener import DiagnosticErrorListener
from enum import Enum

class MessageLevel( Enum ):
   """ Severity attached to a diagnostic message. """
   Warning = 0
   Error = 1

   @staticmethod
   def strep( level ):
      """ Return the human readable label of a MessageLevel member.

      Args:
         level (MessageLevel): the severity to render.

      Returns:
         str: "Warning" or "Error".

      Raises:
         KeyError: if level is not one of the members above.
      """
      # Built inside the method: a dict assigned in the class body would be
      # turned into an enum member.
      labels = {}
      labels[ MessageLevel.Warning ] = "Warning"
      labels[ MessageLevel.Error ] = "Error"
      return labels[ level ]

def errorStr( msgLevel, loc, funcName, msg ):
   """ Build the one-line text of a diagnostic.

   The result has the shape "<level>[ line L:C][ (function 'name')]: <msg>".

   Args:
      msgLevel (MessageLevel): severity, rendered through MessageLevel.strep.
      loc (tuple): ( line, column ) pair; ( 0, 0 ) means "no location" and the
         location part is left out.
      funcName: name of the enclosing function; any falsy value (None, "")
         leaves the function part out.
      msg: the message body.

   Returns:
      str: the formatted diagnostic.
   """
   where = ""
   if funcName:
      where = " (function '%s')" % funcName
   position = ""
   if loc != ( 0, 0 ):
      position = ' line %i:%i' % loc
   severity = MessageLevel.strep( msgLevel )
   return "{0}{1}{2}: {3}".format( severity, position, where, msg )

def sourceLoc( context ):
   """ Return the ( line, column ) of the first token of a parse context.

   Args:
      context: object exposing a 'start' token with 'line' and 'column'.

   Returns:
      tuple: ( line, column ) read from context.start.
   """
   firstToken = context.start
   return ( firstToken.line, firstToken.column )

class RcfSyntaxError( object ):
   """ Syntax error details, stored exactly as handed over by Antlr's
   ErrorListener.

   Attributes:
      recognizer (Parser): Antlr parser instance.
      offendingSymbol (token): the unexpected token.
      line (int): line at which the error was found.
      col (int): column at which the error was found.
      msg (str): Antlr generated message.
      e (Antlr Exception): exception details.
   """
   def __init__( self, recognizer, offendingSymbol, line, col, msg, e ):
      # Position and text, the only parts used to render the message.
      self.line, self.col = line, col
      self.msg = msg
      # Remaining Antlr details, kept for whoever wants to inspect them.
      self.recognizer, self.offendingSymbol = recognizer, offendingSymbol
      self.e = e

   def message( self ):
      """ Render the error as an Error-level string without a function name.
      """
      position = ( self.line, self.col )
      return errorStr( MessageLevel.Error, position, None, self.msg )

class RcfResolutionError( object ):
   """ Resolution error details, as reported by the Rcf symbolic phases.

   Attributes:
      function (RcfAst.Function): function where we found the error.
      offendingRef (RcfAst.Node): offending node in the AST (call or attribute).
   """
   def __init__( self, function, offendingRef ):
      self.offendingRef = offendingRef
      self.function = function

   def messageCall( self ):
      """ Message for an undefined function; names offendingRef.funcName. """
      position = sourceLoc( self.offendingRef.context )
      owner = self.function.name
      callee = self.offendingRef.funcName
      text = "undefined reference to {t} '{name}'".format( t='function',
                                                            name=callee )
      return errorStr( MessageLevel.Error, position, owner, text )

   def messageAttribute( self ):
      """ Message for an undefined attribute; names offendingRef.name. """
      position = sourceLoc( self.offendingRef.context )
      owner = self.function.name
      attrStr = self.offendingRef.name
      text = "undefined reference to {t} '{name}'".format( t='attribute',
                                                            name=attrStr )
      return errorStr( MessageLevel.Error, position, owner, text )

   def message( self ):
      """ Render the error: the call flavour when offendingRef is an
      RcfAst.Call, the attribute flavour for anything else.
      """
      isCall = isinstance( self.offendingRef, RcfAst.Call )
      render = self.messageCall if isCall else self.messageAttribute
      return render()

class RcfDefinitionError( object ):
   """ Holds details regarding definition phase errors.

   It is assumed that a definition error only comes from a function
   definition, e.g. a user defines a function called 'med', or a user defines
   the same function 'foo' twice.

   Attributes:
      function (Ast.Function): Ast node of the function where the error was found.
   """
   def __init__( self, function ):
      self.function = function

   def message( self ):
      """ Render the error, located at the FUNCTION token of the definition.

      The text reports a keyword conflict when RcfKeywords.isKeyword accepts
      the function name, and a symbol redefinition otherwise.
      """
      token = self.function.context.FUNCTION().getSymbol()
      position = ( token.line, token.column )
      fn = self.function.name
      if RcfKeywords.isKeyword( self.function.name ):
         template = "function name '{fn}' conflicts with a language keyword"
      else:
         template = "redefinition of symbol '{fn}'"
      return errorStr( MessageLevel.Error, position, None,
                       template.format( fn=fn ) )

class RcfExtResolutionError( object ):
   """ External resolution error details, as reported by the Rcf symbolic
   phases.

   Attributes:
      function (Ast.Function): Ast node of the function where the error was found.
      offendingRef (Ast.ExternalRef): Ast node of the external ref.
      isWarning (bool): render the message at Warning level instead of Error.
   """
   def __init__( self, function, offendingRef, isWarning=False ):
      self.isWarning = isWarning
      self.offendingRef = offendingRef
      self.function = function

   def message( self ):
      """ Render the undefined external reference, naming its type and name.
      """
      position = sourceLoc( self.offendingRef.context )
      kind = self.offendingRef.type
      refName = self.offendingRef.name
      text = "undefined reference to {t} '{name}'".format( t=kind, name=refName )
      if self.isWarning:
         severity = MessageLevel.Warning
      else:
         severity = MessageLevel.Error
      return errorStr( severity, position, self.function.name, text )

class RcfCycleError( object ):
   """ Cycle error details, for cycles detected after generating the symbol
   table.

   Attributes:
      cycle (list of str): function names representing a detected cycle, each
         converted with str() at construction time.
   """
   def __init__( self, cycle ):
      names = []
      for unicodeFunc in cycle:
         names.append( str( unicodeFunc ) )
      self.cycle = names

   def message( self ):
      """ Render the cycle as "a -> b -> a"; no source location is attached.
      """
      # Repeat the first function at the end so the printed chain is closed.
      closedChain = self.cycle + [ self.cycle[ 0 ] ]
      cycleStr = " -> ".join( closedChain )
      text = "cycle found in function callgraph '{cyc}'".format( cyc=cycleStr )
      # ( 0, 0 ) is the "no location" marker understood by errorStr.
      noLocation = ( 0, 0 )
      return errorStr( MessageLevel.Error, noLocation, None, text )

class RcfTypingError( object ):
   """ Typing error found during the type binding phase.

   Attributes:
      function (AstNode.Function): the function where we found the error.
      offendingRef (AstNode): the AST construction that has invalid type.
      what (str): the description of what went wrong.
   """
   def __init__( self, function, offendingRef, what ):
      self.what = what
      self.offendingRef = offendingRef
      self.function = function

   def message( self ):
      """ Render 'what' as an Error located at the offending node, tagged with
      the name of the enclosing function.
      """
      position = sourceLoc( self.offendingRef.context )
      owner = self.function.name
      return errorStr( MessageLevel.Error, position, owner, self.what )

class RcfExtRefTypingError( RcfTypingError ):
   """ Typing error found during the type binding phase for an external
   reference that was resolved but does not fit the use case.

   Adds nothing to RcfTypingError: it is a distinct type so that these errors
   can be told apart with isinstance, as RcfDiag.isFatal does.
   """

class RcfNoEffectWarning( object ):
   """ Holds a warning about a statement having no effect.

   Attributes:
      function (AstNode.Function): the function where we found the warning.
      offendingRef (AstNode): the AST construction that has no effect.
      what (str): the description of what went wrong.
   """
   def __init__( self, function, offendingRef, what ):
      self.function = function
      self.offendingRef = offendingRef
      self.what = what

   def message( self ):
      """ Render the warning at Warning level.

      Returns:
         str: the text built by errorStr, located at the offending node and
         tagged with the name of the enclosing function.
      """
      # Every other diagnostic in this file locates an AST node through its
      # 'context' attribute and names the function through its 'name'
      # attribute. Do the same here, falling back to the object itself so that
      # a caller handing over a raw parse context or a plain function name
      # keeps working.
      context = getattr( self.offendingRef, 'context', self.offendingRef )
      funcName = getattr( self.function, 'name', self.function )
      line, col = sourceLoc( context )
      return errorStr( MessageLevel.Warning, ( line, col ),
                       funcName, self.what )

class RcfDiag( DiagnosticErrorListener ):
   """ Collects all error and warning messages that occur during the
   compilation.

   Attributes:
      strict (bool): whether we're in strict mode or not. Outside strict mode
         the external reference errors are recorded but are not fatal.
      allErrors (list): every error object recorded so far, in the order the
         callbacks were invoked.
      allWarnings (list): every warning object recorded so far, in the order
         the callbacks were invoked.
   """
   def __init__( self, strict ):
      super( RcfDiag, self ).__init__()
      self.strict = strict
      self.allErrors = []
      self.allWarnings = []

   def isFatal( self, error ):
      """ Whether a recorded error must fail the compilation.

      Everything is fatal, except RcfExtResolutionError and
      RcfExtRefTypingError when not in strict mode.

      Args:
         error: one of the error objects stored in allErrors.
      """
      ignorableErrorTypes = ( RcfExtResolutionError, RcfExtRefTypingError )
      ignorable = isinstance( error, ignorableErrorTypes ) and not self.strict
      return not ignorable

   def hasErrors( self ):
      """ Whether or not the ErrorListener recorded at least one fatal error.
      """
      return any( self.isFatal( err ) for err in self.allErrors )

   def hasWarnings( self ):
      """ Whether or not warnings exist.
      """
      return len( self.allWarnings ) > 0

   def allErrorStrList( self ):
      """ Messages of all recorded errors (fatal or not), in recording order.
      """
      return [ event.message() for event in self.allErrors ]

   def allWarningStrList( self ):
      """ Messages of all recorded warnings, in recording order.
      """
      return [ event.message() for event in self.allWarnings ]

   def yieldErrorType( self, errorType ):
      """ Return the recorded errors that are instances of errorType.

      Despite the name this returns a list, not a generator. Subclasses
      match as well, since the filter relies on isinstance.

      Args:
         errorType (type or tuple of types): the error class(es) to select.
      """
      return [ error for error in self.allErrors
               if isinstance( error, errorType ) ]

   def extRefResolutionError( self, func, offendingExtRef ):
      """ Callback when a user code references a external construct that is not
      defined.

      Args:
         func (Ast.Function): the function where we found the error.
         offendingExtRef (AstNode.ExtRef): Ast node referencing an external
                                           construct.
      """
      error = RcfExtResolutionError( func, offendingExtRef )
      self.allErrors.append( error )

   def resolutionError( self, func, offendingRef ):
      """ Callback when a user code references a construct that is not
      defined (path attribute, or rcf function).

      Args:
         func (Ast.Function): the function where we found the error.
         offendingRef (AstNode): Ast node referencing something.
      """
      error = RcfResolutionError( func, offendingRef )
      self.allErrors.append( error )

   def definitionError( self, func ):
      """ Callback when a user code redefines the same symbol twice or uses a
      language keyword as a symbol name.

      Args:
         func (Ast.Function): the function that redefines an existing symbol.
      """
      error = RcfDefinitionError( func )
      self.allErrors.append( error )

   def syntaxError( self, recognizer, offendingSymbol, line, column, msg, e ):
      """ Callback when Antlr finds a syntax error. Record such error.

         Args:
            recognizer (Parser): parser
            offendingSymbol token: the unexpected token
            line (int): line at which the error was found
            column (int): column at which the error was found
            msg (str): Antlr generated message
            e (Antlr Exception): Exception details
      """
      error = RcfSyntaxError( recognizer, offendingSymbol, line, column, msg, e )
      self.allErrors.append( error )

   def cycleError( self, cycles ):
      """ Callback when the compiler has discovered cycles while traversing
      the callgraph generated from the symbol table. One error is recorded
      per cycle.

         Args:
            cycles (list): list of cycles containing series of function calls
      """
      for cycle in cycles:
         error = RcfCycleError( cycle )
         self.allErrors.append( error )

   def typingError( self, func, offendingSymbol, what ):
      """ Callback when the compiler found a typing error during the type binding
      phase.

         Args:
            func (AstNode): the function where we found the error.
            offendingSymbol (AstNode): construction where we found the error.
            what (string): the custom message.
      """
      error = RcfTypingError( func, offendingSymbol, what )
      self.allErrors.append( error )

   def extRefTypingError( self, func, offendingSymbol, what ):
      """ Callback when the compiler found a typing error during the type binding
      phase, related to the external references only.

         Args:
            func (AstNode): the function where we found the error.
            offendingSymbol (AstNode): construction where we found the error.
            what (string): the custom message.
      """
      # Must be the dedicated subclass: isFatal() only treats the error as
      # ignorable outside strict mode when it is an RcfExtRefTypingError. It
      # is still an RcfTypingError for anyone filtering on the base class.
      error = RcfExtRefTypingError( func, offendingSymbol, what )
      self.allErrors.append( error )

   def noEffectWarning( self, func, offendingSymbol, what ):
      """ Callback when the compiler found a warning about a statement that has
      no effect.

         Args:
            func (AstNode): the function where we found the warning.
            offendingSymbol (AstNode): construct where we found the warning.
            what (string): the custom message.
      """
      warning = RcfNoEffectWarning( func, offendingSymbol, what )
      self.allWarnings.append( warning )

   # The three Antlr diagnostic callbacks below are never expected to be
   # invoked: each one fails an assertion instead of recording anything.
   # Note that assert statements are skipped when Python runs with -O.

   def reportAmbiguity( self, recognizer, dfa, startIndex, stopIndex, exact,
                       ambigAlts, configs ):
      """ Antlr callback for an ambiguity; always fails an assertion. """
      assert False, "Antlr Ambiguity reported"

   def reportAttemptingFullContext( self, recognizer, dfa, startIndex, stopIndex,
                                    conflictingAlts, configs ):
      """ Antlr callback for a full-context attempt; always fails an assertion.
      """
      assert False, "Antlr FullContext attempt reported"

   def reportContextSensitivity( self, recognizer, dfa, startIndex, stopIndex,
                                 prediction, configs ):
      """ Antlr callback for a context sensitivity; always fails an assertion.
      """
      assert False, "Antlr Context Sensitivity reported"
