#!/usr/bin/env python
# Copyright (c) 2009-2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliPlugin import IntfCli, EthIntfCli, VlanCli
import CliCommand
import CliParser
import LazyMount
import SmashLazyMount
import ConfigMount
import Tac
import SubIntfCli
from PortSecCliModels import ( GeneralPortSecurityStatistics,
                               GeneralPortSecurityVlanStatistics,
                               PortSecurityAddresses,
                               PortSecurityInterfaces,
                               ethStr )
from BridgingHostEntryType import isEntryTypeConfigured

# Mounted entities; all start as None and are assigned by Plugin() below.
bridgingConfig = None      # "bridging/config" (Bridging::Config), mounted "r"
bridgingStatus = None      # "bridging/status" (Smash::Bridging::Status) smash reader
portSecConfig = None       # "portsec/config" (PortSec::Config), ConfigMount "w"
portSecLocalConfig = None  # "portsec/localconfig" (PortSec::LocalConfig), "w"
portSecStatus = None       # "portsec/status" (PortSec::Status), mounted "r"
portSecHwCap = None        # "portsec/hwcap" (PortSec::HwCapabilities), mounted "r"

# Interface classes passed as intfType= to IntfCli.Intf.getAll() by the
# show/clear handlers in this file.
validIntfTypes = [ EthIntfCli.EthIntf, SubIntfCli.SubIntf ]
# Tac type used for VLAN ids; VlanId.invalid keys the "default" per-VLAN entry.
VlanId = Tac.Type( "Bridging::VlanId" )

class PortSecIntf( IntfCli.IntfDependentBase ):
   """Interface-dependent hook for port-security configuration.

   Registered with IntfCli.Intf.registerDependentClass() in Plugin().
   """

   def setDefault( self ):
      """Remove this interface's entry from portSecConfig.intfConfig."""
      intfName = self.intf_.name
      del portSecConfig.intfConfig[ intfName ]

def intfConfig( intf ):
   """Return a mutable PortSec::IntfConfig value for interface name `intf`.

   If the interface already has a config entry, a non-const copy of it is
   returned; otherwise a fresh PortSec::IntfConfig value is created. In both
   cases the caller must addMember() the result to persist changes.
   """
   existing = portSecConfig.intfConfig.get( intf )
   if existing is None:
      return Tac.Value( 'PortSec::IntfConfig', intf )
   return Tac.nonConst( existing )

def vlanConfig( psiConf, vlanId ):
   """Return a mutable PortSec::VlanConfig for `vlanId` within `psiConf`.

   Existing entries are copied with Tac.nonConst(); missing ones are created
   as new PortSec::VlanConfig values. The caller stores the result back into
   psiConf.vlanConfig.
   """
   existing = psiConf.vlanConfig.get( vlanId )
   if existing is None:
      return Tac.Value( 'PortSec::VlanConfig', vlanId )
   return Tac.nonConst( existing )

def intfStatus( mode, intf ):
   """Return the port-security status entry for `intf`, or None if absent.

   `mode` is unused; it is kept for interface compatibility.
   """
   statuses = portSecStatus.intfStatus
   return statuses.get( intf.name, None )

def guardGlobalMacAddressConfig( mode, token ):
   """CLI guard: allow the token only if the hardware supports deletion.

   Returns None (token allowed) when
   portSecHwCap.allowSecureAddressDeletion is set, otherwise
   CliParser.guardNotThisPlatform.
   """
   if portSecHwCap.allowSecureAddressDeletion:
      return None
   return CliParser.guardNotThisPlatform

def maximumRange( mode ):
   """Return the (min, max) range allowed for the per-interface maximum.

   The upper bound is portSecHwCap.allowedMaxLimitForLogging when the
   interface's config has `log` set, and allowedMaxLimitNoLogging otherwise
   (including when no config exists yet).
   """
   cfg = None
   if portSecConfig:
      cfg = portSecConfig.intfConfig.get( mode.intf.name )
   if cfg and cfg.log:
      return 1, portSecHwCap.allowedMaxLimitForLogging
   return 1, portSecHwCap.allowedMaxLimitNoLogging

