#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

# Structural tokens of the EBNF grammar.
START_OPTIONAL, END_OPTIONAL = '[', ']'          # [ optional ]
START_GROUP, END_GROUP = '(', ')'                # ( group )
START_REPETITION, END_REPETITION = '{', '}'      # { repetition }
ALTERNATIVE_DELIMITER = '|'                      # left | right

# Separator that tokenize() inserts around every EBNF token.
TOKEN_DELIMITER = ' '
# Every token with structural (non-term) meaning in the grammar.
EBNF_TOKENS = frozenset( ( START_OPTIONAL, END_OPTIONAL,
                           START_GROUP, END_GROUP,
                           START_REPETITION, END_REPETITION,
                           ALTERNATIVE_DELIMITER ) )

class EbnfParseNode( object ):
   '''A single named node in an EBNF parse graph.

   Each node keeps the list of nodes that may come after it plus three
   flags: terminal (starts True), repeatable (starts False) and the
   canMerge value it was constructed with.'''
   __slots__ = ( 'name_', 'nextNodes_', 'terminal_', 'repeatable_', 'canMerge_' )

   def __init__( self, name, canMerge ):
      assert name != '', 'each node must have a name'
      self.name_, self.canMerge_ = name, canMerge
      self.terminal_, self.repeatable_ = True, False
      self.nextNodes_ = list()

   # --- read accessors ---

   def getName( self ):
      '''Return the name given at construction.'''
      return self.name_

   def getNextNodes( self ):
      '''Return the (mutable) list of successor nodes.'''
      return self.nextNodes_

   def isTerminal( self ):
      return self.terminal_

   def isRepeatable( self ):
      return self.repeatable_

   def canMerge( self ):
      return self.canMerge_

   # --- mutators ---

   def addNextNode( self, node ):
      '''Append {node} to the successor list.'''
      self.nextNodes_.append( node )

   def setTerminal( self, terminal ):
      self.terminal_ = terminal

   def setRepeatable( self, repeatable ):
      self.repeatable_ = repeatable

class EbnfParseTree( object ):
   '''Root of an EBNF parse graph.

   Holds the entry nodes of a syntax. The graph reachable from them may
   contain cycles (a { repetition } links its terminal nodes back to its
   entry nodes), so every traversal tracks the nodes it has visited.'''
   __slots__ = ( 'nextNodes_', )

   def __init__( self ):
      self.nextNodes_ = []

   def addNextNode( self, node ):
      '''Add {node} as an entry node of the graph.'''
      self.nextNodes_.append( node )

   def getNextNodes( self ):
      '''Return the (mutable) list of entry nodes.'''
      return self.nextNodes_

   def _iterReachable( self, nextNodes, visitedNodes ):
      '''Yield every node reachable from {nextNodes} exactly once.

      {nextNodes} defaults to the entry nodes and {visitedNodes} to a new
      set. Nodes already in {visitedNodes} are neither yielded nor descended
      into; each yielded node is added to {visitedNodes}. An explicit stack
      is used instead of recursion so that a long chain of nodes (one per
      token of a long syntax) cannot exceed Python's recursion limit.'''
      if nextNodes is None:
         nextNodes = self.nextNodes_
      if visitedNodes is None:
         visitedNodes = set()
      # Traversal order does not matter to callers: they build a set or
      # apply an idempotent flag to each node.
      stack = list( nextNodes )
      while stack:
         node = stack.pop()
         if node in visitedNodes:
            continue
         visitedNodes.add( node )
         yield node
         stack.extend( node.getNextNodes() )

   def getTerminalNodes( self, nextNodes=None, visitedNodes=None ):
      '''Return the set of terminal nodes reachable from {nextNodes}.

      :param nextNodes: nodes to start from (defaults to the entry nodes)
      :param visitedNodes: set of nodes to skip; updated with every node
                           reached (defaults to a new empty set)'''
      return set( node for node in self._iterReachable( nextNodes, visitedNodes )
                  if node.isTerminal() )

   def appendToTerminalNodes( self, nodes, markTermialNodes=True ):
      '''Chain {nodes} after every terminal node of the graph.

      If the graph has no entry nodes yet, {nodes} become its entry nodes
      instead. When {markTermialNodes} is true, each terminal node that had
      at least one node appended is marked non-terminal. (The misspelled
      keyword is kept for callers that pass it by name.)'''
      if not self.nextNodes_:
         self.nextNodes_.extend( nodes )
      else:
         for terminalNode in self.getTerminalNodes():
            for node in nodes:
               terminalNode.addNextNode( node )
               if markTermialNodes:
                  terminalNode.setTerminal( False )

   def setNodesAsRepeatble( self, nextNodes=None, visitedNodes=None ):
      '''Mark every node reachable from {nextNodes} as repeatable.

      Parameters as for getTerminalNodes. (The misspelled name is kept for
      existing callers.)'''
      for node in self._iterReachable( nextNodes, visitedNodes ):
         node.setRepeatable( True )

