# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import ast
import base64
import os
import re
import subprocess
import ssl
from M2Crypto import EVP, X509

# Path of the Chip Signing Root CA certificate. verifyCertificate() loads it and
# checks that the Signing certificate stored in the prefdl was signed by it.
CHIP_SIGNING_ROOT_CA_PATH = "/etc/chip-signing-rootCa.crt"

def processWhitespace( string ):
   """Normalize a prefdl field value before it is added to the signed payload.

   If the value parses (via ast.literal_eval) as a dict, list or tuple literal,
   every space character (' ') is stripped from it. Other whitespace such as tabs
   is left alone. Any other value, including one that is not a Python literal at
   all, is returned unchanged. This keeps signature verification robust to minor,
   unintentional spacing differences introduced when the prefdl is programmed
   during manufacturing.
   """
   try:
      parsed = ast.literal_eval( string )
   except ( ValueError, SyntaxError ):
      parsed = None
   if not isinstance( parsed, ( dict, list, tuple ) ):
      return string
   return string.replace( ' ', '' )

def getCertificate( prefdl ):
   """Returns the Signing certificate stored in the Certificate prefdl field
   in Base64, or "" if the field is not present.

   The field name must start a line, so a different field whose name merely ends
   in "Certificate" (e.g. "RootCertificate") is never returned instead.
   """
   m = re.search( r"(?m)^Certificate: (.+)\n", prefdl )
   return m.group( 1 ) if m else ""

def getSignature( prefdl ):
   """Returns the Signature prefdl field in hex, or "" if it is not present.

   The field name must start a line, so a field whose name merely ends in
   "Signature" is never returned instead.
   """
   m = re.search( r"(?m)^Signature: (\w+)\n", prefdl )
   return m.group( 1 ) if m else ""

def getSignatureList( prefdl ):
   """Returns the SignatureList prefdl field in ASCII, or "" if it is not present.

   The field name must start a line, so a field whose name merely ends in
   "SignatureList" is never returned instead.
   """
   m = re.search( r"(?m)^SignatureList: (.+)\n", prefdl )
   return m.group( 1 ) if m else ""

def concatTpmPublicEndorsementKey( modulus, exponent ):
   """
   Build the TPM public endorsement key string that prefixes the signed payload.

   Both arguments are hex strings. The modulus is zero-padded on the left (via
   str.zfill) to 512 characters and the exponent to 6 characters. The result is
   the padded modulus immediately followed by the padded exponent.
   """
   modulusChars, exponentChars = 512, 6
   return modulus.zfill( modulusChars ) + exponent.zfill( exponentChars )

def getTpmPublicEndorsementKey():
   """
   Returns the TPM public endorsement key as a ( modulus, exponent ) tuple of hex
   strings, with the exponent's "0x" prefix removed.

   Returns None if neither TPM utility can be run successfully, or if its output
   does not contain both a "Modulus:" and an "Exponent: 0x" line.
   """
   # Assume that tpmutil only exists in Aboot; tpm_readPubEk is the EOS
   # equivalent for getting the TPM pub ek. Use the first one that succeeds.
   # A tool that is missing (OSError) or exits non-zero (CalledProcessError)
   # is skipped rather than propagating a non-VerificationError to callers.
   output = None
   for cmd in ( [ 'tpmutil', 'readpubek' ], [ 'tpm_readPubEk' ] ):
      try:
         output = subprocess.check_output( cmd )
         break
      except ( OSError, subprocess.CalledProcessError ):
         continue
   if output is None:
      return None

   modulusMatch = re.search( r"Modulus: (\w+)\n", output )
   exponentMatch = re.search( r"Exponent: 0x(\w+)\n", output )
   if not modulusMatch or not exponentMatch:
      # Fail closed. Returning empty strings would give callers a truthy tuple
      # that does not describe this device's TPM key.
      return None
   return ( modulusMatch.group( 1 ), exponentMatch.group( 1 ) )

def evalSignatureList( signatureListString ):
   """Parse the SignatureList field text into a Python list.

   Returns the list when the text is a list literal (per ast.literal_eval).
   Returns None when it evaluates to any other type or is not a valid literal.
   """
   try:
      parsed = ast.literal_eval( signatureListString )
   except ( ValueError, SyntaxError ):
      parsed = None
   return parsed if isinstance( parsed, list ) else None

