#!/usr/bin/env python
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliSave
import Management
import Tac
from CliMode.Ldap import ( ServerConfigModeBase, GroupPolicyConfigModeBase,
                           LdapServerGroupModeBase )
from CliSavePlugin.NetworkCliSave import networkConfigCmdSeq
import LdapConstants
from IpLibConsts import DEFAULT_VRF

DEFAULT_ROLE_PRIVLEVEL = 1

class LdapConfigSaveMode( Management.MgmtConfigMode ):
   """Save mode for the 'ldap' management configuration mode.

   The instance is always named "ldap" regardless of the parameter it is
   created with, and it is dropped from the output when it has no content.
   """

   def __init__( self, param ):
      # 'param' is intentionally unused: there is only one ldap mode.
      Management.MgmtConfigMode.__init__( self, "ldap" )

   def skipIfEmpty( self ):
      # Never emit an empty ldap section.
      return True

class ServerConfigSaveMode( ServerConfigModeBase, CliSave.Mode ):
   """Save mode for one LDAP server submode.

   The mode parameter is a ( cmd, index ) tuple. cmd is either the literal
   'defaults' (created with index None) or a 'host ...' command string built
   by serverHostCmd, with index taken from the host entry.
   """

   def __init__( self, param ):
      cmd, index = param
      ServerConfigModeBase.__init__( self, cmd )
      CliSave.Mode.__init__( self, cmd )
      self.param = cmd
      self.index_ = index

   def __cmp__( self, other ):
      # The 'defaults' submode sorts before every host submode, and host
      # submodes are ordered by index. Handle 'defaults' explicitly so that
      # two 'defaults' instances compare equal. This also avoids relying on
      # Python 2 ordering None (the defaults index) below integers.
      selfIsDefaults = self.param == "defaults"
      otherIsDefaults = other.param == "defaults"
      if selfIsDefaults or otherIsDefaults:
         # False < True, so the 'defaults' side compares as the smaller one.
         return cmp( otherIsDefaults, selfIsDefaults )
      return cmp( self.index_, other.index_ )

   def skipIfEmpty( self ):
      # Only the 'defaults' submode may be skipped when it has no commands;
      # host submodes return False so they are kept even when empty.
      return self.param_ == 'defaults'

class GroupPolicyConfigSaveMode( GroupPolicyConfigModeBase, CliSave.Mode ):
   """Save mode for a 'group policy <policyName>' submode.

   The mode parameter is the policy name. Empty policies are not emitted.
   """

   def __init__( self, param ):
      # Initialise both bases, in MRO order, with the policy name.
      for baseClass in ( GroupPolicyConfigModeBase, CliSave.Mode ):
         baseClass.__init__( self, param )

   def skipIfEmpty( self ):
      # Leave policies with no commands out of the saved config.
      return True

class LdapServerGroupConfigSaveMode( LdapServerGroupModeBase, CliSave.Mode ):
   """Save mode for an LDAP server group.

   The mode parameter is the group name, and instances sort by it.
   """

   def __init__( self, param ):
      # Initialise both bases, in MRO order, with the group name.
      for baseClass in ( LdapServerGroupModeBase, CliSave.Mode ):
         baseClass.__init__( self, param )

   def __cmp__( self, other ):
      # Order server groups by their mode parameter.
      mine, theirs = self.param_, other.param_
      return cmp( mine, theirs )

# The ldap management mode is a child of global config and owns the
# 'Mgmt.ldap' command sequence.
CliSave.GlobalConfigMode.addChildMode( LdapConfigSaveMode )
LdapConfigSaveMode.addCommandSequence( 'Mgmt.ldap' )

# Server submodes ('server defaults' and one per host) live inside the
# ldap mode.
LdapConfigSaveMode.addChildMode( ServerConfigSaveMode )

ServerConfigSaveMode.addCommandSequence( 'ldap.server' )

# Group policy submodes are emitted after all server submodes.
LdapConfigSaveMode.addChildMode( GroupPolicyConfigSaveMode,
                                 after=[ ServerConfigSaveMode ] )
GroupPolicyConfigSaveMode.addCommandSequence( 'ldap.groupPolicy' )

# LDAP server groups are top-level modes placed before 'Aaa.global'
# (presumably so AAA commands can refer to them -- TODO confirm).
CliSave.GlobalConfigMode.addChildMode( LdapServerGroupConfigSaveMode,
                                       before=[ 'Aaa.global' ] )
LdapServerGroupConfigSaveMode.addCommandSequence( 'ldap.serverGroup' )

# Global LDAP commands go after the network config sequence and before both
# 'Aaa.global' and the LDAP server-group modes.
CliSave.GlobalConfigMode.addCommandSequence( 'Ldap.global',
                                             before=[
                                                'Aaa.global',
                                                LdapServerGroupConfigSaveMode ],
                                             after=[ networkConfigCmdSeq ] )

def serverHostCmd( spec ):
   """Return the 'host ...' string that names an LDAP server submode.

   The result has the form 'host <hostname> [vrf <vrf>] [port <port>]'. The vrf
   is included only when it differs from DEFAULT_VRF, and the port only when it
   differs from LdapConstants.DEFAULT_LDAP_PORT. spec.vrf must be non-empty.
   """
   words = [ "host %s" % spec.hostname ]
   assert spec.vrf != ""
   if spec.vrf != DEFAULT_VRF:
      words.append( "vrf %s" % spec.vrf )
   if spec.port != LdapConstants.DEFAULT_LDAP_PORT:
      words.append( "port %d" % spec.port )
   return " ".join( words )

