#!/bin/python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

from contextlib import contextmanager

import RcfAst
import RcfAstVisitor
import RcfLibCycleDetection
import RcfSymbolTable
import RcfSymbol
import RcfTypeBinding
import RcfMetadata

# Short local aliases for names exported by RcfMetadata: BT is used below for
# the built-in types (BT.Function, BT.Trilean) and RcfBuiltinSymbols provides
# the builtInAttributes collection defined in the global scope.
BT = RcfMetadata.RcfBuiltinTypes
RcfBuiltinSymbols = RcfMetadata.RcfBuiltinSymbols

class DefinitionPhase( RcfAstVisitor.Visitor ):
   """ Definition phase of the semantic analysis.

   During this phase, we visit the AST, and most importantly:

      - we define built-in symbols (med, prefix etc...)
      - we define each function we encounter
      - we annotate the AST nodes with the scope it belongs to.
      - we gather meta data (any)

   !Rules (don't change unless discussing with authors first)
      - Define the visit methods in the order in which the AST nodes are defined.

      @author: matthieu (rcf-dev)

   Attributes:
      diag (RcfDiag): the diagnostic report object.
      globalScope (Scope): the global scope.
      currentScope (Scope): the current scope changes as we change scope.
      currentFunctionName: name of the function visited most recently, None
         until a function has been visited.
   """
   def __init__( self, diag ):
      """ Constructor

      Args:
         diag (RcfDiag): the diagnostic report object.
      """
      super( DefinitionPhase, self ).__init__()
      self.diag = diag
      self.globalScope = RcfSymbolTable.GlobalScope()
      self.currentScope = self.globalScope
      self.defineBuiltInAttributes()
      self.currentFunctionName = None

   @contextmanager
   def scope( self, newScope ):
      """ Makes newScope the current scope, and restores the previous scope
      when exiting the context.

      The previous scope is restored even when the body of the with statement
      raises, so the visitor is never left pointing at a stale inner scope.

      Args:
         newScope ( AbstractScope ): the new scope.

      Yields:
         newScope, for convenience.
      """
      previousScope = self.currentScope
      self.currentScope = newScope
      try:
         yield newScope
      finally:
         self.currentScope = previousScope

   def defineBuiltInAttributes( self ):
      """ Defines all the built-in attributes listed in
      RcfBuiltinSymbols.builtInAttributes in the global scope.
      """
      for symbol in RcfBuiltinSymbols.builtInAttributes:
         self.globalScope.define( symbol )

   #---------------------------------------------------------------------------
   #                       visit methods override.
   #---------------------------------------------------------------------------
   def visitRoot( self, root, **kwargs ):
      """ Visits every function of the root node and returns the root. """
      for function in root.functions:
         self.visitFunction( function )
      return root

   def visitFunction( self, function, **kwargs ):
      """ Defines the function symbol in the current scope (or reports a
      definition error), then visits the function block within the scope of
      that function symbol.
      """
      # todo BUG423280, we should have unit-tested that the symbol node is set!
      funcSymbolScope = RcfSymbol.Function(
         name=function.name, rcfType=BT.Function, retType=BT.Trilean, node=function,
         enclosingScope=self.currentScope )

      symbol = self.currentScope.resolve( function.name )
      if symbol: # This name is already defined, we can't redefine this symbol
         self.diag.definitionError( function )
      elif RcfMetadata.RcfKeywords.isKeyword( function.name ):
         # The function name clashes with a language keyword
         self.diag.definitionError( function )
      else:
         self.currentScope.define( funcSymbolScope )

      # The node is annotated and its block visited even when the definition
      # was rejected above; the error has already been reported to diag.
      self.currentFunctionName = function.name
      function.symbol = funcSymbolScope
      with self.scope( funcSymbolScope ):
         self.visit( function.block )

   def visitBlock( self, block, **kwargs ):
      """ Visits each statement of the block inside a new BlockScope nested in
      the current scope.
      """
      with self.scope( RcfSymbolTable.BlockScope( self.currentScope ) ):
         for stmt in block.stmts:
            self.visit( stmt )

   def visitIfStmt( self, ifStmt, **kwargs ):
      """ Visits the condition, the then block and the else block (if any). """
      self.visit( ifStmt.condition )
      self.visit( ifStmt.thenBlock )
      if ifStmt.elseBlock:
         self.visit( ifStmt.elseBlock )

   def visitCall( self, call, **kwargs ):
      """ Annotates the call node with the current scope. """
      call.scope = self.currentScope

   def visitSequentialExpr( self, sequentialExpr, **kwargs ):
      """ Visits each function call of the sequential expression. """
      for call in sequentialExpr.funcCalls:
         self.visit( call )

   def visitExternalRefOp( self, externalRefOp, **kwargs ):
      """ Visits the attribute and then the external reference operands. """
      self.visit( externalRefOp.attribute )
      self.visit( externalRefOp.extRef )

   def visitExternalRef( self, extRef, **kwargs ):
      """ Annotates the external reference node with the current scope. """
      extRef.scope = self.currentScope

   def visitAssign( self, assign, **kwargs ):
      """ Visits the assigned attribute and then the assigned value. """
      self.visit( assign.attribute )
      self.visit( assign.value )

   def visitReturn( self, returnStmt, **kwargs ):
      """ Visits the returned expression. """
      self.visit( returnStmt.expr )

   def visitConstant( self, constant, **kwargs ):
      """ Only AS path constants need further visiting: their value may embed
      attribute nodes.
      """
      if constant.type == RcfAst.Constant.Type.asPath:
         self.visitAsPathValue( constant )

   def visitAsPathValue( self, asPathValue, **kwargs ):
      """ Visits the attribute node held under the "attr" key of each entry of
      the AS path value, for the entries that have one.

      Args:
         asPathValue: the node whose value is iterated; each entry is accessed
            like a dictionary.
      """
      for asn in asPathValue.value:
         if asn.get( "attr" ):
            self.visit( asn[ "attr" ] )

   def visitAttribute( self, attribute, **kwargs ):
      """ Annotates the attribute node with the current scope. """
      attribute.scope = self.currentScope

   def visitBinOp( self, binOp, **kwargs ):
      """ Visits the left hand side and then the right hand side operands. """
      self.visit( binOp.lhs )
      self.visit( binOp.rhs )

   def visitNot( self, notExpr, **kwargs ):
      """ Visits the negated expression. """
      self.visit( notExpr.expr )