def doShowPortSecurity( mode, args ):
   """Build the model for the general port-security summary.

   For every port-security status entry whose interface is one of
   validIntfTypes, reports the configured maximum, the current address
   count, the violation count and the violation action. Also reports the
   global secure-address move/aging/persistence settings.

   Returns:
      GeneralPortSecurityStatistics.
   """
   intfs = { x.name for x in
             IntfCli.Intf.getAll( mode, intfType=validIntfTypes ) or () }
   portStatistics = {}
   numAddresses = 0
   secureAddressMoves = portSecConfig.allowSecureAddressMoves
   secureAddressAging = portSecConfig.allowSecureAddressAging
   persistence = portSecConfig.persistenceEnabled
   for iname, status in portSecStatus.intfStatus.items():
      if iname not in intfs:
         continue
      config = portSecConfig.intfConfig.get( iname )
      if config:
         maxSecureAddr = config.maxAddrs
         securityAction = config.mode
      else:
         # A status entry may exist without a matching config entry (hence
         # the .get()). Fall back to the action recorded in the status
         # entry instead of dereferencing the missing config.
         maxSecureAddr = 0
         securityAction = status.mode
      addrs = status.addrs
      portStatistics[ iname ] = GeneralPortSecurityStatistics.PortStatistic(
                                   maxSecureAddr=maxSecureAddr,
                                   currentAddr=addrs,
                                   numberOfViolations=status.violations,
                                   securityAction=securityAction )
      numAddresses += addrs
   return GeneralPortSecurityStatistics(
             portStatistics=portStatistics,
             totalAddresses=numAddresses,
             secureAddressMoves=secureAddressMoves,
             secureAddressAging=secureAddressAging,
             persistence=persistence )

def doShowPortSecurityAddress( mode, args ):
   """Build the model listing secure MAC addresses on port-security ports.

   An interface qualifies if it has a port-security status entry and is one
   of validIntfTypes. Every smash FDB entry on a qualifying interface is
   reported. The entry is "secureConfigured" when isEntryTypeConfigured() is
   true for its entry type, and "secureDynamic" otherwise.
   """
   statusEntries = portSecStatus.intfStatus
   shownIntfs = set( intf.name for intf in
                     IntfCli.Intf.getAll( mode, intfType=validIntfTypes ) )
   securedIntfs = set( entry.intfId for entry in
                       statusEntries.itervalues() ).intersection( shownIntfs )

   # TBD: addresses should come from allowedAddress rather than
   # smashFdbStatus.
   fdbEntries = bridgingStatus.smashFdbStatus
   matching = [ fdbEntry for fdbEntry in fdbEntries.itervalues()
                if fdbEntry.intf in securedIntfs ]

   addresses = []
   for fdbEntry in matching:
      if isEntryTypeConfigured( fdbEntry.entryType ):
         entryKind = "secureConfigured"
      else:
         entryKind = "secureDynamic"
      addresses.append( PortSecurityAddresses.Address(
                           macAddress=fdbEntry.key.addr,
                           vlan=fdbEntry.key.fid,
                           entryType=entryKind,
                           interface=fdbEntry.intf,
                           remainingAge=None ) )
   return PortSecurityAddresses( addresses=addresses,
                                 totalAddresses=len( addresses ) )

