#!/usr/bin/env python
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# Usage:
# Capture a BMP session using netcat, for example: netcat -l 5000 > /tmp/bmp_capture
# Add BGP config to connect to netcat:
# router bgp ...
#  monitoring station netcat
#    connection mode active port 5000
#    connection address 127.0.0.1
#
# When the capture is complete (e.g. by checking states in
# 'show bgp monitoring internal'), terminate the netcat process and decode
# /tmp/bmp_capture using this script.
#
# This script can perform the following functions:
# - Detailed dump of all BMP messages
# - Scan the dump and output summary statistics (default)
# - Filter the messages processed based on any combination of:
#   - BMP message type
#   - Peer
#   - Peer address family (ipv4 or ipv6)
#   - Prefix
#   - Pre-policy or post-policy
#
# The prefix filter is specified as <address>/<len>, e.g. "10.1.2.0/24".
#
# BmpDecode.py [options] <filename>
# Options:
# -d           Dump contents of messages
# -s           Output summary statistics (enabled by default)
#
# -t <type>    Filter by BMP message type:
#                0: Route Monitoring
#                1: Statistics Report
#                2: Peer Down Notification
#                3: Peer Up Notification
#                4: Initiation Message
#                5: Termination Message
#                6: Route Mirroring Message
#              Note: more than one -t argument is allowed
# -p <address> Filter by peer address
# -v           Only ipv4 peers
# -V           Only ipv6 peers
# -r <prefix>  Filter by a route prefix (ipv4 or ipv6 prefix)
# -l           Only pre-policy paths
# -L           Only post-policy paths
#

import argparse
#import dpkt
import BmpDpkt
#import BgpDpkt
import struct
import os
import sys
import datetime
#import time
from socket import AF_INET
from socket import AF_INET6
from socket import inet_ntop
from socket import inet_ntoa

def bmpMsgTypeStr( msgType ):
   """Return a human-readable name for BMP message type *msgType*.

   Types without a known name are rendered as "Unknown: <type>".
   """
   names = {
      0 : "Route Monitoring",
      1 : "Stats Report",
      2 : "Peer Down",
      3 : "Peer Up",
      4 : "Initiation",
      5 : "Termination",
      6 : "Route Mirroring",
   }
   name = names.get( msgType )
   if name is None:
      name = "Unknown: " + str( msgType )
   return name

def dumpBmpHeader( bmpMsg ):
   """Print the BMP common header fields: version, length and message type."""
   print( "Version: %s" % ( bmpMsg.version, ) )
   print( "Length: %s" % ( bmpMsg.length, ) )
   print( "Type: %s" % ( bmpMsgTypeStr( bmpMsg.type ), ) )

def dumpInitiation( bmpMsg ):
   """Print type, length and raw data of every TLV in an Initiation message."""
   for infoTlv in bmpMsg.data.tlvs_:
      print( "TLV type: %s Len: %s Data: %s" %
             ( infoTlv.type, infoTlv.length, infoTlv.data ) )

def dumpTermination( bmpMsg ):
   """Print the reason code of each type-1 TLV in a Termination message."""
   reasonTlvs = [ t for t in bmpMsg.data.tlvs_ if t.type == 1 ]
   for reasonTlv in reasonTlvs:
      print( "Reason Code: %s" % ( reasonTlv.reason_code, ) )

def dumpPeerHeader( bmpMsg ):
   """Print the per-peer header: peer address, AS, BGP ID and timestamp.

   Flag 0x80 selects an IPv6 peer address; otherwise the IPv4 address is taken
   from the last 4 bytes of the address field.  The timestamp is rendered in
   local time by datetime.fromtimestamp.
   """
   hdr = bmpMsg.data.peer_hdr
   isIpv6 = hdr.peer_flags & 0x80
   family = AF_INET6 if isIpv6 else AF_INET
   rawAddr = hdr.peer_address if isIpv6 else hdr.peer_address[ 12 : ]
   print( "Peer Address: %s" % ( inet_ntop( family, rawAddr ), ) )
   print( "Peer AS: %s" % ( hdr.peer_as, ) )
   bgpId = inet_ntoa( struct.pack( "!I", hdr.peer_bgp_id ) )
   print( "Peer ID: %s" % ( bgpId, ) )
   stamp = datetime.datetime.fromtimestamp( hdr.timestamp ).replace(
      microsecond=hdr.timestamp_us )
   print( "Timestamp: %s" % ( stamp.strftime( "%Y-%m-%d %H:%M:%S.%f" ), ) )

