#!/usr/bin/env python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

from CliModel import ( Model,
                       List,
                       Int,
                       Str,
                       Enum,
                       Float,
                       Dict,
                       Bool,
                       DeferredModel,
                       Submodel,
)
from ArnetModel import ( IpGenericPrefix,
                         IpGenericAddress,
)
import TableOutput
from Ark import ( utcTimeRelativeToNowStr,
                  timestampToStr,
)
from EosRpkiLib import createTableOutput
import Tac
from TypeFuture import TacLazyType
import Toggles.RpkiToggleLib

# Handle to the 'Rpki::RpkiDefaults' Tac type, obtained through TacLazyType.
# CacheModel.printModel reads invalidSessionId and invalidSerialNumber from it.
rpkiDefaults = TacLazyType( 'Rpki::RpkiDefaults' )

# One Route Origin Authorization (ROA) entry: a prefix, an AS number and a
# maximum prefix length.
# This is the same for ipv4 and ipv6 ROAs.
class Roa( DeferredModel ):
   prefix = IpGenericPrefix( help="Prefix" )
   asn = Int( help="Autonomous System Number" )
   maxLength = Int( help="Max prefix length" )

# Model whose only attribute is the list of Roa entries.
class RpkiShowRoaModel( DeferredModel ):
   # The help text is built from two adjacent literals; the first one must end
   # with a space so that "(RPKI)" and "Route" do not run together.
   roas = List( valueType=Roa,
                help="Resource Public Key Infrastructure (RPKI) "
                     "Route Origin Authorizations (ROAs)" )

# Allowed values for CacheModel.state.
CACHE_STATE_LIST = ( 'idle', 'connected', 'requestRoaReset',
                     'requestRoaUpdate', 'syncing', 'synced', 'disabled',
                     'disconnected', 'misconfigured' )

# Corresponds to enum RpkiProtocolVersion -( 0, 1, -1 )
# CacheModel.setProtocolVersion indexes this tuple with the numeric version, so
# -1 selects the last entry ( 'notNegotiatedYet' ).
PROTOCOL_VERSION_LIST = ( 'version0', 'version1', 'notNegotiatedYet' )

# Allowed values for TcpInformationModel.state. setAttrFromDict indexes this
# tuple with the numeric state it receives, so the order of the entries matters.
TCP_SOCKET_STATE_LIST = ( 'NOTUSED', 'ESTABLISHED', 'SYN-SENT',
                         'SYN-RECEIVED', 'FIN-WAIT-1', 'FIN-WAIT-2',
                         'TIME-WAIT', 'CLOSED', 'CLOSE-WAIT',
                         'LAST-ACK', 'LISTEN', 'CLOSING' )

# Bit masks applied to the integer 'options' value in
# TcpInformationModel.setAttrFromDict.
# NOTE(review): these look like the Linux tcp_info tcpi_options flags - confirm.
# TCPI_OPT_ECN_SEEN and TCPI_OPT_SYN_DATA are not used in the code visible here.
TCPI_OPT_TIMESTAMPS = 1
TCPI_OPT_SACK = 2
TCPI_OPT_WSCALE = 4
TCPI_OPT_ECN = 8
TCPI_OPT_ECN_SEEN = 16
TCPI_OPT_SYN_DATA = 32

# Boolean view of the TCP option flags. TcpInformationModel.setAttrFromDict fills
# it in from a bit mask and TcpInformationModel.render prints each flag as
# enabled/disabled. Every flag defaults to False.
class TcpOptionsModel( Model ):
   timestamps = Bool( default=False, help='Timestamps enabled' )
   selectiveAcks = Bool( default=False, help='Selective Acknowledgments enabled' )
   windowScale = Bool( default=False, help='Window Scale enabled' )
   ecn = Bool( default=False, help='Explicit Congestion Notification enabled' )

