#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import re
import shutil
import xml.etree.ElementTree as et
import requests
import time
import json
import Tac
from EosCloudInitLib import userData, runCmd, emptyUserData, \
   processUserData, readFile, writeFile, getAllIntfs, metadataIp, \
   generateUserConfig, setupConnection, cleanupExit, firstTimeBoot, getDnsIp

## Passed as the first argument to processUserData() by cloudInit(); nothing
## in this file populates it (NOTE(review): presumably filled/consumed inside
## EosCloudInitLib -- confirm)
configs = {}

## wire server request constants
## Sent as the x-ms-version / x-ms-agent-name headers on every wire server
## request made in this file
xmsVersion = '2015-04-05'
waagentName = 'WALinuxAgent'
## Azure wire server address queried for the goal state
wireServerIp = '168.63.129.16'
## Instance metadata (compute) endpoint; cloudInit() reads 'resourceId' from
## the JSON it returns
instanceIdentityUrl = ( "http://%s/metadata/instance/compute?api-version=2019-08-01"
                        % metadataIp )
## File that cloudInit() writes the instance resourceId into
instanceIdFile = Tac.Type( "CloudUtils::CloudAttr" ).instanceIdFile
## Persistent copy of the provisioning drive's ovf-env.xml: written by
## readProvisioningDrive() and read back by getFingerprintFromOvfEnv()
ovfEnvFile = '/mnt/flash/.ovf-env.xml'

def retrieveGoalStateXml():
   """ Fetch the goal state document from the Azure wire server.

   Returns the response body as text. Returns None if the request could not
   be completed: timeout, connection failure, or a non-2xx HTTP status. This
   lets callers skip wire-server dependent work instead of aborting boot.
   """
   try:
      response = requests.get( 'http://%s/machine/?comp=goalstate' % wireServerIp,
                               headers={
                                  'x-ms-version': xmsVersion,
                                  'x-ms-agent-name': waagentName
                               }, timeout=5 )
      ## An HTTP error page is not a goal state document; treat it as failure
      response.raise_for_status()
   except requests.exceptions.RequestException as e:
      ## Covers Timeout, ConnectionError and HTTPError alike
      print( "Could not retrieve goal state from wire server: %s" % e )
      return None
   return response.text

def getFingerPrint( certFile ):
   """ Return the SHA1 fingerprint of the x509 certificate in certFile.

   The openssl x509 fingerprint format:
   'SHA1 Fingerprint=E2:FE:09:38:0C:83:F1:A6:2A:FB:7B:97:57:9B:D9:8E:99:C9:9B:1A'
   is converted to the Azure fingerprint format:
   'E2FE09380C83F1A62AFB7B97579BD98E99C99B1A'
   """
   ## Ask for SHA1 explicitly rather than relying on openssl's default
   ## digest, because the Azure thumbprints we compare against are SHA1.
   fingerPrint, _ = runCmd( [ 'sh', '-c',
                              '''openssl x509 -noout -sha1 -fingerprint < "$0" ''',
                              certFile ] )

   ## Keep what follows '=' and strip whitespace. Slicing off a fixed last
   ## character would truncate the digest when there is no trailing newline.
   _, _, digest = fingerPrint.partition( '=' )
   return digest.strip().replace( ':', '' )

def parseCerts( pemPath ):
   """ Parse out fingerprints and associated SSH keys from the certificates
   present in the certificate file.

   pemPath is a PEM bundle that may mix private keys and x509 certificates.
   Private key blocks are discarded. For every certificate block, the public
   key is converted with ssh-keygen.

   Returns a dict mapping each certificate's fingerprint (as produced by
   getFingerPrint) to the ssh-keygen output for that certificate.
   """
   import tempfile

   sshKeys = {}
   tmpCert = []
   ## Use a private, unpredictable scratch file rather than a fixed /tmp path.
   ## This runs as root, and the file briefly holds certificate data.
   fd, tmpCertFile = tempfile.mkstemp( prefix='tmpCertFile' )
   os.close( fd )
   try:
      for line in readFile( pemPath ):
         tmpCert.append( line )
         if re.match( r'[-]+END .*?KEY[-]+$', line ):
            ## ignore private keys
            tmpCert = []
         elif re.match( r'[-]+END .*?CERTIFICATE[-]+$', line ):
            writeFile( tmpCertFile, ''.join( tmpCert ) )
            sshKey, _ = runCmd( [ 'sh', '-c',
                         ( '''openssl x509 -noout -pubkey < "$0" |'''
                           ''' ssh-keygen -i -m PKCS8 -f /dev/stdin''' ),
                         tmpCertFile ] )
            sshKeys[ getFingerPrint( tmpCertFile ) ] = sshKey
            tmpCert = []
   finally:
      ## Always remove the scratch file, even if a helper raised
      os.remove( tmpCertFile )
   return sshKeys