def dumpPeerUp( bmpMsg ):
   """Print a Peer Up Notification.

   Output: the per-peer header, the local address and ports, then the sent and
   received BGP OPEN messages with their capabilities.
   """
   def dumpOpen( heading, openMsg ):
      # openMsg.open holds the OPEN fields; its capabilities are listed under
      # openMsg.data.data.capabilities.
      body = openMsg.open
      print( heading )
      print( "  Version: %s" % ( body.v, ) )
      print( "  ASN: %s" % ( body.asn, ) )
      print( "  Holdtime: %s" % ( body.holdtime, ) )
      print( "  Identifier: %s" %
             ( inet_ntoa( struct.pack( "!I", body.identifier ) ), ) )
      for capability in openMsg.data.data.capabilities:
         print( "  %r" % ( capability, ) )

   dumpPeerHeader( bmpMsg )
   peerUp = bmpMsg.data.peer_up
   # Same address-family rule as the peer address: flag 0x80 => IPv6,
   # otherwise IPv4 in the last 4 bytes.
   if bmpMsg.data.peer_hdr.peer_flags & 0x80:
      localAddr = inet_ntop( AF_INET6, peerUp.local_addr )
   else:
      localAddr = inet_ntop( AF_INET, peerUp.local_addr[ 12 : ] )

   print( "Local Address: %s" % ( localAddr, ) )
   print( "Local Port: %s" % ( peerUp.local_port, ) )
   print( "Remote Port: %s" % ( peerUp.remote_port, ) )

   dumpOpen( "Sent OPEN:", peerUp.sent_open )
   dumpOpen( "Received OPEN:", peerUp.rcvd_open )

def dumpPeerDown( bmpMsg ):
   """Print a Peer Down Notification: per-peer header followed by the reason."""
   dumpPeerHeader( bmpMsg )
   reason = bmpMsg.data.reason.reason
   print( "Reason: %s" % ( reason, ) )

def dumpIpv4Nlri( bmpMsg ):
   """Print the IPv4 withdrawn routes and announced NLRI of an UPDATE.

   Both section lengths are always printed; the route list of a section is
   printed only when it is non-empty.
   """
   update = bmpMsg.data.update.data
   sections = ( ( "Withdrawn", update.withdrawn ),
                ( "Announced", update.announced ) )
   for ( label, routes ) in sections:
      print( "IPv4 %s length =  %d" % ( label, len( routes ) ) )
      if not routes:
         continue
      print( "IPv4 %s routes:" % ( label, ) )
      for route in routes:
         print( "Prefix: %s/%d" % ( inet_ntop( AF_INET, route.prefix ),
                                    route.len ) )

def dumpIpv6Nlri( bmpMsg ):
   """Print routes from the MP_UNREACH_NLRI (15) and MP_REACH_NLRI (14) attributes.

   Withdrawn routes come first.  A section is printed whenever its attribute is
   present, even if it lists no routes.  Prefixes are rendered as IPv6.
   """
   def showRoutes( label, routes ):
      print( "IPv6 %s length =  %d" % ( label, len( routes ) ) )
      print( "IPv6 %s routes:" % ( label, ) )
      for route in routes:
         print( "Prefix: %s/%d" % ( inet_ntop( AF_INET6, route.prefix ),
                                    route.len ) )

   mpReach = None
   mpUnreach = None
   # The last occurrence of each attribute type wins.
   for attr in bmpMsg.data.update.data.attributes:
      if attr.type == 14:
         mpReach = attr
      if attr.type == 15:
         mpUnreach = attr

   if mpUnreach is not None:
      showRoutes( "Withdrawn", mpUnreach.mp_unreach_nlri.withdrawn )
   if mpReach is not None:
      showRoutes( "Announced", mpReach.mp_reach_nlri.announced )