# TCP socket information, used as CacheDetailModel.tcpInformation. It is
# populated from a dictionary of raw values by setAttrFromDict() and printed by
# render().
class TcpInformationModel( Model ):
   localTcpAddress = IpGenericAddress( help='TCP local address' )
   localTcpPort = Int( help='TCP local port' )
   state = Enum( values=TCP_SOCKET_STATE_LIST, help='TCP Socket current state' )
   options = Submodel( valueType=TcpOptionsModel, help='TCP socket options' )
   sendWindowScale = Int( help='TCP socket send window scale factor' )
   recvWindowScale = Int( help='TCP socket receive window scale factor' )
   retransmitTimeout = Int( help='TCP socket retransmission timeout '
                                 '(microseconds)' )
   delayedAckTimeout = Int( help='TCP socket delayed ack timeout (microseconds)' )
   maxSegmentSize = Int( help='TCP outgoing maximum segment size (bytes)' )
   sendRtt = Int( help='TCP round-trip time (microseconds)' )
   sendRttVariance = Int( help='TCP round-trip time variance (microseconds)' )
   slowStartThreshold = Int( help='TCP send slow start size threshold (bytes)' )
   congestionWindow = Int( help='TCP send congestion window (bytes)' )
   recvRtt = Int( help='TCP receive round-trip time (microseconds)' )
   recvWindow = Int( help='TCP advertised receive window (bytes)' )
   totalRetrans = Int( help='Total number of TCP retransmissions' )
   inputQueueLength = Int( help='TCP input queue length' )
   inputQueueMaxLength = Int( help='TCP input queue max length' )
   outputQueueLength = Int( help='TCP output queue length' )
   outputQueueMaxLength = Int( help='TCP output queue max length' )

   def setAttrFromDict( self, data ):
      """Populate this model from a dictionary of raw TCP socket values.

      An 'options' entry is treated as an integer bit mask ( TCPI_OPT_* ) and is
      expanded into a TcpOptionsModel; a 'state' entry is treated as an index
      into TCP_SOCKET_STATE_LIST. Every other entry is assigned to the attribute
      of the same name. The dictionary passed in is left unmodified.
      """
      # Decode on a shallow copy so that removing 'options' and translating
      # 'state' does not alter the caller's dictionary.
      attrs = dict( data )

      self.options = TcpOptionsModel()
      if 'options' in attrs:
         optionFlag = attrs.pop( 'options' )
         self.options.timestamps = bool( optionFlag & TCPI_OPT_TIMESTAMPS )
         self.options.selectiveAcks = bool( optionFlag & TCPI_OPT_SACK )
         self.options.windowScale = bool( optionFlag & TCPI_OPT_WSCALE )
         self.options.ecn = bool( optionFlag & TCPI_OPT_ECN )

      if 'state' in attrs:
         attrs[ 'state' ] = TCP_SOCKET_STATE_LIST[ attrs[ 'state' ] ]

      for key in attrs:
         setattr( self, key, attrs[ key ] )

   def render( self ):
      """Print the TCP socket information; prints nothing when state is unset."""
      if self.state is None:
         return

      def valueOrNA( value ):
         return value if value is not None else "not available"

      def enabledOrDisabled( value ):
         return 'enabled' if value else 'disabled'

      # Only the queue lengths are printed with a placeholder when unset.
      outputQueueLength = valueOrNA( self.outputQueueLength )
      outputQueueMaxLength = valueOrNA( self.outputQueueMaxLength )
      inputQueueLength = valueOrNA( self.inputQueueLength )
      inputQueueMaxLength = valueOrNA( self.inputQueueMaxLength )

      print( "TCP Socket Information:" )
      print( "  Local TCP address is {}, local port is {}".format(
         self.localTcpAddress, self.localTcpPort ) )
      print( "  TCP state is {}".format( self.state ) )
      print( "  Send-Q: {}/{}".format( outputQueueLength, outputQueueMaxLength ) )
      print( "  Recv-Q: {}/{}".format( inputQueueLength, inputQueueMaxLength ) )
      print( "  Outgoing Maximum Segment Size (MSS): {}".format(
         self.maxSegmentSize ) )
      print( "  Total Number of TCP retransmissions: {}".format(
         self.totalRetrans ) )
      print( "  Options:" )
      print( "    Timestamps: {}".format(
         enabledOrDisabled( self.options.timestamps ) ) )
      print( "    Selective Acknowledgments: {}".format(
         enabledOrDisabled( self.options.selectiveAcks ) ) )
      print( "    Window Scale: {}".format(
         enabledOrDisabled( self.options.windowScale ) ) )
      print( "    Explicit Congestion Notification (ECN): {}".format(
         enabledOrDisabled( self.options.ecn ) ) )
      print( "  Socket Statistics:" )
      print( "    Window Scale (wscale): {},{}".format(
         self.sendWindowScale, self.recvWindowScale ) )
      # The timers below are stored in microseconds and printed in milliseconds.
      print( "    Retransmission Timeout (rto): %.1fms" %
            ( self.retransmitTimeout / 1000.0 ) )
      print( "    Round-trip Time (rtt/rtvar): %.1fms/%.1fms" %
            ( self.sendRtt / 1000.0, self.sendRttVariance / 1000.0 ) )
      print( "    Delayed Ack Timeout (ato): %.1fms" %
            ( self.delayedAckTimeout / 1000.0 ) )
      print( "    Congestion Window (cwnd): {}".format( self.congestionWindow ) )

      # NOTE(review): 65535 and above presumably mean "no threshold set" -
      # confirm.
      if self.slowStartThreshold < 65535:
         print( "    Slow-start Threshold (ssthresh): {}".format(
            self.slowStartThreshold ) )
      if self.sendRtt > 0 and self.maxSegmentSize and self.congestionWindow:
         # sendRtt is in microseconds, so bits per microsecond reads as Mbps.
         print( "    TCP Throughput: %.2f Mbps" %
                ( float( self.congestionWindow ) *
                  ( float( self.maxSegmentSize ) * 8. / float( self.sendRtt ) ) ) )

      if self.recvRtt:
         print( "    Recv Round-trip Time (rcv_rtt): %.1fms" %
               ( float( self.recvRtt ) / 1000 ) )
      if self.recvWindow:
         print( "    Advertised Recv Window (rcv_space): {}".format(
            self.recvWindow ) )

