# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import BasicCli
import CliCommand
import CliMatcher
import CliPlugin.MsdpCli as MsdpCli
import CliPlugin.MsdpModels as MsdpModels
import CliPlugin.IpAddrMatcher as IpAddrMatcher
from CliPlugin.McastCommonCli import mcastRoutingSupportedGuard
from CliPlugin.RouterMulticastCliLib import AddressFamily
import CliPlugin.VrfCli as VrfCli
from CliPlugin.VrfCli import DEFAULT_VRF
from CliToken.Ip import ipMatcherForShow
import CommonGuards
import PimsmSmashHelper
import ShowCommand
import Tac

# Shared CLI tokens used by the MSDP show commands in this file.

# 'msdp' keyword matcher.
matcherMsdp = CliMatcher.KeywordMatcher(
      'msdp', helpdesc='MSDP protocol information' )
# 'msdp' keyword wrapped in a node so that mcastRoutingSupportedGuard is
# applied wherever the keyword is used.
nodeMsdp = CliCommand.Node(
      matcher=matcherMsdp, guard=mcastRoutingSupportedGuard )
# 'sa-cache' keyword matcher, shared by the sa-cache and pim sa-cache commands.
matcherSaCache = CliMatcher.KeywordMatcher(
      'sa-cache', helpdesc='Show MSDP source-advertisement cache' )

def vrfNameForShowCommand( mode, args ):
   """Resolve the VRF a show command should operate on.

   Uses the 'VRF' entry of args when it is set, otherwise the VRF of the
   current CLI session. Returns the VRF name, or None (after adding an error
   to mode) when MsdpCli.vrfExists() rejects the name for IPv4.
   """
   requestedVrf = args.get( 'VRF' )
   if not requestedVrf:
      requestedVrf = VrfCli.vrfMap.getCliSessVrf( mode.session )
   if MsdpCli.vrfExists( AddressFamily.ipv4, requestedVrf ):
      return requestedVrf
   mode.addError( "Invalid vrf name %s " % requestedVrf )
   return None

def peerResetCount( peerStatus_ ):
   """Return peerStatus_.fsmResetTransitions, or 0 when peerStatus_ is falsy."""
   return peerStatus_.fsmResetTransitions if peerStatus_ else 0

def peerStatus( peerIp, vrfName=DEFAULT_VRF ):
   """Look up the status entry for peerIp in the given VRF.

   Returns None when the VRF has no entry in MsdpCli.msdpStatus.vrfStatus or
   when the VRF has no peerStatus entry for peerIp.
   """
   vrfStatusColl = MsdpCli.msdpStatus.vrfStatus
   if vrfName in vrfStatusColl:
      return vrfStatusColl[ vrfName ].peerStatus.get( peerIp )
   return None

#--------------------------------------------------------------------------------
# show msdp [ VRF ] mesh-group
#
# Deprecated cmd: show ip msdp [ VRF ] mesh-group
#--------------------------------------------------------------------------------
def showMeshGroup( mode, args ):
   """Handler for the mesh-group show commands.

   Returns a MeshGroupModel whose members map each configured mesh-group name
   to a Group holding its member addresses in sorted order. Returns an empty
   model when the VRF has no MSDP config, and None when the VRF name is
   invalid (the error has already been added to mode).
   """
   vrfName = vrfNameForShowCommand( mode, args )
   if vrfName is None:
      return None

   model = MsdpModels.MeshGroupModel()
   vrfConfigColl = MsdpCli.msdpConfig.vrfConfig
   if vrfName in vrfConfigColl:
      meshGroups = vrfConfigColl[ vrfName ].meshGroup
      for groupName, meshGroup in meshGroups.iteritems():
         entry = MsdpModels.MeshGroupModel.Group()
         for memberIp in sorted( meshGroup.member ):
            entry.ipList.append( memberIp )
         model.members[ groupName ] = entry
   return model

# 'mesh-group' keyword matcher, shared by the current and the deprecated
# mesh-group show commands.
matcherMeshGroup = CliMatcher.KeywordMatcher( 'mesh-group',
      helpdesc='Show MSDP mesh group information' )