def doShowPortSecurityInterface( mode, args ):
   """Build the per-interface port-security detail model.

   Every interface returned by IntfCli.Intf.getAll() (optionally restricted
   by args['INTF']) gets an entry. Fields start from fixed defaults. They are
   then overridden from the interface's PortSec::IntfConfig, if present, and
   from its status entry, if present.

   Returns:
      PortSecurityInterfaces keyed by interface name; empty when no
      interface matches.
   """
   candidates = IntfCli.Intf.getAll( mode, args.get( 'INTF', None ),
                                     intfType=validIntfTypes )
   if not candidates:
      return PortSecurityInterfaces( interfaces={} )

   # Fields that are identical for every interface; computed once.
   commonFields = dict(
      agingTime=int( bridgingConfig.hostAgingTime / 60.0 + 0.5 ),
      agingType="inactivity",
      secureStaticAddressAging="disabled",
      secureAddressMoves=portSecConfig.allowSecureAddressMoves,
      secureAddressAging=portSecConfig.allowSecureAddressAging,
      persistence=portSecConfig.persistenceEnabled )

   models = {}
   for candidate in candidates:
      fields = dict( commonFields )
      # Per-interface defaults, used when nothing is configured.
      fields.update( portSecurityEnabled=False,
                     portStatus='secure-down',
                     violationMode='none',
                     portMaxEnabled=True,
                     maxMacAddresses=1,
                     totalMacAddresses=None,
                     configuredMacAddresses=None,
                     addressChanges=None,
                     lastChangeDetails=None,
                     lastViolation=None,
                     securityViolationCount=None,
                     logAddrsAfterLimit="disabled",
                     allowedAddresses=[] )

      conf = portSecConfig.intfConfig.get( candidate.name )
      if conf:
         stat = intfStatus( mode, candidate )
         fields[ 'portMaxEnabled' ] = not conf.vlanBased
         fields[ 'maxMacAddresses' ] = conf.maxAddrs
         if conf.enabled:
            if stat:
               if stat.restrictionStatus == 'restrictionActive':
                  fields[ 'portStatus' ] = ( 'secure-shutdown'
                                             if conf.mode == 'shutdown'
                                             else 'secure-protected' )
               elif candidate.lineProtocolState() == 'up':
                  fields[ 'portStatus' ] = 'secure-up'
            fields[ 'portSecurityEnabled' ] = True
            fields[ 'violationMode' ] = conf.mode
         if stat:
            fields[ 'totalMacAddresses' ] = stat.addrs
            fields[ 'configuredMacAddresses' ] = stat.staticAddrs
            fields[ 'addressChanges' ] = stat.addrChanges
            fields[ 'securityViolationCount' ] = stat.violations
            addrModel = PortSecurityInterfaces.Interface.InterestingAddress
            # Status timestamps are shifted by (Tac.utcNow() - Tac.now()),
            # i.e. they are assumed to be on the Tac.now() clock.
            if stat.addrs:
               fields[ 'lastChangeDetails' ] = addrModel(
                  macAddress=stat.lastNewAddr,
                  vlan=stat.lastNewVlanId,
                  time=stat.lastNewAddrTime + Tac.utcNow() - Tac.now() )
            if stat.violations:
               fields[ 'lastViolation' ] = addrModel(
                  macAddress=stat.lastViolatingAddr,
                  vlan=stat.lastViolatingVlanId,
                  time=stat.lastViolationTime + Tac.utcNow() - Tac.now() )
            if conf.mode == 'protect':
               if conf.log:
                  fields[ 'logAddrsAfterLimit' ] = "enabled"
               fields[ 'allowedAddresses' ] = [
                  "%s:%s" % ( ethStr( str( key.addr ) ), key.fid )
                  for key in stat.allowedAddr ]

      models[ candidate.name ] = PortSecurityInterfaces.Interface( **fields )
   return PortSecurityInterfaces( interfaces=models )

