# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import antlr4
import os

from Assert import assertEqual
import Tac
import Tracing
import QuickTrace
from RcfAetCleanup import AetCleanupHelper
from RcfAetGen import genAbstractEvalTree
from RcfAstGen import genAst
from RcfDiag import RcfDiag
from RcfLexer import (
      InputStream,
      RcfLexer,
   )
from RcfParser import RcfParser
import RcfSymbolTableGen

# Short aliases for the tracing facilities used in this module.
t0 = Tracing.t0
qv, qt8 = QuickTrace.Var, QuickTrace.trace8

class RcfCompileRequest( object ):
   """Input for the RCF compiler

   Arguments
   ---------
   rcfCodeString : str
      The RCF source text to be compiled.
   strictMode : bool
      When True (the default), the compiler follows strict rules. Compilation
      then fails if external constructs that the code refers to (for example
      prefix-lists or community-lists) are not configured.
      strictMode is False, for instance, when the RCF code is compiled as part
      of the startup-config.
   """
   def __init__( self, rcfCodeString, strictMode=True ):
      self.strictMode = strictMode
      self.rcfCodeString = rcfCodeString

class RcfCompileResult( object ):
   """Output from the RCF compiler

   Arguments
   ---------
   rcfCompileRequest : RcfCompileRequest
      The compile request that this result answers.
   success : bool
      True when compilation succeeded, False otherwise.
   allFunctionStatus : Rcf::AllFunctionStatus
      What the compiler produced. Functions whose AET did not change may
      point at previously published AETs.
   errorList : (list of str) or (None)
      None when compilation succeeded; otherwise the list of error strings.
   """
   def __init__( self, rcfCompileRequest, success, allFunctionStatus,
         errorList=None ):
      self.allFunctionStatus = allFunctionStatus
      self.errorList = errorList
      self.success = success
      self.rcfCompileRequest = rcfCompileRequest

   def publish( self, publishedAfs, rcfStatus ):
      """Reconcile previously published state with self.allFunctionStatus.

      publishedAfs : Rcf::AllFunctionStatus
         The previously published functions. Its aet collection is updated in
         place to match self.allFunctionStatus.aet.
      rcfStatus
         Its functionNames collection is kept in step with the function names
         being published.

      The reconciliation does three things:
      1. functions that are no longer in the code are cleaned up and deleted
      2. functions whose AET changed are replaced, and the replaced AET is
         cleaned up
      3. newly added functions are published
      """
      publishedAet = publishedAfs.aet
      compiledAet = self.allFunctionStatus.aet
      # Snapshot both name sets before anything is mutated below.
      oldNames = set( publishedAet )
      newNames = set( compiledAet )

      helper = AetCleanupHelper()

      # Pass 1: drop functions that disappeared from the code.
      for name in oldNames - newNames:
         t0( "Deleting RCF function|", name )
         qt8( "Deleting RCF function|", qv( name ) )
         # Remove this function's AET nodes from the instantiating
         # collections before deleting the function itself.
         helper.cleanup( publishedAet[ name ] )
         del publishedAet[ name ]
         rcfStatus.functionNames.remove( name )

      # Pass 2: publish new functions and those whose AET changed.
      for name in newNames:
         previous = publishedAet.get( name )
         current = compiledAet[ name ]
         changed = previous != current
         if changed:
            t0( "Updating RCF function|", name )
            qt8( "Updating RCF function|", qv( name ) )
            publishedAet.addMember( current )
            # A newly added function has no previous AET to clean up.
            if previous:
               helper.cleanup( previous )

         if name not in rcfStatus.functionNames:
            rcfStatus.functionNames.add( name )

class RcfCompiler( object ):
   """RCF compiler instance.

   aclConfig : Acl::AclListConfig
      This contains the list of all configured IPv4 and IPv6 prefix lists,
      community lists, etc.
   publishedAfs : Rcf::AllFunctionStatus
      The currently published version of functions
   """
   def __init__( self, aclConfig, publishedAfs ):
      """Store the ACL configuration and the currently published functions.

      Both are used on every call to compile().
      """
      self.aclConfig = aclConfig
      self.publishedAfs = publishedAfs

   @staticmethod
   def failureResult( request, diag ):
      """Build an unsuccessful RcfCompileResult for request.

      The result carries a newly created, empty Rcf::AllFunctionStatus and
      the error strings collected so far by diag.
      """
      emptyAfs = Tac.newInstance( "Rcf::AllFunctionStatus", "Empty" )
      return RcfCompileResult( request,
                               success=False,
                               allFunctionStatus=emptyAfs,
                               errorList=diag.allErrorStrList() )

   def compile( self, request ):
      """Compile the RCF code carried by request.

      request : RcfCompileRequest
         Code to compile, strictness, etc
      Returns an RcfCompileResult instance. If diag records errors during
      lexing or parsing, or by the end of semantic analysis, the result has
      success=False and carries diag's error strings.
      """
      diag = RcfDiag( strict=request.strictMode )

      # Step 1: Get token stream. Lexer errors are routed to diag instead of
      # the default ANTLR console listener.
      lexer = RcfLexer( InputStream( request.rcfCodeString ) )
      lexer.removeErrorListeners()
      lexer.addErrorListener( diag )
      tokenStream = antlr4.CommonTokenStream( lexer )

      # Step 2: Perform parsing. Parser errors are also routed to diag.
      parser = RcfParser( tokenStream )
      parser.removeErrorListeners()
      parser.addErrorListener( diag )
      parseTree = parser.rcf()

      # Do not build an AST from input that failed to lex or parse.
      if diag.hasErrors():
         return self.failureResult( request, diag )

      # Step 3: Build AST
      rcfAst = genAst( parseTree, diag )

      # Step 4: Run the semantic analysis
      RcfSymbolTableGen.runSemanticAnalysis( rcfAst, diag, self.aclConfig )

      if diag.hasErrors():
         return self.failureResult( request, diag )

      # Step 5: Generate the AllFunctionStatus. The published functions are
      # passed in, presumably so that AETs that did not change can be reused.
      afs = genAbstractEvalTree( rcfAst, self.publishedAfs )

      if not os.environ.get( 'BREADTH_TEST' ):
         for aet in afs.aet.itervalues():
            # This should not get populated in product code
            assertEqual( len( aet.allAetNodeKeys ), 0 )

      return RcfCompileResult( request, success=True, allFunctionStatus=afs )