class _Expression( object ):
   __slots__ = ( 'sequences_', )

   def __init__( self ):
      self.sequences_ = []

   def addSequence( self, sequence ):
      self.sequences_.append( sequence )

   def popPrevSeqeunce( self ):
      return self.sequences_.pop()

   def numSequences( self ):
      return len( self.sequences_ )

   def isOptional( self ):
      return all( seq.isOptional() for seq in self.sequences_ )

   def _getOptionalNodes( self, sequences ):
      ebnfParseTree = EbnfParseTree()
      firstPass = True
      while sequences:
         sequence = sequences.pop( 0 )
         subtreeOptional = isinstance( sequence, _Optional )
         subEbnfParseTree = sequence.getEbnfParseTree()
         nodes = subEbnfParseTree.getNextNodes()
         ebnfParseTree.appendToTerminalNodes( nodes, not subtreeOptional )
         if not firstPass:
            for node in nodes:
               ebnfParseTree.addNextNode( node )

         if not subtreeOptional:
            return True, ebnfParseTree.getNextNodes()
         firstPass = False
      return False, ebnfParseTree.getNextNodes()

   def _checkOptOrOpt( self ):
      '''This is for bug 422312
      If we have the syntax "a [ b ] | [ c ]", the parser does not accept "a".
      The following code asserts if it finds such pattern.'''
      for s in self.sequences_:
         if isinstance( s, _Alternative ):
            left = s.left
            right = s.right
            if left.isOptional() and right.isOptional():

               # Strip brackets and wrap in parens if expressions
               # cosnsist of more than one token.
               left = str( left )[ 2 : -2 ]
               right = str( right )[ 2 : -2 ]
               badSyntax = str( s )
               if ' ' in left:
                  left = '( %s )' % left
               if ' ' in right:
                  right = '( %s )' % right
               goodSyntax = '[ %s | %s ]' % ( left, right )
               error = "Instead of having {badSyntax!r}, make it {goodSyntax!r}"
               assert 0, error.format( badSyntax=badSyntax, goodSyntax=goodSyntax )

   def getEbnfParseTree( self ):
      self._checkOptOrOpt()
      ebnfParseTree = EbnfParseTree()
      sequences = [ i for i in self.sequences_ ]
      while sequences:
         subtreeOptional = isinstance( sequences[ 0 ], _Optional )
         markTerminalNodes = True
         if subtreeOptional:
            markTerminalNodes, nodes = self._getOptionalNodes( sequences )
         else:
            sequence = sequences.pop( 0 )
            subEbnfParseTree = sequence.getEbnfParseTree()
            nodes = subEbnfParseTree.getNextNodes()
         ebnfParseTree.appendToTerminalNodes( nodes,
                                              markTermialNodes=markTerminalNodes )

      return ebnfParseTree

   def __str__( self ):
      return ' '.join( str( i ) for i in self.sequences_ )

class _DelimitedExpression( object ):
   __slots__ = ( 'expression_', 'template' )
   template = None

   def __init__( self, expression ):
      self.expression_ = expression

   def getEbnfParseTree( self ):
      return self.expression_.getEbnfParseTree()

   def isOptional( self ):
      return self.expression_.isOptional()

   def __str__( self ):
      return self.template % self.expression_