class ResolutionPhase( RcfAstVisitor.Visitor ):
   """ Resolution phase of the semantic analysis.

   During this phase, we try to resolve the symbol that we find in their scope.
   This phase can emit errors (resolution errors and/or warnings).

      - We raise Rcf diags errors if symbols are not defined.
      - We raise Rcf diags errors if external symbols are not defined.
      - We gather meta data (any)

   Note that ResolutionPhase doesn't decide whether a given error is fatal or
   not, it's RcfDiag's job to do so.

   !Rules (don't change unless discussing with authors first)

      - Define the visit methods in the order in which the AST nodes are defined.

      @author: matthieu (rcf-dev)

   Attributes:
      diags (RcfDiag): the diagnostic report object.
      externalScope (ExternalScope): scope used to resolve external references.
      currentFunction: the function node being visited, None before the first
         function is visited.
      callgraph: receives one (caller name, callee name) entry per resolved call.
      refFromAssign (bool): True while the children of an assignment are visited.
   """
   def __init__( self, diags, aclListConfig, callgraph ):
      """ Constructor

      Args:
         diags (RcfDiag): the diagnostic report object.
         aclListConfig: the Tacc collection of ACL symbols defined in EOS config.
         callgraph: the call graph to populate; its add( caller, callee ) method
            is called for every call that resolves to a function.
      """
      super( ResolutionPhase, self ).__init__()
      self.diags = diags
      self.externalScope = RcfSymbolTable.ExternalScope( aclListConfig )
      self.currentFunction = None
      self.callgraph = callgraph
      self.refFromAssign = False

   def visitRoot( self, root, **kwargs ):
      """ Visits every function of the root node. """
      for function in root.functions:
         self.visit( function )

   def visitFunction( self, function, **kwargs ):
      """ Records the function as the current one and visits its block. """
      self.currentFunction = function
      self.visit( function.block )

   def visitBlock( self, block, **kwargs ):
      """ Visits each statement of the block. """
      for stmt in block.stmts:
         self.visit( stmt )

   def visitIfStmt( self, ifStmt, **kwargs ):
      """ Visits the condition, the then block and the else block (if any). """
      self.visit( ifStmt.condition )
      self.visit( ifStmt.thenBlock )
      if ifStmt.elseBlock:
         self.visit( ifStmt.elseBlock )

   def visitSequentialExpr( self, sequentialExpr, **kwargs ):
      """ Visits each function call of the sequential expression. """
      for call in sequentialExpr.funcCalls:
         self.visit( call )

   def visitCall( self, call, **kwargs ):
      """ Resolves the callee in the scope attached to the call node.

      A resolution error is reported when the name is unknown or does not
      designate a function. Otherwise the call is added to the call graph and
      the call node is annotated with the resolved symbol.
      """
      calleeName = call.funcName
      symbol = call.scope.resolve( calleeName )
      if not symbol or not isinstance( symbol, RcfSymbol.Function ):
         self.diags.resolutionError( self.currentFunction, call )
         return
      self.callgraph.add( self.currentFunction.name, calleeName )
      call.symbol = symbol

   def visitExternalRefOp( self, externalRefOp, **kwargs ):
      """ Visits the attribute and then the external reference operands. """
      self.visit( externalRefOp.attribute )
      self.visit( externalRefOp.extRef )

   def visitExternalRef( self, extRef, **kwargs ):
      """ Resolves the external reference in the external scope, reporting an
      error when it cannot be resolved.
      """
      # whether we're referencing this external construct from an assignment
      # - this has bearing on whether some community list exist or not -
      extRef.refFromAssign = self.refFromAssign
      extSymbol = self.externalScope.resolve( extRef )
      if not extSymbol:
         self.diags.extRefResolutionError( self.currentFunction, extRef )
      # The node is annotated even when the resolution failed (falsy symbol).
      extRef.symbol = extSymbol

   def visitAssign( self, assign, **kwargs ):
      """ Visits the assigned attribute and value with refFromAssign raised.

      The previous value of the flag is restored afterwards, even if visiting
      a child raises, so that the flag never leaks to unrelated nodes.
      """
      previousRefFromAssign = self.refFromAssign
      self.refFromAssign = True
      try:
         self.visit( assign.attribute )
         self.visit( assign.value )
      finally:
         self.refFromAssign = previousRefFromAssign

   def visitReturn( self, returnStmt, **kwargs ):
      """ Visits the returned expression. """
      self.visit( returnStmt.expr )

   def visitConstant( self, constant, **kwargs ):
      """ Only AS path constants need resolution: their value may embed
      attribute nodes.
      """
      if constant.type == RcfAst.Constant.Type.asPath:
         self.resolveAsPathValue( constant )

   def resolveAsPathValue( self, asPathValue ):
      """ Visits the attribute node held under the "attr" key of each entry of
      the AS path value, for the entries that have one.
      """
      for asn in asPathValue.value:
         if asn.get( "attr" ):
            self.visit( asn[ "attr" ] )

   def visitAttribute( self, attribute, **kwargs ):
      """ Resolves the attribute name in the scope attached to the node, and
      either annotates the node with the symbol or reports a resolution error.
      """
      attributeSymbol = attribute.scope.resolve( attribute.name )
      if attributeSymbol:
         attribute.symbol = attributeSymbol
      else:
         self.diags.resolutionError( self.currentFunction, attribute )

   def visitBinOp( self, binOp, **kwargs ):
      """ Visits the left hand side and then the right hand side operands. """
      self.visit( binOp.lhs )
      self.visit( binOp.rhs )

   def visitNot( self, notExpr, **kwargs ):
      """ Visits the negated expression. """
      self.visit( notExpr.expr )