def calculatePayload( prefdl, tpmPublicKey, signatureList, signatureListString ):
   """Calculates and returns the expected payload given the state of the running
   system. Signing this payload with the Signing certificate's private key results
   in the Signature. This calculated value is used to check against the actual
   Signature stored in the prefdl.

   The payload is the concatenation of the following fields in the order shown:

   TPM public EK modulus (hex string, zero padded to 512 chars) +
   TPM public EK exponent (hex string, zero padded to 6 chars) +
   SignatureList (value only) + prefdl fields listed in SignatureList (in order,
   value only)

   SignatureList and each field value are passed through processWhitespace()
   first. Each field is looked up as "<name>: <value>" at the start of a line,
   with the name matched literally (regex-escaped). A field whose name merely
   ends with a listed name (e.g. "ChassisSerialNumber" for "SerialNumber") is
   therefore never used instead. A listed field that is not found contributes "".

   e.g.
   prefdl:
   ...
   SerialNumber: SSJ17450458
   HwEpoch: 02.00
   SignatureList: ['SerialNumber','HwEpoch','SwFeatures']
   SwFeatures: {'feature1':'enabled','feature2':'disabled'}
   ...

   payload:
   ccdd8520b874ca7b0968b0401ed7febd4bba1b4983a7831698fe2d54157e50c22f42803486bb18a77
   e0a2e6d14571593976ca1259910c3cc7cbb675f6e22eb8d2b7e69c1a9eb7b0e5473ec3bc0a94d4abf
   a9009dd2436944d5cc5e12916f6b6e978e10fcc05ad13404801538442738aebc53ad1df1641ceb641
   3bb9d94b5a400959aff90f36af0da2782b0c5219643dea55f7f15837abfa4e048e1b89b41f56f1289
   bdaadb56a47155fa0f41ad48e272a5dc925a860b07af8f7b5d92e0640c98146a8f33fa4271b1f69d1
   aa3a9e4ee5e8823906835d43b32d8c1098dad6ac1f10a2393c50deb81b7cef8fd07adcc325cfd95eb
   be5e0eefd099f10a6a41fadf73 + 010001 +
   ['SerialNumber','HwEpoch','SwFeatures'] + SSJ17450458 + 02.00 +
   {'feature1':'enabled','feature2':'disabled'}
   """
   parts = [ tpmPublicKey, processWhitespace( signatureListString ) ]
   for field in signatureList:
      # Format the entry as text first (as the original "%s" % field did), then
      # escape it so regex metacharacters in a name are matched literally.
      fieldName = re.escape( "%s" % ( field, ) )
      m = re.search( r"(?m)^%s: (.+)\n" % fieldName, prefdl )
      val = m.group( 1 ) if m else ""
      parts.append( processWhitespace( val ) )
   return "".join( parts )

def verifyCertificate( signingCert ):
   """Returns True if the Signing certificate stored in the Certificate prefdl
   field is signed by the Chip Signing Root CA. This verifies the chain of trust
   between the two certificates.

   signingCert is the M2Crypto X509 object built from the Certificate field.
   Returns False, rather than raising, if the root CA file is missing or cannot
   be loaded, so a broken root CA fails closed like a failed verification.

   NOTE(review): only signingCert.verify() against the root CA's public key is
   checked here. Validity dates and issuer fields are not examined; confirm that
   this is intended.
   """
   if not os.path.exists( CHIP_SIGNING_ROOT_CA_PATH ):
      return False
   try:
      chipSigningRootCa = X509.load_cert( CHIP_SIGNING_ROOT_CA_PATH )
   except ( EnvironmentError, X509.X509Error ):
      # Unreadable or malformed root CA file.
      return False
   result = signingCert.verify( chipSigningRootCa.get_pubkey() )
   return result == 1

def verifySignatureFormat( signature ):
   """Returns True if the Signature is 256 bytes encoded in exactly 512 hex chars
   (0-9, a-f, A-F) and nothing else.

   Each character is checked explicitly. The old int( signature, 16 ) test also
   accepted a "0x" prefix, a leading sign, surrounding whitespace and (on
   Python 3) underscores. Those strings then fail to hex-decode when the
   signature is verified.
   """
   if len( signature ) != 512:
      return False
   # \Z (not $) so that a trailing newline is not accepted.
   return re.match( r"[0-9a-fA-F]+\Z", signature ) is not None

