# Copyright (c) 2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliSave, Management, Tac
import SysMgrLib
from IpLibConsts import DEFAULT_VRF
from CliMode.VrfConfig import VrfConfigMode
from CliMode.SshTunnel import SshTunnelMode
from CliMode.SshTunnel import SshTunnelVrfMode
from CliMode.SshUser import SshUserMode

# Tac type handles whose enum/constant members are compared against config
# values below. AuthenticationMode supplies `password` (checked in saveSsh).
authenticationMode = Tac.Type( "Mgmt::Ssh::AuthenticationMode" )
# ServerPort supplies `invalid` (tunnel ports equal to it are treated as not
# configured and are not saved) and `defaultPort` (compared against the SSH
# server port before saving 'server-port').
serverPort = Tac.Type( "Mgmt::Ssh::ServerPort" )

class SshConfigMode( Management.MgmtConfigMode ):
   """CliSave mode for the top-level SSH management configuration."""

   def __init__( self, param ):
      # `param` is ignored: the base mode is always initialized with the
      # fixed name "ssh". saveSsh fetches this mode via getSingletonInstance().
      Management.MgmtConfigMode.__init__( self, "ssh" )

class SshTunnelConfigMode( SshTunnelMode, CliSave.Mode ):
   """CliSave mode for an SSH tunnel configured directly under the SSH mode.

   saveSshTunnels creates instances of it keyed by the tunnel name.
   """

   def __init__( self, param ):
      # Order matters: CliSave.Mode is initialized with self.longModeKey,
      # which is presumably set by SshTunnelMode.__init__ -- TODO confirm.
      SshTunnelMode.__init__( self, param )
      CliSave.Mode.__init__( self, self.longModeKey )

class SshVrfTunnelConfigMode( SshTunnelVrfMode, CliSave.Mode ):
   """CliSave mode for an SSH tunnel configured inside an SSH VRF sub-mode.

   saveSshTunnels creates instances of it keyed by ( vrfName, tunnel name ).
   """

   def __init__( self, param ):
      # Order matters: CliSave.Mode is initialized with self.longModeKey,
      # which is presumably set by SshTunnelVrfMode.__init__ -- TODO confirm.
      SshTunnelVrfMode.__init__( self, param )
      CliSave.Mode.__init__( self, self.longModeKey )

# Attach the SSH mode under global configuration and give it the command
# sequence ('Mgmt.ssh') that saveSsh writes its top-level commands into.
CliSave.GlobalConfigMode.addChildMode( SshConfigMode )
SshConfigMode.addCommandSequence( 'Mgmt.ssh' )

# Tunnels configured directly under the SSH mode ('Mgmt.ssh.tunnel').
SshConfigMode.addChildMode( SshTunnelConfigMode )
SshTunnelConfigMode.addCommandSequence( 'Mgmt.ssh.tunnel' )

class SshVrfConfigMode( VrfConfigMode, CliSave.Mode ):
   """CliSave mode for a per-VRF SSH sub-mode.

   saveSsh creates instances keyed by ( vrfName, 'ssh', sshConfig ).
   """
   def __init__( self, param ):
      # Both bases receive the same instance parameter; VrfConfigMode is
      # initialized first.
      VrfConfigMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Per-VRF SSH sub-modes and their own command sequence ('Mgmt.ssh.vrf').
SshConfigMode.addChildMode( SshVrfConfigMode )
SshVrfConfigMode.addCommandSequence( 'Mgmt.ssh.vrf' )
# Tunnels inside a VRF sub-mode; `after=` presumably orders them after the
# VRF's own 'Mgmt.ssh.vrf' commands in the output -- TODO confirm.
SshVrfConfigMode.addChildMode( SshVrfTunnelConfigMode, after=[ 'Mgmt.ssh.vrf' ] )
SshVrfTunnelConfigMode.addCommandSequence( 'Mgmt.ssh.vrf.tunnel' )