# Extra per-cache information, attached as CacheModel.detail and printed by
# CacheModel.printDetail.
class CacheDetailModel( Model ):
   # The keepalive attributes are declared only when the toggle is enabled at the
   # time this class body is executed; printDetail checks the same toggle before
   # reading them.
   if Toggles.RpkiToggleLib.toggleRpkiCacheTcpKeepaliveEnabled():
      tcpKeepaliveIdleTime = Int( help="Idle time (seconds) before TCP keepalive" )
      tcpKeepaliveProbeInterval = Int(
            help="Interval (seconds) between TCP keepalive probes" )
      tcpKeepaliveProbeCount = Int(
            help="Number of keepalive probes before closing connection" )
   lastConfigChange = Float( help="Time when the host, port, VRF or local "
                                  "interface was changed in the config in UTC" )
   lastConnectionError = Float( help="Time of last connection error to "
                                     "cache server in UTC" )
   lastProtocolError = Float( help="Time of sending/receiving protocol "
                                   "error to/from cache server in UTC" )
   overrideConfiguredIntervalValues = Bool( help="Cache suggested or default "
                                                 "interval values will be used" )
   tcpInformation = Submodel( valueType=TcpInformationModel,
                              optional=True, help='TCP information' )

# One cache server entry; RpkiShowCacheModel holds these keyed by cache name and
# prints each of them through printModel().
class CacheModel( Model ):
   host = Str( help="IP address or hostname of cache server" )
   port = Int( help="Transport port of cache server" )
   vrf = Str( help="VRF where the cache server is configured" )
   refreshInterval = Int( help="Number of seconds between cache server polls" )
   retryInterval = Int( help="Number of seconds between poll error and cache "
                             "server poll" )
   expireInterval = Int( help="Number of seconds to retain data synced from cache "
                              "server" )
   preference = Int( help="Cache server preference. Lower value means "
                          "higher preference" )
   protocolVersion = Enum( values=PROTOCOL_VERSION_LIST,
                           help="RPKI-RTR protocol version" )
   state = Enum( values=CACHE_STATE_LIST, help="Cache server current state" )
   sessionId = Int( help="Session ID for this cache instance" )
   serialNumber = Int( help="Serial number for the last successful sync" )
   lastUpdateSync = Float( help="Time of last serial sync with "
                                "cache server in UTC" )
   lastFullSync = Float( help="Time of last reset sync with cache "
                              "server in UTC" )
   lastSerialQuery = Float( help="Time of last serial query sent to "
                                 "cache server in UTC" )
   lastResetQuery = Float( help="Time of last reset query sent to "
                                "cache server in UTC" )
   entries = Int( help="Number of Route Origin Authorization entries" )
   deleted = Bool( help="The config for this cache has been deleted" )
   activeSince = Float( help='Time the connection became active in UTC' )
   reasonInactive = Str( help='Reason the connection is not active', optional=True )
   transportProtocol = Enum(
         values=Tac.Type( "Rpki::RpkiTransportAuthenticationType" ).attributes,
         help="Transport protocol for connection to RPKI cache server" )
   detail = Submodel( valueType=CacheDetailModel, help='Detail information',
                      optional=True )

   def setProtocolVersion( self, protocolVersion ):
      """Store the PROTOCOL_VERSION_LIST entry at index protocolVersion."""
      versionName = PROTOCOL_VERSION_LIST[ protocolVersion ]
      self.protocolVersion = versionName

   def setReasonInactive( self, errorReason, activeSince ):
      """Set reasonInactive from the connection status.

      An active connection ( truthy activeSince ) clears the reason; otherwise
      errorReason is used when given, and a fixed text is stored when it is not.
      """
      if activeSince:
         # Active connection
         reason = None
      elif errorReason:
         reason = errorReason
      elif self.deleted:
         # The two fallback texts are displayed only in CAPI output. Text output
         # does not display anything when the cache doesn't have a TCP connection
         # because it's deleted, less preferred or misconfigured.
         reason = 'Cache is deleted'
      else:
         reason = 'Cache is idle or misconfigured'
      self.reasonInactive = reason

   def printDetail( self ):
      """Print the attributes of self.detail; the caller checks it is set."""
      detail = self.detail
      if Toggles.RpkiToggleLib.toggleRpkiCacheTcpKeepaliveEnabled():
         keepalive = ( detail.tcpKeepaliveIdleTime,
                       detail.tcpKeepaliveProbeInterval,
                       detail.tcpKeepaliveProbeCount )
         if all( value == 0 for value in keepalive ):
            print( "TCP keepalive: disabled" )
         else:
            idleTime, probeInterval, probeCount = keepalive
            print( "TCP keepalive idle time: {} seconds".format( idleTime ) )
            print( "TCP keepalive probe interval: {} seconds".format(
               probeInterval ) )
            print( "TCP keepalive probe count: {}".format( probeCount ) )

      timestampLines = (
         ( "Last config change: {}", detail.lastConfigChange ),
         ( "Last connection error: {}", detail.lastConnectionError ),
         ( "Last protocol error: {}", detail.lastProtocolError ),
      )
      for template, when in timestampLines:
         print( template.format( utcTimeRelativeToNowStr( when ) ) )

      tcpInformation = detail.tcpInformation
      if tcpInformation:
         tcpInformation.render()

   def printModel( self, cacheId ):
      """Print this cache's text output under the heading cacheId.

      A cache whose config has been deleted gets a notice with its ROA expiry
      time and omits the interval, protocol, state, session and serial lines.
      """
      isDeleted = self.deleted

      print( "{}:".format( cacheId ) )
      if isDeleted:
         # Fall back to the query times when neither sync time is set.
         lastSync = ( max( self.lastFullSync, self.lastUpdateSync ) or
                      max( self.lastSerialQuery, self.lastResetQuery ) )
         # Convert expireTime to now time from utc Time
         expireTime = lastSync + self.expireInterval - Tac.utcNow() + Tac.now()
         print( "This cache's config has been removed;",
                "ROAs retained until expiry at",
                timestampToStr( expireTime, relative=False ) )

      print( "Host: {} port {}".format( self.host, self.port ) )
      print( "VRF: {}".format( self.vrf ) )
      if not isDeleted:
         print( "Refresh interval: {} seconds".format( self.refreshInterval ) )
         print( "Retry interval: {} seconds".format( self.retryInterval ) )
         print( "Expire interval: {} seconds".format( self.expireInterval ) )
         if self.detail and self.detail.overrideConfiguredIntervalValues:
            print( "Configured refresh interval is greater than expire "
                   "interval; Using cache suggested or default interval values" )

      print( "Preference: {}".format( self.preference ) )
      if not isDeleted:
         # Anything other than version0/version1 is shown as not negotiated.
         versionText = { 'version0': '0', 'version1': '1' }.get(
            self.protocolVersion, 'Not negotiated yet' )
         print( "Protocol version: {}".format( versionText ) )
         print( "State: {}".format( self.state ) )

         # The invalid sentinels from rpkiDefaults are displayed as "None".
         if self.sessionId != rpkiDefaults.invalidSessionId:
            sessionId = self.sessionId
         else:
            sessionId = "None"
         print( "Session ID: {}".format( sessionId ) )

         if self.serialNumber != rpkiDefaults.invalidSerialNumber:
            serialNumber = self.serialNumber
         else:
            serialNumber = "None"
         print( "Serial number: {}".format( serialNumber ) )

      timestampLines = (
         ( "Last update sync: {}", self.lastUpdateSync ),
         ( "Last full sync: {}", self.lastFullSync ),
         ( "Last serial query: {}", self.lastSerialQuery ),
         ( "Last reset query: {}", self.lastResetQuery ),
      )
      for template, when in timestampLines:
         print( template.format( utcTimeRelativeToNowStr( when ) ) )
      print( "Entries: {}".format( self.entries ) )

      # Idle or misconfigured cache will not have any connection status.
      if self.activeSince:
         # Remove the word "ago"
         activeDuration = utcTimeRelativeToNowStr( self.activeSince ).replace(
            ' ago', '', 1 )
         print( 'Connection: Active ({})'.format( activeDuration ) )
      else:
         reason = self.reasonInactive
         # The fallback texts stored by setReasonInactive are not displayed.
         suppressed = ( 'idle or misconfigured', 'Cache is deleted' )
         if reason and not any( text in reason for text in suppressed ):
            print( 'Connection: {}'.format( reason ) )

      if self.detail:
         self.printDetail()