class _Repetition( _DelimitedExpression ):
   '''"{ expression }": the wrapped expression can repeat.'''
   template = '{ %s }'

   def getEbnfParseTree( self ):
      '''Build the wrapped expression's graph, mark every node in it as
      repeatable, then link each terminal node back to every entry node.'''
      tree = self.expression_.getEbnfParseTree()
      exitNodes = tree.getTerminalNodes()
      tree.setNodesAsRepeatble()
      entryNodes = tree.getNextNodes()
      for exitNode in exitNodes:
         for entryNode in entryNodes:
            exitNode.addNextNode( entryNode )
      return tree

class _Group( _DelimitedExpression ):
   '''"( expression )": parenthesised grouping. The parse graph is the
   wrapped expression's own (inherited getEbnfParseTree).'''
   template = '( %s )'

class _Optional( _DelimitedExpression ):
   '''"[ expression ]": the wrapped expression may be omitted.'''
   template = '[ %s ]'

   def isOptional( self ):
      '''Always True, regardless of the wrapped expression.'''
      return True

class _Alternative( object ):
   __slots__ = ( 'left', 'right', 'canMerge_' )

   def __init__( self, left, right ):
      self.left = left
      self.right = right

   def isOptional( self ):
      return self.left.isOptional() and self.right.isOptional()

   def getEbnfParseTree( self ):
      ebnfParseTree = EbnfParseTree()
      for child in ( self.left, self.right ):
         for node in child.getEbnfParseTree().getNextNodes():
            ebnfParseTree.addNextNode( node )
      return ebnfParseTree

   def __str__( self ):
      return '%s | %s' % ( self.left, self.right )

class _Term( object ):
   __slots__ = ( 'term_', 'canMerge_' )

   def __init__( self, term, canMerge ):
      self.term_ = term
      self.canMerge_ = canMerge

   def isOptional( self ):
      return False

   def getEbnfParseTree( self ):
      ebnfParseTree = EbnfParseTree()
      ebnfParseTree.addNextNode( EbnfParseNode( self.term_, self.canMerge_ ) )
      return ebnfParseTree

   def __str__( self ):
      return self.term_

def expectToken_( tokens, expectedToken ):
   """
   Consume the first token of {tokens} and check it is {expectedToken}.

   :param tokens: list of remaining tokens; its first element is removed
   :param expectedToken: the token that must come next
   :returns: the consumed token
   :raises AssertionError: if {tokens} is empty or its first token differs
                           (a mismatching token is still consumed)
   """
   # Raise explicitly rather than with `assert` so the check also runs
   # under `python -O`; otherwise an unbalanced bracket is silently accepted.
   if not tokens:
      raise AssertionError( 'Expected "%s" got end of input instead' %
                            expectedToken )
   token = tokens.pop( 0 )
   if token != expectedToken:
      raise AssertionError( 'Expected "%s" got "%s" instead' %
                            ( expectedToken, token ) )
   return token

# The EBNF language itself can be recursively defined using EBNF as:
# sequence := ( alternative | repetition | optional | group | term ) [ sequence ]
# expression := sequence {'|' sequence}
# alternative := expression '|' expression
# repetition := '{' expression '}'
# optional := '[' expression ']'
# group := '(' expression ')'
# term := any token that is not one of the EBNF tokens above