def checkEor( bmpMsg ):
   """Classify a Route Monitoring message's UPDATE as an End-of-RIB marker.

   Per RFC 4724 section 2:
   - the IPv4 unicast EoR is an UPDATE with no withdrawn routes, no NLRI and
     no path attributes at all;
   - the EoR for any other AFI/SAFI is an UPDATE whose only attribute is an
     MP_UNREACH_NLRI (type 15) that withdraws nothing.

   Returns:
      ( afi, safi ) of the End-of-RIB, ( 1, 1 ) for IPv4 unicast, or ( 0, 0 )
      when the message is not an End-of-RIB marker.
   """
   updMsg = bmpMsg.data.update.data
   if updMsg.withdrawn or updMsg.announced:
      return ( 0, 0 )

   attributes = list( updMsg.attributes )
   if not attributes:
      # Minimum-length UPDATE: IPv4 unicast End-of-RIB.
      return ( 1, 1 )

   if len( attributes ) == 1 and attributes[ 0 ].type == 15:
      mpUnreachNlri = attributes[ 0 ].mp_unreach_nlri
      if not mpUnreachNlri.withdrawn:
         return ( mpUnreachNlri.afi, mpUnreachNlri.safi )

   # Anything else (attributes without NLRI, MP_REACH present, MP_UNREACH with
   # routes, ...) is an ordinary UPDATE.
   return ( 0, 0 )

def dumpRouteMonitoring( bmpMsg ):
   """Print a Route Monitoring message.

   Output: the per-peer header, whether the routes are pre- or post-policy,
   then either the End-of-RIB AFI/SAFI or the IPv4 and MP (IPv6) route lists.
   """
   dumpPeerHeader( bmpMsg )

   isPostPolicy = ( bmpMsg.data.peer_hdr.peer_flags &
                    BmpDpkt.PEERFLAG_POST_POLICY )
   policyLine = ( "Post-policy announcement" if isPostPolicy
                  else "Pre-policy announcement" )
   print( policyLine )

   ( afi, safi ) = checkEor( bmpMsg )
   if ( afi, safi ) != ( 0, 0 ):
      print( "End-of-Rib: afi=%d, safi=%d" % ( afi, safi ) )
      return
   dumpIpv4Nlri( bmpMsg )
   dumpIpv6Nlri( bmpMsg )

def dumpMessage( bmpMsg ):
   """Print the common header, then the type-specific body, then a blank line.

   Types without a body dumper (e.g. Stats Report, Route Mirroring) get only
   the common header.
   """
   dumpBmpHeader( bmpMsg )
   bodyDumpers = {
      BmpDpkt.ROUTE_MONITORING : dumpRouteMonitoring,
      BmpDpkt.PEER_DOWN : dumpPeerDown,
      BmpDpkt.PEER_UP : dumpPeerUp,
      BmpDpkt.INITIATION : dumpInitiation,
      BmpDpkt.TERMINATION : dumpTermination,
   }
   dumper = bodyDumpers.get( bmpMsg.type )
   if dumper is not None:
      dumper( bmpMsg )

   print( "" )

