# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
import os
from collections import defaultdict

from Assert import assertEqual, assertIn, assertGreater, assertTrue

import Tracing
import QuickTrace

from RcfAetCleanup import AetCleanupHelper
import RcfAst
from RcfAstVisitor import Visitor
from RcfImmediateValueHelper import AetValueGen
import RcfMetadata
import RcfTypeFuture as Rcf
import RoutingTypeFuture as Routing

# Tracing / QuickTrace shortcuts used by the helper classes below.
t0, qv, qt8 = Tracing.t0, QuickTrace.Var, QuickTrace.trace8

# Shortcuts into the RCF metadata module.
RcfTypeSystem, BT = RcfMetadata.RcfTypeSystem, RcfMetadata.RcfBuiltinTypes

class AetGenPhase( Visitor ):
   """ Generation of the Abstract Evaluation Tree (AET) from the AST.

   During this phase:
      - we assume that the AST is fully valid.
      - we build the AET node representing the AST we traverse.
      - we build the metadata information required for the evaluation of the AET.
      - we dedup compiler AET against published AET.

   Attributes:
      aetDedupHelper: RcfAetDedupHelper deduping every compiled function against
                      the published AETs and collecting the result.
      metadataHelper: RcfAetMetadataHelper allocating the indices of external
                      constructs (prefix lists, AS path lists, community lists).
      currentFunction: the RcfAst.Function whose body is being compiled.
      breadthTestMode: value of the BREADTH_TEST environment variable, or None.
      functionToAetKeys: encoded function name -> set of keys copied into the
                         compiled function's allAetNodeKeys. In breadth test mode
                         the set is exposed as Rcf.currentFunctionKeySet
                         (presumably filled in by AET node construction - TODO
                         confirm).
   """

   # External reference type (RcfAst.ExternalRef.type) -> name of the attribute of
   # the function metadata that records references of that type. The same name is
   # what RcfAetMetadataHelper.getIndex() expects as its metadataType argument.
   _extRefTypeToMetadataAttr = {
      "prefix_list_v4": "v4PrefixListsUsed",
      "prefix_list_v6": "v6PrefixListsUsed",
      "as_path_list": "asPathAccessListsUsed",
      "community_list": "commListsUsed",
      "ext_community_list": "extCommListsUsed",
   }

   def __init__( self, publishedAfs ):
      """
      Args:
         publishedAfs: AllFunctionStatus holding the already published AETs; used
                       both for dedup and for re-using external construct indices.
      """
      super( AetGenPhase, self ).__init__()
      self.aetDedupHelper = RcfAetDedupHelper( publishedAfs )
      self.metadataHelper = RcfAetMetadataHelper( publishedAfs )
      self.currentFunction = None
      self.breadthTestMode = os.environ.get( 'BREADTH_TEST' )
      self.functionToAetKeys = defaultdict( set )

   def visitRoot( self, root, **kwargs ):
      """Compile every function of the AST root.

      Returns:
         The AllFunctionStatus accumulated by aetDedupHelper, after checking that
         external construct indices are consistent across all of its AETs.
      """
      for function in root.functions:
         self.visit( function )
      RcfAetMetadataHelper( self.aetDedupHelper.resultAfs ).verifySanity()
      return self.aetDedupHelper.resultAfs

   def setCurrentFunction( self, function ):
      """Make `function` the function whose body is being compiled.

      In breadth test mode, also points Rcf.currentFunctionKeySet at the key set
      tracked for that function in functionToAetKeys.
      """
      self.currentFunction = function
      if self.breadthTestMode:
         # The encoded name is only needed as a functionToAetKeys key here.
         fnName = function.name.encode( 'utf8' )
         Rcf.currentFunctionKeySet = self.functionToAetKeys[ fnName ]

   def visitFunction( self, function, **kwargs ):
      """ Build an AET for the given function.

      Args: function (RcfAst.Function): the AST function we want to transform in
                                        AET.

      Returns:
         The dedup'd AET function: either the newly compiled Rcf.Eval.Function or
         the equivalent one already present in the published AllFunctionStatus.

      Note:
         This function can be called when we're walking the functions definitions
         of the AST's Root, or when we want to get the AET address of a function
         from an AST's function call perspective.

         If we never saw this function before during this compilation, we build it.
         Otherwise we use the dedup'd AET stored inside aetDedupHelper.
      """
      fnName = function.name.encode( 'utf8' )

      if self.aetDedupHelper.functionExistsInResult( fnName ):
         # We already visited this one, return the deduped AET stored in the result.
         return self.aetDedupHelper.getAet( fnName )

      self.setCurrentFunction( function )
      # Fresh metadata/features, filled in while visiting the body: attributes,
      # external references and callees all contribute to them.
      function.metadata = Rcf.Eval.MetadataType()
      function.features = Routing.Policy.RouteMapFeatures()

      body = self.visit( function.block )
      compiledFunction = Rcf.Eval.Function( fnName,
                                            function.metadata,
                                            function.features,
                                            body )
      for key in self.functionToAetKeys[ fnName ]:
         compiledFunction.allAetNodeKeys.add( key )
      # Todo: no need to pass name
      return self.aetDedupHelper.dedupAet( fnName, compiledFunction )

   def visitBlock( self, block, **kwargs ):
      """Build the Rcf.Eval.BlockStmt linked list for a block of statements.

      Returns an empty BlockStmt( None, None ) when the block has no statements.
      """
      aetStmts = []
      # first, build all the statements
      for stmt in block.stmts:
         aetStmts.append( self.visitExpr( stmt ) )

      # walk the aet statement backwards and construct the
      # linked list of blockstatememt from the last to the
      # first element. Walking back aetStmts simulates the stack
      # so we don't have to use stack based recursion during the construction
      # of the blockstmt itself.
      previousBlockStmt = None
      aetBlockStmt = None
      for aetStmt in reversed( aetStmts ):
         aetBlockStmt = Rcf.Eval.BlockStmt( aetStmt, previousBlockStmt )
         previousBlockStmt = aetBlockStmt

      # return the constructed block if any, or an empty blockstmt otherwise
      return aetBlockStmt or Rcf.Eval.BlockStmt( None, None )

   def visitIfStmt( self, ifStmt, **kwargs ):
      """Build an Rcf.Eval.IfExpression; the else block is None when absent."""
      condition = self.visitExpr( ifStmt.condition )
      thenBlock = self.visit( ifStmt.thenBlock )
      elseBlock = self.visit( ifStmt.elseBlock ) if ifStmt.elseBlock else None
      aetIfStmt = Rcf.Eval.IfExpression( condition, thenBlock, elseBlock )
      return aetIfStmt

   def visitSequentialExpr( self, sequentialExpr, **kwargs ):
      """Build the Rcf.Eval.SequentialExpression linked list for a sequence of
      function calls, optionally terminated by a final tri-state boolean.

      Returns None when there is neither a final boolean nor any function call.
      """
      finalBool = sequentialExpr.finalBool
      if finalBool is not None:
         aetValue = AetValueGen.aetTriStateNameToValueMap[ finalBool.value ]
         aet = Rcf.Eval.SequentialExpression(
                     Rcf.Eval.TriStateBoolImmediateExpression( aetValue ), None )
      else:
         aet = None

      # Build it bottom up, since the AET actually holds a linked list.
      for funcCall in reversed( sequentialExpr.funcCalls ):
         entry = self.visit( funcCall )
         aet = Rcf.Eval.SequentialExpression( entry, aet )

      return aet

   def visitCall( self, call, **kwargs ):
      """Build an Rcf.Eval.FunctionCallExpression for a call to another function.

      The callee's metadata and route map features are merged into the caller's.
      """
      caller = self.currentFunction
      calleePtr = self.visit( call.symbol.node )
      # Visiting a not-yet-compiled callee makes it the current function; switch
      # back to the caller before continuing with the caller's body.
      self.setCurrentFunction( caller )
      caller.metadata.updateWith( calleePtr.metadata )
      caller.features.updateFrom( calleePtr.features )
      calleeName = call.funcName.encode( "utf8" )
      aetFuncCall = Rcf.Eval.FunctionCallExpression( calleeName, calleePtr )
      return aetFuncCall

   def visitExternalRefOp( self, externalRefOp, **kwargs ):
      """Build the AET node for an operation between an attribute and an external
      reference (e.g. a prefix list).

      The node type comes from the attribute type's aetOpTypes, and is built with
      ( attribute AET, external ref utf8 name, [isIpV4 if set], external ref
      index ).
      """
      ctorArgs = []
      aetAttr = self.visit( externalRefOp.attribute )
      attrSym = externalRefOp.attribute.symbol
      ctorArgs.append( aetAttr )
      ctorArgs.append( externalRefOp.extRef.utf8name )
      if externalRefOp.extRef.isIpV4 is not None:
         ctorArgs.append( externalRefOp.extRef.isIpV4 )
      index = self.visit( externalRefOp.extRef )
      ctorArgs.append( index )
      opAetType = attrSym.rcfType.aetOpTypes[ externalRefOp.op ]
      opAetNode = opAetType( *ctorArgs )
      return opAetNode

   def visitExternalRef( self, extRef, **kwargs ):
      """Return the allocated index of an external reference (see extRefIndex)."""
      return self.extRefIndex( extRef )

   def visitAssign( self, assign, **kwargs ):
      """Build the AET node for an assignment to an attribute."""
      attrSym = assign.attribute.symbol
      op = assign.op
      value = assign.value
      # For some assign statements, the metadata may provide specific lhs, op, and
      # rhs AET types as a way of overriding the AET construction. Additionally, the
      # metadata may provide specific ctor args to be used when instantiating the rhs
      # AET node.
      lhsAetType, opAetType, rhsAetType, rhsAetCtorArgs = RcfTypeSystem.aetTypes.get(
            ( attrSym.rcfType, op, value.evalType ), ( None, None, None, None ) )
      if lhsAetType:
         return opAetType( lhsAetType(), rhsAetType( *rhsAetCtorArgs ) )
      else:
         aetAttr = self.visit( assign.attribute )
         ctorArgs = []
         ctorArgs.append( aetAttr )
         if isinstance( value, RcfAst.ExternalRef ):
            ctorArgs.append( value.utf8name )
         ctorArgs.append( self.visitValue( value ) )
         opAetType = attrSym.rcfType.aetOpTypes[ op ]
         opAetNode = opAetType( *ctorArgs )
         return opAetNode

   def visitReturn( self, returnStmt, **kwargs ):
      """Build an Rcf.Eval.ReturnExpression wrapping the returned expression."""
      aetExpr = self.visitExpr( returnStmt.expr )
      aetReturnExpr = Rcf.Eval.ReturnExpression( aetExpr )
      return aetReturnExpr

   def visitAttribute( self, attribute, **kwargs ):
      """Instantiate the attribute's AET type, recording the route map features it
      needs in the current function's features."""
      attrAetType = attribute.symbol.aetType
      attribute.symbol.updateRouteMapFeatures( self.currentFunction.features )
      attrAetNode = attrAetType()
      return attrAetNode

   def visitBinOp( self, binOp, **kwargs ):
      """Dispatch a binary operation to the boolean or attribute builder."""
      if binOp.isBooleanOp:
         # We treat boolean ops in a special way because we don't need to
         # grab AET types from the operands: 'and' and 'or' operation
         # only ever work on boolean (unlike 'is' or '>', which AET operation
         # node is driven by it's operand types).
         aetExpr = self.visitBooleanBinOp( binOp )
      else:
         aetExpr = self.visitAttrBinOp( binOp )
      return aetExpr

   def visitNot( self, notExpr, **kwargs ):
      """Build an Rcf.Eval.NotExpression around the negated expression."""
      aetExpr = self.visitExpr( notExpr.expr )
      return Rcf.Eval.NotExpression( aetExpr )

   #-- custom
   def visitExpr( self, expr ):
      """Visit an expression, mapping RcfAst.Constant directly to a tri-state
      boolean immediate expression."""
      if isinstance( expr, RcfAst.Constant ):
         constant = expr
         aetTriStateValue = AetValueGen.aetTriStateNameToValueMap[ constant.value ]
         aetExpr = Rcf.Eval.TriStateBoolImmediateExpression( aetTriStateValue )
      else:
         aetExpr = self.visit( expr )
      return aetExpr

   def visitValue( self, value ):
      """Build the AET for a value through AetValueGen (which can call back into
      this visitor)."""
      valueVisitor = AetValueGen( aetGenVisitor=self )
      return valueVisitor.build( value )

   def visitAttrBinOp( self, binOp ):
      """Build the AET node for a non-boolean binary operation whose lhs is an
      attribute."""
      op = binOp.operator
      attribute = binOp.lhs
      value = binOp.rhs
      attrSym = attribute.symbol
      # For some attribute binary operation statements, the metadata may provide
      # specific lhs, op, and rhs AET types as a way of overriding the AET
      # construction. Additionally, the metadata may provide specific ctor args to be
      # used when instantiating the rhs AET node.
      lhsAetType, opAetType, rhsAetType, rhsAetCtorArgs = RcfTypeSystem.aetTypes.get(
            ( attrSym.rcfType, op, value.evalType ), ( None, None, None, None ) )
      if lhsAetType and rhsAetType:
         return opAetType( lhsAetType(), rhsAetType( *rhsAetCtorArgs ) )
      elif lhsAetType:
         # Override without an rhs type: the op node only takes the lhs.
         return opAetType( lhsAetType() )
      else:
         aetAttr = self.visit( attribute )
         aetValue = self.visitValue( value )
         binOpType = attrSym.rcfType.aetOpTypes[ op ]
         return binOpType( aetAttr, aetValue )

   def visitBooleanBinOp( self, binOp ):
      """Build an AndExpression for 'and', an OrExpression otherwise."""
      aetLhs = self.visitExpr( binOp.lhs )
      aetRhs = self.visitExpr( binOp.rhs )
      if binOp.operator == "and":
         aetExpr = Rcf.Eval.AndExpression( aetLhs, aetRhs )
      else:
         aetExpr = Rcf.Eval.OrExpression( aetLhs, aetRhs )
      return aetExpr

   def extRefIndex( self, extRef ):
      """Allocate (or re-use) the index of an external construct reference and
      record it in the current function's metadata.

      Args:
         extRef (RcfAst.ExternalRef): the external reference; its `type` must be
            one of the keys of _extRefTypeToMetadataAttr.

      Returns:
         The index allocated by metadataHelper for extRef.name.

      Raises:
         KeyError: if extRef.type is not a supported external reference type.
      """
      # A single table drives both the metadata attribute updated here and the
      # metadataHelper index namespace, so the two can never disagree.
      metadataAttr = self._extRefTypeToMetadataAttr[ extRef.type ]
      metadataType = getattr( self.currentFunction.metadata, metadataAttr )
      index = self.metadataHelper.getIndex( extRef.name, metadataAttr )
      metadataType[ str( extRef.name ) ] = index
      return index

