# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import errno
import socket
import array
from struct import unpack_from

# Tests in Arrow should import Marco before importing this module.
# We do not import Marco here as other EOS code may wish to use the TacMarco version
# of MarcoDebug.
import MarcoDebug as Debug

import Arrow.ClientLibConstants as clc
import Arrow.monotonic_time as mTime
from baseRow import BaseRow
from Arrow.Protocol.Handshake import Handshake as AtpHandshake
from Arrow.Protocol.Control import Control as AtpControl

# Arrow message-type constants (rowIs, rowDel, rowDelRange, ... are read from
# it below); created once for the whole module.
mType = clc.ArrowMessageType()
# Debug/trace handle used for all logging in this module (dh.DEBUG0 ... dh.DEBUG9).
dh = Debug.Handle( handleName="ArrowMessageHandler" )

# suppress pylint errors caused by not finding methods which are dynamically
# generated by protocolBuffers during initialization
# pylint: disable-msg=E1101

# FIXME: these exception classes should be moved to a utils.py file somewhere.
class ConnectionError( Exception ):
   """Raised when the connection to the server has failed for some reason.

   Base class of the more specific connection errors in this module."""

class HandshakeError( ConnectionError ):
   """Raised when the connection to the server couldn't be established."""

class LostConnectionError( ConnectionError ):
   """Raised when an established connection to the server has been lost."""