def genSymbolTable( rcfAst, diag, aclListConfig ):
   """ Generate the symbol table: define symbols, resolve names, detect cycles.

   The definition phase and then the resolution phase are applied to the AST.
   The call graph filled in during resolution is then searched for cycles, and
   any cycle found is reported to the diagnostics object.

   Args:
      rcfAst: the RCF Abstract Syntax Tree root.
      diag: the RCF diagnostics object.
      aclListConfig: Tacc entity holding the collections of ACLs defined in Sysdb.
   """
   callgraph = RcfLibCycleDetection.Callgraph()

   # Both phases are built up front and then applied to the AST in order.
   phases = (
      DefinitionPhase( diag ),
      ResolutionPhase( diag, aclListConfig, callgraph ),
   )
   for phase in phases:
      phase( rcfAst )

   cycles = RcfLibCycleDetection.findAllCycles( callgraph )
   if cycles:
      diag.cycleError( cycles )

def runSemanticAnalysis( rcfAst, diag, aclListConfig ):
   """ Runs the semantic analysis over the AST.

   The following phases are run, in order:

      - SymbolTable generation / validation:
         - Definition
         - Resolution
         - Cycle detection

      - Type binding / validation (skipped when the diagnostics object reports
        errors after the symbol table generation)

   Args:
      rcfAst: the RCF Abstract Syntax Tree root.
      diag: the RCF diagnostics object.
      aclListConfig: Tacc entity holding the collections of ACLs defined in Sysdb.
   """
   genSymbolTable( rcfAst, diag, aclListConfig )

   if not diag.hasErrors():
      # Eventually, when we have full AET support, we should run
      # The pristine type system: RcfTypeBinding.TypeBindingPhase
      typeBindingPhase = RcfTypeBinding.TypeBindingPhase( diag )
      typeBindingPhase( rcfAst )