class RcfAetMetadataTypeHelper( object ):
   """This class helps manage index re-use and allocation for the index values that
   need to be allocated for specific external constructs (prefix lists etc).

   Indices start at 1; 0 is never handed out so it can serve as an invalid value.

   Attributes:
      afs: AllFunctionStatus that has the AETs.
      metadataType: Name of metadata attribute inside Rcf::Eval::FunctionMetadata.
      afsInUseIndicesDict: A dict of ExternalEntityName -> index for all functions
                           in self.afs, extended with every new index handed out
                           by getIndex().
      freeIndexGenerator: The generator for generating new index values.
   """
   def __init__( self, afs, metadataType ):
      self.afs = afs
      self.metadataType = metadataType
      self.afsInUseIndicesDict = self.getAfsInUseIndicesDict( self.afs )
      self.freeIndexGenerator = self.freeIndexGeneratorFunction()

   def getAfsInUseIndicesDict( self, afs ):
      """Returns dict of mappings from external entity name to index.

      Asserts that every index is greater than 0 and that a given name maps to the
      same index in every AET of `afs`.
      """
      ret = {}
      for aet in afs.aet.values():
         metadataDict = getattr( aet.metadata, self.metadataType )
         for name, index in metadataDict.items():
            assertGreater( index, 0 )
            if ret.get( name ):
               # The name/index mapping should be the same across all AETs
               assertEqual( ret[ name ], index )
            else:
               ret[ name ] = index
      return ret

   def freeIndexGeneratorFunction( self ):
      """Yield increasing index values that are not currently in use.

      The in-use indices are re-read from afsInUseIndicesDict on every next()
      call, so indices recorded between calls (e.g. by getIndex()) are skipped.
      """
      # Start the range from 1, so 0 can be used as an invalid value. Zero is the
      # default value for an int and the code can assert if it sees a zero index.
      index = 1
      while True:
         # Snapshot into a set so that skipping a run of k in-use indices costs
         # O(k) instead of O(k * n) linear scans over the dict values.
         inUseIndices = set( self.afsInUseIndicesDict.values() )
         while index in inUseIndices:
            index += 1
         yield index
         index += 1

   def getIndex( self, name ):
      """Get an index for an external construct.  Re-use the index collected
      from afs, if it exists there; otherwise allocate a fresh one and remember
      it so later requests for the same name get the same index."""
      index = self.afsInUseIndicesDict.get( name )
      if index is not None:
         # Reuse the existing index
         t0( "Rcf:RAMTH:getIndex: re-used index" )
         qt8( "Rcf:RAMTH:getIndex: re-used index" )
      else:
         index = next( self.freeIndexGenerator )
         self.afsInUseIndicesDict[ name ] = index
      # todo (matthieu): factorise tracing
      t0( "Rcf:RAMTH:getIndex|",
           str( self.metadataType ),
           " ",
           name,
           index )
      qt8( "Rcf:RAMTH:getIndex|",
           qv( str( self.metadataType ) ),
           qv( " " ),
           qv( name ),
           qv( index ) )

      assertGreater( index, 0 )
      return index

   def verifySanity( self ):
      """Verify that the same index value is used in every reference to an external
      construct, across all RCF functions.
      """
      # The below operation has the necessary asserts to verify consistency
      self.getAfsInUseIndicesDict( self.afs )