# Current form of the mesh-group show command: show msdp [ VRF ] mesh-group.
# Output is produced by showMeshGroup() as a MeshGroupModel.
class MsdpMeshGroupCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show msdp [ VRF ] mesh-group'
   # Maps each token in the syntax string to its matcher/node.
   data = {
      'msdp': nodeMsdp,
      'VRF': MsdpCli.vrfExprFactory,
      'mesh-group': matcherMeshGroup,
   }
   handler = showMeshGroup
   cliModel = MsdpModels.MeshGroupModel

# Register the command with the CLI.
BasicCli.addShowCommandClass( MsdpMeshGroupCmd )

# Deprecated 'show ip msdp' spelling of the mesh-group command. It shares the
# handler and model of MsdpMeshGroupCmd; the 'mesh-group' node carries the
# deprecatedByCmd text naming the replacement command.
class IpMsdpMeshGroupCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show ip msdp [ VRF ] mesh-group'
   # Maps each token in the syntax string to its matcher/node.
   data = {
      'ip': ipMatcherForShow,
      'msdp': nodeMsdp,
      'VRF': MsdpCli.vrfExprFactory,
      'mesh-group': CliCommand.Node( matcher=matcherMeshGroup,
         deprecatedByCmd='show msdp mesh-group in enable mode' ),
   }
   handler = showMeshGroup
   cliModel = MsdpModels.MeshGroupModel

# Register the command with the CLI.
BasicCli.addShowCommandClass( IpMsdpMeshGroupCmd )

#--------------------------------------------------------------------------------
# show ip msdp [ VRF ] peer [ PEERADDR ] [ accepted-sas ]
#--------------------------------------------------------------------------------
def getPeerModel( peerConfig_, acceptedSas, vrfName=DEFAULT_VRF ):
   """Build a MsdpModels.PeerInfo for one configured peer.

   peerConfig_ is the peer's config entry, acceptedSas selects whether the
   SAs learnt from this peer are attached, and vrfName is the VRF to read
   status from. Returns None when the peer has no status entry.
   """
   remoteAddr = peerConfig_.remote
   status = peerStatus( remoteAddr, vrfName=vrfName )
   if not status:
      return None

   info = MsdpModels.PeerInfo()
   info.peerIpAddress = remoteAddr
   info.description = peerConfig_.description
   info.state = status.state
   if status.state == 'established':
      # Shift the established timestamp by the current offset between
      # Tac.utcNow() and Tac.now() (presumably converting it to UTC
      # wall-clock time -- confirm against the Tac time API).
      establishedAt = status.fsmEstablishedTime
      info.sessionStartTime = establishedAt + Tac.utcNow() - Tac.now()
   info.resetCount = peerResetCount( status )

   # The connection source is only reported when one is configured; the
   # address itself comes from status, the interface from config.
   if peerConfig_.connSrc:
      info.connSourceInterface = peerConfig_.connSrc
      info.connSourceAddress = status.connSrc
   info.saFilterIn = peerConfig_.saFilterIn
   info.saFilterOut = peerConfig_.saFilterOut

   if not acceptedSas:
      return info

   # Attach every SA in this VRF's cache whose sender is this peer.
   saCache = MsdpCli.msdpStatus.vrfStatus[ vrfName ].saMsg
   for sgKey, saEntry in saCache.iteritems():
      if saEntry.remote == peerConfig_.remote:
         pair = MsdpModels.SGPair( sourceAddress=sgKey.sAddr,
                                   groupAddress=sgKey.gAddr )
         info.acceptedSas.append(
            MsdpModels.SaMsg( sourceGroupPair=pair, rpAddress=saEntry.rpAddr ) )
   return info

