#!/usr/bin/env python
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from calendar import timegm
import itertools

import BasicCli
import ShowCommand
import CliCommand
import CliMatcher
import CliParser
import ConfigMount
import DateTimeRule
import LazyMount
import Tac
import ConfigMgmtMode
from CliMode.SharedSecretProfileMode import SharedSecretProfileCliMode
from CliPlugin.Security import SecurityConfigMode
from CliPlugin.SharedSecretProfileModel import ( SharedSecretProfiles,
                                                 SharedSecretProfile, SecretModel )
from MgmtSecurityLib import Lifetime, Secret
from ReversibleSecretCli import ( reversibleAuthPasswordExpression,
                                  tryDecodeToken )

# Mount of the shared-secret profile config (writable) and status (read-only).
# Both stay None until Plugin() assigns them; every command below reads them
# at call time.
config = None
status = None

# 'security' keyword shared by the shared-secret profile show commands.
securityKwMatcher = CliMatcher.KeywordMatcher( 'security',
                                               helpdesc="Show security status" )

class SharedSecretProfileMode( SharedSecretProfileCliMode, BasicCli.ConfigModeBase ):
   """Configuration mode for a single shared-secret profile.

   Entering the mode looks up the named profile in config and creates it
   (config.newProfile) when it does not exist yet. The profile's config entity
   is kept in self.profileConfig for the commands of this mode.
   """
   name = "Shared-Secret Profile Configuration"
   modeParseTree = CliParser.ModeParseTree()

   def __init__( self, parent, session, profileName ):
      self.profileName = profileName
      self.session_ = session

      # Reuse an existing profile; otherwise create it on first entry.
      if profileName in config.profile:
         profileConfig = config.profile[ profileName ]
      else:
         profileConfig = config.newProfile( profileName )
      self.profileConfig = profileConfig

      SharedSecretProfileCliMode.__init__( self, profileName )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

class SharedSecretProfileModeCommandClass( CliCommand.CliCommandClass ):
   """Shared Secret Profile mode commands

      From management security mode,
         session shared-secret profile NAME
      enters the shared secret profile configuration mode for NAME,
      creating the profile if it does not exist yet.

      From management security mode,
         ( no | default ) session shared-secret profile NAME
      deletes the shared secret profile NAME together with every secret
      lifetime configured in it."""

   syntax = "session shared-secret profile PROFILE_NAME"
   noOrDefaultSyntax = syntax
   data = {
      'session': 'configure session settings',
      'shared-secret': 'configure settings involving a shared secret',
      'profile': 'configure a profile of shared secret lifetimes',
      'PROFILE_NAME': CliMatcher.DynamicNameMatcher(
         lambda mode: config.profile,
         'shared-secret profile name' )
   }

   @staticmethod
   def handler( mode, args ):
      profileName = args[ 'PROFILE_NAME' ]
      profileMode = mode.childMode( SharedSecretProfileMode,
                                    profileName=profileName )
      mode.session_.gotoChildMode( profileMode )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      profileName = args[ 'PROFILE_NAME' ]
      if profileName not in config.profile:
         # Nothing configured under this name; nothing to remove.
         return
      del config.profile[ profileName ]

SecurityConfigMode.addCommandClass( SharedSecretProfileModeCommandClass )

def noSecurityConfig( mode ):
   """Remove every shared-secret profile from the config.

   This has the same effect as running '( no | default ) session shared-secret
   profile NAME' for each configured profile. It is meant for situations such
   as the parent mode being deleted. 'mode' is not used.
   """
   config.profile.clear()

# Time-of-day matcher shared by the start and end time tokens of every lifetime
# expression built by lifetimeExpression().
timeMatcher = DateTimeRule.ValidTimeMatcher()

def dateRuleValue( mode, date ):
   """Turn a parsed date [ month, day, year ] into a dict.

   Returns { 'year': ..., 'month': ..., 'day': ... }. Only years 1970 through
   2037 are accepted. Any other year is reported on 'mode', and the command is
   aborted with CliParser.AlreadyHandledError.
   """
   month, day, year = date[ 0 ], date[ 1 ], date[ 2 ]
   if not ( 1969 < year < 2038 ):
      mode.addError( 'Date must be after 1970 and before 2038' )
      raise CliParser.AlreadyHandledError()
   return { 'year': year, 'month': month, 'day': day }

def timeRuleValue( mode, time ):
   """Turn a parsed time ( hour, minute, second ) into a dict keyed by
   'hour', 'minute' and 'second'. 'mode' is not used."""
   hour, minute, second = time[ 0 ], time[ 1 ], time[ 2 ]
   return dict( hour=hour, minute=minute, second=second )