def parseTokens_( tokens, canMerge=True, checkNumSequences=False ):
   """
   Given a sequence of tokens, generate a syntax tree.
   We assume there are no two tokens with an identical name.

   :param tokens: list of tokens; tokens are consumed from the front as they
                  are parsed
   :param canMerge: initial canMerge value given to the _Terms created here
                    (cleared by '[', '{' and '|', see the notes below)
   :param checkNumSequences: if True, return as soon as one sequence has been
                             added (used for the right-hand side of '|')
   :returns: an _Expression. Parsing stops without consuming a closing
             ']', ')' or '}', or when the tokens run out.
   """
   expression = _Expression()
   while tokens:
      nextToken = tokens[ 0 ]
      if nextToken == START_OPTIONAL:
         expectToken_( tokens, START_OPTIONAL )
         # NOTE(review): this rebinding persists for the rest of this loop,
         # so terms parsed after the optional in this same expression are
         # also created with canMerge=False -- confirm this is intended.
         canMerge = False
         result = _Optional( parseTokens_( tokens, canMerge=canMerge ) )
         expectToken_( tokens, END_OPTIONAL )
      elif nextToken == START_GROUP:
         expectToken_( tokens, START_GROUP )
         # A group passes the current canMerge value through unchanged.
         result = _Group( parseTokens_( tokens, canMerge=canMerge ) )
         expectToken_( tokens, END_GROUP )
      elif nextToken == START_REPETITION:
         expectToken_( tokens, START_REPETITION )
         # Also persists for the rest of this loop (see the note above).
         canMerge = False
         result = _Repetition( parseTokens_( tokens, canMerge=canMerge ) )
         expectToken_( tokens, END_REPETITION )
      elif nextToken == ALTERNATIVE_DELIMITER:
         expectToken_( tokens, ALTERNATIVE_DELIMITER )
         # Also persists for the rest of this loop (see the note above).
         canMerge = False
         # The left operand is the sequence parsed just before '|'. A leading
         # '|' makes popPrevSeqeunce() raise IndexError (empty list).
         prevExpression = expression.popPrevSeqeunce()
         prevExpression.canMerge_ = False
         # The right operand is a single sequence, so "a | b c" parses as
         # "( a | b ) c".
         nextExpression = parseTokens_( tokens, canMerge=canMerge,
                                        checkNumSequences=True )
         result = _Alternative( prevExpression, nextExpression )
      elif nextToken in ( END_OPTIONAL, END_GROUP, END_REPETITION ):
         # Stop without consuming the closing token; a bracketed caller
         # consumes it with expectToken_().
         return expression
      else:
         result = _Term( tokens.pop( 0 ), canMerge )

      expression.addSequence( result )
      if checkNumSequences and expression.numSequences():
         return expression

   return expression

def parseEbnf( syntax ):
   """
   Parse an EBNF {syntax} string into an EbnfParseTree.

   :param syntax: the EBNF syntax to parse
   :returns: the EbnfParseTree for {syntax}
   :raises AssertionError: if {syntax} is malformed, including a closing
                           ']', ')' or '}' that has no matching opener
   """
   tokens = tokenize( syntax )
   expression = parseTokens_( tokens )
   # At top level parseTokens_ only returns early at a closing token it has
   # no opener for. Previously that token and everything after it were
   # silently dropped; reject it instead.
   if tokens:
      raise AssertionError( 'Unexpected "%s" in syntax %r' %
                            ( tokens[ 0 ], syntax ) )
   return expression.getEbnfParseTree()

def tokenize( syntax ):
   """
   Convert the {syntax} into a {list} of tokens

   Every EBNF token is padded with TOKEN_DELIMITER so it becomes a token of
   its own, then the string is split on runs of whitespace (spaces, tabs,
   newlines, carriage returns). No empty tokens are produced.

   :param syntax: syntax
   :type syntax: str
   :returns: a list of tokens
   :rtype: list
   """
   for token in EBNF_TOKENS:
      syntax = syntax.replace( token, '%s%s%s' %
                                ( TOKEN_DELIMITER, token, TOKEN_DELIMITER ) )
   # split() with no separator treats any whitespace run as one delimiter and
   # drops empty strings (this relies on TOKEN_DELIMITER being whitespace).
   # Splitting only on ' ' had two effects. A tab stayed inside a token. A
   # whitespace-only piece such as '\t' or '\r' became an empty-string token.
   return syntax.split()

def isKeyword( name ):
   """Return True unless {name} is written as a '<...>' placeholder."""
   isPlaceholder = name.startswith( '<' ) and name.endswith( '>' )
   return not isPlaceholder

def stripBracket( name ):
   """Return {name} without its surrounding '<' and '>' if it is a
   placeholder; keywords are returned unchanged."""
   if isKeyword( name ):
      return name
   return name[ 1 : -1 ]

def getTerms( tokens ):
   """Return, in order, the tokens of {tokens} that are not EBNF tokens."""
   return list( filter( lambda token: token not in EBNF_TOKENS, tokens ) )