class RcfAetMetadataHelper( object ):
   """Wrapper managing index re-use and allocation for every supported external
   construct type (prefix lists etc), delegating each type to its own
   RcfAetMetadataTypeHelper.

   Attributes:
      afs: AllFunctionStatus that has the AETs.
      metadataTypes: The names of external construct types that are supported
                     (attribute names inside the AET function metadata).
      helpers: Dict of metadataType -> RcfAetMetadataTypeHelper instance
   """
   def __init__( self, afs ):
      self.afs = afs
      self.metadataTypes = [
         'v4PrefixListsUsed',
         'v6PrefixListsUsed',
         'asPathAccessListsUsed',
         'commListsUsed',
         'extCommListsUsed',
      ]
      self.helpers = { metadataType: RcfAetMetadataTypeHelper( afs, metadataType )
                       for metadataType in self.metadataTypes }

   def getIndex( self, name, metadataType ):
      """Return the index of external construct `name` of type `metadataType`,
      which must be one of self.metadataTypes."""
      assertIn( metadataType, self.metadataTypes )
      typeHelper = self.helpers[ metadataType ]
      return typeHelper.getIndex( name )

   def verifySanity( self ):
      """Check index consistency in every per-type helper, then ask each AET of
      self.afs to verify itself against its own metadata."""
      for typeHelper in self.helpers.values():
         typeHelper.verifySanity()
      for aet in self.afs.aet.values():
         isSane = aet.verifySanity( aet.metadata )
         assertTrue( isSane )