class MessageHandler( object ):
   """Exchange Arrow (ATP) rows with a peer over a connected socket.

   Construction sends the local ATP handshake row; the peer's handshake row is
   read by the first readAndDecodeMsg() call.  After that each received
   message is a data row, optionally preceded by a control row that sets the
   message type, table id and table path, plus a second data row for
   rowDelRange messages.  Every row begins with a length field that is decoded
   as a little-endian unsigned 16-bit value (see readRow()).
   """

   def __init__( self, sock, timeoutHandler=None,
                 timeoutTime=0.2, blocking=True,
                 readHandler=None, writeHandler=None ):
      """Wrap 'sock' and send the local ATP handshake row.

      sock           -- connected socket; its blocking mode and timeout are
                        changed here.
      timeoutHandler -- optional no-argument callable run by
                        readAndDecodeMsg() after a receive timeout, or when
                        more than timeoutTime seconds have passed since it last
                        ran.  Work other than reads is expected to be done
                        there; for a non-blocking handler it should not write.
      timeoutTime    -- socket timeout in seconds, and the minimum interval
                        between timeoutHandler runs.
      blocking       -- passed to sock.setblocking() (see the note below).
      readHandler    -- stored in readHandler_; intended to receive
                        (msgType, message), but not called by this class.
      writeHandler   -- stored in writeHandler_; presumably meant to be called
                        when the socket is writable (TODO confirm); not called
                        by this class.
      """
      dh.DEBUG8( "MessageHandler.__init__: self.timeInterval: "
                 + str( timeoutTime ) )
      # receive timeout and timeoutHandler interval (200 ms by default)
      self.timeInterval = timeoutTime
      self.sock_ = sock
      # NOTE(review): per the socket module docs, settimeout() replaces the
      # mode chosen by setblocking(), so the socket always ends up in timeout
      # mode and 'blocking' has no lasting effect.  Kept for compatibility.
      self.sock_.setblocking( 1 if blocking else 0 )
      self.sock_.settimeout( self.timeInterval )
      self.readHandler_ = readHandler
      self.writeHandler_ = writeHandler
      self.toHandler_ = timeoutHandler
      # number of bytes of sendBuffer_ already written to the socket
      self.sendOffset_ = 0

      self.running_ = True
      # monotonic time at which the timeout handler last ran
      self.timeoutTime = 0.0

      # statistics reported by statsToString()
      self.recvWakeups_ = 0
      self.recvMsgs_ = 0
      self.sendAttempts_ = 0
      self.sendCalls_ = 0
      self.sendMsgs_ = 0
      self.partialSends_ = 0
      self.requeuedSends_ = 0

      # partially sent message still owed to the socket (see sendMessage())
      self.sendBuffer_ = None
      self.atpControlRowType_ = None
      # context set by the last control row received / sent
      self.recvMsgType_ = None
      self.recvTableId_ = None
      self.recvTablePath_ = None
      self.sendMsgType_ = None
      self.sendTableId_ = None
      self.sendTablePath_ = None

      # pylint: disable-msg=W0212
      AtpControl._rowTypeIdIs( 0 )

      # control row type announced by the peer's handshake row
      self.remoteControlRowId_ = None
      # True until the peer's handshake row has been read
      self.firstRead_ = True

      self.sendProtocol()

   def terminate( self ):
      """Ask readAndDecodeMsg() to stop; it then returns the error tuple."""
      dh.DEBUG9( "MessageHandler for sock %s terminating" % self.sock_ )
      self.running_ = False

   def readFromSocket( self, length, peek=False ):
      """Receive 'length' bytes, waiting until all of them have arrived.

      With peek=True the bytes are left in the kernel socket buffer.  Returns
      the data received; if it is shorter than 'length', a signal was caught
      or an error or disconnect occurred.  Raises socket.timeout if receiving
      blocks for longer than self.timeInterval.
      """
      dh.DEBUG9( "readFromSocket: length: " + str( length ) + " peek: " +
                 str( peek ) )
      # MSG_WAITALL asks the kernel not to return until all bytes are in.
      # NOTE(review): because __init__ calls settimeout(), CPython drives the
      # fd in non-blocking mode and polls first; on Linux MSG_WAITALL may then
      # return a short read if only part of a row has arrived -- confirm.
      flags = socket.MSG_WAITALL
      if peek:
         flags |= socket.MSG_PEEK
      buf = self.sock_.recv( length, flags )
      if len( buf ) != length:
         dh.DEBUG8( "readFromSocket: Wanted to %s %d bytes, got %d. Most likely, "
                    "connection lost."
                    % ( ( "peek at " if peek else "recv " ), length, len( buf ) ) )
      return buf

   @staticmethod
   def _hexDump( data ):
      """Format 'data' as colon-separated two-digit lowercase hex bytes."""
      return ':'.join( '%02x' % b for b in bytearray( data ) )

   def readRow( self ):
      """Read one complete row from the socket.

      Returns the row (including its length field) as received from recv(),
      "" for a null row (length 0), or None if an error or disconnect cut the
      read short.  Raises socket.timeout if blocked longer than
      self.timeInterval.
      """
      dh.DEBUG9( "readRow: enter" )
      # Only peek at the length: if the blocking read times out, the bytes stay
      # queued in the kernel and the next attempt starts from a row boundary.
      rowLenBuf = self.readFromSocket( clc.rowLengthLength, peek=True )
      if len( rowLenBuf ) != clc.rowLengthLength:
         return None

      buf = array.array( 'B', rowLenBuf )
      ( rowLength, ) = unpack_from( "<H", buf, 0 )
      if rowLength > 0:
         # rowLength counts the length field too, so this also consumes the
         # bytes we peeked at.
         rowBuf = self.readFromSocket( rowLength )
         if len( rowBuf ) != rowLength:
            return None

         dh.DEBUG9( "readRow: len = %d row = %s"
                    % ( len( rowBuf ), self._hexDump( rowBuf ) ) )
      else:
         # null row: remove the length bytes we peeked at from the recv queue.
         dh.DEBUG9( "readRow: got a null row" )
         rowLenBuf = self.readFromSocket( clc.rowLengthLength )
         if len( rowLenBuf ) != clc.rowLengthLength:
            return None
         rowBuf = ""

      dh.DEBUG9( "readRow: exit" )
      return rowBuf

   def readProtocolVersion( self ):
      """Read and validate the peer's ATP handshake row.

      Records the peer's control row type in remoteControlRowId_ and returns
      True.  Raises HandshakeError if the row cannot be read or carries an
      unexpected protocol version; socket.timeout propagates from readRow().
      """
      dh.DEBUG9( "reading protocol version" )
      row = self.readRow()
      if row is None:
         raise HandshakeError( "connection failed while reading ATP handshake" )
      handshakeRow = AtpHandshake( row=row )
      version = handshakeRow.protocolVersion()
      if version != clc.ArrowATPVersion:
         raise HandshakeError( "ATP protocol version mismatch: got %s, "
                               "expected %s" % ( version, clc.ArrowATPVersion ) )
      self.remoteControlRowId_ = handshakeRow.controlRowType()
      dh.DEBUG9( "protocol version ok" )
      return True

   def processControlRow( self, row ):
      """Update the receive context (msg type, table id, path) from 'row'."""
      controlRow = AtpControl( row=row )
      self.recvMsgType_ = controlRow.msgType()
      self.recvTableId_ = controlRow.tableId()
      self.recvTablePath_ = controlRow.tablePath()

   def _readMessageRow( self ):
      """Read the next data row, first consuming a control row if present.

      Returns the row, "" for a null row, or None on error/disconnect;
      socket.timeout propagates (a consumed control row is already recorded).
      """
      row = self.readRow()
      if row and BaseRow( row=row ).rowType() == self.remoteControlRowId_:
         self.processControlRow( row )
         row = self.readRow()
      return row

   def readAndDecodeMsg( self ):
      """Receive and decode the next message.

      Returns ( msgType, tableId, tablePath, row1, row2 ): the first three come
      from the most recent control row, and row2 is only read for rowDelRange
      messages ("" otherwise).  Returns ( 0, 0, None, None, None ) on
      disconnect, socket error, or after terminate().  Receive timeouts are
      retried; after every attempt the timeout handler runs if that attempt
      timed out or timeInterval has elapsed since it last ran.  Raises
      HandshakeError if the peer's handshake row is bad.
      """
      errorTuple = ( 0, 0, None, None, None )
      # row1 lives outside the loop so that a timeout while reading row2 does
      # not discard a row1 that has already been consumed from the socket.
      row1 = None
      row2 = ""
      timeout = True
      while timeout and self.running_:
         try:
            timeout = False
            self.recvWakeups_ += 1
            if self.firstRead_:
               self.readProtocolVersion()
               self.firstRead_ = False
            if row1 is None:
               row1 = self._readMessageRow()
               if row1 is None:
                  return errorTuple
            if self.recvMsgType_ == mType.rowDelRange:
               row2 = self.readRow()
               if row2 is None:
                  return errorTuple
            self.recvMsgs_ += 1
            break
         except socket.timeout:
            dh.DEBUG9( "MessageHandler.readAndDecodeMsg: caught socket.timeout" )
            timeout = True
         except socket.error as msg:
            dh.DEBUG1( "exception in socket handling: %s" % msg )
            return errorTuple
         finally:
            currentTime = mTime.monotonic_time()
            if timeout or ( currentTime - self.timeoutTime > self.timeInterval ):
               self.timeoutTime = currentTime
               if self.toHandler_ is not None:
                  ( self.toHandler_ )()

      if not self.running_:
         dh.DEBUG9( "not running" )
         return errorTuple

      return ( self.recvMsgType_, self.recvTableId_, self.recvTablePath_, row1,
               row2 )

   def sendMessage( self, message ):
      """Send 'message', first finishing any partially sent earlier message.

      Each send uses MSG_DONTWAIT.  When the kernel accepts no more data
      (EAGAIN/EWOULDBLOCK, socket.timeout, or a zero-byte send), the unsent
      remainder is kept in sendBuffer_/sendOffset_ for the next call.

      Returns True if a previously buffered message could not be finished; in
      that case 'message' was neither sent nor buffered and the caller must
      retry it.  Returns False once 'message' has been sent or buffered.
      Raises LostConnectionError on EPIPE; other errors are logged and
      re-raised.
      """
      self.sendAttempts_ += 1
      queued = self.sendBuffer_ is not None
      if queued:
         msgLen = len( self.sendBuffer_ )
         origin = self.sendOffset_
      else:
         msgLen = len( message )
         origin = 0
         self.sendBuffer_ = message
      while True:
         try:
            self.sendCalls_ += 1
            sent = self.sock_.send( self.sendBuffer_[ origin: ],
                                    socket.MSG_DONTWAIT )
         except socket.timeout:
            break
         except socket.error as e:
            if e.errno in ( errno.EWOULDBLOCK, errno.EAGAIN ):
               break
            if e.errno == errno.EPIPE:
               # str( e ) keeps errno/strerror; e.message is empty for these
               raise LostConnectionError( str( e ) )
            dh.DEBUG0( "socket.send() socket.error: " + str( e ) )
            raise
         except IOError as e:
            if e.errno in ( errno.EWOULDBLOCK, errno.EAGAIN ):
               break
            if e.errno == errno.EPIPE:
               raise LostConnectionError( str( e ) )
            dh.DEBUG1( "socket.send() IOError: " + str( e ) )
            raise
         except Exception as e:
            # arbitrary exceptions need not have an errno attribute
            if getattr( e, 'errno', None ) == errno.EPIPE:
               raise LostConnectionError( str( e ) )
            dh.DEBUG0( "socket.send() Exception: " + str( e ) )
            raise
         # every except clause above either breaks or raises, so 'sent' is
         # the value returned by socket.send()
         if sent == 0: # EOF?
            break
         origin += sent
         assert origin <= msgLen, "sent data past end of msg!"
         if origin == msgLen:
            self.sendMsgs_ += 1
            if queued:
               # earlier message finished; continue with the new one
               dh.DEBUG9( "Finished queued msg" )
               queued = False
               self.sendBuffer_ = message
               msgLen = len( message )
               origin = 0
            else: # done
               break
      if origin < msgLen:
         self.partialSends_ += 1
         # sendBuffer_ holds the unfinished message; resume at 'origin'
         self.sendOffset_ = origin
      else:
         self.sendBuffer_ = None
         self.sendOffset_ = 0
      # If 'message' was completed or buffered, "queued" is False and nothing
      # needs requeueing; otherwise the caller has to requeue 'message'.
      if queued:
         self.requeuedSends_ += 1
      return queued

   def sendProtocol( self ):
      """Send the local ATP handshake row; returns sendMessage()'s result."""
      handshakeRow = AtpHandshake()
      handshakeRow.protocolVersionIs( clc.ArrowATPVersion )
      handshakeRow.controlRowTypeIs( AtpControl.protRowType() )
      dh.DEBUG9( "sendProtocol " + handshakeRow.toString() )
      return self.sendMessage( handshakeRow.buf().tostring() )

   def sendControlRow( self, msgType, tableId, tablePath ):
      """Send a control row and record it as the current send context.

      Returns sendMessage()'s result.  If that is True, the control row was
      neither sent nor buffered, so the recorded context is cleared; the next
      sendRowOp() then sends a fresh control row.
      """
      controlRow = AtpControl()
      controlRow.rowTypeIs( AtpControl.protRowType() )
      controlRow.msgTypeIs( msgType )
      controlRow.tableIdIs( tableId )
      controlRow.tablePathIs( tablePath )
      self.sendMsgType_ = msgType
      self.sendTableId_ = tableId
      self.sendTablePath_ = tablePath
      dh.DEBUG9( "sendControlRow " + controlRow.toString() )
      requeue = self.sendMessage( controlRow.buf().tostring() )
      if requeue:
         # the peer has not seen this context; forget it
         self.sendMsgType_ = None
         self.sendTableId_ = None
         self.sendTablePath_ = None
      return requeue

   def sendRowOp( self, tableId, row, pathName, msgType ):
      """Send 'row' as a msgType operation on a table.

      A control row is sent first when the message type differs from the last
      control row, or when the table does: by path when tableId is 0,
      otherwise by tableId.  Returns True if the caller must retry (requeue)
      this operation.
      """
      if pathName is None:
         pathName = ""
      if ( msgType != self.sendMsgType_ or
           ( tableId == 0 and pathName != self.sendTablePath_ ) or
           ( tableId != 0 and tableId != self.sendTableId_ ) ):
         if self.sendControlRow( msgType, tableId, pathName ):
            # Control row not accepted: sending the data row now would let it
            # reach the peer without its context, so requeue the whole op.
            return True
      rowstr = row.buf().tostring()
      dh.DEBUG9( "sendRowOp %d row: %s, tbl=%d pathName = %s"
                  % ( msgType, row.toString(), tableId, pathName ) )
      return self.sendMessage( rowstr )

   def sendRowIs( self, tableId, row, pathName ):
      """Send a rowIs operation; returns True if it must be requeued."""
      return self.sendRowOp( tableId, row, pathName, mType.rowIs )

   def sendRowDel( self, tableId, row, pathName ):
      """Send a rowDel operation; returns True if it must be requeued."""
      return self.sendRowOp( tableId, row, pathName, mType.rowDel )

   # FIXME: need to generate sendRowDelRange

   def statsToString( self ):
      """Return the receive and send counters as two text lines."""
      rstr = "recv: msgs/wakeups: %d/%d" % ( self.recvMsgs_, self.recvWakeups_ )
      fmt = "send: msgs/part/re-q/attempts/calls %d/%d/%d/%d/%d"
      sstr = fmt % ( self.sendMsgs_, self.partialSends_, self.requeuedSends_,
                     self.sendAttempts_, self.sendCalls_ )
      return rstr + "\n" + sstr

   def showStats( self ):
      """Print statsToString() to stdout."""
      print( self.statsToString() )