def decodeFile( args, filename, msgFilter, counters=None ):
   """Decode the BMP messages stored back-to-back in *filename*.

   Each message that passes *msgFilter* is dumped when args.detail is set and
   added to *counters* when counters is not None.

   Decoding stops, with a note on stderr, at the first header whose length is
   smaller than the header itself, or at a final header/message that is cut
   short (e.g. netcat was terminated mid-message).  Previously these cases
   raised struct.error, looped forever (length 0) or passed a partial message
   to the decoder.
   """
   with open( filename, 'rb' ) as f:
      # XXX: Assumes we can read the entire file into memory.
      # This should change so we read one message at a time.
      data = f.read()

   # BMP common header: version (1 byte), message length (4 bytes), type (1 byte).
   hdrFmt = "!BLB"
   hdrLen = struct.calcsize( hdrFmt )
   offset = 0
   dataLen = len( data )

   while offset < dataLen:
      remaining = dataLen - offset
      if remaining < hdrLen:
         sys.stderr.write( "Ignoring %d trailing bytes at offset %d: "
                           "truncated BMP header\n" % ( remaining, offset ) )
         break
      ( _, msgLength, _ ) = struct.unpack_from( hdrFmt, data, offset )
      if msgLength < hdrLen:
         # A length this small cannot be right, and a length of 0 would
         # never advance offset.
         sys.stderr.write( "Invalid BMP message length %d at offset %d; "
                           "stopping\n" % ( msgLength, offset ) )
         break
      if msgLength > remaining:
         sys.stderr.write( "Ignoring truncated BMP message at offset %d "
                           "(length %d, only %d bytes left)\n" %
                           ( offset, msgLength, remaining ) )
         break

      msgString = data[ offset : offset + msgLength ]
      bmpMsg = BmpDpkt.BMP( msgString )
      offset += msgLength

      if not msgFilter.match( bmpMsg ):
         continue

      if args.detail:
         dumpMessage( bmpMsg )

      if counters is not None:
         counters.update( bmpMsg, msgFilter )

