# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliModel import Model, Submodel, Dict, Int, Bool, Str, List, Float, Enum
from ArnetModel import MacAddress
from IntfModels import Interface
from MacsecCommon import ( 
   tacCipherSuiteToCli, 
   trafficPolicyOnNoMkaToCli, 
   intfTrafficStatus
)
from ReversibleSecretCli import decodeKey
from TableOutput import terminalWidth
import Arnet
import hashlib
import Tac
from textwrap import TextWrapper
import Toggles.MacsecToggleLib as macsecToggle

# A Macsec::Sak value constructed with no arguments. A port whose
# cpStatus.txSak still equals this has no transmit key installed: it is
# reported with an empty key identifier and is not counted as secured.
txSakDefault = Tac.Value( "Macsec::Sak" )

class MacsecMessageCounterDetail( Model ):
   """Per-category breakdown of errored MKA messages on one interface."""
   rxInvalid = Int( help="Invalid interface or null packets" )
   rxEapolError = Int( help="EAPOL errored packets" )
   rxBasicParamSetError = Int( help="Basic parameter set errored received packets" )
   rxUnrecognizedCkn = Int( help="Unrecognized ckn in received packets" )
   rxIcvValidationError = Int( help="ICV validation error in received packets" )
   rxLivePeerListError = Int( help="Live peerlist set errored received packets" )
   rxPotentialPeerListError = Int( help="Potential peerlist set errored"
                                        " received packets" )
   rxSakUseSetError = Int( help="SAK Use Set errored received packets" )
   rxDistSakSetError = Int( help="Dist SAK Set errored received packets" )
   rxDistCakSetError = Int( help="Dist CAK Set errored received packets" )
   rxIcvIndicatorError = Int( help="ICV Indicator errored received packets" )
   rxUnrecognizedSetError = Int( help="Unrecognized parameter set errored"
                                      " received packets" )
   txInvalid = Int( help="Invalid interface in tx packets" )

   def fromTacc( self, intfCounter ):
      """Copy every per-category counter from intfCounter into this model.

      Args:
         intfCounter: counter entry exposing ``rxMsgCounter`` (all rx* fields
            are read from it) and ``txMsgCounter`` (only ``invalid`` is read).
      """
      rxCounter = intfCounter.rxMsgCounter
      self.rxInvalid = rxCounter.invalid
      self.rxEapolError = rxCounter.eapolErr
      self.rxBasicParamSetError = rxCounter.basicParamSetErr
      self.rxUnrecognizedCkn = rxCounter.unrecognizedCkn
      self.rxIcvValidationError = rxCounter.icvValidationErr
      self.rxLivePeerListError = rxCounter.livePeerListErr
      self.rxPotentialPeerListError = rxCounter.potentialPeerListErr
      self.rxSakUseSetError = rxCounter.sakUseSetErr
      self.rxDistSakSetError = rxCounter.distSakSetErr
      self.rxDistCakSetError = rxCounter.distCakSetErr
      self.rxIcvIndicatorError = rxCounter.icvIndicatorErr
      self.rxUnrecognizedSetError = rxCounter.unrecognizedSetErr
      self.txInvalid = intfCounter.txMsgCounter.invalid