def getFingerprintFromOvfEnv():
   """ Return the Fingerprint text of the first SSH PublicKey listed in the
   ProvisioningSection of ovfEnvFile.

   Returns None in three cases: the file cannot be read, it is not
   well-formed XML, or it lacks any element on the
   ProvisioningSection/SSH/PublicKeys/PublicKey/Fingerprint path.
   Only the first public key is used; if there are several, this is only
   logged.
   """
   try:
      rootElmt = et.parse( ovfEnvFile ).getroot()
   except IOError:
      print( "Couldn't find file " + ovfEnvFile )
      return None
   except et.ParseError:
      ## A truncated/corrupt copy must not abort the ssh-key setup
      print( "Couldn't parse file " + ovfEnvFile )
      return None

   ns = '{http://schemas.microsoft.com/windowsazure}'
   try:
      provElmt = rootElmt.find( './/%sProvisioningSection' % ns )
      sshKeysElmt = provElmt.find( './/%sSSH' % ns )
      publicKeysElmt = sshKeysElmt.find( './/%sPublicKeys' % ns )
      publicKeyElmt = publicKeysElmt.find( './/%sPublicKey' % ns )
      fingerPrint = publicKeyElmt.find( './/%sFingerprint' % ns ).text
   except AttributeError:
      ## find() returned None somewhere along the path
      print( "SSH-Key not contained in ovf-env.xml file" )
      return None

   ## If there are multiple public keys just log it for now
   ## Currently, we support only matching the first fingerprint
   if len( publicKeysElmt.findall( './/%sPublicKey' % ns ) ) > 1:
      print( "Multiple public keys found" )
   return fingerPrint

def _certificatesUri( goalStateXml ):
   """ Return the Certificates URI from an Azure goal state document.

   Walks Container/RoleInstanceList/RoleInstance/Configuration/Certificates
   and returns that element's text. Returns None if the document cannot be
   parsed, any element on the path is missing, or the element is empty.
   """
   try:
      elmt = et.fromstring( goalStateXml )
   except et.ParseError:
      return None
   for tag in ( 'Container', 'RoleInstanceList', 'RoleInstance',
                'Configuration', 'Certificates' ):
      elmt = elmt.find( tag )
      if elmt is None:
         return None
   return elmt.text or None

def copySshKeys():
   """ Fetch the VM certificates from the Azure wire server and install the
   SSH public key matching ovf-env.xml's fingerprint into /mnt/flash/key.pub.

   Steps:
   1. Retrieve the goal state.
   2. Generate a throw-away transport certificate.
   3. Download the certificates (encrypted to the transport certificate).
   4. Decrypt them into a PEM bundle.
   5. Pick the key whose fingerprint matches getFingerprintFromOvfEnv().

   Intermediate files, including the transport private key, live in a
   private temporary directory that is removed on every exit path. Failures
   along the way are logged and the function returns without installing a
   key.
   """
   import tempfile

   keyPath = '/mnt/flash/key.pub'

   goalStateXml = retrieveGoalStateXml()
   if goalStateXml is None:
      return

   ## retrieve certificates URI
   certificatesUri = _certificatesUri( goalStateXml )
   if certificatesUri is None:
      print( "No ssh-keys found" )
      return

   workDir = tempfile.mkdtemp()
   transportPriv = os.path.join( workDir, 'TransportPriv.pem' )
   transportPub = os.path.join( workDir, 'TransportPub.pem' )
   p7mPath = os.path.join( workDir, 'Certificates.p7m' )
   pemPath = os.path.join( workDir, 'Certificates.pem' )
   try:
      ## generate transport certificate
      runCmd( [ 'openssl', 'req', '-x509', '-nodes', '-subj', '/CN=LinuxTransport',
                '-days', '32768', '-newkey', 'rsa:2048', '-keyout',
                transportPriv, '-out', transportPub ] )
      ## drop the first/last (BEGIN/END) lines and join the body onto one line
      transportKey = ''.join( readFile( transportPub )[ 1 : -1 ] ).replace(
         '\n', '' )

      ## fetch certificates; bounded by a timeout so boot cannot hang here
      try:
         response = requests.get( certificatesUri,
                                  headers={
                                     'x-ms-version': xmsVersion,
                                     'x-ms-agent-name': waagentName,
                                     'x-ms-cipher-name': 'DES_EDE3_CBC',
                                     'x-ms-guest-agent-public-x509-cert':
                                        transportKey
                                  }, timeout=30 )
         response.raise_for_status()
      except requests.exceptions.RequestException as e:
         print( "Could not fetch certificates from wire server: %s" % e )
         return

      ## extract data from xml to p7m
      try:
         dataElmt = et.fromstring( response.text ).find( 'Data' )
      except et.ParseError:
         dataElmt = None
      if dataElmt is None or not dataElmt.text:
         print( "No certificate data returned by wire server" )
         return
      p7mContents = [ '''\
MIME-Version:1.0
Content-Disposition: attachment; filename="/var/lib/waagent/Certificates.p7m"
Content-Type: application/x-pkcs7-mime; name="/var/lib/waagent/Certificates.p7m"
Content-Transfer-Encoding: base64

''' ]
      p7mContents.append( dataElmt.text )
      writeFile( p7mPath, p7mContents )

      ## decode p7m and generate pem file (paths passed as shell arguments)
      _, rc = runCmd( [ 'sh', '-c',
                        ( '''openssl cms -decrypt -in "$0" -inkey "$1" '''
                          '''-recip "$2" '''
                          '''| openssl pkcs12 -nodes -password pass: '''
                          '''-out "$3"''' ),
                        p7mPath, transportPriv, transportPub, pemPath ] )
      if rc != 0 or not os.path.isfile( pemPath ):
         print( "Could not decrypt certificates, rc %d" % rc )
         return

      ## copy the ssh key matching ovf-env.xml's fingerprint to the key file
      sshKeys = parseCerts( pemPath )
      sshKey = sshKeys.get( getFingerprintFromOvfEnv(), None )
      if sshKey is None:
         print( "Couldn't find matching fingerprint from ovf-env.xml file" )
         return
      with open( keyPath, 'w' ) as keyFile:
         keyFile.write( sshKey )
   finally:
      ## never leave the transport private key or decrypted certs behind
      shutil.rmtree( workDir, ignore_errors=True )
   ## Should we remove this ovfEnvFile?
   ## os.remove( ovfEnvFile )