def datetimeRuleValue( mode, date, time ):
   """Combine a date dict (year/month/day) and a time dict (hour/minute/second),
   interpreted as UTC, into whole seconds since the epoch via calendar.timegm.
   'mode' is not used."""
   dateFields = tuple( date[ key ] for key in ( 'year', 'month', 'day' ) )
   timeFields = tuple( time[ key ] for key in ( 'hour', 'minute', 'second' ) )
   return int( timegm( dateFields + timeFields ) )

def lifetimeRuleValue( mode, start, end ):
   """Build a Lifetime from start and end times (seconds since the epoch).

   The end must be strictly after the start. Otherwise the error is reported
   on 'mode', and the command is aborted with CliParser.AlreadyHandledError.
   """
   if not start < end:
      mode.addError( 'lifetime must end after it starts' )
      raise CliParser.AlreadyHandledError()
   return Lifetime( start, end )

def lifetimeExpression( name ):
   """Return a CliExpression class that parses one secret lifetime.

   The expression matches either
      START_DATE START_TIME END_DATE END_TIME
   or the keyword 'infinite'. Every token key is suffixed with '_' + name, so
   several lifetime expressions can live in the same command.

   On a match, the adapter stores a Lifetime in args[ name ]. 'infinite' gives
   Lifetime( 0, 0 ). Otherwise it is Lifetime( start, end ) in seconds since the
   epoch. Bad dates, or an end not after the start, are reported on the mode
   and raise CliParser.AlreadyHandledError.
   """
   startDateKey, startTimeKey, endDateKey, endTimeKey = (
      prefix + name for prefix in ( 'START_DATE_', 'START_TIME_',
                                    'END_DATE_', 'END_TIME_' ) )
   infKey = 'infinite_' + name

   class LifetimeExpression( CliCommand.CliExpression ):
      expression = "( %s ) | %s" % ( " ".join( ( startDateKey, startTimeKey,
                                                 endDateKey, endTimeKey ) ),
                                     infKey )
      data = {
         startDateKey: DateTimeRule.dateExpression( startDateKey ),
         startTimeKey: timeMatcher,
         endDateKey: DateTimeRule.dateExpression( endDateKey ),
         endTimeKey: timeMatcher,
         infKey: CliMatcher.KeywordMatcher( 'infinite',
                                             helpdesc='infinite lifetime' )
      }

      @staticmethod
      def adapter( mode, args, argsList ):
         if args.get( infKey ):
            args[ name ] = Lifetime( 0, 0 )
            return
         if startDateKey not in args:
            # This lifetime alternative was not used on the command line.
            return

         startDate, startTime, endDate, endTime = [
            args[ key ] for key in ( startDateKey, startTimeKey,
                                     endDateKey, endTimeKey ) ]

         def toEpoch( date, time ):
            # Validate and convert one date/time pair to seconds since epoch.
            return datetimeRuleValue( mode, dateRuleValue( mode, date ),
                                      timeRuleValue( mode, time ) )

         start = toEpoch( startDate, startTime )
         end = toEpoch( endDate, endTime )
         args[ name ] = lifetimeRuleValue( mode, start, end )

   return LifetimeExpression

def _findSecret( mode, keyId ):
   for s in mode.profileConfig.secret:
      if s.id == keyId:
         return s
   return None

class SecretConfigCommand( CliCommand.CliCommandClass ):
   # Configure, or with no/default remove, secret ID in the current
   # shared-secret profile. A single LIFETIME applies to both the receive and
   # transmit directions. Otherwise separate receive and transmit lifetimes
   # are given, in either order.
   syntax = "secret ID SECRET LIFETIME | " \
            "( receive-lifetime RECV_LIFETIME " \
            "transmit-lifetime TRANS_LIFETIME ) | " \
            "( transmit-lifetime TRANS_LIFETIME " \
            "receive-lifetime RECV_LIFETIME )"
   noOrDefaultSyntax = "secret ID ..."

   data = {
      'secret': 'Configure lifetimes for a specified key',
      'receive-lifetime': 'Configure the lifetime for receiving the key',
      'transmit-lifetime': 'Configure the lifetime for transmitting the key',
      'ID': CliMatcher.IntegerMatcher( 0, 255,
                                       helpdesc='Identifier for the key' ),
      'SECRET': reversibleAuthPasswordExpression( 'SECRET' ),
      'LIFETIME': lifetimeExpression( 'LIFETIME' ),
      'RECV_LIFETIME': lifetimeExpression( 'RECV_LIFETIME' ),
      'TRANS_LIFETIME': lifetimeExpression( 'TRANS_LIFETIME' )
   }

   @staticmethod
   def handler( mode, args ):
      keyId = args[ 'ID' ]
      # A shared LIFETIME covers both directions; otherwise fall back to the
      # direction-specific lifetimes.
      recvLifetime = args.get( 'LIFETIME', args.get( 'RECV_LIFETIME' ) )
      transLifetime = args.get( 'LIFETIME', args.get( 'TRANS_LIFETIME' ) )
      secret = tryDecodeToken( args[ 'SECRET' ], algorithm='MD5', mode=mode )
      if not secret:
         # Presumably tryDecodeToken has already reported the failure on mode.
         return

      # The syntax is expected to always supply a lifetime for each direction.
      assert recvLifetime is not None
      assert transLifetime is not None

      existing = _findSecret( mode, keyId )
      newSecret = Secret( keyId, secret, recvLifetime, transLifetime )
      if existing != newSecret:
         # Secrets are replaced rather than edited in place: drop any old
         # entry for this ID, then add the new one.
         if existing is not None:
            mode.profileConfig.secret.remove( existing )
         mode.profileConfig.secret.add( newSecret )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      existing = _findSecret( mode, args[ 'ID' ] )
      # _findSecret signals "absent" with None. Compare against that sentinel
      # instead of relying on the truthiness of the Secret value itself, as
      # handler() does above.
      if existing is not None:
         mode.profileConfig.secret.remove( existing )