# show ip msdp [ VRF ] peer [ PEERADDR ] [ accepted-sas ]
# Reports one PeerInfo per configured peer that has a status entry, optionally
# limited to a single peer and optionally including its accepted SAs.
class IpMsdpPeerCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show ip msdp [ VRF ] peer [ PEERADDR ] [ accepted-sas ]'
   data = {
      'ip': ipMatcherForShow,
      'msdp': nodeMsdp,
      'VRF': MsdpCli.vrfExprFactory,
      'peer': 'Show MSDP peer information',
      'PEERADDR': MsdpCli.msdpPeerAddrMatcher,
      'accepted-sas': 'Show only accepted SAs',
   }

   @staticmethod
   def handler( mode, args ):
      vrfName = vrfNameForShowCommand( mode, args )
      if vrfName is None:
         return None
      peerIp = args.get( 'PEERADDR' )
      acceptedSas = 'accepted-sas' in args

      peerInfos = []
      vrfConfigColl = MsdpCli.msdpConfig.vrfConfig
      if vrfName not in vrfConfigColl:
         return MsdpModels.PeerModel( peerList=peerInfos )

      # Work out which peer configs to report on: the single requested peer
      # (error if it is not configured) or every peer in the VRF.
      if peerIp:
         requested = MsdpCli.peerConfig( peerIp, vrfName=vrfName )
         if requested:
            candidates = [ requested ]
         else:
            MsdpCli.printNonexistentPeerError( mode, peerIp )
            candidates = []
      else:
         candidates = vrfConfigColl[ vrfName ].peerConfig.itervalues()

      for peerCfg in candidates:
         info = getPeerModel( peerCfg, acceptedSas, vrfName )
         # getPeerModel returns None for peers without a status entry.
         if info:
            peerInfos.append( info )
      return MsdpModels.PeerModel( peerList=peerInfos )

   cliModel = MsdpModels.PeerModel

BasicCli.addShowCommandClass( IpMsdpPeerCmd )

#--------------------------------------------------------------------------------
# show ip msdp [ VRF ] sa-cache [ SRCORGRP1 [ SRCORGRP2 ] ] [ rejected ]
#--------------------------------------------------------------------------------
# show ip msdp [ VRF ] sa-cache [ SRCORGRP1 [ SRCORGRP2 ] ] [ rejected ]
# With one address the cache is filtered by group; with two addresses the
# first is the source and the second the group. 'rejected' additionally
# reports the rejected SA cache.
class IpMsdpSaCacheCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show ip msdp [ VRF ] sa-cache [ SRCORGRP1 [ SRCORGRP2 ] ]' \
            '[ rejected ]'
   data = {
      'ip': ipMatcherForShow,
      'msdp': nodeMsdp,
      'VRF': MsdpCli.vrfExprFactory,
      'sa-cache': matcherSaCache,
      'SRCORGRP1': IpAddrMatcher.IpAddrMatcher(
         helpdesc='Multicast Source or Group Address' ),
      'SRCORGRP2': IpAddrMatcher.IpAddrMatcher( helpdesc='Multicast Group Address' ),
      'rejected': 'Show Rejected Source Active Cache',
   }

   @staticmethod
   def handler( mode, args ):
      vrfName = vrfNameForShowCommand( mode, args )
      if vrfName is None:
         return None
      srcOrGrp1 = args.get( 'SRCORGRP1' )
      srcOrGrp2 = args.get( 'SRCORGRP2' )
      rejected = args.get( 'rejected' )

      model = MsdpModels.SaCacheModel()
      if vrfName not in MsdpCli.msdpStatus.vrfStatus:
         return model
      model.setRejected( False )

      # Two addresses mean ( source, group ); a single address is a group.
      if srcOrGrp1 and srcOrGrp2:
         source, group = srcOrGrp1, srcOrGrp2
      else:
         source, group = None, ( srcOrGrp1 or None )

      if group and IpAddrMatcher.validateMulticastIpAddr( group ):
         mode.addError( "Invalid Multicast Group: %s" % group )
         # Stop here, as the other error paths in this file do, instead of
         # also emitting a cache listing for a filter that was rejected.
         return None

      def accept( sg ):
         # Apply the source/group filter to one cache key.
         if source and group:
            return source == sg.sAddr and group == sg.gAddr
         if group:
            return group == sg.gAddr
         return True

      def appendMatching( saColl, dest ):
         # Append a SaMsg to dest for every entry of saColl passing the filter.
         for sg, sa in saColl.iteritems():
            if accept( sg ):
               sgPair = MsdpModels.SGPair(
                     sourceAddress=sg.sAddr, groupAddress=sg.gAddr )
               dest.append( MsdpModels.SaMsg( sourceGroupPair=sgPair,
                                              rpAddress=sa.rpAddr,
                                              remoteAddress=sa.remote ) )

      vrfStatus = MsdpCli.msdpStatus.vrfStatus[ vrfName ]
      appendMatching( vrfStatus.saMsg, model.acceptedSaMsg )
      if rejected:
         model.setRejected( True )
         appendMatching( vrfStatus.rejectedSAMsg, model.rejectedSaMsg )
      return model

   cliModel = MsdpModels.SaCacheModel

