# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import errno
import tempfile
import os
import re

import Tac
import Tracing
from HttpServiceConstants import ServerConstants
import SslCertKey
from M2Crypto import X509

ErrorType = Tac.Type( "Mgmt::Security::Ssl::ErrorType" )

# Tracing handle for this module; its trace1..trace4 levels are exposed
# under the shorter names used throughout this file.
traceHandle = Tracing.Handle( 'CapiSsl' )
warn, info, trace, debug = ( traceHandle.trace1, traceHandle.trace2,
                             traceHandle.trace3, traceHandle.trace4 )

def getCertFilepath():
   """Returns the SSL certificate file path: the SSL root directory
   (see getSslRootDir) with ServerConstants.SSL_CERT_FILE appended."""
   sslDir = getSslRootDir()
   return sslDir + ServerConstants.SSL_CERT_FILE

def getKeyFilepath():
   """Returns the SSL key file path: the SSL root directory
   (see getSslRootDir) with ServerConstants.SSL_KEY_FILE appended."""
   sslDir = getSslRootDir()
   return sslDir + ServerConstants.SSL_KEY_FILE

def getSslRootDir():
   """Returns the directory used to store SSL data, creating it if needed.

   The path is ServerConstants.BASE_SSL_FILE_DIR prefixed with the
   FILESYSTEM_ROOT environment variable (empty string when unset).
   Raises OSError if the directory cannot be created.
   """
   fsRoot = os.environ.get( "FILESYSTEM_ROOT", "" )
   trace( "FileSystem root is ", fsRoot )
   rootDir = fsRoot + ServerConstants.BASE_SSL_FILE_DIR
   if not os.path.exists( rootDir ):
      trace( "Creating directory to store SSL data", rootDir )
      try:
         os.makedirs( rootDir )
      except OSError as e:
         # Another caller may create the directory between the exists()
         # check above and makedirs(); that is fine as long as the path
         # really is a directory now.
         if e.errno != errno.EEXIST or not os.path.isdir( rootDir ):
            raise
   return rootDir

def parseCertKey( certKey ):
   """Finds and returns the first certificate and key in a PEM file
   with both.

   certKey: PEM text containing a certificate block and a key block, in
   either order.
   Returns a ( certificate, sslKey ) tuple. Each element spans from the
   first BEGIN marker of its kind through the first matching END marker
   that follows that BEGIN marker (markers included).
   If a BEGIN or END marker is missing, the regex search yields None and
   an AttributeError propagates.
   """

   keyBegin = r'-----BEGIN[^\n]* KEY-----'
   keyEnd = r'-----END[^\n]* KEY-----'
   certBegin = r'-----BEGIN[^\n]* CERTIFICATE-----'
   certEnd = r'-----END[^\n]* CERTIFICATE-----'

   def _extractBlock( beginPattern, endPattern ):
      # Returns the text from the first BEGIN marker to the END marker after it.
      beginMatch = re.search( beginPattern, certKey )
      # Search for END only after BEGIN, so a stray END marker earlier in
      # the text cannot produce an empty or truncated slice.
      endMatch = re.compile( endPattern ).search( certKey, beginMatch.end() )
      return certKey[ beginMatch.start() : endMatch.end() ]

   certificate = _extractBlock( certBegin, certEnd )
   sslKey = _extractBlock( keyBegin, keyEnd )
   return ( certificate, sslKey )