def srDevName():
   """ Return '/dev/<name>' for the first /dev entry whose name begins with
   'sr' followed by at least one digit (e.g. '/dev/sr0'), or None if no such
   entry exists. "First" follows os.listdir() ordering.
   """
   srPattern = re.compile( r'(sr[0-9]+)' )
   for entry in os.listdir( '/dev' ):
      if srPattern.match( entry ) is not None:
         return '/dev/' + entry
   return None

def readProvisioningDrive( srDev ):
   """ Mount the provisioning drive srDev and collect its contents.

   Copies ovf-env.xml to ovfEnvFile and CustomData.bin to userData. If the
   CustomData.bin copy fails, emptyUserData is written to userData instead.
   The drive is unmounted afterwards, including when reading it raises.
   If the mount fails, cleanupExit is called with the mount return code.

   Returns a list of config lines:
   - 'hostname <HostName>' when a HostName is provisioned;
   - the output of generateUserConfig( UserName, UserPassword ) when a
     UserName is provisioned.
   """
   ns = '{http://schemas.microsoft.com/windowsazure}'

   def childText( parent, tag ):
      ## text of the first <tag> below parent, or None if there is none
      elmt = parent.find( './/%s%s' % ( ns, tag ) )
      return elmt.text if elmt is not None else None

   ## mount sr device; reuse an existing mount point so a directory left
   ## behind by an earlier attempt does not make os.makedirs raise
   srMount = "/mnt/sr_disk"
   if not os.path.isdir( srMount ):
      os.makedirs( srMount )
   _, rc = runCmd( [ "mount", srDev, srMount ] )
   if rc != 0:
      cleanupExit( "Could not mount %s" % srDev, rc )

   hostName = userName = userPassword = None
   try:
      ## read provisioning information and user data from drive
      provisionPath = os.path.join( srMount, "ovf-env.xml" )
      try:
         shutil.copy2( provisionPath, ovfEnvFile )
      except IOError:
         print( "ovf-env.xml could not be copied" )
      provElmt = et.parse( provisionPath ).getroot().find(
         './/%sProvisioningSection' % ns )
      if provElmt is not None:
         hostName = childText( provElmt, 'HostName' )
         userName = childText( provElmt, 'UserName' )
         userPassword = childText( provElmt, 'UserPassword' )
      else:
         print( "No ProvisioningSection found in ovf-env.xml" )

      ## copy custom data file from sr device to user data file
      try:
         shutil.copyfile( os.path.join( srMount, "CustomData.bin" ), userData )
      except IOError:
         print( "CustomData.bin could not be copied to .userdata; "
                "using default empty user-data instead" )
         with open( userData, 'w' ) as userDataFile:
            userDataFile.write( emptyUserData )
   finally:
      ## unmount sr device, even if reading the drive failed
      runCmd( [ "umount", srMount ] )

   ## add provisioning config lines
   configLines = []
   if hostName:
      configLines.append( 'hostname %s\n' % hostName )
   if userName:
      configLines.extend( generateUserConfig( userName, userPassword ) )
   return configLines