class SshUserConfigMode( SshUserMode, CliSave.Mode ):
   """CliSave mode for per-user SSH settings.

   saveSshUserConfig creates instances of it keyed by the user name.
   """
   def __init__( self, param ):
      # Both bases receive the same instance parameter; SshUserMode is
      # initialized first.
      SshUserMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Per-user sub-modes under the SSH mode ('Mgmt.ssh.user').
SshConfigMode.addChildMode( SshUserConfigMode )
SshUserConfigMode.addCommandSequence( 'Mgmt.ssh.user' )

def saveKnownHosts( config, options, cmds, cliSaveMode=SshVrfConfigMode ):
   """
   Add a 'known-host <host> <key-type> <public-key>' command to `cmds` for
   each entry in `config.knownHost` that should be written from the mode
   being saved (`cliSaveMode`).

   Entries with `configuredInSshMode` set are skipped when saving
   SshVrfConfigMode; all other entries are skipped when saving
   SshConfigMode. `options` is unused and kept for a uniform signature.
   """
   for hostKey in config.knownHost:
      entry = config.knownHost[ hostKey ]
      # Pick the mode this entry must NOT be written from.
      if entry.configuredInSshMode:
         skipWhen = SshVrfConfigMode
      else:
         skipWhen = SshConfigMode
      if cliSaveMode == skipWhen:
         continue
      cliKeyType = SysMgrLib.tacKeyTypeToCliKey[ entry.type ]
      command = 'known-host %s %s %s' % ( entry.host, cliKeyType,
                                          entry.publicKey )
      cmds.addCommand( command )

def saveSshTunnels( config, parentMode, tunnelMode, options, vrfName="" ):
   """
   For each tunnel in `config.tunnel` that belongs to `tunnelMode`, get or
   create a `tunnelMode` instance under `parentMode` and add the tunnel's
   commands to it.

   A default-VRF tunnel may have been entered under 'management ssh' (old
   way, `configuredInSshMode` set) or under 'vrf default' (new way); each
   kind is skipped when saving the other kind's mode. With a non-empty
   `vrfName` the instance is keyed by ( vrfName, tunnel name ) and uses
   the 'Mgmt.ssh.vrf.tunnel' sequence, otherwise it is keyed by the tunnel
   name and uses 'Mgmt.ssh.tunnel'.
   """
   for tunnelName in config.tunnel:
      tunnel = config.tunnel[ tunnelName ]
      if tunnel.configuredInSshMode:
         foreignMode = SshVrfTunnelConfigMode
      else:
         foreignMode = SshTunnelConfigMode
      if tunnelMode == foreignMode:
         continue

      if vrfName:
         instanceKey = ( vrfName, tunnel.name )
         sequence = 'Mgmt.ssh.vrf.tunnel'
      else:
         instanceKey = tunnel.name
         sequence = 'Mgmt.ssh.tunnel'
      instance = parentMode[ tunnelMode ].getOrCreateModeInstance( instanceKey )
      tunnelCmds = instance[ sequence ]

      # The server line is only meaningful when address, user and port are
      # all set.
      hasServer = ( tunnel.sshServerAddress and tunnel.sshServerUsername and
                    tunnel.sshServerPort != serverPort.invalid )
      if hasServer:
         tunnelCmds.addCommand( 'ssh-server %s user %s port %d' %
                                ( tunnel.sshServerAddress,
                                  tunnel.sshServerUsername,
                                  tunnel.sshServerPort ) )
      if tunnel.localPort != serverPort.invalid:
         tunnelCmds.addCommand( 'local port %d' % tunnel.localPort )
      if tunnel.remoteHost and tunnel.remotePort != serverPort.invalid:
         tunnelCmds.addCommand( 'remote host %s port %d' %
                                ( tunnel.remoteHost, tunnel.remotePort ) )
      if ( tunnel.serverAliveInterval != tunnel.serverAliveIntervalDefault or
           options.saveAll ):
         tunnelCmds.addCommand( 'server-alive interval %d' %
                                tunnel.serverAliveInterval )
      if ( tunnel.serverAliveMaxLost != tunnel.serverAliveMaxLostDefault or
           options.saveAll ):
         tunnelCmds.addCommand( 'server-alive count-max %d' %
                                tunnel.serverAliveMaxLost )
      if tunnel.unlimitedRestarts:
         tunnelCmds.addCommand( 'unlimited-restarts' )
      if tunnel.enable:
         tunnelCmds.addCommand( "no shutdown" )
      elif options.saveAll:
         tunnelCmds.addCommand( "shutdown" )