class MacsecMessageCountersInterface( Model ):
   rxPacketsSuccess = Int( help="Total of successfully received packets" )
   rxPacketsFailure = Int( help="Total of invalid/errored received packets" )
   txPacketsSuccess = Int( help="Total of successfully transmitted packets" )
   txPacketsFailure = Int( help="Total of invalid/errored transmitted packets" )

   details = Submodel( help="MAC security counter detailed information",
                       valueType=MacsecMessageCounterDetail,
                       optional=True )

   def fromTacc( self, intfCounter, detail=False ):

      self.rxPacketsSuccess = intfCounter.rxMsgCounter.success
      self.rxPacketsFailure = intfCounter.rxMsgCounter.failure
      self.txPacketsSuccess = intfCounter.txMsgCounter.success
      self.txPacketsFailure = intfCounter.txMsgCounter.failure     
 
      if detail:
         self.details = MacsecMessageCounterDetail()
         self.details.fromTacc( intfCounter )

   def render( self ):
      if not self.details:
         print "%-15d %-15d %-15d %-15d" % \
            ( self.rxPacketsSuccess,
              self.rxPacketsFailure,
              self.txPacketsSuccess,
              self.txPacketsFailure )
      if self.details:
         detailIndentation = ' ' * 4
         errorIndentation = ' ' * 8
         print detailIndentation + "Tx packet success: %d" \
                        % self.txPacketsSuccess
         print detailIndentation + "Tx packet failure: %d" \
                        % self.txPacketsFailure
         print errorIndentation + "Tx invalid: %d" \
                        % self.details.txInvalid
         print detailIndentation + "Rx packet success: %d" \
                        % self.rxPacketsSuccess
         print detailIndentation + "Rx packet failure: %d" \
                        % self.rxPacketsFailure
         print errorIndentation + "Rx invalid: %d" \
                        % ( self.details.rxInvalid )
         print errorIndentation + "Rx eapol error: %d" \
                        % ( self.details.rxEapolError )
         print errorIndentation + "Rx basic parameter set error: %d" \
                        % ( self.details.rxBasicParamSetError )
         print errorIndentation + "Rx unrecognized CKN error: %d" \
                        % ( self.details.rxUnrecognizedCkn )
         print errorIndentation + "Rx ICV validation error: %d" \
                        % ( self.details.rxIcvValidationError )
         print errorIndentation + "Rx live peer list error: %d" \
                        % ( self.details.rxLivePeerListError )
         print errorIndentation + "Rx potential peer list error: %d" \
                        % ( self.details.rxPotentialPeerListError )
         print errorIndentation + "Rx SAK use set error: %d" \
                        % ( self.details.rxSakUseSetError )
         print errorIndentation + "Rx distributed SAK set error: %d" \
                        % ( self.details.rxDistSakSetError )
         print errorIndentation + "Rx distributed CAK set error: %d" \
                        % ( self.details.rxDistCakSetError )
         print errorIndentation + "Rx ICV Indicator error: %d" \
                        % ( self.details.rxIcvIndicatorError )
         print errorIndentation + "Rx unrecognized parameter set error: %d" \
                        % ( self.details.rxUnrecognizedSetError )

class MacsecMessageCounters( Model ):
   interfaces = Dict( help="A mapping between interfaces and"
                            " MAC security MKA counters", keyType=Interface,
                            valueType=MacsecMessageCountersInterface )

   def render( self ):
      if not self.interfaces:
         return

      printHeader = True
      if self.interfaces.values()[ 0 ].details:
         printHeader = False

      if printHeader:
         print "%-15s %-15s %-15s %-15s %-15s" % \
            ( 'Interface', 'Rx Success',
               'Rx Failure', 'Tx Success', 'Tx Failure' )
      for key in Arnet.sortIntf( self.interfaces ):
         if printHeader:
            # below print has a comma so that a new line is not printed
            print "%-15s" % key,
         else:
            print
            print "Interface: %s" % key
         self.interfaces[ key ].render()

class MacsecParticipantDetail( Model ):
   """Detailed MKA participant state: key server, distributed SAK and peers."""
   keyServerAddr = MacAddress( help="Mac address of the key server" )
   keyServerPortId = Int( help="Port Id of the key server" )

   sakTransmit = Bool( help="True if the participant is using the distributed"
                            " SAK for transmit" )
   llpnExhaustion = Int( help="Number of times LLPN exhaustion event detected when"
                              " this participant is the principal actor" )
   keyServerMsgId = Str( help="Message identifier of the key server" )
   keyNum = Int( help="Key number of the distributed session association key" )
   livePeerList = List( help="List of all live peers for this participant",
                        valueType=str )
   potentialPeerList = List( help=" List of all potential peers for this"
                                  " participant", valueType=str )

   def fromTacc( self, actorStatus ):
      """Copy key-server, distributed-SAK and peer information from actorStatus.

      The peer lists hold the keys of ``actorStatus.livePeer`` and
      ``actorStatus.potentialPeer``, in the order those collections yield them.
      """
      keyServer = actorStatus.keyServer
      self.keyServerAddr = keyServer.addr
      self.keyServerPortId = keyServer.portNum
      self.sakTransmit = actorStatus.sakTransmit
      self.llpnExhaustion = actorStatus.llpnExhaustion
      distSak = actorStatus.distSak
      self.keyServerMsgId = distSak.keyMsgId
      self.keyNum = distSak.keyNum
      self.livePeerList = [ peer for peer in actorStatus.livePeer.keys() ]
      self.potentialPeerList = [ peer for peer in
                                 actorStatus.potentialPeer.keys() ]