def verifySignature( signature, signingCert, payload ):
   """Returns True if the Signature stored in the Signature prefdl field verifies
   as an RSASSA-PSS signature, under the Signing certificate's RSA public key,
   over the SHA-256 digest of the payload calculated for the running system.

   signature is the hex string from the Signature field, signingCert is the
   M2Crypto X509 Signing certificate, and payload is the expected payload string.
   """
   rsaPublicKey = signingCert.get_pubkey().get_rsa()
   digester = EVP.MessageDigest( 'sha256' )
   digester.update( bytes( payload ) )
   payloadDigest = digester.final()
   rawSignature = signature.decode( 'hex' )
   # NOTE(review): salt_length=-2 is presumably OpenSSL's "recover the salt
   # length from the signature" mode for PSS verification; confirm against the
   # M2Crypto/OpenSSL docs.
   verified = rsaPublicKey.verify_rsassa_pss( payloadDigest, rawSignature,
                                              'sha256', salt_length=-2 )
   return verified == 1

def verifySignatureListFields( prefdl, signatureList ):
   """Returns a list of SignatureList members that are not successfully located
   in the prefdl, in SignatureList order.

   A member counts as present when some line of the prefdl starts with
   "<name>: ". The name is regex-escaped, so metacharacters in it are matched
   literally instead of raising re.error. The line-start anchor also finds a
   field on the very first line of the prefdl.
   """
   missingFields = []
   for field in signatureList:
      fieldName = re.escape( "%s" % ( field, ) )
      if not re.search( r"(?m)^%s: " % fieldName, prefdl ):
         missingFields.append( field )
   return missingFields

class VerificationError( Exception ):
   """Raised by verifyPrefdl() when a verification step fails; the single
   argument is a human-readable description of the failure."""

   def __init__( self, errorMsg ):
      super( VerificationError, self ).__init__( errorMsg )

def verifyPrefdl( prefdl ):
   """Verify the prefdl's signature against this device's TPM endorsement key.

   The checks run in this order:
   - The Certificate field decodes from Base64 (DER) into an X509 certificate.
   - The Signature field is present and passes verifySignatureFormat().
   - SignatureList parses to a non-empty list whose fields all exist.
   - The TPM public endorsement key can be read.
   - The Signing certificate is signed by the Chip Signing Root CA.
   - The Signature verifies against the payload built from the TPM key and
     the listed fields.

   Returns the ( modulus, exponent ) tuple from getTpmPublicEndorsementKey().
   Raises VerificationError when any of the checks above fails.
   """
   if not prefdl:
      raise VerificationError( "Failed to read prefdl" )

   certificate = getCertificate( prefdl )
   if not certificate:
      raise VerificationError( "Failed to read Certificate from prefdl" )

   try:
      derCertificate = base64.b64decode( certificate )
   except ( TypeError, ValueError ):
      # Python 2 reports bad Base64 as TypeError; Python 3 raises
      # binascii.Error, which is a ValueError subclass.
      raise VerificationError( "Certificate invalid - failed to decode as Base64" )

   try:
      pemCertificate = str( ssl.DER_cert_to_PEM_cert( derCertificate ) )
      signingCertX509 = X509.load_cert_string( pemCertificate )
   except X509.X509Error:
      raise VerificationError(
         "Certificate invalid - failed to load into X509 object" )

   signature = getSignature( prefdl )
   if not signature:
      raise VerificationError( "Failed to read Signature from prefdl" )
   if not verifySignatureFormat( signature ):
      raise VerificationError( "Signature is not 256-byte hex encoded string" )

   signatureListString = getSignatureList( prefdl )
   if not signatureListString:
      raise VerificationError( "Failed to read SignatureList from prefdl" )
   signatureList = evalSignatureList( signatureListString )
   if not signatureList:
      # Covers both "not a list" (None) and an empty list.
      raise VerificationError(
         "Failed to convert SignatureList to an iterable list" )
   missingFields = verifySignatureListFields( prefdl, signatureList )
   if missingFields:
      raise VerificationError(
         "Failed to find %s in the prefdl" % str( missingFields ) )

   tpmPublicKeyComponents = getTpmPublicEndorsementKey()
   # Also reject empty components. A payload built from a missing modulus or
   # exponent would not bind the Signature to this device's TPM.
   if not tpmPublicKeyComponents or not all( tpmPublicKeyComponents ):
      raise VerificationError( "Failed to read TPM public endorsement key" )
   tpmPubKeyConcat = concatTpmPublicEndorsementKey( tpmPublicKeyComponents[ 0 ],
                                                    tpmPublicKeyComponents[ 1 ] )
   payload = calculatePayload( prefdl, tpmPubKeyConcat, signatureList,
                               signatureListString )

   if not verifyCertificate( signingCertX509 ):
      raise VerificationError(
         "Chain of trust verification failed between Signing certificate "
         "and Chip Signing Root CA" )
   if not verifySignature( signature, signingCertX509, payload ):
      raise VerificationError(
         "Failed to verify Signature against expected payload" )

   return tpmPublicKeyComponents
