# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

'''
Helpers to create files in /persist/secure/signatureVerification.
These files are used at boot time for signature verification of
various file types (such as extensions).

The files record the SSL profile containing the certs used to verify
the signature (along with other settings, such as whether verification
is enforced), since Sysdb may not be running at verification time.
'''

from __future__ import absolute_import, division, print_function
from enum import Enum
import errno
import os

# Root directory holding one signature verification mapping file per
# IMAGE_TYPE value.
BASEDIR = os.path.join( "/persist/secure", "signatureVerification" )

class FILE_ATTR( object ):
   ''' Keys that may appear in a signature verification mapping file.

      SSL_PROFILE ('sslProfile') - name of the SSL profile holding the
         certs the signature is checked against
      ENFORCE_SIGNATURE ('enforceSignature') - bool of whether signature
         verification is enforced '''

   SSL_PROFILE = 'sslProfile'
   ENFORCE_SIGNATURE = 'enforceSignature'

class IMAGE_TYPE( Enum ):
   ''' Kinds of image whose signatures may be verified at boot.

   Each member's value is also the name of the mapping file, under
   /persist/secure/signatureVerification, that stores the ssl profile,
   enforcement flag, etc. for that kind of image. For example, EXTENSION
   maps to /persist/secure/signatureVerification/extension. '''

   EXTENSION = 'extension'

class FileMgr( object ):
   ''' Manage the signature verification mapping file for one image type.

   The file lives at BASEDIR/<imageType value> and holds newline-separated
   "key: value" lines, where each key is normally a FILE_ATTR attribute. '''

   def __init__( self, imageType ):
      ''' imageType must be a value of IMAGE_TYPE (ex "extension").

      Raises ValueError for an unsupported image type. '''
      try:
         imageEnum = IMAGE_TYPE( imageType )
      except ValueError:
         # Raise explicitly rather than `assert False`: asserts are stripped
         # under `python -O`, which would leave imageEnum unbound below.
         raise ValueError( "Unsupported image type: %s" % imageType )

      self.mapFilename = os.path.join( BASEDIR, imageEnum.value )

   def writeAttrs( self, attrDict ):
      ''' Write the attributes in attrDict into self.mapFilename.

      attrDict is expected to be a dict where key is in FILE_ATTR
      and value is... the value. It will be written as a newline-separated
      list of "key: value" pairs, replacing any previous contents.

      The data is first written to a temporary file in the same directory
      and then renamed over self.mapFilename. A crash or power loss mid-write
      therefore leaves either the old file or the new one. It never leaves a
      truncated file that would be read at boot with attributes (such as
      enforceSignature) missing.

      Note that readAttrs splits each line on the first ':' and strips
      whitespace. Keys containing ':' and values containing newlines
      therefore do not round-trip, and empty values are dropped on read. '''
      try:
         os.makedirs( os.path.dirname( self.mapFilename ) )
      except OSError as e:
         # An already-existing directory is fine; anything else is not.
         if e.errno != errno.EEXIST:
            raise

      tmpFilename = self.mapFilename + '.tmp'
      renamed = False
      try:
         with open( tmpFilename, 'w' ) as tmpFile:
            for attrType, attrValue in attrDict.items():
               tmpFile.write( "%s: %s\n" % ( attrType, attrValue ) )
            # Ensure the contents are on disk before the rename publishes them.
            tmpFile.flush()
            os.fsync( tmpFile.fileno() )
         # On POSIX, rename() atomically replaces an existing destination.
         os.rename( tmpFilename, self.mapFilename )
         renamed = True
      finally:
         if not renamed:
            # Don't leave a stale temp file behind if anything failed.
            try:
               os.remove( tmpFilename )
            except OSError:
               pass

   def readAttrs( self ):
      ''' Read the mapping file and return a dict mapping FILE_ATTR
      attributes to their values.

      Values are always returned as strings; for example, a bool written by
      writeAttrs reads back as 'True' or 'False'. Lines without a ':' and
      lines whose key or value is empty are skipped. Returns {} if the file
      cannot be opened or read. '''
      attrs = {}
      try:
         with open( self.mapFilename, 'r' ) as mapFile:
            for line in mapFile:
               # Split on the first ':' only, so values may contain ':'.
               key, sep, val = line.partition( ':' )
               key = key.strip()
               val = val.strip()
               if sep and key and val:
                  attrs[ key ] = val
      except IOError:
         return {}
      return attrs

   def deleteFile( self ):
      ''' Remove the mapping file. This is best effort: any OSError
      (including the file not existing) is ignored. '''
      try:
         os.remove( self.mapFilename )
      except OSError:
         pass