BasicCli.addShowCommandClass( IpMsdpSaCacheCmd )

#--------------------------------------------------------------------------------
# show ip msdp [ VRF ] pim sa-cache
#--------------------------------------------------------------------------------
# show ip msdp [ VRF ] pim sa-cache
# Lists the ( source, group ) / RP entries held in MsdpCli.msdpPimStatus for
# the VRF.
class IpMsdpPimSaCacheCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show ip msdp [ VRF ] pim sa-cache'
   data = {
      'ip': ipMatcherForShow,
      'msdp': nodeMsdp,
      'VRF': MsdpCli.vrfExprFactory,
      'pim': 'MSDP related PIM protocol information',
      'sa-cache': matcherSaCache,
   }

   @staticmethod
   def handler( mode, args ):
      vrfName = vrfNameForShowCommand( mode, args )
      if vrfName is None:
         return None
      model = MsdpModels.PimSaCacheModel()
      pimVrfStatus = MsdpCli.msdpPimStatus.vrfStatus
      # The entries are read from msdpPimStatus, so that is the collection
      # that must contain the VRF; checking only msdpStatus could let the
      # lookup below fail. The msdpStatus check is kept so the command still
      # returns an empty model whenever it did before.
      if ( vrfName not in MsdpCli.msdpStatus.vrfStatus or
           vrfName not in pimVrfStatus ):
         return model
      for ( sg, sa ) in pimVrfStatus[ vrfName ].saMsg.iteritems():
         sgPair = MsdpModels.SGPair( sourceAddress=sg.sAddr, groupAddress=sg.gAddr )
         info = MsdpModels.SaMsg( sourceGroupPair=sgPair, rpAddress=sa.rpAddr )
         model.saCache.append( info )
      return model

   cliModel = MsdpModels.PimSaCacheModel

BasicCli.addShowCommandClass( IpMsdpPimSaCacheCmd )

#--------------------------------------------------------------------------------
# show msdp [ VRF ] rpf-peer RP
#
# Deprecated cmd: show ip msdp [ VRF ] rpf-peer RP
#--------------------------------------------------------------------------------
def doShowRpfPeerForRp( mode, args ):
   """Handler for the rpf-peer show commands.

   Looks up the RPF peer for args[ 'RP' ] in the VRF's rpfMap. When the RP is
   not in the map yet, the RP is added to the VRF's cliRp config collection
   and the handler waits up to 10 seconds for an entry to appear.

   Returns a RpfPeerModel with queryCompleted set to False when the VRF has
   no usable rpf map status or no MSDP config, or when the wait times out.
   Returns None when the VRF name is invalid (the error has already been
   added to mode).
   """
   vrfName = vrfNameForShowCommand( mode, args )
   if vrfName is None:
      return None
   rp = args[ 'RP' ]

   model = MsdpModels.RpfPeerModel( rpAddress=rp )

   rpfMapStatus = None
   if vrfName in MsdpCli.rpfMapStatusDir:
      rpfMapStatus = MsdpCli.rpfMapStatusDir[ vrfName ]
   # A missing *or* falsy status entry cannot be queried; previously a falsy
   # entry fell through and failed on rpfMapStatus.rpfMap.
   if not rpfMapStatus or vrfName not in MsdpCli.msdpConfig.vrfConfig:
      model.queryCompleted = False
      return model

   vrfConfig = MsdpCli.msdpConfig.vrfConfig[ vrfName ]
   if rp not in rpfMapStatus.rpfMap:
      model.queryStarted = True
      # Temporary config entry for the RP, removed again below.
      vrfConfig.cliRp[ rp ] = True
   try:
      Tac.waitFor( lambda: rpfMapStatus.rpfMap.get( rp ),
                   timeout=10,
                   warnAfter=None, sleep=True, maxDelay=0.1,
                   description = ' routing table to find rpf-peer' )
      model.rpfPeer = rpfMapStatus.rpfMap.get( rp )
      model.queryCompleted = True
   except Tac.Timeout:
      model.queryCompleted = False
   finally:
      # Always remove the temporary entry, including when the wait is
      # aborted by something other than a timeout (e.g. an interrupt).
      if vrfConfig.cliRp.get( rp ):
         del vrfConfig.cliRp[ rp ]
   return model