def doShowPortSecurityVlan( mode, args ):
   """Build the per-interface, per-VLAN port-security statistics model.

   Args:
      mode: CLI mode.
      args: parsed CLI arguments. Optional 'INTFS' restricts the Ethernet
         interfaces. Optional 'VLANS' (an object with an ``ids`` attribute)
         restricts the VLANs; without it, all VLANs from VlanCli.Vlan.getAll()
         are reported.

   Returns:
      GeneralPortSecurityVlanStatistics with a VlanDict per interface and the
      total address count over every entry reported.
   """
   intfObjs = IntfCli.Intf.getAll( mode, args.get( 'INTFS' ),
                                   intfType=EthIntfCli.EthIntf )
   # Guard against a falsy getAll() result before iterating, as the other
   # handlers in this file do.
   intfs = set( x.name for x in intfObjs or () )
   if not intfs:
      return GeneralPortSecurityVlanStatistics( interfaces={},
                                                totalAddresses=0 )
   allVlans = { vlan.id for vlan in VlanCli.Vlan.getAll( mode ) }
   requested = args.get( 'VLANS' )
   vlans = requested.ids if requested else allVlans
   vlanStatistic = GeneralPortSecurityVlanStatistics.VlanDict.VlanStatistic

   intfModels = {}
   numAddresses = 0
   for ( iname, status ) in portSecStatus.intfStatus.items():
      if iname not in intfs:
         continue
      vlanModels = {}
      for vlanId in status.vlanStatus:
         if vlanId == VlanId.invalid:
            # VlanId.invalid keys the "default" per-VLAN limit (configured
            # with the 'default' keyword). Its per-VLAN counters live in
            # status.defaultVlanCounter.
            vlanStat = status.vlanStatus.get( vlanId )
            for defaultVlanId in status.defaultVlanCounter:
               # Apply the requested VLAN filter as well as the existence
               # check; without it, every VLAN counted against the default
               # limit was shown regardless of 'VLANS'.
               if defaultVlanId not in allVlans or defaultVlanId not in vlans:
                  continue
               dVlanCounter = status.defaultVlanCounter.get( defaultVlanId )
               addrs = dVlanCounter.addrs
               vlanModels[ defaultVlanId ] = vlanStatistic(
                     maxAddrs=vlanStat.maxAddrs,
                     numAddrs=addrs,
                     numViolations=dVlanCounter.violations,
                     action=status.mode )
               numAddresses += addrs
            # The invalid id itself is never a reportable VLAN.
            continue

         if vlanId not in vlans:
            continue
         vlanStat = status.vlanStatus.get( vlanId )
         addrs = vlanStat.addrs
         vlanModels[ vlanId ] = vlanStatistic(
               maxAddrs=vlanStat.maxAddrs,
               numAddrs=addrs,
               numViolations=vlanStat.violations,
               action=status.mode )
         numAddresses += addrs
      intfModels[ iname ] = GeneralPortSecurityVlanStatistics.VlanDict(
           vlans=vlanModels )
   return GeneralPortSecurityVlanStatistics( interfaces=intfModels,
                                             totalAddresses=numAddresses )

def clearPortSecurity( mode, args ):
   """Post a clear request for each selected interface.

   A request is recorded by storing the current Tac.now() timestamp in
   portSecLocalConfig.clearPortSecurityRequest under the interface name.
   """
   targets = IntfCli.Intf.getAll( mode, args.get( 'INTF', None ),
                                  intfType=validIntfTypes )
   for target in targets or ():
      portSecLocalConfig.clearPortSecurityRequest[ target.name ] = Tac.now()

def doEnDisPortSecurityLogging( mode, no=None ):
   """Turn the per-interface port-security `log` flag on or off.

   Enabling goes through doEnDisPortSecurity() with protect mode and
   logging on. It is refused, with a CLI error, when the configured maximum
   exceeds portSecHwCap.allowedMaxLimitForLogging. Disabling only clears
   `log` on an existing config entry and does nothing if none exists.
   """
   intfName = mode.intf.name
   if not no:
      current = intfConfig( intfName )
      if current and \
            current.maxAddrs > portSecHwCap.allowedMaxLimitForLogging:
         mode.addError( "Error enabling logging. Configured maximum "
                        "addresses %d is higher than the limit %d" % (
                        current.maxAddrs,
                        portSecHwCap.allowedMaxLimitForLogging ) )
         return
      doEnDisPortSecurity( mode, portSecMode='protect', logging=True )
      return

   existing = portSecConfig.intfConfig.get( intfName )
   if not existing:
      return
   updated = Tac.nonConst( existing )
   updated.log = False
   portSecConfig.intfConfig.addMember( updated )