class MacsecParticipant( Model ):
   msgId = Str( help="Member identifier generated for the participant" )
   electedSelf = Bool( help="Indicates whether the participant has elected itself"
                            " as the key server" )
   success = Bool( help="Participant has successfully elected a key"
                        " server and has at least one live peer" )
   defaultActor = Bool( help="Participant is spawned from a configured"
                        " fallback key" )
   principalActor = Bool( help="Participant is also a principal actor" )
   details =  Submodel( valueType=MacsecParticipantDetail, optional=True,
                     help="MAC security participant detail information" )

   def fromTacc( self, actorStatus, detail=False ):
      self.msgId = actorStatus.msgId
      self.electedSelf = actorStatus.electedSelf
      self.success = actorStatus.success
      self.defaultActor = actorStatus.defaultActor
      self.principalActor = actorStatus.principal
      if detail:
         self.details = MacsecParticipantDetail()
         self.details.fromTacc( actorStatus )

   def render( self ):
      cknInformationIndentation = ' ' * 6
      print cknInformationIndentation + "Member ID: %s" % self.msgId
      keyMgmtRole = "Key Server" if self.electedSelf else "Non Key Server"
      print cknInformationIndentation + "Key management role: %s" % keyMgmtRole
      print cknInformationIndentation + "Success: %s" % self.success
      print cknInformationIndentation + "Principal: %s" % self.principalActor
      keyType = "Fallback" if self.defaultActor else "Primary"
      print cknInformationIndentation + "Key type: %s" % keyType
      if self.details:
         sci = "None"
         if self.details.keyServerAddr:
            sci = str( self.details.keyServerAddr ) + "::" + \
                  str( self.details.keyServerPortId )
         print cknInformationIndentation + "KeyServer SCI: %s" % sci
         print cknInformationIndentation + "SAK transmit: %s" % \
               self.details.sakTransmit
         print cknInformationIndentation + "LLPN exhaustion: %d" % \
               self.details.llpnExhaustion
         ki = "None"
         if self.details.keyServerMsgId:
            ki = self.details.keyServerMsgId + ":" + str( self.details.keyNum )
         print cknInformationIndentation + "Distributed key identifier: %s" % ki
         print cknInformationIndentation + "Live peer list: %s" % \
               self.details.livePeerList
         print cknInformationIndentation + "Potential peer list: %s" % \
               self.details.potentialPeerList


class MacsecParticipantsInterface( Model ):
   participants = Dict( help="A mapping between CKN and"
                             " MAC security participants", keyType=str,
                             valueType=MacsecParticipant )

   def render( self ):
      if not self.participants:
         return

      for key in sorted( self.participants ):
         cknIndentation = ' ' * 4
         print cknIndentation + "CKN: %s" % key
         self.participants[ key ].render()
         print 

class MacsecParticipants( Model ):
   interfaces = Dict( help="A mapping between interfaces and"
                           " MAC security interface participants", 
                           keyType=Interface,
                           valueType=MacsecParticipantsInterface )

   def render( self ):
      if not self.interfaces:
         return

      for key in Arnet.sortIntf( self.interfaces ):
         print "Interface: %s" % key
         self.interfaces[ key ].render()

