#!/bin/python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

""" Translate the Antlr Parse tree into the Rcf AST.
"""

from __future__ import absolute_import, division, print_function

from RcfParser import RcfParser
from RcfVisitor import RcfVisitor
import RcfAst
from RcfImmediateValueHelper import RcfImmediateValueHelper

class AstGenVisitor( RcfVisitor ):
   """ This visitor converts the Antlr parse tree to the Rcf AST.

   !Rules (don't change unless discussing with authors first)

      - Define the visit methods following the same order in which
        the parser rules are defined in the grammar file.

   Args:
      diag: diag object to store errors/warnings.

   Attributes:
      self.diag: the diag object given to the constructor; receives warnings
         emitted while building the AST (e.g. 'statement has no effect').
      self.currentFunction (str): name of the function declaration currently
         being visited (None until the first function is visited). Used to
         attribute diagnostics to a function.

   @author: matthieu (rcf-dev)
   """
   def __init__( self, diag ):
      self.diag = diag
      self.currentFunction = None

   def __call__( self, parseTree ):
      """ Visitor instances are functors.

      Primes the visitor to walk over the parse tree.

      Args:
         parseTree: the parse tree from Antlr.

      Returns:
         The AST for this parse tree.
      """
      return parseTree.accept( self )

   def visitRcf( self, ctx ):
      """ Visit a parse tree produced by RcfParser#rcf.

      Returns:
         RcfAst.Root holding one RcfAst.Function per function declaration,
         in source order.
      """
      functions = [ self.visitFuncDecl( funcDecl )
                    for funcDecl in ctx.funcDecl() ]
      return RcfAst.Root( ctx, functions )

   def visitFuncDecl( self, ctx ):
      """ Visit a parse tree produced by RcfParser#funcDecl.
      """
      # 'function <name>' will be parsed as one token, extract <name>
      name = ctx.FUNCTION().getText().split()[ 1 ]
      # Set before visiting the body so diagnostics emitted while visiting
      # its statements are attributed to this function.
      self.currentFunction = name
      block = self.visitBlock( ctx.block() )
      return RcfAst.Function( ctx, name=name, block=block )

   def visitBlock( self, ctx ):
      """ Visit a parse tree produced by RcfParser#block.

      Returns:
         RcfAst.Block holding the AST of each statement, in source order.
      """
      stmts = [ stmt.accept( self ) for stmt in ctx.stmt() ]
      return RcfAst.Block( ctx, stmts )

   def visitStmt( self, ctx ):
      """ Visit a parse tree produced by RcfParser#stmt.

      Dispatches to the visitor of the concrete statement kind. A bare
      expression statement is accepted, but unless it is a sequential
      expression or a function call a 'statement has no effect' warning is
      reported to self.diag.

      Raises:
         NotImplementedError: if the statement kind is not handled.
      """
      if ctx.ifStmt():
         return self.visitIfStmt( ctx.ifStmt() )
      elif ctx.assignStmt():
         return self.visitAssignStmt( ctx.assignStmt() )
      elif ctx.returnStmt():
         return self.visitReturnStmt( ctx.returnStmt() )
      elif ctx.expr():
         expr = ctx.expr()
         mayHaveEffect = isinstance( expr, ( RcfParser.SeqContext,
                                             RcfParser.FuncCallContext ) )
         if not mayHaveEffect:
            # Other types of expressions are no-ops when used as standalone
            # statements.  Eg.
            # function foo() { med is 100; }
            # The User clearly meant to write 'med = 100' but wrote
            # 'med is 100' instead.
            # Add a warning in this case.
            what = 'statement has no effect'
            self.diag.noEffectWarning( self.currentFunction, expr, what )

         # Continue building the AST.
         return expr.accept( self )
      else:
         raise NotImplementedError( 'Unknown statement node' )

   def visitAssignStmt( self, ctx ):
      """ Visit a parse tree produced by RcfParser#assignStmt.

      Returns:
         RcfAst.Assign( attribute, value, op ) where op is the text of the
         ASSIGN_OP token.
      """
      attribute = self.visitAttribute( ctx.attribute() )
      op = str( ctx.ASSIGN_OP().getText() )
      value = self.visit( ctx.value() )
      return RcfAst.Assign( ctx, attribute, value, op )

   def visitValue( self, ctx ):
      """ Visit a parse tree produced by RcfParser#value.

      A bare string is resolved, in this order, as an enum constant, an
      interface constant, or else an attribute reference. Immediates and
      external references are delegated to their own visitors.
      """
      if ctx.string():
         # a string could be an enum, interface, or attribute.
         cvalue = ctx.string().getText()
         enumType = RcfImmediateValueHelper.getEnumType( cvalue )
         interfaceType = RcfImmediateValueHelper.getInterfaceType( cvalue )
         if enumType is not None:
            return RcfAst.Constant( ctx, cvalue, enumType )
         elif interfaceType is not None:
            return RcfAst.Constant( ctx, cvalue, interfaceType )
         else:
            # NOTE(review): relies on the string context exposing ID() like
            # an attribute context does -- confirm against the grammar.
            return self.visitAttribute( ctx.string() )
      elif ctx.immediate():
         return self.visit( ctx.immediate() )
      else:
         return self.visit( ctx.extRef() )

   @staticmethod
   def _asDotToInt( text ):
      """ Convert asdot ASN text '<high>.<low>' to ( high << 16 ) + low.
      """
      asnums = text.split( '.' )
      return ( int( asnums[ 0 ] ) << 16 ) + int( asnums[ 1 ] )

   def asPathFromAsPathContext( self, asPathContext ):
      """ Get an ordered, python friendly list of AS numbers given the
      asPathContext object built by the parser.

      Args:
         asPathContext (RcfParser.AsPathContext): the parsed asPath object
            from Antlr.

      Returns:
         Ordered list of single-item dictionaries, one per ASN, mapping a
         tag to the AST node for that ASN:

            - 'int':    RcfAst.Constant of type integer,
            - 'as_dot': RcfAst.Constant of type asDot whose value is
                        ( high << 16 ) + low,
            - 'attr':   RcfAst.Attribute.

         e.g. "50 20.30" maps to:
            [ { 'int': Constant( 50 ) }, { 'as_dot': Constant( 1310750 ) } ]

      Raises:
         NotImplementedError: if an ASN has none of the known forms.
      """
      def getAsn( asn ):
         """ Return a ( tag, astNode ) pair for one parsed ASN. """
         # Constants are anchored on the whole as-path context rather than
         # on the individual ASN node.
         if asn.integer():
            return 'int', RcfAst.Constant( asPathContext,
                                           int( asn.integer().getText() ),
                                           RcfAst.Constant.Type.integer )
         if asn.asDot():
            return 'as_dot', RcfAst.Constant(
               asPathContext, self._asDotToInt( asn.asDot().getText() ),
               RcfAst.Constant.Type.asDot )
         if asn.attribute():
            return 'attr', self.visitAttribute( asn.attribute() )
         # Explicit raise rather than 'assert False' so the failure is not
         # silently skipped under 'python -O'.
         raise NotImplementedError(
            "Unknown ASN parsed type: %s" % asn.getText() )

      asPathListResult = list()
      # An as-path holds zero or more ASNs; 'or []' keeps the original guard
      # against asn() returning a falsy value.
      for asn in asPathContext.asn() or []:
         tag, node = getAsn( asn )
         asPathListResult.append( { tag: node } )
      return asPathListResult

   def visitImmediate( self, ctx ):
      """ Visit a parse tree produced by RcfParser#immediate.

      Returns:
         RcfAst.Constant whose type is picked from the immediate's form
         (integer, prefix, ip address, as-path, asdot, empty or none).

      Raises:
         NotImplementedError: if the immediate has none of the known forms.
      """
      if ctx.integer():
         ctype = RcfAst.Constant.Type.integer
         cvalue = long( ctx.integer().getText() )
      elif ctx.ipPrefix():
         ctype = RcfAst.Constant.Type.prefix
         cvalue = ctx.ipPrefix().getText()
      elif ctx.ipAddress():
         ctype = RcfAst.Constant.Type.ipAddress
         cvalue = ctx.ipAddress().getText()
      elif ctx.asPath():
         ctype = RcfAst.Constant.Type.asPath
         cvalue = self.asPathFromAsPathContext( ctx.asPath() )
      elif ctx.asDot():
         ctype = RcfAst.Constant.Type.asDot
         cvalue = self._asDotToInt( ctx.asDot().getText() )
      elif ctx.EMPTY():
         ctype = RcfAst.Constant.Type.empty
         cvalue = None
      elif ctx.NONE():
         ctype = RcfAst.Constant.Type.none
         cvalue = None
      else:
         # Explicit raise: with 'assert False' under 'python -O' execution
         # would fall through with ctype/cvalue unbound.
         raise NotImplementedError( "Unknown constant type: %s" % type( ctx ) )
      return RcfAst.Constant( ctx, cvalue, ctype )

   def visitReturnStmt( self, ctx ):
      """ Visit a parse tree produced by RcfParser#returnStmt.
      """
      expr = ctx.expr().accept( self )
      return RcfAst.Return( ctx, expr )

   def visitIfStmt( self, ctx ):
      """ Visit a parse tree produced by RcfParser#ifStmt.

      Handles three shapes:
         - if/then:            no else block (elseBlock is None),
         - if/then/else:       second block() is the else block,
         - if/then/else if...: the nested ifStmt is wrapped in a synthetic
                               Block so the else branch is always a Block.
      """
      condition = ctx.expr().accept( self )
      thenBlock = self.visitBlock( ctx.block()[ 0 ] )
      if ctx.ifStmt(): # if then else if ...
         # Create a dummy block with a nested if statement node inside
         elseBlock = RcfAst.Block( ctx, [ ctx.ifStmt().accept( self ) ] )
         assert len( ctx.block() ) == 1
      elif len( ctx.block() ) == 2: # if then else
         elseBlock = self.visitBlock( ctx.block()[ 1 ] )
      else: # if then
         elseBlock = None
      return RcfAst.IfStmt( ctx, condition, thenBlock, elseBlock )

   def visitRelational( self, ctx ):
      """ Visit a parse tree produced by RcfParser#relational.

      Returns:
         RcfAst.BinOp( BIN_REL text, attribute, value ).
      """
      operator = ctx.BIN_REL().getText()
      attribute = self.visitAttribute( ctx.attribute() )
      constant = self.visit( ctx.value() )
      return RcfAst.BinOp( ctx, operator, attribute, constant )

   def visitAttribute( self, ctx ):
      """ Visit a parse tree produced by RcfParser#attribute.

      The attribute name is the ID tokens joined with '.'.
      """
      name = '.'.join( part.getText() for part in ctx.ID() )
      return RcfAst.Attribute( ctx, name )

   def visitExtRef( self, ctx ):
      """ Visit a parse tree produced by RcfParser#extRef.

      The EXTERNAL_REF token text is whitespace-separated '<type> <name>'.
      """
      etype, ename = ctx.EXTERNAL_REF().getText().split()
      return RcfAst.ExternalRef( ctx, ename, etype )

   def visitFuncCall( self, ctx ):
      """ Visit a parse tree produced by RcfParser#FuncCall.
      """
      calleeName = ctx.funcCallExpr().ID().getText()
      return RcfAst.Call( ctx, calleeName )

   def visitNot( self, ctx ):
      """ Visit a parse tree produced by RcfParser#Not.
      """
      expr = ctx.expr().accept( self )
      return RcfAst.Not( ctx, expr )

   def visitParens( self, ctx ):
      """ Visit a parse tree produced by RcfParser#Parens.

      Parentheses produce no node of their own; the inner expression's AST
      is returned directly.
      """
      return ctx.expr().accept( self )

   def visitExternalRefOp( self, ctx ):
      """ Visit a parse tree produced by RcfParser#ExternalRefOp.

      The EXT_REF_OP text selects the match mode: 'match_exact' sets
      isExact, 'match_covered' sets isMatchCovered; any other operator
      leaves both False.
      """
      attribute = self.visitAttribute( ctx.attribute() )
      opText = ctx.EXT_REF_OP().getText()
      isExact = opText == 'match_exact'
      isMatchCovered = opText == 'match_covered'
      extRef = self.visitExtRef( ctx.extRef() )
      return RcfAst.ExternalRefOp( ctx, attribute, isExact, isMatchCovered,
                                   extRef )

   def visitAnd( self, ctx ):
      """ Visit a parse tree produced by RcfParser#And.
      """
      lhs = ctx.expr()[ 0 ].accept( self )
      rhs = ctx.expr()[ 1 ].accept( self )
      return RcfAst.BinOp( ctx, 'and', lhs, rhs )

   def visitOr( self, ctx ):
      """ Visit a parse tree produced by RcfParser#Or.

      Operands keep source order: lhs is expr()[0], rhs is expr()[1].
      """
      lhs = ctx.expr()[ 0 ].accept( self )
      rhs = ctx.expr()[ 1 ].accept( self )
      return RcfAst.BinOp( ctx, 'or', lhs, rhs )

   def visitRel( self, ctx ):
      """ Visit a parse tree produced by RcfParser#Rel.
      """
      return self.visitRelational( ctx.relational() )

   def visitSeq( self, ctx ):
      """ Visit a parse tree produced by RcfParser#Seq.
      """
      return self.visitSequentialExpr( ctx.sequentialExpr() )

   def visitTrueFalseUnknown( self, ctx ):
      """ Visit a parse tree produced by RcfParser#TrueFalseUnknown.

      Only ctx.getText() is used, so this also accepts a bare TRUE/FALSE/
      UNKNOWN terminal node (see visitSequentialExpr). 'unknown' yields a
      trilean constant; any other text yields a boolean constant.
      """
      cvalue = ctx.getText()
      if cvalue == "unknown":
         ctype = RcfAst.Constant.Type.trilean
      else:
         ctype = RcfAst.Constant.Type.boolean
      return RcfAst.Constant( ctx, cvalue, ctype )

   def visitSequentialExpr( self, ctx ):
      """ Visit a parse tree produced by RcfParser#sequentialExpr.

      Returns:
         RcfAst.SequentialExpr holding one RcfAst.Call per funcCallExpr (in
         source order) and the optional trailing true/false/unknown constant
         (None when absent).
      """
      funcCalls = [ RcfAst.Call( expr, expr.ID().getText() )
                    for expr in ctx.funcCallExpr() ]
      if ctx.TRUE() is not None:
         finalBool = self.visitTrueFalseUnknown( ctx.TRUE() )
      elif ctx.FALSE() is not None:
         finalBool = self.visitTrueFalseUnknown( ctx.FALSE() )
      elif ctx.UNKNOWN() is not None:
         finalBool = self.visitTrueFalseUnknown( ctx.UNKNOWN() )
      else:
         finalBool = None

      return RcfAst.SequentialExpr( ctx, funcCalls, finalBool )

def genAst( parseTree, diag ):
   """ Build the Rcf AST for a valid Antlr parse tree.

   Args:
      parseTree: the parse tree produced by the parsing phase.
      diag: diag object that collects the warnings reported while the AST
         is built (e.g. statements that have no effect).

   Returns:
      The AST built by AstGenVisitor; for a tree rooted at the 'rcf' rule
      this is an RcfAst.Root.
   """
   visitor = AstGenVisitor( diag )
   astRoot = visitor( parseTree )
   return astRoot