def doEnDisPortSecurity( mode, no=None, portSecMode='shutdown', logging=False ):
   """Enable or disable port security on the current interface.

   Args:
      mode: interface config mode; mode.intf.name selects the interface.
      no: truthy for the 'no' form, which disables port security, clears
         logging and resets the violation mode to 'shutdown'.
      portSecMode: violation mode to set when enabling.
      logging: value for the `log` flag when enabling.
   """
   intfName = mode.intf.name
   isNo = bool( no )
   if isNo and not portSecConfig.intfConfig.get( intfName ):
      # Disabling an unconfigured interface: avoid creating a config entry.
      return
   entry = intfConfig( intfName )
   if not isNo:
      entry.mode = portSecMode
      entry.log = logging
      entry.enabled = True
   else:
      entry.enabled = False
      entry.log = False
      entry.mode = 'shutdown' # reset to default
   portSecConfig.intfConfig.addMember( entry )

def doEnDisVlanBasedPortSec( mode, args ):
   """Add or remove per-VLAN maximum address limits on the current interface.

   Args:
      mode: interface config mode; mode.intf.name selects the interface.
      args: parsed CLI arguments.
         - 'MAXIMUM': the limit to apply.
         - 'default': targets the default entry (VlanId.invalid).
         - 'VLAN_SET': otherwise, its .ids lists the VLANs.
         The no/default form of the command removes the entries instead.

   Only allowed while the interface's violation mode is 'shutdown'. In any
   other mode a warning is emitted and nothing changes.
   """
   intfName = mode.intf.name
   removing = CliCommand.isNoOrDefaultCmd( args )
   if removing and not portSecConfig.intfConfig.get( intfName ):
      # Nothing is configured, so there is nothing to remove. Return rather
      # than writing a default PortSec::IntfConfig, consistent with
      # doEnDisPortSecurity and doSetMaximum.
      return
   psiConf = intfConfig( intfName )
   if psiConf.mode != 'shutdown':
      mode.addWarning( "This command is not supported in protect mode" )
      return
   maximum = args.get( 'MAXIMUM' )
   if 'default' in args:
      vlans = [ VlanId.invalid ]
   else:
      vlans = args.get( 'VLAN_SET' ).ids

   for vlan in vlans:
      vlanId = Tac.Value( 'Bridging::VlanIdOrAnyOrNone', vlan )
      if removing:
         del psiConf.vlanConfig[ vlanId ]
      else:
         psvConf = vlanConfig( psiConf, vlanId )
         psvConf.maxAddrs = maximum
         psiConf.vlanConfig[ vlanId ] = psvConf
   portSecConfig.intfConfig.addMember( psiConf )

def doSetMaximum( mode, args ):
   """Set the per-interface maximum address count and the vlanBased flag.

   `vlanBased` is true when the 'disabled' keyword is present in args.
   Setting the default (maximum 1, not VLAN-based) on an interface with no
   config entry is a no-op.
   """
   intfName = mode.intf.name
   limit = args.get( "MAXIMUM", 1 )
   vlanBased = 'disabled' in args
   isDefault = limit == 1 and not vlanBased
   if isDefault and not portSecConfig.intfConfig.get( intfName ):
      # Don't create a default PortSec::IntfConfig object for no reason.
      return
   entry = intfConfig( intfName )
   entry.maxAddrs = limit
   entry.vlanBased = vlanBased
   portSecConfig.intfConfig.addMember( entry )