def saveSshUserConfig( config, parentMode, userMode, options, vrfName="" ):
   """
   Get or create a `userMode` instance under `parentMode` for every user in
   `config.user`, in sorted name order, and save that user's settings.

   Only TCP forwarding is saved: 'tcp forwarding <value>' when it differs
   from the default, or 'no tcp forwarding' when it is the default and
   `options.saveAll` is set. `vrfName` is accepted but unused.
   """
   for name in sorted( config.user ):
      instance = parentMode[ userMode ].getOrCreateModeInstance( name )
      userCmds = instance[ 'Mgmt.ssh.user' ]
      user = config.user[ name ]
      forwarding = user.userTcpForwarding
      if forwarding != user.userTcpForwardingDefault:
         userCmds.addCommand( "tcp forwarding %s" % forwarding )
      elif options.saveAll:
         userCmds.addCommand( "no tcp forwarding" )

@CliSave.saver( 'Mgmt::Ssh::Config', 'mgmt/ssh/config' )
def saveSsh( sshConfig, root, sysdbRoot, options ):
   """
   Save the 'Mgmt::Ssh::Config' entity (mounted at 'mgmt/ssh/config').

   Top-level settings go into the 'Mgmt.ssh' sequence of the singleton
   SshConfigMode. Most settings are written only when they differ from
   their default, or always when options.saveAll is set. Tunnels, users
   and per-VRF settings are written into their own sub-modes. `sysdbRoot`
   is not used here.
   """
   mode = root[ SshConfigMode ].getSingletonInstance()
   cmds = mode[ 'Mgmt.ssh' ]

   if ( sshConfig.idleTimeout.timeout !=
        sshConfig.idleTimeout.defaultTimeout or options.saveAll ):
      # Only need to save timeout if different from default.
      # The stored value is divided by 60 for the CLI -- presumably seconds
      # in config vs. minutes on the command line; TODO confirm.
      cmds.addCommand( "idle-timeout %s" %
                       ( int( sshConfig.idleTimeout.timeout / 60 ) ) )

   # Only 'password' is written explicitly; any other mode is written as
   # 'keyboard-interactive', and only under saveAll.
   if sshConfig.authenticationMode == authenticationMode.password:
      cmds.addCommand( 'authentication mode password' )
   elif options.saveAll:
      cmds.addCommand( 'authentication mode keyboard-interactive' )

   if sshConfig.serverPort != serverPort.defaultPort or options.saveAll:
      cmds.addCommand( 'server-port %d' % sshConfig.serverPort )

   if sshConfig.cipher != sshConfig.cipherDefault or options.saveAll:
      cmds.addCommand( 'cipher %s' % sshConfig.cipher )

   if sshConfig.kex != sshConfig.kexDefault or options.saveAll:
      cmds.addCommand( 'key-exchange %s' % sshConfig.kex )

   if sshConfig.mac != sshConfig.macDefault or options.saveAll:
      cmds.addCommand( 'mac %s' % sshConfig.mac )

   # Amount and unit share one command, so both are written if either one
   # differs from its default.
   if ( sshConfig.rekeyDataAmount != sshConfig.rekeyDataAmountDefault or
        sshConfig.rekeyDataUnit != sshConfig.rekeyDataUnitDefault or
        options.saveAll ):
      cmds.addCommand( 'rekey frequency %d %s' % ( sshConfig.rekeyDataAmount,
                                                   sshConfig.rekeyDataUnit ) )
   # Same pairing for the time-based rekey limit and its unit.
   if ( sshConfig.rekeyTimeLimit != sshConfig.rekeyTimeLimitDefault or
        sshConfig.rekeyTimeUnit != sshConfig.rekeyTimeUnitDefault or
        options.saveAll ):
      cmds.addCommand( 'rekey interval %d %s' % ( sshConfig.rekeyTimeLimit,
                                                  sshConfig.rekeyTimeUnit ) )

   if sshConfig.hostkey != sshConfig.hostkeyDefault or options.saveAll:
      cmds.addCommand( 'hostkey server %s' % sshConfig.hostkey )

   if ( sshConfig.connLimit != sshConfig.connLimitDefault or
        options.saveAll ):
      cmds.addCommand( 'connection limit %s' % sshConfig.connLimit )

   if ( sshConfig.perHostConnLimit != sshConfig.perHostConnLimitDefault or
        options.saveAll ):
      cmds.addCommand( 'connection per-host %s' % sshConfig.perHostConnLimit )

   if sshConfig.fipsRestrictions:
      cmds.addCommand( 'fips restrictions' )
   elif options.saveAll:
      cmds.addCommand( 'no fips restrictions' )

   if sshConfig.enforceCheckHostKeys:
      cmds.addCommand( 'hostkey client strict-checking' )
   elif options.saveAll:
      cmds.addCommand( 'no hostkey client strict-checking' )

   # Under saveAll the default value is written as the literal 'auto'.
   val = sshConfig.permitEmptyPasswords
   if val != sshConfig.permitEmptyPasswordsDefault:
      cmds.addCommand( 'authentication empty-passwords %s' % val )
   elif options.saveAll:
      cmds.addCommand( 'authentication empty-passwords auto' )

   # Known hosts entered directly in this mode are written here; the rest
   # are written from the VRF sub-modes in the loop at the end.
   saveKnownHosts( sshConfig, options, cmds, cliSaveMode=SshConfigMode )

   val = sshConfig.clientAliveInterval
   if val != sshConfig.clientAliveIntervalDefault:
      cmds.addCommand( 'client-alive interval %s' % val )
   elif options.saveAll:
      cmds.addCommand( 'default client-alive interval' )

   val = sshConfig.clientAliveCountMax
   if val != sshConfig.clientAliveCountMaxDefault:
      cmds.addCommand( 'client-alive count-max %s' % val )
   elif options.saveAll:
      cmds.addCommand( 'default client-alive count-max' )

   if sshConfig.serverState == "disabled":
      cmds.addCommand( 'shutdown' )
   elif options.saveAll:
      cmds.addCommand( 'no shutdown' )

   if ( sshConfig.successfulLoginTimeout.timeout !=
        sshConfig.successfulLoginTimeout.defaultTimeout or options.saveAll ):
      # Only need to save timeout if different from default.
      time = sshConfig.successfulLoginTimeout.timeout
      # A zero timeout is saved as the negated command.
      if time == 0:
         cmds.addCommand( "no login timeout" )
      else:
         cmds.addCommand( "login timeout %d" % ( time, ) )

   if sshConfig.logLevel != sshConfig.logLevelDefault or options.saveAll:
      cmds.addCommand( "log-level %s" % sshConfig.logLevel )

   if sshConfig.loggingTargetEnabled:
      cmds.addCommand( "logging target system" )
   elif options.saveAll:
      cmds.addCommand( "no logging target system" )

   if sshConfig.dscpValue != sshConfig.dscpValueDefault:
      cmds.addCommand( "qos dscp %s" % sshConfig.dscpValue )
   elif options.saveAll:
      cmds.addCommand( "qos dscp %s" % sshConfig.dscpValueDefault )

   # Tunnels entered directly under this mode (old way); tunnels entered
   # under a VRF sub-mode are written in the VRF loop below.
   if len( sshConfig.tunnel ) > 0:
      saveSshTunnels( sshConfig, mode, SshTunnelConfigMode, options )

   # File lists are sorted so the saved command is stable.
   if sshConfig.caKeyFiles:
      caKeyFiles = " ".join( sorted( sshConfig.caKeyFiles ) )
      cmds.addCommand( "trusted-ca key public %s" % caKeyFiles )
   elif options.saveAll:
      cmds.addCommand( "no trusted-ca key public" )

   if sshConfig.hostCertFiles:
      hostCertFiles = " ".join( sorted( sshConfig.hostCertFiles ) )
      cmds.addCommand( "hostkey server cert %s" % hostCertFiles )
   elif options.saveAll:
      cmds.addCommand( "no hostkey server cert" )

   if sshConfig.revokedUserKeysFiles:
      revokedUserKeysFiles = " ".join( sorted( sshConfig.revokedUserKeysFiles ) )
      cmds.addCommand( "user-keys revoke-list %s" % revokedUserKeysFiles )
   elif options.saveAll:
      cmds.addCommand( "no user-keys revoke-list" )

   if len( sshConfig.user ) > 0:
      saveSshUserConfig( sshConfig, mode, SshUserConfigMode, options )

   # One SSH VRF sub-mode per entry in sshConfig.vrfConfig.
   for ( vrfName, vrfConfig ) in sshConfig.vrfConfig.iteritems():
      vrfMode = mode[ SshVrfConfigMode ].getOrCreateModeInstance(
         ( vrfName, 'ssh', sshConfig ) )
      vrfCmds = vrfMode[ 'Mgmt.ssh.vrf' ]
      # pylint thinks 'vrfCmds' is a list
      # pylint: disable-msg=E1103
      # 'globalDefault' is only written (as 'default shutdown') under saveAll.
      if vrfConfig.serverState == "enabled":
         vrfCmds.addCommand( "no shutdown" )
      elif vrfConfig.serverState == "disabled":
         vrfCmds.addCommand( "shutdown" )
      elif options.saveAll and vrfConfig.serverState == "globalDefault":
         vrfCmds.addCommand( "default shutdown" )
      # Default-VRF known hosts and tunnels are kept in the top-level config;
      # saveKnownHosts/saveSshTunnels skip the ones entered directly under
      # 'management ssh' (configuredInSshMode) when saving VRF modes.
      conf = sshConfig if vrfName == DEFAULT_VRF else vrfConfig
      saveKnownHosts( conf, options, vrfCmds )
      saveSshTunnels( conf, vrfMode, SshVrfTunnelConfigMode, options,
                      vrfName=vrfName )