# Collection of CacheModel entries keyed by cache name.
class RpkiShowCacheModel( Model ):
   caches = Dict( keyType=str, valueType=CacheModel,
                  help="Collection of cache server information indexed by cache "
                  "server's name" )

   def render( self ):
      """Print every cache in name order, each followed by blank lines."""
      if not self.caches:
         return
      for name in sorted( self.caches ):
         cache = self.caches[ name ]
         cache.printModel( name )
         print( "\n" )

# Error counters common to both directions; CacheRxErrorCounterModel and
# CacheTxErrorCounterModel extend it with their direction-specific counters.
class CacheErrorCounterModel( Model ):
   corruptData = Int( help='Count of Corrupt Data PDU' )
   internalError = Int( help='Count of Internal Error PDU' )
   unsupportedProtocol = Int( help='Count of Unsupported Protocol PDU' )
   unsupportedPdu = Int( help='Count of Unsupported PDU' )
   unexpectedProtocol = Int( help='Count of Unexpected Protocol PDU' )
   other = Int( help='Count of Other PDU' )

# Received error counters of one cache.
class CacheRxErrorCounterModel( CacheErrorCounterModel ):
   noData = Int( help='Count of No Data PDU' )
   invalidRequest = Int( help='Count of Invalid Request PDU' )
   pduTimeout = Int( help='Count of timeouts waiting for PDU' )

   def renderRxErrorCounters( self, cacheName, table ):
      """Add one row for cacheName to table.

      The cell order follows the 'Errors received' headings in
      RpkiShowCacheErrorCounterModel.render.
      """
      row = (
         cacheName,
         self.corruptData,
         self.internalError,
         self.unsupportedProtocol,
         self.unexpectedProtocol,
         self.unsupportedPdu,
         self.noData,
         self.invalidRequest,
         self.pduTimeout,
         self.other,
      )
      table.newRow( *row )

