# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import re
import tempfile
from SslCertKey import dirCreate
import Tracing
import Tac

# Trace handle for this module; t0 (presumably its level-0 trace function) is
# used for the debug messages emitted by the functions in this file.
traceHandle = Tracing.Handle( 'SshCertLib' )
t0 = traceHandle.trace0

# Tac types. Constants supplies the SSH directory/file paths used here
# (sshBaseDir, caKeysDirPath(), hostCertsDirPath(), revokeListsDirPath(),
# hostCertPath()); KeyType's `invalid` value is returned for host certificates
# that cannot be read or validated.
Constants = Tac.Type( "Mgmt::Ssh::Constants" )
KeyType = Tac.Type( "Mgmt::Ssh::KeyType" )

class SshHostCertError( Exception ):
   """Raised when an SSH host certificate is unparseable or is not a host cert."""

class SshKeyError( Exception ):
   """Raised when a public key read from a key file fails validation."""

def createSshDirs():
   """Create the SSH base directory, then the CA-keys, host-certs and
   revoke-lists directories, via dirCreate, in that order."""
   dirCreate( Constants.sshBaseDir )
   # The remaining paths come from accessor functions; each one is evaluated
   # just before its directory is created, as in a straight-line sequence.
   subDirPathFns = ( Constants.caKeysDirPath,
                     Constants.hostCertsDirPath,
                     Constants.revokeListsDirPath )
   for pathFn in subDirPathFns:
      dirCreate( pathFn() )

def validateMultipleKeysFile( keyFile ):
   """Validate a file containing public keys, one per line.

   Whitespace-only lines are skipped. Every other line is written to its own
   temporary file and checked with ``ssh-keygen -lf``; validation stops at
   the first line ssh-keygen rejects.

   Args:
      keyFile: path of the key file to validate.

   Raises:
      SshKeyError: if any non-blank line is rejected by ssh-keygen.
      IOError/OSError: if keyFile cannot be opened (not caught here).
   """
   t0( "Validating key file: %s" % keyFile )
   with open( keyFile, "r" ) as keyFileHandle:
      for key in keyFileHandle:
         if not key.strip():
            continue
         # Text mode: the default 'w+b' mode rejects str writes on Python 3;
         # 'w' accepts the str line on both Python 2 and Python 3.
         with tempfile.NamedTemporaryFile( mode="w" ) as tmpFile:
            tmpFile.write( key )
            # Make sure ssh-keygen sees the data before it reads the file.
            tmpFile.flush()
            try:
               Tac.run( [ "ssh-keygen", "-lf", tmpFile.name ],
                        stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                        ignoreReturnCode=False )
            except Tac.SystemCommandError as e:
               t0( "Invalid key: %s : Error: %s" % ( key.rstrip(),
                                                     e.output.rstrip() ) )
               raise SshKeyError( "Invalid key: %s" % key.rstrip() )
            else:
               t0( "Valid key: %s" % key.rstrip() )

def validateHostCert( certFile ):
   """Check that certFile is a parseable OpenSSH *host* certificate.

   Runs ``ssh-keygen -Lf`` on the file, then checks that the reported
   certificate type is a host certificate.

   Args:
      certFile: path of the certificate file.

   Returns:
      The text captured from ``ssh-keygen -Lf certFile``.

   Raises:
      SshHostCertError: "Invalid certificate" if ssh-keygen fails on the
         file, or "Not a host certificate" if it parses but its type is not
         a host certificate (e.g. a user certificate).
   """
   try:
      cert = Tac.run( [ "ssh-keygen", "-Lf", certFile ],
                      stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                      ignoreReturnCode=False )
   except Tac.SystemCommandError as e:
      t0( "Invalid cert: %s" % e.output.rstrip() )
      raise SshHostCertError( "Invalid certificate" )

   # This check must run after the try block: `ssh-keygen -L` also succeeds
   # on user certificates, so a successful parse alone is not enough.
   if not re.search( r"Type: .*cert.*@openssh\.com host certificate", cert ):
      raise SshHostCertError( "Not a host certificate" )
   return cert

def getHostCertKeyType( certFile ):
   """Return the leading key-type field of a stored host certificate.

   certFile is resolved to a path with Constants.hostCertPath. The
   certificate is first checked with validateHostCert. The first
   space-separated field of the stripped file contents is then returned
   (for an OpenSSH certificate this is presumably the certificate
   algorithm name).

   Returns:
      That first field, or KeyType.invalid if validation fails or the file
      cannot be opened/read.
   """
   certPath = Constants.hostCertPath( certFile )
   try:
      validateHostCert( certPath )
      with open( certPath, 'r' ) as handle:
         contents = handle.read()
   except ( OSError, IOError, SshHostCertError ):
      return KeyType.invalid
   leadingField, _, _ = contents.strip().partition( ' ' )
   return leadingField

def hostCertsByKeyTypes( certFiles ):
   """Group host certificate names by the key type getHostCertKeyType reports.

   Args:
      certFiles: iterable of certificate names accepted by getHostCertKeyType.

   Returns:
      dict mapping each reported key type to the list of names with that
      type, in input order.
   """
   # NOTE(review): getHostCertKeyType as written returns either a string
   # field or KeyType.invalid, never None, so failing certificates end up
   # grouped under KeyType.invalid instead of being dropped -- confirm intent.
   grouped = {}
   for name in certFiles:
      kind = getHostCertKeyType( name )
      if kind is None:
         continue
      if kind not in grouped:
         grouped[ kind ] = []
      grouped[ kind ].append( name )
   return grouped