class BmpFilter( object ):
   """Decide which BMP messages are processed.

   Each add* method enables one criterion.  match() accepts a message only if
   it satisfies every enabled criterion.  When a prefix filter is active,
   match() also records in announceMatched / withdrawMatched whether the
   prefix was found among the announced or the withdrawn routes of the last
   Route Monitoring message it evaluated; BmpCounters reads these flags.
   """
   PrePolicy = 'PrePolicy'
   PostPolicy = 'PostPolicy'
   PolicyTypes = frozenset( [ PrePolicy, PostPolicy ] )

   PeerIpv4 = 'PeerIpv4'
   PeerIpv6 = 'PeerIpv6'
   PeerTypes = frozenset( [ PeerIpv4, PeerIpv6 ] )

   def __init__( self ):
      self.messageTypeFilter = None  # accepted BMP message types, or None = all
      self.policyFilter = None       # PrePolicy / PostPolicy, or None
      self.peerAddressFilter = None  # peer address string, or None
      self.peerTypeFilter = None     # PeerIpv4 / PeerIpv6, or None
      self.prefixFilter = None       # "<address>/<len>" string, or None
      self.withdrawMatched = False
      self.announceMatched = False
      # BMP message types that carry a per-peer header.
      self.messageTypesWithPeerHdr = [ 0, 1, 2, 3, 6 ]

   @staticmethod
   def _canonicalAddress( address ):
      """Return *address* as inet_ntop would render it.

      Filters are compared as strings against inet_ntop output, so input such
      as "2001:DB8:0::1" must become "2001:db8::1" to ever match.  Values that
      are not valid IPv4/IPv6 literals are returned unchanged.
      """
      import socket
      for family in ( AF_INET, AF_INET6 ):
         try:
            return inet_ntop( family, socket.inet_pton( family, address ) )
         except ( socket.error, ValueError, TypeError ):
            pass
      return address

   @classmethod
   def _canonicalPrefix( cls, prefix ):
      """Return "<address>/<len>" *prefix* with its address canonicalized.

      Anything not of that shape is returned unchanged.
      """
      if prefix is None:
         return prefix
      ( address, slash, length ) = prefix.partition( '/' )
      if not slash:
         return prefix
      try:
         length = int( length )
      except ValueError:
         return prefix
      return "%s/%d" % ( cls._canonicalAddress( address ), length )

   @staticmethod
   def _peerAddress( peerHdr ):
      """Return the peer address in *peerHdr* as a string (flag 0x80 => IPv6)."""
      if peerHdr.peer_flags & 0x80:
         return inet_ntop( AF_INET6, peerHdr.peer_address )
      # IPv4 peers use only the last 4 bytes of the address field.
      return inet_ntop( AF_INET, peerHdr.peer_address[ 12 : ] )

   def addMessageType( self, msgType ):
      if self.messageTypeFilter is None:
         self.messageTypeFilter = []
      self.messageTypeFilter.append( msgType )

   def addPolicy( self, policy ):
      assert policy in self.PolicyTypes
      self.policyFilter = policy

   def addPeerAddress( self, address ):
      self.peerAddressFilter = self._canonicalAddress( address )

   def addPeerType( self, peerType ):
      assert peerType in self.PeerTypes
      self.peerTypeFilter = peerType

   def addPrefixFilter( self, prefix ):
      self.prefixFilter = self._canonicalPrefix( prefix )

   def extractV4Nlri( self, bmpMsg, announceList, withdrawList ):
      """Append the UPDATE's IPv4 withdrawn/announced prefixes as "a.b.c.d/len"."""
      updMsg = bmpMsg.data.update.data
      for rt in updMsg.withdrawn:
         withdrawList.append( "%s/%d" % ( inet_ntop( AF_INET, rt.prefix ),
                                          rt.len ) )
      for rt in updMsg.announced:
         announceList.append( "%s/%d" % ( inet_ntop( AF_INET, rt.prefix ),
                                          rt.len ) )

   def extractV6Nlri( self, bmpMsg, announceList, withdrawList ):
      """Append MP_UNREACH (15) / MP_REACH (14) prefixes, rendered as IPv6."""
      updMsg = bmpMsg.data.update.data
      mpReach = None
      mpUnreach = None
      for attr in updMsg.attributes:
         if attr.type == 14:
            mpReach = attr
         elif attr.type == 15:
            mpUnreach = attr
      if mpUnreach is not None:
         for rt in mpUnreach.mp_unreach_nlri.withdrawn:
            withdrawList.append( "%s/%d" % ( inet_ntop( AF_INET6, rt.prefix ),
                                             rt.len ) )
      if mpReach is not None:
         for rt in mpReach.mp_reach_nlri.announced:
            announceList.append( "%s/%d" % ( inet_ntop( AF_INET6, rt.prefix ),
                                             rt.len ) )

   def match( self, bmpMsg ):
      """Return True if *bmpMsg* satisfies every configured filter."""
      if ( self.messageTypeFilter is not None and
           bmpMsg.type not in self.messageTypeFilter ):
         return False

      if self.policyFilter is not None:
         # Pre/post-policy only applies to Route Monitoring (type 0) messages.
         if bmpMsg.type != 0:
            return False
         postPolicy = ( bmpMsg.data.peer_hdr.peer_flags &
                        BmpDpkt.PEERFLAG_POST_POLICY )
         wanted = self.PostPolicy if postPolicy else self.PrePolicy
         if self.policyFilter != wanted:
            return False

      if self.peerAddressFilter is not None:
         if bmpMsg.type not in self.messageTypesWithPeerHdr:
            return False
         if self._peerAddress( bmpMsg.data.peer_hdr ) != self.peerAddressFilter:
            return False

      if self.peerTypeFilter is not None:
         if bmpMsg.type not in self.messageTypesWithPeerHdr:
            return False
         isIpv6 = bmpMsg.data.peer_hdr.peer_flags & 0x80
         wanted = self.PeerIpv6 if isIpv6 else self.PeerIpv4
         if self.peerTypeFilter != wanted:
            return False

      if self.prefixFilter is not None:
         if bmpMsg.type != 0:
            return False
         self.withdrawMatched = False
         self.announceMatched = False
         announceList = []
         withdrawList = []
         self.extractV4Nlri( bmpMsg, announceList, withdrawList )
         self.extractV6Nlri( bmpMsg, announceList, withdrawList )
         # An announcement takes precedence if the prefix appears in both lists.
         if self.prefixFilter in announceList:
            self.announceMatched = True
         elif self.prefixFilter in withdrawList:
            self.withdrawMatched = True
         else:
            return False

      return True