def addServerDefaults( serverCmds, ldapConfig, options ):
   """Add the server-level settings found on ldapConfig to serverCmds.

   Used for the 'server defaults' submode, where ldapConfig is the top-level
   Ldap::Config, and for each host submode, where it is that host's
   serverConfig. Settings whose value is falsy produce no command. The search
   password is passed through CliSave.sanitizedOutput together with options.
   """
   # ( attribute, command template ) pairs, emitted in this order.
   simpleSettings = ( ( 'baseDn', 'base-dn %s' ),
                      ( 'userRdnAttribute', 'rdn attribute user %s' ),
                      ( 'sslProfile', 'ssl-profile %s' ) )
   for attrName, cmdTemplate in simpleSettings:
      attrValue = getattr( ldapConfig, attrName )
      if attrValue:
         serverCmds.addCommand( cmdTemplate % attrValue )

   # A set defaultGroupPolicy wins over activeGroupPolicy. Both lookups are
   # guarded by hasattr, presumably because not every config object passed
   # in defines both attributes -- TODO confirm.
   for policyAttr in ( 'defaultGroupPolicy', 'activeGroupPolicy' ):
      if hasattr( ldapConfig, policyAttr ) and getattr( ldapConfig, policyAttr ):
         serverCmds.addCommand( 'authorization group policy %s' %
                                getattr( ldapConfig, policyAttr ) )
         break

   # Emit search credentials unless they equal the empty username/password.
   credentials = ldapConfig.searchUsernamePassword
   if credentials != Tac.Value( "Ldap::UsernamePassword", "", "" ):
      username = credentials.username
      if ' ' in username:
         # Usernames containing spaces are wrapped in double quotes.
         username = '"{}"'.format( username )
      password = CliSave.sanitizedOutput( options, credentials.password )
      serverCmds.addCommand( 'search username %s password 7 %s' %
                             ( username, password ) )

def addGroupPolicy( groupPolicyCmds, groupPolicy ):
   """Add the commands for one 'group policy <policyName>' submode.

   Emits 'search filter objectclass <group> attribute <member>' when the
   search filter differs from the empty Ldap::ObjectClassOptions value. It then
   emits one 'group "<group>" role <role> [privilege <n>]' line per entry in
   groupPolicy.groupRolePrivilege. The privilege clause is omitted when it
   equals DEFAULT_ROLE_PRIVLEVEL.
   """
   searchFilter = groupPolicy.searchFilter
   if searchFilter != Tac.Value( "Ldap::ObjectClassOptions", "", "" ):
      groupPolicyCmds.addCommand( 'search filter objectclass %s attribute %s' % (
         searchFilter.group, searchFilter.member ) )
   for groupRolePriv in groupPolicy.groupRolePrivilege.itervalues():
      cmd = 'group "%s" role %s' % ( groupRolePriv.group, groupRolePriv.role )
      if groupRolePriv.privilege != DEFAULT_ROLE_PRIVLEVEL:
         cmd += ' privilege %d' % groupRolePriv.privilege
      groupPolicyCmds.addCommand( cmd )

@CliSave.saver( 'Ldap::Config', "security/aaa/ldap/config" )
def saveLdap( ldapConfig, root, sysdbRoot, options ):
   """Save Ldap::Config into the ldap mode.

   Creates one server submode per host (visited in index order), the
   'server defaults' submode, and one 'group policy <policyName>' submode per
   configured policy.
   """
   ldapMode = root[ LdapConfigSaveMode ].getOrCreateModeInstance( 'ldap' )

   # 'server host <hostnameOrIp> ...' submodes.
   for hostEntry in sorted( ldapConfig.host.values(), key=lambda h: h.index ):
      hostCmd = serverHostCmd( hostEntry )
      hostMode = ldapMode[ ServerConfigSaveMode ].getOrCreateModeInstance(
         ( hostCmd, hostEntry.index ) )
      hostCmds = hostMode[ 'ldap.server' ]
      hostConfig = ldapConfig.host[ hostEntry.spec ].serverConfig
      addServerDefaults( hostCmds, hostConfig, options )

   # 'server defaults' submode, filled from the top-level config.
   defaultsMode = ldapMode[ ServerConfigSaveMode ].getOrCreateModeInstance(
      ( 'defaults', None ) )
   addServerDefaults( defaultsMode[ 'ldap.server' ], ldapConfig, options )

   # 'group policy <policyName>' submodes.
   for policyName, policy in ldapConfig.groupPolicy.iteritems():
      policyMode = ldapMode[ GroupPolicyConfigSaveMode ].getOrCreateModeInstance(
         policyName )
      addGroupPolicy( policyMode[ 'ldap.groupPolicy' ], policy )

@CliSave.saver( 'Aaa::HostGroup', 'security/aaa/config' )
def saveHostGroup( entity, root, sysdbRoot, options ):
   """Save an 'ldap' Aaa::HostGroup as an LDAP server-group mode.

   Each member becomes 'server <hostname> [vrf <vrf>] [port <port>]'. The vrf
   is written when it is not DEFAULT_VRF. The port is written when it is not
   LdapConstants.DEFAULT_LDAP_PORT or when options.saveAll is set. Host groups
   of any other type are ignored.
   """
   if entity.groupType == 'ldap':
      groupMode = root[ LdapServerGroupConfigSaveMode ].getOrCreateModeInstance(
         entity.name )
      groupCmds = groupMode[ 'ldap.serverGroup' ]
      for member in entity.member.values():
         hostSpec = member.spec
         words = [ "server %s" % hostSpec.hostname ]
         assert hostSpec.vrf != ''
         if hostSpec.vrf != DEFAULT_VRF:
            words.append( "vrf %s" % hostSpec.vrf )
         showPort = ( hostSpec.port != LdapConstants.DEFAULT_LDAP_PORT or
                      options.saveAll )
         if showPort:
            words.append( "port %d" % hostSpec.port )
         groupCmds.addCommand( " ".join( words ) )