# Sent error counters of one cache.
class CacheTxErrorCounterModel( CacheErrorCounterModel ):
   withdrawalUnknown = Int( help='Count of Withdrawal Unknown PDU' )
   duplicateAnnouncement = Int( help='Count of Duplicate Announcement PDU' )

   def renderTxErrorCounters( self, cacheName, table ):
      """Add one row for cacheName to table.

      The cell order follows the 'Errors sent' headings in
      RpkiShowCacheErrorCounterModel.render.
      """
      row = (
         cacheName,
         self.corruptData,
         self.internalError,
         self.unsupportedProtocol,
         self.unexpectedProtocol,
         self.unsupportedPdu,
         self.withdrawalUnknown,
         self.duplicateAnnouncement,
         self.other,
      )
      table.newRow( *row )

# Sent and received error counters of every cache, keyed by cache name.
class RpkiShowCacheErrorCounterModel( Model ):
   cachesRx = Dict( keyType=str, valueType=CacheRxErrorCounterModel,
                  help='A mapping of cache names to their received error counters' )
   cachesTx = Dict( keyType=str, valueType=CacheTxErrorCounterModel,
                  help='A mapping of cache names to their sent error counters' )

   def render( self ):
      """Print the 'Errors sent' table followed by the 'Errors received' table.

      Nothing is printed when cachesRx is empty.
      """
      if not self.cachesRx:
         return

      # Leading columns shared by both tables.
      commonHeadings = ( 'Cache', 'Corrupt Data', 'Internal Error',
                         'Unsupported Protocol', 'Unexpected Protocol',
                         'Unsupported PDU' )

      sentHeadings = commonHeadings + (
         'Withdrawal Unknown', 'Duplicate Announcement', 'Other' )
      sentTable = TableOutput.TableFormatter()
      createTableOutput( sentTable, sentHeadings )
      print( 'Errors sent:' )
      for name in sorted( self.cachesTx ):
         self.cachesTx[ name ].renderTxErrorCounters( name, sentTable )
      print( sentTable.output() )

      receivedHeadings = commonHeadings + (
         'No Data', 'Invalid Request', 'PDU Timeout', 'Other' )
      receivedTable = TableOutput.TableFormatter()
      createTableOutput( receivedTable, receivedHeadings )
      print( 'Errors received:' )
      for name in sorted( self.cachesRx ):
         self.cachesRx[ name ].renderRxErrorCounters( name, receivedTable )
      print( receivedTable.output() )