class BmpMessageSize( object ):
   """Running total, maximum and minimum BMP message length for one type."""
   def __init__( self ):
      self.totalLength = 0
      self.largest = 0
      # None until the first message is seen.  The BMP length field is 32 bits,
      # so no fixed "large" starting value is safe.
      self.smallest = None

   def update( self, bmpMsg ):
      """Fold bmpMsg.length into the running statistics."""
      msgLen = bmpMsg.length
      self.totalLength += msgLen
      if msgLen > self.largest:
         self.largest = msgLen
      # Compare against msgLen (the original compared the builtin `len`, so
      # smallest was never updated).
      if self.smallest is None or msgLen < self.smallest:
         self.smallest = msgLen

class BmpCounters( object ):
   """Summary statistics over the BMP messages that passed the filter."""
   def __init__( self ):
      self.prePolicyAnnouncements = 0
      self.prePolicyWithdrawn = 0
      self.postPolicyAnnouncements = 0
      self.postPolicyWithdrawn = 0
      self.messageCounter = {}  # BMP message type -> number of messages
      self.messageSizes = {}    # BMP message type -> BmpMessageSize

   @staticmethod
   def _routeCounts( bmpMsg, msgFilter ):
      """Return ( announced, withdrawn ) route counts for a Route Monitoring message.

      With an active prefix filter, the message was accepted because of that one
      prefix.  It therefore counts as a single announcement if the filter found
      the prefix announced, and as a single withdrawal otherwise.  Without a
      prefix filter, the IPv4 lists plus the last MP_REACH (14) / MP_UNREACH (15)
      attribute's routes are counted.
      """
      if msgFilter and msgFilter.prefixFilter:
         return ( 1, 0 ) if msgFilter.announceMatched else ( 0, 1 )

      updMsg = bmpMsg.data.update.data
      mpReach = None
      mpUnreach = None
      for attr in updMsg.attributes:
         if attr.type == 14:
            mpReach = attr
         if attr.type == 15:
            mpUnreach = attr

      announced = len( updMsg.announced )
      withdrawn = len( updMsg.withdrawn )
      if mpReach is not None:
         announced += len( mpReach.mp_reach_nlri.announced )
      if mpUnreach is not None:
         withdrawn += len( mpUnreach.mp_unreach_nlri.withdrawn )
      return ( announced, withdrawn )

   def update( self, bmpMsg, msgFilter=None ):
      """Count *bmpMsg* and, for Route Monitoring, its announced/withdrawn routes."""
      msgType = bmpMsg.type
      self.messageCounter[ msgType ] = self.messageCounter.get( msgType, 0 ) + 1
      if msgType not in self.messageSizes:
         self.messageSizes[ msgType ] = BmpMessageSize()
      self.messageSizes[ msgType ].update( bmpMsg )

      if msgType != BmpDpkt.ROUTE_MONITORING:
         return

      ( announced, withdrawn ) = self._routeCounts( bmpMsg, msgFilter )
      if bmpMsg.data.peer_hdr.peer_flags & BmpDpkt.PEERFLAG_POST_POLICY:
         self.postPolicyAnnouncements += announced
         self.postPolicyWithdrawn += withdrawn
      else:
         self.prePolicyAnnouncements += announced
         self.prePolicyWithdrawn += withdrawn

   def output( self ):
      """Print the accumulated statistics, per-type sections ordered by type."""
      print( "Total BMP Message Count: %d" % sum( self.messageCounter.values() ) )
      print( "BMP Messages by Type:" )
      for msgType in sorted( self.messageCounter ):
         print( "  %s: %d" % ( bmpMsgTypeStr( msgType ),
                               self.messageCounter[ msgType ] ) )
      print( "BMP Message Length by Type:" )
      for msgType in sorted( self.messageSizes ):
         msgSize = self.messageSizes[ msgType ]
         averageLen = msgSize.totalLength // self.messageCounter[ msgType ]
         print( "  %s: Average=%d, Max=%d, Min=%d" %
                ( bmpMsgTypeStr( msgType ), averageLen, msgSize.largest,
                  msgSize.smallest ) )
      # The double space after each colon below preserves the existing output.
      print( "Pre-policy announcements:  %d" % self.prePolicyAnnouncements )
      print( "Pre-policy withdrawn:  %d" % self.prePolicyWithdrawn )
      print( "Post-policy announcements:  %d" % self.postPolicyAnnouncements )
      print( "Post-policy withdrawn:  %d" % self.postPolicyWithdrawn )
      print( "Total number of announcements:  %d" %
             ( self.prePolicyAnnouncements + self.postPolicyAnnouncements ) )
      print( "Total number of withdrawn:  %d" %
             ( self.prePolicyWithdrawn + self.postPolicyWithdrawn ) )