SharedSecretProfileMode.addCommandClass( SecretConfigCommand )

class ShowSharedSecretProfileMixin( object ):
   """Helpers shared by the shared-secret profile show commands.

   Builds a SharedSecretProfile CLI model from a profile's configured secrets
   (config) and its currently selected receive/transmit secrets (status).
   """

   @staticmethod
   def _rotationSequence( getLifetime, secrets ):
      """
      Return the order in which 'secrets' become the preferred secret.

      getLifetime( s ) selects which lifetime of a secret to use (receive or
      transmit). The result is a list of strings. Each entry is either a secret
      ID or "No Secret" for a period in which no secret is valid. Consecutive
      repeats are collapsed, and an uninformative leading "No Secret" is dropped.

      A given profile's secrets will become "current" in a deterministic sequence
      based on their respective lifetimes. A trivial example is 3 secrets with
      disjoint lifetimes:

      Secret 1: now => for 3 months
      Secret 2: 4 months from now => for 3 more months
      Secret 3: 8 months from now => for 3 more months

      These secrets become valid in the order { 1, 2, 3 }. In general, the
      order in which the secrets are chosen is determined by their start times.

      Evaluating every representable U32 time would be far too slow. Instead,
      only the times at which the preferred secret can change are examined:
      each lifetime's start and end, plus start - 1 and end + 1. These are the
      instants just before a secret becomes valid and just after it expires.
      """
      noSecret = "No Secret"
      # Materialize once: the secrets are scanned again for every candidate time.
      secrets = list( secrets )

      # Lifetime::{start,end} are U32, so start - 1 for a start of 0 and
      # end + 1 for an end of 2^32 - 1 wrap around within 32 bits.
      def unsignedIncr( val ):
         return ( val + 1 ) & 0xffffffff

      def unsignedDecr( val ):
         return ( val - 1 ) & 0xffffffff

      def significantTimes():
         """Sorted, de-duplicated times at which the choice can change."""
         times = set()
         for s in secrets:
            lifetime = getLifetime( s )
            times.update( ( lifetime.start, unsignedDecr( lifetime.start ),
                            lifetime.end, unsignedIncr( lifetime.end ) ) )
         return sorted( times )

      def mostPreferredSecret( time ):
         """
         Return the ID of the most preferred secret valid at 'time', or
         "No Secret". The preferred secret is the one whose lifetime began
         most recently; ties go to the higher ID. Infinite-duration secrets are
         only chosen when no other secret is valid, and then the highest ID wins.
         """
         validSecrets = [ s for s in secrets
                          if getLifetime( s ).isValidAtTime( time ) ]
         if not validSecrets:
            return noSecret

         finiteSecrets = [ s for s in validSecrets
                           if not getLifetime( s ).isInfinite() ]
         if finiteSecrets:
            # Compare ( start, id ) lexicographically. Summing the two would let
            # a higher ID outrank a secret that started up to 255s later.
            return max( finiteSecrets,
                        key=lambda s: ( getLifetime( s ).start, s.id ) ).id
         # Every valid secret is infinite: tie-break on the ID alone.
         return max( validSecrets, key=lambda s: s.id ).id

      # Collapse runs of the same choice, e.g.
      # { 5, 5, 10, 15, 15, 15, No Secret } -> { 5, 10, 15, No Secret }.
      sequence = [ secretId for secretId, _ in itertools.groupby(
                      mostPreferredSecret( t ) for t in significantTimes() ) ]

      # Unless there is an infinite-duration secret or one starting at time 0
      # (1970), the sequence starts with "No Secret", which is useless here.
      # Trim it only after collapsing, so that several leading "No Secret"
      # samples are all removed.
      if sequence and sequence[ 0 ] == noSecret:
         del sequence[ 0 ]

      return [ str( secretId ) for secretId in sequence ]

   @staticmethod
   def _computeSharedSecretProfileModel( name ):
      """Build the SharedSecretProfile model for the configured profile 'name'.

      Fills in the currently selected receive/transmit secrets (from status),
      the receive and transmit rotation sequences, and one SecretModel per
      configured secret.
      """
      model = SharedSecretProfile()
      model.profileName = name
      profile = config.profile[ name ]

      # status is mounted read-only here and is presumably filled in
      # asynchronously, so a just-created profile may have no entry yet.
      # Report "no current secret" then instead of failing the show command.
      rxSecret = Secret()
      txSecret = Secret()
      if name in status.currentSecret:
         currentSecret = status.currentSecret[ name ]
         receiveSecrets = list( currentSecret.receiveSecret.keys() )
         if receiveSecrets:
            rxSecret = receiveSecrets[ 0 ]
         txSecret = currentSecret.transmitSecret

      # Assumes a default-constructed Secret is falsy, i.e. "no secret".
      if rxSecret:
         model.currentRxId = rxSecret.id
         model.currentRxExpiration = rxSecret.receiveLifetime.end
      if txSecret:
         model.currentTxId = txSecret.id
         model.currentTxExpiration = txSecret.transmitLifetime.end

      # A list, because the secrets are iterated several times below.
      secrets = list( profile.secret.keys() )
      # pylint: disable=protected-access
      model.rxRotation = ShowSharedSecretProfileMixin._rotationSequence(
            lambda s: s.receiveLifetime, secrets )
      model.txRotation = ShowSharedSecretProfileMixin._rotationSequence(
            lambda s: s.transmitLifetime, secrets )
      # pylint: enable=protected-access

      model.secrets = []
      for s in secrets:
         secret = SecretModel()

         secret.secretID = s.id
         secret.secret = s.secret
         secret.rxLifetimeStart = s.receiveLifetime.start
         secret.rxLifetimeEnd = s.receiveLifetime.end
         secret.txLifetimeStart = s.transmitLifetime.start
         secret.txLifetimeEnd = s.transmitLifetime.end

         model.secrets.append( secret )

      return model