class MacsecInterfaceDetailData( Model ):
   """Detailed MAC security state of one interface (keys, traffic, FIPS)."""
   keyServerPriority = Int( help="Configured key server priority for the interface" )
   oldKeyMsgId = Str( help="A 96 bit message identifier for the old key" )
   oldKeyMsgNum = Int( help="Message number of the old key" )
   oldKeyTransmitting = Bool( help="The old key is currently used for"
                                   " encrypting data packets" )
   oldKeyReceiving = Bool( help="The old key is currently used for"
                                " decrypting data packets" )
   latestKeyMsgId = Str( help="A 96 bit message identifier for the latest key" )
   latestKeyMsgNum = Int( help="Message number of the latest key" )
   latestKeyTransmitting = Bool( help="The latest key is currently used for"
                                      " encrypting data packets" )
   latestKeyReceiving = Bool( help="The latest key is currently used for"
                                   " decrypting data packets" )
   sessionReKeyPeriod = Float( help="Period in seconds after which the session keys"
                                    " are refreshed" )
   localSsci = Str( help="Value of local SSCI" )
   bypassProtocol = List( help="List of protocols without MAC security protection",
                          valueType=str )
   traffic = Enum( values=( intfTrafficStatus ), help="Type of traffic" )
   cachedSak = Bool( help="Whether cached SAK is used" )
   fipsPostStatus = Enum( values=( "none", "inProgress", "passed", "failed" ),
                          help="FIPS Power-on self-test" )
   profileName = Str( help="MAC security profile name" )

   def fromTacc( self, intfStatus, cpStatus, hwPostStatus, portStatus ):
      """Fill the detail model from interface, controlled-port and port status.

      Args:
         intfStatus: interface status (priority, rekey period, LLDP bypass,
            profile name).
         cpStatus: controlled-port status (old/latest SAK contexts, SSCI,
            enabled and unprotected-traffic flags).
         hwPostStatus: hardware FIPS POST result: None, 'complete', 'failed',
            or anything else, which is reported as in progress.
         portStatus: port status; only ``numSuccessActors`` is read.
      """
      self.keyServerPriority = intfStatus.keyServerPriority
      self.sessionReKeyPeriod = intfStatus.reKeyPeriod

      oldCtx = cpStatus.oldSakContext
      self.oldKeyMsgId = oldCtx.sak.keyMsgId
      self.oldKeyMsgNum = oldCtx.sak.keyNum
      self.oldKeyTransmitting = oldCtx.txInstalled
      self.oldKeyReceiving = oldCtx.rxInstalled

      latestCtx = cpStatus.latestSakContext
      self.latestKeyMsgId = latestCtx.sak.keyMsgId
      self.latestKeyMsgNum = latestCtx.sak.keyNum
      self.latestKeyTransmitting = latestCtx.txInstalled
      self.latestKeyReceiving = latestCtx.rxInstalled

      self.localSsci = cpStatus.mySsci.ssci

      # A disabled controlled port carries no traffic. Otherwise traffic is
      # either unencrypted, or protected; protected traffic with no
      # successful actor is reported as using a cached SAK.
      usingCachedSak = False
      if not cpStatus.controlledPortEnabled:
         trafficKind = 'None'
      elif cpStatus.unprotectedTraffic:
         trafficKind = 'Unprotected'
      else:
         trafficKind = 'Protected'
         usingCachedSak = not ( portStatus.numSuccessActors > 0 )
      self.traffic = trafficKind
      self.cachedSak = usingCachedSak

      self.bypassProtocol = [ "LLDP" ] if intfStatus.bypassLldp else []

      if hwPostStatus is None:
         postResult = "none"
      elif hwPostStatus == 'complete':
         postResult = "passed"
      elif hwPostStatus == 'failed':
         postResult = "failed"
      else:
         postResult = "inProgress"
      self.fipsPostStatus = postResult

      self.profileName = intfStatus.profileName