# 'rpf-peer' keyword matcher and the RP address matcher, shared by the current
# and the deprecated rpf-peer show commands.
matcherRpfPeer = CliMatcher.KeywordMatcher( 'rpf-peer',
      helpdesc='Show RPF Peer Information' )
rpAddrMatcher = IpAddrMatcher.IpAddrMatcher( helpdesc='PIM RP Address' )

# Current form of the rpf-peer show command: show msdp [ VRF ] rpf-peer RP.
# Output is produced by doShowRpfPeerForRp() as a RpfPeerModel.
class MsdpRpfPeerRpCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show msdp [ VRF ] rpf-peer RP'
   # Maps each token in the syntax string to its matcher/node. The 'rpf-peer'
   # keyword is additionally guarded by CommonGuards.standbyGuard.
   data = {
      'msdp': nodeMsdp,
      'VRF': MsdpCli.vrfExprFactory,
      'rpf-peer': CliCommand.Node( matcher=matcherRpfPeer,
         guard=CommonGuards.standbyGuard ),
      'RP': rpAddrMatcher,
   }

   handler = doShowRpfPeerForRp
   cliModel = MsdpModels.RpfPeerModel

# Register the command with the CLI.
BasicCli.addShowCommandClass( MsdpRpfPeerRpCmd )

# Deprecated 'show ip msdp' spelling of the rpf-peer command. It shares the
# handler and model of MsdpRpfPeerRpCmd; the 'rpf-peer' node carries the
# deprecatedByCmd text naming the replacement command.
class IpMsdpRpfPeerRpCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show ip msdp [ VRF ] rpf-peer RP'
   # Maps each token in the syntax string to its matcher/node. The 'rpf-peer'
   # keyword is additionally guarded by CommonGuards.standbyGuard.
   data = {
      'ip': ipMatcherForShow,
      'msdp': nodeMsdp,
      'VRF': MsdpCli.vrfExprFactory,
      'rpf-peer': CliCommand.Node( matcher=matcherRpfPeer,
         deprecatedByCmd='show msdp rpf-peer in enable mode',
         guard=CommonGuards.standbyGuard ),
      'RP': rpAddrMatcher,
   }

   handler = doShowRpfPeerForRp
   cliModel = MsdpModels.RpfPeerModel

# Register the command with the CLI.
BasicCli.addShowCommandClass( IpMsdpRpfPeerRpCmd )