def main():
   """Parse the command line, decode the capture file and print the results.

   Summary statistics are printed by default.  -d prints a per-message dump;
   given alone it replaces the summary, and together with -s both are printed.
   """
   parser = argparse.ArgumentParser()
   parser.add_argument( "filename", help="File containing BMP messages" )
   parser.add_argument( "-s", "--summary", help="Output summary statistics",
                        action="store_true" )
   parser.add_argument( "-d", "--detail", help="Output detailed message dump",
                        action="store_true" )
   # All RFC 7854 message types can be filtered on; BmpFilter only compares
   # the type number, so 1 and 6 work even though they are not dumped in detail.
   parser.add_argument( "-t", "--type", type=int, action="append",
                        choices=[ 0, 1, 2, 3, 4, 5, 6 ],
                        help='''Filter by BMP message type (may be repeated):
                                0=RouteMonitoring
                                1=StatsReport
                                2=PeerDown
                                3=PeerUp
                                4=Initiation
                                5=Termination
                                6=RouteMirroring''' )
   parser.add_argument( "-p", "--peer", help="Filter by peer address" )
   vGroup = parser.add_mutually_exclusive_group()
   vGroup.add_argument( "-v", "--v4peers", help="Filter to only include ipv4 peers",
                        action="store_true" )
   vGroup.add_argument( "-V", "--v6peers", help="Filter to only include ipv6 peers",
                        action="store_true" )
   parser.add_argument( "-r", "--route",
                        help="Filter by route prefix. Format: prefix/length" )
   lGroup = parser.add_mutually_exclusive_group()
   lGroup.add_argument( "-l", "--pre-policy",
                        help="Filter to only include pre-policy",
                        action="store_true" )
   lGroup.add_argument( "-L", "--post-policy",
                        help="Filter to only include post-policy",
                        action="store_true" )

   args = parser.parse_args()

   filename = args.filename
   if not os.access( filename, os.R_OK ):
      print( "Unable to read file: %s" % ( filename, ) )
      sys.exit( 1 )

   if not args.detail and not args.summary:
      args.summary = True

   counters = None
   if args.summary:
      counters = BmpCounters()

   msgFilter = BmpFilter()
   if args.type:
      for t in args.type:
         msgFilter.addMessageType( t )

   if args.post_policy:
      msgFilter.addPolicy( BmpFilter.PostPolicy )
   elif args.pre_policy:
      msgFilter.addPolicy( BmpFilter.PrePolicy )

   if args.peer:
      msgFilter.addPeerAddress( args.peer )

   if args.v4peers:
      msgFilter.addPeerType( BmpFilter.PeerIpv4 )
   elif args.v6peers:
      msgFilter.addPeerType( BmpFilter.PeerIpv6 )

   if args.route:
      msgFilter.addPrefixFilter( args.route )

   decodeFile( args, filename, msgFilter, counters )

   if args.summary:
      if args.detail:
         print( "" )
      print( "Summary:" )
      counters.output()

if __name__ == "__main__":
   # Decode only when executed as a script, not when imported.
   main()
