#!/usr/bin/env python
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import sys
import shlex
import EntityManager
import PyClient
from DeviceNameLib import kernelIntfToEosIntf
import os
#pylint: disable=redefined-outer-name

#####################################################################################
#                 Getting statistics from iptables output                           #
# Global Counters                                                                   #
#    Bypass:                                                                        #
#       Packet count from FIREWALL Chain against last rule ( target RETURN )        #
#    Invalid Packet Drop:                                                           #
#       Sum of packets from first rule in every SEGMENT chain ( cstate INVALID )    #
#    Flows Created:                                                                 #
#       Rule with cstate INVALID, NEW in FORWARD chain - Invalid packet drop        #
#                                                                                   #
# Policy Counters                                                                   #
#    Hit:                                                                           #
#       Sum of all packets from rules in Policy ( _PO ) chain                       #
#    Drop:                                                                          #
#       Sum of all packets in policy chain where rule has target as AR_DROP         #
#    Default Drop:                                                                  #
#       Packets hitting last rule in a policy chain                                 #
#                                                                                   #
# Segment Counters                                                                  #
#    For each SEGMENT chain ( dest ), for each rule identify the source based on    #
#    interface or subnet and update hit count.                                      #
#                                                                                   #
# Doc: https://goo.gl/yLGZFG                                                        #
#####################################################################################

# Per-run statistics parsed from the iptables output.
#
# policy  : { '<policy-name>' : { 'Hit' : <count>, 'Drop' : <count>,
#                                 'Default Drop' : <count> } }
# segment : { '<dest-seg>' : { '<src-seg1>' : <count>,
#                              '<src-seg2>' : <count> } }
policy, segment = {}, {}

# Global counters; see the header comment for how each one is derived.
bypassPkts = invalidPkts = newAndInvalid = 0

# Sysdb mounts of firewall/hw/config and firewall/counters, filled in by
# mountFwConfig().
fwConfig = fwCounters = None

def cleanupOutput( rules ):
   """Drop empty ('') and single-space (' ') entries from a list of lines.

   The list is filtered in place and the same list object is returned, so
   callers may use either the argument or the return value.
   """
   blank = ( ' ', '' )
   rules[ : ] = [ line for line in rules if line not in blank ]
   return rules

def getIntCount( strCount ):
   """Convert an iptables counter string to an int.

   A K/M/G/T/P suffix is expanded by replacing it with 3/6/9/12/15 zeros
   (i.e. powers of 1000), e.g. '12K' -> 12000. Strings without a suffix are
   passed straight to int().
   """
   zerosPerSuffix = ( ( 'K', 3 ), ( 'M', 6 ), ( 'G', 9 ), ( 'T', 12 ),
                      ( 'P', 15 ) )
   for suffix, zeros in zerosPerSuffix:
      if suffix in strCount:
         return int( strCount.replace( suffix, '0' * zeros ) )
   return int( strCount )

def updatePolicyCounters( policyName, rules, vrfName ):
   """Compute Hit / Drop / Default Drop counters for one policy chain.

   policyName: key under which the result is stored in the module-level
      `policy` dict (always (re)initialised, even for an empty chain).
   rules: rule lines of the policy chain from `iptables -nvL`, without the
      chain header and column-header lines.
   vrfName: VRF used to look up the Acl agent's drop/accept chain names.

   Every rule but the last is classified by its target (3rd field):
      - AR_DROP / LOG_N_DROP chain  -> counted in both 'Drop' and 'Hit'
      - AR_ACCEPT / LOG_N_ACCEPT    -> counted in 'Hit'
      - anything else               -> not counted
   The last rule's packet count is stored as 'Default Drop'.
   """
   # Resolve the kernel chain names of the drop/accept targets for this VRF
   # from ar/Acl/Acl/ctx/filterChain; names that are not present stay None
   # and therefore never match a rule target.
   aclAgentRoot = PyClient.PyClient( 'ar', 'Acl' ).agentRoot()
   filterChain = aclAgentRoot[ 'Acl' ].ctx.filterChain
   chain = dict.fromkeys( ( 'AR_DROP', 'LOG_N_DROP',
                            'AR_ACCEPT', 'LOG_N_ACCEPT' ) )
   for aclName in chain:
      aclKey = Tac.newInstance( 'Acl::IptablesAclKey', 'ip', 'out', vrfName,
                                aclName )
      if aclKey in filterChain:
         chain[ aclName ] = filterChain[ aclKey ]
   dropTargets = ( chain[ 'AR_DROP' ], chain[ 'LOG_N_DROP' ] )
   acceptTargets = ( chain[ 'AR_ACCEPT' ], chain[ 'LOG_N_ACCEPT' ] )

   counters = { 'Hit' : 0, 'Drop' : 0, 'Default Drop' : 0 }
   policy[ policyName ] = counters
   if not rules:
      # Empty chain: previously rules[ -1 ] raised IndexError and aborted the
      # whole collection run. Report zero counters instead.
      return

   for rule in rules[ : -1 ]:
      fields = rule.split()
      if len( fields ) < 3:
         # Not a rule line (no target column); nothing to classify.
         continue
      count = getIntCount( fields[ 0 ] )
      target = fields[ 2 ]
      if target in dropTargets:
         counters[ 'Drop' ] += count
         counters[ 'Hit' ] += count
      if target in acceptTargets:
         counters[ 'Hit' ] += count

   # The final rule of a policy chain is its default (no-match) drop.
   counters[ 'Default Drop' ] = getIntCount( rules[ -1 ].split()[ 0 ] )