#--------------------------------------------------------------------------------
# show ip msdp [ VRF ] sanity
#--------------------------------------------------------------------------------
# show ip msdp [ VRF ] sanity
# Cross-checks the MSDP and PIM SA caches against the multicast routing table
# (MRT) and prints every ( source, group ) pair found on one side only. This
# command prints directly and has no CLI model.
class IpMsdpSanityCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show ip msdp [ VRF ] sanity'
   data = {
      'ip': ipMatcherForShow,
      'msdp': nodeMsdp,
      'VRF': MsdpCli.vrfExprFactory,
      'sanity': 'Check consistency between SA Caches and Multicast Routing Table',
   }

   @staticmethod
   def handler( mode, args ):
      vrfName = vrfNameForShowCommand( mode, args )
      if vrfName is None:
         return

      def reportMissing( header, candidates, known ):
         # Print each pair of candidates that is absent from known, preceded
         # by header the first time something is printed.
         headerPrinted = False
         for pair in candidates:
            if pair in known:
               continue
            if not headerPrinted:
               print( header )
               headerPrinted = True
            print( "(%s, %s)" % pair )

      pimsmSmashStatus = PimsmSmashHelper.mountInDependencyOrder(
         MsdpCli.entityMgr, "routing/pim/sparsemode/status/" + vrfName, 'keyshadow' )
      groupSourceMap = MsdpCli.groupSourceMapColl.vrfGroupSourceMap.get( vrfName )

      # ( source, group ) string pairs of MRT routes, split by route flag:
      # msdpDiscovered routes are "remote", mayAdvertise routes are "local".
      remoteSource = set()
      localSource = set()
      if groupSourceMap:
         for grpAddr in groupSourceMap.group:
            grp = groupSourceMap.group.get( grpAddr )
            for srcAddr in grp.groupSource:
               # Entries with the all-zero source are skipped.
               if srcAddr.stringValue == '0.0.0.0':
                  continue
               routeKey = Tac.Value("Routing::Pim::RouteKey", srcAddr, grpAddr )
               mroute = pimsmSmashStatus.route.get( routeKey )
               if mroute is None:
                  continue
               pair = ( srcAddr.stringValue, grpAddr.stringValue )
               if mroute.routeFlags.msdpDiscovered:
                  remoteSource.add( pair )
               if mroute.routeFlags.mayAdvertise:
                  localSource.add( pair )

      msdpVrfStatus = MsdpCli.msdpStatus.vrfStatus
      pimVrfStatus = MsdpCli.msdpPimStatus.vrfStatus

      if vrfName in pimVrfStatus:
         reportMissing(
            "PIM SA cache entries not in the MRT",
            ( ( sg.sAddr, sg.gAddr ) for sg in pimVrfStatus[ vrfName ].saMsg ),
            localSource )

      if vrfName in msdpVrfStatus:
         saKeys = { ( item.sAddr, item.gAddr )
                    for item in msdpVrfStatus[ vrfName ].saMsg.keys() }
         reportMissing( "MSDP-learnt MRT entries not in the SA cache",
                        remoteSource, saKeys )
         reportMissing(
            "SA cache entries not in the MRT",
            ( ( sg.sAddr, sg.gAddr ) for sg in msdpVrfStatus[ vrfName ].saMsg ),
            remoteSource )

      if vrfName in pimVrfStatus:
         pimSaKeys = { ( item.sAddr, item.gAddr )
                       for item in pimVrfStatus[ vrfName ].saMsg.keys() }
         reportMissing( "May-Notify-MSDP entries not in the PIM SA cache"
                        " (need not be an error condition)",
                        localSource, pimSaKeys )

BasicCli.addShowCommandClass( IpMsdpSanityCmd )

#--------------------------------------------------------------------------------
# show ip msdp [ VRF ] summary
#--------------------------------------------------------------------------------
def countAcceptedSAsFromPeer( peer, vrfName=DEFAULT_VRF ):
   """Count the SA cache entries in vrfName whose remote is peer.remote.

   Returns None when the VRF has no entry in MsdpCli.msdpStatus.vrfStatus.
   """
   vrfStatusColl = MsdpCli.msdpStatus.vrfStatus
   if vrfName not in vrfStatusColl:
      return None
   saCache = vrfStatusColl[ vrfName ].saMsg
   return sum( 1 for sa in saCache.itervalues() if sa.remote == peer.remote )

# show ip msdp [ VRF ] summary
# Reports one PeerSummary (state, session start, SA count, reset count) per
# configured peer that has a status entry.
class IpMsdpSummaryCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show ip msdp [ VRF ] summary'
   data = {
      'ip': ipMatcherForShow,
      'msdp': nodeMsdp,
      'VRF': MsdpCli.vrfExprFactory,
      'summary': 'Show MSDP peer status',
   }

   @staticmethod
   def handler( mode, args ):
      vrfName = vrfNameForShowCommand( mode, args )
      if vrfName is None:
         return None

      summaries = []
      vrfConfigColl = MsdpCli.msdpConfig.vrfConfig
      if vrfName in vrfConfigColl:
         for peerCfg in vrfConfigColl[ vrfName ].peerConfig.itervalues():
            status = peerStatus( peerCfg.remote, vrfName=vrfName )
            # Peers without status are skipped; this guards against a crash
            # when the MSDP agent is shut down.
            if not status:
               continue
            entry = MsdpModels.PeerSummary()
            entry.peerIpAddress = peerCfg.remote
            entry.state = status.state
            if status.state == 'established':
               # Shift the established timestamp by the current offset
               # between Tac.utcNow() and Tac.now().
               establishedAt = status.fsmEstablishedTime
               entry.sessionStartTime = establishedAt + Tac.utcNow() - Tac.now()
            entry.saCount = countAcceptedSAsFromPeer( peerCfg, vrfName )
            entry.resetCount = peerResetCount( status )
            summaries.append( entry )
      return MsdpModels.SummaryModel( peerList=summaries )

   cliModel = MsdpModels.SummaryModel

BasicCli.addShowCommandClass( IpMsdpSummaryCmd )