def writeSslAttrToFile( fileName, attrText ):
   """Writes the named attribute's value to a file.

   fileName: path of the file to create or overwrite.
   attrText: text to write; an empty/falsy value is rejected.
   Returns an empty string if ok, otherwise returns an error message.
   A chmod failure with EPERM is tolerated (traced at debug level); any
   other open/write/chmod failure is returned as its message.
   """
   trace( '_writeSslAttrToFile entry', fileName )
   if not attrText:
      info( '_writeSslAttrToFile exit empty value' )
      return 'illegal empty value'

   # Writing and chmod are handled in separate try blocks so the EPERM
   # tolerance applies only to chmod. A single try ordered as
   # 'except IOError' before 'except OSError' breaks on Python 3, where
   # IOError is an alias of OSError and would catch chmod errors first.
   try:
      with open( fileName, 'w' ) as f:
         f.write( attrText )
   except EnvironmentError as e:
      warn( '_writeSslAttrToFile exit', e )
      return str( e )

   try:
      # NOTE(review): 0o777 makes the file world-writable and executable,
      # which is unusually permissive for SSL certificate/key material;
      # confirm which processes need access before tightening it.
      os.chmod( fileName, 0o777 )
   except EnvironmentError as e:
      if e.errno != errno.EPERM:
         warn( '_writeSslAttrToFile exit chmod', e )
         return str( e )
      debug( '_writeSslAttrToFile exit chmod', e )
   trace( '_writeSslAttrToFile exit' )
   return ""

def validateCertificateAndKey( certificate, key ):
   """Validates certificate and key text.

   Each value is written to its own temporary file in the SSL root directory
   (see getSslRootDir), flushed and fsync'd, and the pair is checked by
   validateCertificateAndKeyFiles, whose result ("" when valid, otherwise
   an error message) is returned. The temporary files are deleted when the
   with blocks exit.
   """
   trace( 'validateCertificateAndKey enter' )
   with tempfile.NamedTemporaryFile( dir=getSslRootDir() ) as certTmp:
      with tempfile.NamedTemporaryFile( dir=getSslRootDir() ) as keyTmp:
         for tmp, payload in ( ( certTmp, certificate ), ( keyTmp, key ) ):
            tmp.write( payload )
            tmp.flush()
            # Force the data to disk before the validators reopen by name.
            os.fsync( tmp )
         return validateCertificateAndKeyFiles( certTmp.name, keyTmp.name )

def validateCertificateAndKeyFiles( certFileName, keyFileName ):
   """Validates a certificate file and key file as a pair.

   Steps, stopping at the first failure:
     1. read both files (IOError -> read-failure message);
     2. SslCertKey.validateCertificate on the certificate file, with date
        validation enabled (-> ServerConstants.INVALID_SSL_CERT);
     3. SslCertKey.validateRsaPrivateKey on the key file
        (-> ServerConstants.INVALID_SSL_KEY);
     4. SslCertKey.isCertificateMatchesKey on the file contents; a false
        result or an X509.X509Error -> ServerConstants.INVALID_SSL_CERT_KEY.
   Returns "" when every step passes.
   """
   trace( 'validateCertificateAndKeyFiles enter' )
   errorExitMsg = 'validateCertificateAndKeyFiles error exit'

   try:
      with open( certFileName, 'r' ) as certFh:
         certText = certFh.read()
      with open( keyFileName, 'r' ) as keyFh:
         keyText = keyFh.read()
   except IOError:
      trace( errorExitMsg )
      return "SSL certificate/key error: Unable to read Certificate/Key files"

   try:
      SslCertKey.validateCertificate( certFileName, validateDates=True,
                                      maxPemCount=None )
   except SslCertKey.SslCertKeyError:
      trace( errorExitMsg )
      return ServerConstants.INVALID_SSL_CERT

   try:
      SslCertKey.validateRsaPrivateKey( keyFileName )
   except SslCertKey.SslCertKeyError:
      trace( errorExitMsg )
      return ServerConstants.INVALID_SSL_KEY

   # Both pass individually; now confirm they belong together. An X509Error
   # raised while comparing is treated as a mismatch.
   try:
      pairMatches = SslCertKey.isCertificateMatchesKey( certText, keyText )
   except X509.X509Error:
      pairMatches = False
   if pairMatches:
      return ""
   trace( errorExitMsg )
   return ServerConstants.INVALID_SSL_CERT_KEY