class RcfAetDedupHelper( object ):
   """This manages AET dedup, when the compiler generates the AETs.
   There should be one instance of RcfAetDedupHelper for every batch compilation.

   Attributes:
      publishedAfs: AllFunctionStatus that has the already published AETs.
      resultAfs: The AllFunctionStatus result that is produced by this class.
      cleanupHelper: AetCleanupHelper used to clean up compiled AETs that are
                     discarded in favour of an identical published one.
   """
   def __init__( self, publishedAfs ):
      self.publishedAfs = publishedAfs
      self.resultAfs = Rcf.AllFunctionStatus( "resultAfs" )
      self.cleanupHelper = AetCleanupHelper()

   def functionExistsInResult( self, funcName ):
      """Returns whether the function already exists in the deduped result."""
      return funcName in self.resultAfs.aet

   def getAet( self, funcName ):
      """Returns the deduped AET, if it exists in the result.  This should be
      called only after checking if functionExistsInResult().
      """
      assert self.functionExistsInResult( funcName )
      return self.resultAfs.aet[ funcName ]

   def dedupAet( self, funcName, newAet ):
      """Does AET Dedup.
      If AET comparison between newAet and the AET in self.publishedAfs (if it is
      present there) returns true, it reuses the AET present in self.publishedAfs
      and discards newAet.
      Stores the deduped AET in self.resultAfs and returns the deduped AET.

      Raises:
         AssertionError: if newAet is None/falsy (all-or-nothing semantics), or,
            when asserts are enabled, if funcName is already in the result.
      """
      # funcName should not already exist in result, when this function is called.
      assert not self.functionExistsInResult( funcName )

      if not newAet:
         # Explicit raise instead of `assert False`: with `python -O` the assert
         # would be stripped and addMember() below would hit an unbound name.
         # If we don't do all-or-nothing, then we need 'dedupedAet = None' here.
         raise AssertionError(
               'newAet can not be None, with all-or-nothing semantics' )

      publishedAet = self.publishedAfs.aet.get( funcName )
      isDuplicate = publishedAet and publishedAet.compare( newAet )
      if isDuplicate:
         dedupedAet = publishedAet
         t0( "Rcf:RADH:dedupAet: re-use|", funcName )
         qt8( "Rcf:RADH:dedupAet: re-use|", qv( funcName ) )
         # The freshly compiled copy is redundant with the published one.
         self.cleanupHelper.cleanup( newAet )
      else:
         t0( "Rcf:RADH:dedupAet: new|", funcName )
         qt8( "Rcf:RADH:dedupAet: new|" , qv( funcName ) )
         dedupedAet = newAet

      self.resultAfs.aet.addMember( dedupedAet )
      return dedupedAet

def genAbstractEvalTree( rcfAst, publishedAfs ):
   """Generate the AET for a validated RCF AST.

   Args:
      rcfAst: the AST to compile.
      publishedAfs: AllFunctionStatus holding the currently published AETs; it is
                    also stored as Rcf.publishedAfs before compilation starts.

   Returns:
      The result of running an AetGenPhase visitor on rcfAst (presumably the
      deduped AllFunctionStatus from visitRoot when rcfAst is a Root - confirm
      against Visitor.__call__).
   """
   Rcf.publishedAfs = publishedAfs
   return AetGenPhase( publishedAfs=publishedAfs )( rcfAst )