# PDU counters of one cache.
class CacheCounterModel( Model ):
   serialNotify = Int( help="Count of Serial Notify PDU" )
   serialQuery = Int( help="Count of Serial Query PDU" )
   resetQuery = Int( help="Count of Reset Query PDU" )
   cacheResponse = Int( help="Count of Cache Response PDU" )
   ipv4Prefix = Int( help="Count of IPv4 Prefix PDU" )
   ipv6Prefix = Int( help="Count of IPv6 Prefix PDU" )
   endOfData = Int( help="Count of End of Data PDU" )
   cacheReset = Int( help="Count of Cache Reset PDU" )

   def renderCounterModel( self, cacheName, table ):
      """Add one row for cacheName to table."""
      # The IPv4 and IPv6 prefix counts are shown as a single summed column.
      prefixPduCount = self.ipv4Prefix + self.ipv6Prefix
      row = (
         cacheName,
         self.serialNotify,
         self.serialQuery,
         self.resetQuery,
         self.cacheResponse,
         prefixPduCount,
         self.endOfData,
         self.cacheReset,
      )
      table.newRow( *row )

# PDU counters of every cache, keyed by cache name.
class RpkiShowCacheCounterModel( Model ):
   caches = Dict( keyType=str, valueType=CacheCounterModel,
                    help="A mapping of cache names to their counters" )

   def render( self ):
      """Print one table row per cache in name order; nothing when empty."""
      if not self.caches:
         return

      counterTable = TableOutput.TableFormatter()
      createTableOutput( counterTable,
                         ( "Cache", "Serial Notify", "Serial Query",
                           "Reset Query", "Cache Response", "Prefix PDU",
                           "End Of Data", "Cache Reset" ) )
      for name in sorted( self.caches ):
         counters = self.caches[ name ]
         counters.renderCounterModel( name, counterTable )

      print( counterTable.output() )

# IPv4 and IPv6 ROA counts; used per cache and for the overall total in
# RpkiShowRoaSummaryModel.
class CacheRoaSummaryModel( Model ):
   ipv4RoaCount = Int( help="Count of IPv4 ROAs" )
   ipv6RoaCount = Int( help="Count of IPv6 ROAs" )

   def renderRoaSummaryModel( self, cacheName, table ):
      """Add one row for cacheName to table: name, IPv4 count, IPv6 count."""
      row = ( cacheName, self.ipv4RoaCount, self.ipv6RoaCount )
      table.newRow( *row )

# ROA counts of every cache, keyed by cache name, plus the overall total.
class RpkiShowRoaSummaryModel( Model ):
   caches = Dict( keyType=str, valueType=CacheRoaSummaryModel,
                  help="A mapping of cache names to the count of ROAs" )
   total = Submodel( valueType=CacheRoaSummaryModel,
                     help="Total number of ROAs in the local database" )

   def render( self ):
      """Print the per-cache table and then the totals; nothing when empty."""
      if not self.caches:
         return

      summaryTable = TableOutput.TableFormatter()
      createTableOutput( summaryTable, ( "Cache", "IPv4", "IPv6" ) )
      for name in sorted( self.caches ):
         summary = self.caches[ name ]
         summary.renderRoaSummaryModel( name, summaryTable )
      print( summaryTable.output() )

      # Add the total output
      totals = self.total
      print( "Total ROAs in local database:" )
      print( "IPv4: %d" % totals.ipv4RoaCount )
      print( "IPv6: %d" % totals.ipv6RoaCount )