@CliSave.saver( 'Acl::Input::CpConfig', 'acl/cpconfig/cli' )
def saveSshIpAclRev1( aclCpConfig, root, sysdbRoot, options, requireMounts ):
   """
   Save the SSH service ACLs (for 'ip' and 'ipv6') into the SSH config mode.

   For every VRF that has an 'ssh' service entry, emit
   '<type> access-group <acl> in' for the default VRF, or
   '<type> access-group <acl> vrf <vrf> in' for any other VRF. A configured
   ACL is always saved. Otherwise the non-empty default ACL is saved, but
   only when options.saveAll is set. `sysdbRoot` and `requireMounts` are
   not used here.
   """
   def aclNameToSave( serviceConfig ):
      # Returns the ACL name to write, or '' when nothing should be written.
      if serviceConfig.aclName != '':
         return serviceConfig.aclName
      if options.saveAll:
         return serviceConfig.defaultAclName
      return ''

   def saveServiceAcl( aclType ):
      serviceAcls = aclCpConfig.cpConfig[ aclType ].serviceAcl
      for vrfName, serviceAclVrfConfig in serviceAcls.iteritems():
         serviceConfig = serviceAclVrfConfig.service.get( 'ssh' )
         if not serviceConfig:
            continue
         aclName = aclNameToSave( serviceConfig )
         if aclName == '':
            continue
         # Always use the singleton instance, exactly as saveSsh does, so
         # ACL commands land in the same SSH mode instance as the rest of
         # the SSH configuration (the saveAll path previously created a
         # separately keyed instance via getOrCreateModeInstance( 'ssh' )).
         mode = root[ SshConfigMode ].getSingletonInstance()
         cmds = mode[ 'Mgmt.ssh' ]
         if vrfName == DEFAULT_VRF:
            cmds.addCommand( '%s access-group %s in' % ( aclType, aclName ) )
         else:
            cmds.addCommand( '%s access-group %s vrf %s in' %
                             ( aclType, aclName, vrfName ) )

   for aclType in ( 'ip', 'ipv6' ):
      saveServiceAcl( aclType )