def updateBypassPkts( rule ):
   """Set the global `bypassPkts` from a rule line that contains 'RETURN'.

   main() passes the last rule of the _FIREWALL chain. The value is assigned,
   not accumulated. Lines without 'RETURN' are ignored.
   """
   global bypassPkts
   if 'RETURN' not in rule:
      return
   pktField = rule.split()[ 0 ]
   bypassPkts = getIntCount( pktField )

def updateInvalidPkts( line ):
   """Add a rule's packet count to the global `invalidPkts`.

   Only lines containing both 'DROP' and 'INVALID' (substring match) are
   counted; main() passes the first rule of each SEGMENT chain.
   """
   global invalidPkts
   if 'DROP' not in line or 'INVALID' not in line:
      return
   invalidPkts += getIntCount( line.split()[ 0 ] )

def updateNewAndInvalid( rules ):
   """Set the global `newAndInvalid` from the first 'INVALID' + 'NEW' rule.

   Scans `rules` (lines of the FORWARD chain) for the first line containing
   both substrings and stores its packet count. Leaves the global untouched
   if no line matches.
   """
   global newAndInvalid
   matches = [ r for r in rules if 'INVALID' in r and 'NEW' in r ]
   if matches:
      newAndInvalid = getIntCount( matches[ 0 ].split()[ 0 ] )

def mountFwConfig():
   """Mount firewall/hw/config and firewall/counters read-only from Sysdb.

   The mounts are stored in the module globals `fwConfig` and `fwCounters`.
   The mount group is closed with blocking=True before returning.
   """
   global fwConfig, fwCounters
   sysdbEm = EntityManager.Sysdb( 'ar' )
   group = sysdbEm.mountGroup()
   fwConfig = group.mount( 'firewall/hw/config', 'Firewall::HwConfig',
                           mode='r' )
   fwCounters = group.mount( 'firewall/counters', 'Firewall::Counters',
                             mode='r' )
   group.close( blocking=True )

def getSourceSegment( rule, vrfName ):
   """Return the source segment matched by one SEGMENT-chain rule line.

   `rule` is a line of `iptables -nvL` output. Its whitespace-split field 5
   is taken as the ingress kernel interface when it contains 'et', and field
   7 as the source prefix unless it contains '0.0.0.0/0'.

   Returns:
      '-'  if the rule matches on neither an interface nor a prefix;
      the name of the first segment of `vrfName` (from fwConfig) whose
      intfVlanRangeSet contains the interface or whose subnets match the
      prefix;
      None if no segment matches.
   """
   fields = rule.split()
   intf = fields[ 5 ] if 'et' in fields[ 5 ] else None
   ipPrefix = None if '0.0.0.0/0' in fields[ 7 ] else fields[ 7 ]
   if intf is None and ipPrefix is None:
      return '-'

   # Both lookups are invariant across segments, so do them once.
   eosIntf = kernelIntfToEosIntf( intf ) if intf is not None else None
   segments = fwConfig.vrf[ vrfName ].segmentDir.segment
   for segName in segments:
      segmentObj = segments[ segName ]
      if intf is not None and eosIntf in segmentObj.intfVlanRangeSet:
         return segName
      if ipPrefix is None:
         continue
      for subnet in segmentObj.subnets:
         if subnet.stringValue == ipPrefix:
            return segName
         # A /32 subnet is also matched against the bare address, presumably
         # because iptables prints host prefixes without '/32'.
         if subnet.len == 32 and subnet.address == ipPrefix:
            return segName
   return None

def updateSegmentCounters( destSeg, rules, vrfName='default' ):
   """Accumulate per-source packet counts for destination segment `destSeg`.

   For each rule line, the source segment is resolved with getSourceSegment()
   and the rule's packet count is added to segment[ destSeg ][ source ],
   creating entries on first use.
   """
   perSource = segment.setdefault( destSeg, {} )
   for line in rules:
      src = getSourceSegment( line, vrfName=vrfName )
      perSource[ src ] = perSource.get( src, 0 ) + \
            getIntCount( line.split()[ 0 ] )