class MacsecInterface( Model ):
   address = MacAddress( help="MAC address of the local port" )
   portId = Int( help="Port number of the local port" )
   controlledPort = Bool( help="Indicates whether MAC security is operational or"
                               " not" )
   keyMsgId = Str( help="Message identifier of the distributed key currently in"
                        " use" )
   keyNum = Int( help="Key number of the distributed key currently in use" )
 
   details = Submodel( help="MAC security Interface detail information",
                       valueType=MacsecInterfaceDetailData, optional=True )

   def fromTacc( self, portStatus, cpStatus, intfStatus, hwPostStatus,
                 detail=False ):

      self.address = portStatus.mySci.addr
      self.portId = portStatus.mySci.portNum
      self.controlledPort = cpStatus.controlledPortEnabled
      if cpStatus.txSak != txSakDefault:
         self.keyMsgId = cpStatus.txSak.keyMsgId
         self.keyNum = cpStatus.txSak.keyNum
      else:
         self.keyMsgId = ""
         self.keyNum = 0      

      if detail:
         self.details = MacsecInterfaceDetailData()
         self.details.fromTacc( intfStatus, cpStatus, hwPostStatus, portStatus )

   def renderKey( self, msgId, msgNum, rx, tx ):
      if not msgId:
         return "None"

      flags = ""
      if rx:
         flags += "R"
      if tx:
         flags += "T"
      return msgId + ":" + str( msgNum ) + "(%s)" % flags

   def render( self ):
      ki = "None"
      if self.keyMsgId:
         ki = self.keyMsgId + ':' + \
               str( self.keyNum )
         
      if not self.details:
         print "%-10s::%-6s %-20s %-30s" % \
               ( self.address,
                 self.portId,
                 self.controlledPort, ki )
      else:
         detailIndentation = ' ' * 4
         print detailIndentation + "Profile: %s" % self.details.profileName
         print detailIndentation + "SCI: %s::%s" % \
               ( self.address, self.portId )
         print detailIndentation + "SSCI: %s" % ( self.details.localSsci )
         print detailIndentation + "Controlled port: %s" % \
               self.controlledPort
         print detailIndentation + "Key server priority: %d" % \
               self.details.keyServerPriority
         print detailIndentation + "Session rekey period: %d" % \
               self.details.sessionReKeyPeriod
         trafficStatus =  intfTrafficStatus[ self.details.traffic ]
         if self.details.cachedSak:
            trafficStatus += ' using cached SAK'
         print detailIndentation + "Traffic: %s" % trafficStatus
         if self.details.bypassProtocol:
            print detailIndentation + "Bypassed protocols: %s" % \
                  ", ".join( self.details.bypassProtocol )
         print detailIndentation + "Key in use: %s" % ki
         print detailIndentation + "Latest key: %s" % \
               self.renderKey( self.details.latestKeyMsgId,
                               self.details.latestKeyMsgNum,
                               self.details.latestKeyReceiving,
                               self.details.latestKeyTransmitting )
         print detailIndentation + "Old key: %s" % \
               self.renderKey( self.details.oldKeyMsgId, 
                               self.details.oldKeyMsgNum,
                               self.details.oldKeyReceiving,
                               self.details.oldKeyTransmitting )
         if self.details.fipsPostStatus != "none":
            print detailIndentation + "FIPS POST:",
            if self.details.fipsPostStatus == "inProgress":
               print "In Progress"
            else:
               print "%s" % self.details.fipsPostStatus.capitalize()


class MacsecInterfaces( Model ):
   interfaces = Dict( help="A mapping between interfaces and" 
                           " MAC security interface information", 
                           keyType=Interface,
                           valueType=MacsecInterface )

   def render( self ):
      if not self.interfaces:
         return

      printHeader = True
      if self.interfaces.values()[ 0 ].details:
         printHeader = False

      if printHeader:
         print "%-15s %-25s %-20s %-20s" % \
               ( 'Interface', 'SCI', 'Controlled Port', 'Key in Use' )

      for key in Arnet.sortIntf( self.interfaces ):
         if printHeader:
            print "%-15s" % key,
         else:
            print
            print "Interface: %s" % key
         self.interfaces[ key ].render()

class MacsecStatus( Model ):
   activeProfiles = Int( help="Number of profiles configured in at least"
                              "one interface" ) 
   delayProtection = Bool( help="True if data delay protection is enabled" )
   eapolDestMac = MacAddress( help="MACsec EAPoL destination MAC address" )
   if macsecToggle.toggleMacsecConfigurableEapolEtherTypeEnabled():
      eapolEtherType = Int( help="MACsec EAPoL Ethernet type" )
   fipsMode = Bool( help="True if FIPS Mode is enabled" )
   securedInterfaces = Int( help="Number of interfaces with mac security"
                                 "enabled" )
   licenseEnabled = Bool( help="True if License is enabled" )

   def fromTacc( self, status ):
      self.activeProfiles = len( set( [ intfStatus.profileName for intfStatus in
                                 status.intfStatus.values() ] ) )
      self.delayProtection = status.delayProtection
      self.eapolDestMac = status.eapolAttr.destinationMac
      if macsecToggle.toggleMacsecConfigurableEapolEtherTypeEnabled():
         self.eapolEtherType = status.eapolAttr.etherType
      self.fipsMode = status.fipsStatus.fipsRestrictions
      securedInterfaces = len( [ cpStatus for cpStatus in status.cpStatus.values()
                               if cpStatus.txSak != txSakDefault ] )
      self.securedInterfaces = securedInterfaces
      self.licenseEnabled = status.licenseEnabled

   def render( self ):

      def printLine( label, content ):
         print "%-25s %s " % ( label, content )

      printLine( "Active Profiles:", self.activeProfiles )
      printLine( "Data Delay Protection:", "No" )
      printLine( "EAPoL Destination MAC:",
                 self.eapolDestMac.displayString )
      if macsecToggle.toggleMacsecConfigurableEapolEtherTypeEnabled():
         printLine( "EAPoL Ethernet Type:",
               '0x{0:0{1}X}'.format( self.eapolEtherType, 4 ) )
      printLine( "FIPS Mode:", "Yes" if self.fipsMode else "No" )
      printLine( "Secured Interfaces:", self.securedInterfaces )
      printLine( "License:", "Enabled" if self.licenseEnabled else "Disabled" )