class ShowSharedSecretProfilesCmd( ShowCommand.ShowCliCommandClass,
                                   ShowSharedSecretProfileMixin ):
   # Show every configured shared-secret profile, keyed by profile name.
   syntax = "show management security session shared-secret profile"
   data = {
      'management': ConfigMgmtMode.managementShowKwMatcher,
      'security': securityKwMatcher,
      'session': 'Show security session information',
      'shared-secret': 'Show shared-secrets',
      'profile': 'Show a shared-secret profile configuration'
   }
   cliModel = SharedSecretProfiles

   @classmethod
   def handler( cls, mode, args ):
      profilesModel = cls.cliModel()
      profileNames = [ profile.name for profile in config.profile.values() ]
      for profileName in profileNames:
         profilesModel.profiles[ profileName ] = \
            cls._computeSharedSecretProfileModel( profileName )
      return profilesModel

BasicCli.addShowCommandClass( ShowSharedSecretProfilesCmd )

class ShowSharedSecretProfileCmd( ShowCommand.ShowCliCommandClass,
                                  ShowSharedSecretProfileMixin ):
   # Show a single shared-secret profile. An unknown NAME yields an empty model.
   syntax = "show management security session shared-secret profile NAME"
   data = {
      'management': ConfigMgmtMode.managementShowKwMatcher,
      'security': securityKwMatcher,
      'session': 'Show security session information',
      'shared-secret': 'Show shared-secrets',
      'profile': 'Show a shared-secret profile configuration',
      'NAME': CliMatcher.DynamicNameMatcher( lambda mode: config.profile,
                                            'shared-secret profile name' )
   }
   cliModel = SharedSecretProfile

   @classmethod
   def handler( cls, mode, args ):
      profileName = args[ 'NAME' ]
      if profileName not in config.profile:
         return cls.cliModel()
      return cls._computeSharedSecretProfileModel( profileName )

BasicCli.addShowCommandClass( ShowSharedSecretProfileCmd )

def Plugin( entityManager ):
   """Mount the shared-secret profile config (writable, via ConfigMount) and
   status (read-only, via LazyMount), and publish them as the module-level
   'config' and 'status' used by the commands above."""
   global config, status
   configPath = "mgmt/security/sh-sec-prof/config"
   statusPath = "mgmt/security/sh-sec-prof/status"
   config = ConfigMount.mount( entityManager, configPath,
                               "Mgmt::Security::SharedSecretProfile::Config", "w" )
   status = LazyMount.mount( entityManager, statusPath,
                             "Mgmt::Security::SharedSecretProfile::Status", "r" )