def doEnDisMacAddressLimitMaximum( mode, args ):
   """Enable the MAC address limit with a maximum, or disable it.

   Enabling sets `enabled` and `maxAddrs` from args['MAXIMUM'] (default 1).
   The no/default form clears `enabled` and resets `maxAddrs` to 1; it does
   nothing when the interface has no config entry. `log` is always cleared.
   """
   intfName = mode.intf.name
   limit = args.get( 'MAXIMUM', 1 )
   disabling = CliCommand.isNoOrDefaultCmd( args )
   if disabling and not portSecConfig.intfConfig.get( intfName ):
      # Nothing to disable; avoid creating a config entry.
      return
   entry = intfConfig( intfName )
   entry.enabled = not disabling
   entry.maxAddrs = limit if not disabling else 1
   # Logging was not supported when this command was introduced, so it is
   # always cleared here.
   entry.log = False
   portSecConfig.intfConfig.addMember( entry )

def doSetMacAddressLimitViolationMode( mode, args ):
   """Set the violation mode: 'shutdown' if that keyword was given, else
   'protect'."""
   chosenMode = 'shutdown' if args.get( 'shutdown' ) else 'protect'
   entry = intfConfig( mode.intf.name )
   entry.mode = chosenMode
   portSecConfig.intfConfig.addMember( entry )

def doEnAging( mode, args ):
   """Set the global allowSecureAddressAging flag."""
   enabled = True
   portSecConfig.allowSecureAddressAging = enabled

def doDisAging( mode, args ):
   """Clear the global allowSecureAddressAging flag."""
   enabled = False
   portSecConfig.allowSecureAddressAging = enabled

def doEnMoves( mode, args ):
   """Configure which secure address moves are allowed.

   The 'data' and 'phone' keywords set their specific flags. The general
   allowSecureAddressMoves flag is set only when neither keyword is given.
   """
   movesData = 'data' in args
   movesPhone = 'phone' in args
   portSecConfig.allowSecureAddressMovesData = movesData
   portSecConfig.allowSecureAddressMovesPhone = movesPhone
   portSecConfig.allowSecureAddressMoves = not ( movesData or movesPhone )

def doDisMoves( mode, args ):
   """Clear the general, data and phone secure-address-move flags."""
   for attr in ( 'allowSecureAddressMoves',
                 'allowSecureAddressMovesData',
                 'allowSecureAddressMovesPhone' ):
      setattr( portSecConfig, attr, False )

def doEnPersistence( mode, args ):
   """Set the global persistenceEnabled flag."""
   enabled = True
   portSecConfig.persistenceEnabled = enabled

def doDisPersistence( mode, args ):
   """Clear the global persistenceEnabled flag."""
   enabled = False
   portSecConfig.persistenceEnabled = enabled

def Plugin( entityManager ):
   """Plugin entry point: mount the entities used by this module.

   Assigns the module-level mount handles (bridgingConfig, bridgingStatus,
   portSecConfig, portSecLocalConfig, portSecStatus, portSecHwCap) and
   registers PortSecIntf as an interface-dependent class.

   Args:
      entityManager: the entity manager passed to the mount helpers.
   """
   global bridgingConfig, bridgingStatus
   global portSecConfig, portSecLocalConfig, portSecStatus, portSecHwCap
   bridgingConfig = LazyMount.mount( entityManager, "bridging/config",
                                     "Bridging::Config", "r" )
   # Port-security configuration is mounted through ConfigMount with write
   # access; the CLI handlers above modify it.
   portSecConfig = ConfigMount.mount( entityManager, "portsec/config",
                                      "PortSec::Config", "w" )
   # Written by clearPortSecurity() to post clear requests.
   portSecLocalConfig = LazyMount.mount( entityManager, "portsec/localconfig",
                                         "PortSec::LocalConfig", "w" )
   portSecStatus = LazyMount.mount( entityManager, "portsec/status",
                                    "PortSec::Status", "r" )
   portSecHwCap = LazyMount.mount( entityManager, "portsec/hwcap",
                                   "PortSec::HwCapabilities", "r" )
   IntfCli.Intf.registerDependentClass( PortSecIntf )
   # Smash reader for the bridging FDB used by doShowPortSecurityAddress().
   bridgingStatus = SmashLazyMount.mount( entityManager, "bridging/status",
                                          "Smash::Bridging::Status",
                                          SmashLazyMount.mountInfo( 'reader' ) )