def setupIntfLink():
   """ Bring the link up on every interface whose driver is hv_netvsc.

   The Azure hypervisor expects the interface link to be up so that network
   packets are sent over the synthetic and bonded Accelerated Networking
   interface immediately after the driver initializes. Otherwise the
   hypervisor ejects the Accelerated Networking interface bonded to the
   synthetic interface because of the Azure failsafe logic. The solution is
   to bring the link up on all synthetic interfaces.
   """
   for netDev in getAllIntfs():
      output, rc = runCmd( [ "ethtool", "-i", netDev ] )
      if rc != 0:
         continue
      ## Skip devices whose ethtool output has no driver line instead of
      ## crashing on a failed regex match
      match = re.search( r'driver: (.+)', output )
      if match is None or match.group( 1 ) != 'hv_netvsc':
         continue
      runCmd( [ "ip", "link", "set", "dev", netDev, "up" ] )
      ## run dhclient on each interface and kill it
      runCmd( [ "dhclient", netDev ] )
      time.sleep( 2 )
      runCmd( [ 'pkill', 'dhclient' ] )

def getMlxIntfs():
   """ Return the interfaces from getAllIntfs() whose ethtool-reported
   driver is 'mlx4_en'.

   Interfaces are skipped if 'ethtool -i' fails or its output has no driver
   line.
   """
   mlxDevs = []
   for netDev in getAllIntfs():
      output, rc = runCmd( [ "ethtool", "-i", netDev ] )
      if rc != 0:
         continue
      ## guard against output without a 'driver:' line
      match = re.search( r'driver: (.+)', output )
      if match is not None and match.group( 1 ) == 'mlx4_en':
         mlxDevs.append( netDev )
   return mlxDevs

def setMlxOffloads():
   """ Turn off the rx, tx, sg, tso, lro and gro ethtool features on every
   interface returned by getMlxIntfs(), logging the return code of any
   'ethtool -K' call that fails.
   """
   featuresOff = [ "rx", "off", "tx", "off", "sg", "off", "tso", "off",
                   "lro", "off", "gro", "off" ]
   for intf in getMlxIntfs():
      _, status = runCmd( [ "ethtool", "-K", intf ] + featuresOff )
      if status != 0:
         print( "Set Offload failed with error: %d" % status )

def cloudInit():
   """ Azure cloud-init entry point.

   1. Brings up synthetic interfaces and turns off Mellanox VF offloads.
   2. Sets up the connection to the gateway.
   3. Records the instance resourceId from the metadata service in
      instanceIdFile.
   4. On first boot, reads the provisioning drive (if attached) and copies
      SSH keys from the wire server.
   5. Hands any provisioning config lines to processUserData().
   """
   ## Bringup the links on all synthetic interfaces
   setupIntfLink()

   ## Turn off mellanox vf offloads
   setMlxOffloads()
   ## setup connection to gateway
   setupConnection( ping=False )

   instanceIdentity, rc = runCmd( [ "curl", "-H", "Metadata:true",
                                  instanceIdentityUrl ] )
   if rc != 0:
      cleanupExit( "Could not get the instanceId. Skipping downloading user data. " +
                   "RC %d" % rc, rc )
   try:
      instanceId = json.loads( instanceIdentity )[ 'resourceId' ]
      with open( instanceIdFile, "w" ) as fd:
         fd.write( instanceId )
   ## ValueError: response is not valid JSON (e.g. an error page);
   ## TypeError: valid JSON that is not an object
   except ( IOError, KeyError, ValueError, TypeError ):
      cleanupExit( "Failed to write instance id file" )

   configLines = None

   if firstTimeBoot():
      ## get sr device name
      srDev = srDevName()

      ## read provisioning info from drive if attached
      if srDev is None:
         print( "No provisioning drive attached" )
      else:
         configLines = readProvisioningDrive( srDev )
         ## set DNS server (required for Waagent in ARM)
         configLines.append( 'ip name-server %s\n' % getDnsIp() )

      ## copy ssh keys over to /mnt/flash/key.pub from wire server
      copySshKeys()
   else:
      ## PLACEHOLDER recovery user-data download here
      pass

   ## process first time boot
   processUserData( configs, extraConfig=configLines )