class MacsecProfile( Model ):
   cipher = Enum( values=tuple( tacCipherSuiteToCli.values() ),
                  help="Cipher suite for a profile" )
   ckn = Str( help="Primary connection association key name" )
   cakSha256Hash = Str( help="Connection association key" )
   fallbackCkn = Str( help="Fallback connection association key name" )
   fallbackCakSha256Hash = Str( help="Fallback connection association key" )
   source = Str( help="Cli or EosSdk agent name" )
   priority = Int( help="MACsec config priority" )
   keyRetirePolicy = Enum( values=( "immediate", "delayed" ),
                            help="Key retirement policy" )
   unprotectedTrafficPolicy = Enum ( values=trafficPolicyOnNoMkaToCli.values(), 
                                     help="Unprotected traffic policy" )
   interfaces = List( valueType=Interface,
                        help="List of interfaces this profile is configured on" )

   def fromTacc( self, profileConfig ):
      self.cipher = tacCipherSuiteToCli.get( profileConfig.cipherSuite, "unknown" )
      self.ckn = profileConfig.key.ckn
      self.cakSha256Hash = self.sha256Cak( profileConfig.key.cak )
      self.fallbackCkn = profileConfig.defaultKey.ckn
      self.fallbackCakSha256Hash = self.sha256Cak( profileConfig.defaultKey.cak )
      self.interfaces = profileConfig.intf.keys()
      self.keyRetirePolicy = "immediate" if profileConfig.keyRetire else "delayed"
      self.unprotectedTrafficPolicy = trafficPolicyOnNoMkaToCli.get( 
         profileConfig.trafficPolicyOnNoMka, "unknown" )

   def sha256Cak( self, cak ):
      if cak == "":
         return cak
      return hashlib.sha256( decodeKey( cak ) ).hexdigest()
   
   def _printIntfs( self, intfs, title, lineLen ):
      titleLen = len( title )
      intfStr = ", ".join( Arnet.sortIntf( intfs) )
      intfStr = title + " " + intfStr
      indent = ' ' * ( titleLen + 1 )
      wrapper = TextWrapper(width=lineLen, subsequent_indent=indent)
      intfStr = wrapper.fill( intfStr )
      print intfStr

   def render( self ):
      indent = " " * 4
      def printLine( label, content ):
         print "%s%s %s " % ( indent, label, content )
      printLine( "Cipher:", self.cipher )
      printLine( "Primary CKN:", self.ckn )
      printLine( "Primary CAK SHA-256 hash:", self.cakSha256Hash )
      printLine( "Fallback CKN:", self.fallbackCkn if self.fallbackCkn else "" )
      printLine( "Fallback CAK SHA-256 hash:",
                 self.fallbackCakSha256Hash if self.fallbackCakSha256Hash else "" )
      printLine( "Source:", self.source )
      printLine( "Priority:", self.priority )
      printLine( "Key retirement policy:", self.keyRetirePolicy )
      if macsecToggle.toggleMacsecTrafficBlockOnMkaFailureEnabled():
         printLine( "Unprotected traffic policy:", self.unprotectedTrafficPolicy )
      self._printIntfs( self.interfaces, indent + "Configured on:",
            terminalWidth() - 1 )

class MacsecProfiles( Model ):
   profiles = Dict( help="Mapping between profile name and profile details",
                    keyType=str, valueType=MacsecProfile )
   def render( self ):
      for k, v in self.profiles.iteritems():
         print "Profile:", k
         v.render()

class MacsecAllProfiles( Model ):
   """MAC security profiles grouped by the source that configured them."""
   sources = Dict( help="Mapping between source and profiles",
                   keyType=str, valueType=MacsecProfiles )

   def render( self ):
      """Render every source's profiles; the source name itself is not printed."""
      for sourceProfiles in self.sources.itervalues():
         sourceProfiles.render()