def getNsName( vrfName='default' ):
   """Return the network-namespace name used for `vrfName`.

   'default' maps to 'default'; any other VRF maps to 'ns-<vrfName>'.
   """
   if vrfName == 'default':
      return 'default'
   return 'ns-' + vrfName

def updateSysdb():
   client = PyClient.PyClient( 'ar', 'Sysdb' )
   root = client.agentRoot()
   firewall = root[ 'firewall' ]
   if vrfName in firewall[ 'counters' ].vrf:
      vrf = firewall[ 'counters' ].vrf[ vrfName ]
      vrf.countersDir.globalCounters.invalidPkts = invalidPkts
      vrf.countersDir.globalCounters.flowCreated = \
            newAndInvalid - invalidPkts - bypassPkts
      vrf.countersDir.globalCounters.firewallBypass = bypassPkts

      for policyName in vrf.countersDir.policy:
         if policyName in policy:
            pol = vrf.countersDir.policy[ policyName ]
            pol.defaultDrop = policy[ policyName ][ 'Default Drop' ]
            pol.dropCount = policy[ policyName ][ 'Drop' ]
            pol.hitCount = policy[ policyName ][ 'Hit' ]

      for dstSegName in vrf.countersDir.segment:
         if dstSegName in segment:
            dstSeg = vrf.countersDir.segment[ dstSegName ]
            for srcSegName in dstSeg.segmentCounters:
               if srcSegName in segment[ dstSegName ]:
                  srcSeg = dstSeg.segmentCounters[ srcSegName ]
                  srcSeg.hitCount = segment[ dstSegName ][ srcSegName ]

def printOutput():
   print "VRF: ", vrfName
   print "Global counters"
   print "   Flows created: ", newAndInvalid - invalidPkts
   print "   Invalid pkt drop: ", invalidPkts
   print "   Bypass: ", bypassPkts
   print "Note: Bypass means no egress segment configured."
   print
   print "Policy Counter: ", policy
   print "Note: Hit means no of packets evaluated."
   print "      Default drop happen when there is no matching rule in policy."
   print
   print "Segment Counter: ", segment
   print

def main( vrfName='default' ):
   # Mount firewall/hw/config and firewall/counters
   mountFwConfig()

   # Get chain names corresponding to policy name and segment name
   if vrfName not in fwCounters.vrf:
      return
   policyChainNames = {}
   for pName in fwCounters.vrf[ vrfName ].countersDir.policy:
      pol = fwCounters.vrf[ vrfName ].countersDir.policy[ pName ]
      policyChainNames[ pol.policyChainName ] = pName
   segChainNames = {}
   for segName in fwCounters.vrf[ vrfName ].countersDir.segment:
      seg = fwCounters.vrf[ vrfName ].countersDir.segment[ segName ]
      segChainNames[ seg.segmentChainName ] = segName

   # Get 'iptables -nvL' output and parse each chain to get counters
   cmd = "ip netns exec %s iptables -nvL" % getNsName( vrfName )
   output = Tac.run( shlex.split( cmd ), stdout=Tac.CAPTURE, asRoot=True )
   chains = output.split( 'Chain' )
   for chain in chains:
      ruleList = chain.splitlines()
      ruleList = cleanupOutput( ruleList )
      if not ruleList:
         continue
      # ruleList[ 0 ] contain chain name
      # ruleList[ 1 ] contain headers
      chainName = ruleList[ 0 ].split()[ 0 ]
      if chainName in policyChainNames:
         updatePolicyCounters( policyChainNames[ chainName ], ruleList[ 2: ],
                               vrfName )

      if '_FIREWALL' in chainName:
         updateBypassPkts( ruleList[ -1 ] )

      if chainName in segChainNames:
         updateInvalidPkts( ruleList[ 2 ] )
         updateSegmentCounters( segChainNames[ chainName ], ruleList[ 3: ], vrfName )

      if 'FORWARD' in chainName:
         updateNewAndInvalid( ruleList )

   updateSysdb()
   #printOutput()

if __name__ == "__main__":
   # When the BTEST environment variable is set (presumably a test-harness
   # run), skip collection and just exit successfully.
   if 'BTEST' not in os.environ:
      # Optional single command-line argument: the VRF name. The module-level
      # `vrfName` binding is kept because other functions in this file may
      # read it as a global.
      vrfName = sys.argv[ 1 ] if len( sys.argv ) == 2 else 'default'
      main( vrfName )
   sys.exit( 0 )
