# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import hashlib
import os
import shutil
import tempfile
import Tracing
import zipfile
import SignatureFile
import SignatureRequest
import SwiSignLib
import VerifySwi

# Module trace helpers: logInfo / logWarn are bound to trace levels 0 and 1 of
# the 'AufLib' trace handle.
logInfo = Tracing.Handle( 'AufLib' ).trace0
logWarn = Tracing.Handle( 'AufLib' ).trace1

# Dev-CA signing material. The signing cert/key pair is what Auf.genSignedAuf
# passes to SignatureRequest.getDataFromDevCA when useDevCA=True.
# AUF_DEV_CHAIN is an ( intermediate CA cert, root CA cert ) pair that
# Auf.isSigned / Auf.__str__ verify a dev-signed auf against.
AUF_DEV_SIGNING_CERT="/etc/swi-signing-devCA/aufsign.crt"
AUF_DEV_SIGNING_KEY="/etc/swi-signing-devCA/aufsign.key"
AUF_DEV_INT_CA_CERT="/etc/swi-signing-devCA/aufint.crt"
AUF_DEV_INT_CA_KEY="/etc/swi-signing-devCA/aufint.key"
AUF_DEV_ROOT_CA_CERT=VerifySwi.DEV_ROOT_CA_FILE_NAME
AUF_DEV_CHAIN = ( AUF_DEV_INT_CA_CERT, AUF_DEV_ROOT_CA_CERT )

# Placeholders for BUG516648
# Arista certificate chain, in the same ( intermediate CA cert, root CA cert )
# order as AUF_DEV_CHAIN. The aufSignCa / aufIntCa paths are placeholders until
# BUG516648 is resolved.
# NOTE(review): AUF_ARISTA_SIGNING_CA_CERT is not used by the code shown here.
AUF_ARISTA_SIGNING_CA_CERT="/etc/swi-signature-aufSignCa.crt"
AUF_ARISTA_INT_CA_CERT="/etc/swi-signature-aufIntCa.crt"
AUF_ARISTA_ROOT_CA_CERT=VerifySwi.ARISTA_ROOT_CA_FILE_NAME
AUF_ARISTA_CHAIN = ( AUF_ARISTA_INT_CA_CERT, AUF_ARISTA_ROOT_CA_CERT )

# Every chain that Auf.__str__ tries when reporting how an auf is signed.
certs = [ AUF_DEV_CHAIN,
          AUF_ARISTA_CHAIN ]

# Signing server used by Auf.genSignedAuf when useDevCA is False.
# NOTE(review): this is a license-dev host even though it serves the non-dev
# signing path -- confirm that is intended.
AUF_SIGN_URL = "https://license-dev.aristanetworks.com/sign/aboot-release/"

def calcSha( _file ):
   """Return the hex SHA-256 digest of the file at path `_file`.

   The file is opened in binary mode so the digest covers the exact on-disk
   bytes (payload images are raw binary; text mode would decode/translate
   them). It is hashed in fixed-size chunks so large images are not loaded
   into memory all at once.
   """
   h = hashlib.sha256()
   with open( _file, 'rb' ) as f:
      for chunk in iter( lambda: f.read( 65536 ), b'' ):
         h.update( chunk )
   return h.hexdigest()

def copyFileSection( fromFile, toFile, start, length ):
   """Copy `length` bytes of `fromFile`, starting at byte offset `start`, into
   `toFile` (which is created or truncated).

   Both files are opened in binary mode: ROM images are raw bytes and must not
   go through text decoding or newline translation. If `fromFile` ends before
   `start + length`, only the bytes that exist are copied.
   """
   with open( fromFile, 'rb' ) as src, open( toFile, 'wb' ) as dst:
      src.seek( start )
      dst.write( src.read( length ) )

class Payload( object ):
   """One payload section of an auf.

   Attributes:
      sha     -- sha256 hex digest of the section image
      imgPath -- path of the section image file
      offset  -- first element of `romOffsets`: where the section starts in the ROM
      size    -- second element of `romOffsets`: the section length in bytes
   """

   def __init__( self, sha, imgPath, romOffsets ):
      offset, size = romOffsets[ 0 ], romOffsets[ 1 ]
      self.sha, self.imgPath = sha, imgPath
      self.offset = offset
      self.size = size

class VersionField( object ):
   '''
   Smol class to represent a version field which may be a wildcard.

   Construction:
      VersionField()        -> wildcard ( self.v is None ), printed as 'x'
      VersionField( 3 )     -> self.v == 3 (any non-str value goes through int())
      VersionField( '3' )   -> self.v == 3
      VersionField( '3-5' ) -> self.v == 3, self.r == [ 5 ]; every further
                               dash-separated part is appended to self.r
                               ( '3-5-7' -> self.r == [ 5, 7 ] )

   Comparisons use only the leading value self.v. A wildcard compares equal to
   anything, greater than anything, and never less than anything.
   '''

   def __init__( self, v='x' ):
      self.r = None

      if not isinstance( v, str ):
         self.v = int( v )
      elif v == 'x':
         self.v = None
      elif "-" in v:
         t = v.split( "-" )
         self.v = int( t[ 0 ] )
         # Keep every trailing part. __str__, __len__ and __call__ already
         # handle an r of any length, so keeping only t[ 1 ] would silently
         # corrupt values like '1-2-3' on a parse/print round-trip.
         self.r = [ int( x ) for x in t[ 1: ] ]
      else:
         self.v = int( v )

   def __call__( self ):
      '''Return self.v, or [ self.v ] + self.r for a dashed value.'''
      if self.r:
         return [ self.v ] + self.r
      return self.v

   def __len__( self ):
      '''Number of dash-separated parts (1 unless the value is dashed).'''
      if self.r:
         return len( self.r ) + 1
      return 1

   def __str__( self ):
      if self.r:
         return "-".join( [ str( x ) for x in [ self.v ] + self.r ] )

      return 'x' if self.v is None else str( self.v )

   def __add__( self, x ):
      # The result carries only self.v (any self.r is dropped). A wildcard
      # raises TypeError because None does not support '+'.
      return VersionField( self.v + x )

   def __sub__( self, x ):
      # Same caveats as __add__.
      return VersionField( self.v - x )

   def __lt__( self, other ):
      # A wildcard is never less than anything. Return False rather than None
      # so every comparison operator yields a real bool.
      return False if self.v is None else self.v < other

   def __gt__( self, other ):
      return True if self.v is None else self.v > other

   def __eq__( self, other ):
      return True if self.v is None else self.v == other

   def __ne__( self, other ):
      return False if self.v is None else self.v != other

class Auf( object ):
   """ Class that represents an .auf

   There are two parts: a temporary directory holding the unpacked contents of
   an auf, and the auf (zip) file itself. All operations work on the temporary
   directory; call genAuf to update (remake) the auf file.

   If the passed-in file already exists and is non-empty, it is parsed and its
   contents are unpacked into the temporary dir. The temp dir is removed once
   the object is destroyed.

   Archive / temp dir contents (as written by genAuf):
      info                 - key: value metadata (see parseInfo)
      layout               - ROM section map (see parseLayout)
      compat_check.sh      - optional compatibility script
      payloads/<name>.img  - one image per section
   """

   compatScript = "compat_check.sh"

   def __init__( self, aufFile ):
      self.aufFile = aufFile
      self.sections = {}   # section name -> Payload
      self.layout = {}     # section name -> ( start offset, size ) in the ROM
      self.line = VersionField()
      self.major = VersionField()
      self.minor = VersionField()
      self.version = 1
      self.name = ""

      self.tmpdir = tempfile.mkdtemp()
      logInfo( "Working in temporary directory %s" % self.tmpdir )

      if os.path.isfile( self.aufFile ) and os.path.getsize( self.aufFile ) != 0:
         if not zipfile.is_zipfile( self.aufFile ):
            raise IOError( 'Invalid file format' )
         logInfo( "auf exists, parsing" )

         # Close the archive even if extraction fails part way through.
         zf = zipfile.ZipFile( self.aufFile, mode='r' )
         try:
            zf.extractall( self.tmpdir )
         finally:
            zf.close()
         # The layout must be parsed first: parseInfo looks sections up in it.
         self.parseLayout()
         with open( self.toTmpDir( "info" ), "r" ) as f:
            self.parseInfo( f.readlines() )

      if not os.path.isdir( self.toTmpDir( "payloads" ) ):
         os.makedirs( self.toTmpDir( "payloads" ) )

   def __del__( self ):
      # tmpdir is unset if __init__ failed before mkdtemp() returned. Errors are
      # ignored because exceptions cannot usefully propagate from a finalizer.
      tmpdir = getattr( self, 'tmpdir', None )
      if tmpdir:
         shutil.rmtree( tmpdir, ignore_errors=True )

   def delete( self ):
      """Remove the auf file itself (the temp dir is left untouched)."""
      os.remove( self.aufFile )

   def copyToTmpDir( self, from_, to ):
      """Copy file `from_` to the relative path `to` inside the temp dir."""
      shutil.copyfile( from_, self.toTmpDir( to ) )

   def toTmpDir( self, p ):
      """Return the path of relative path `p` inside the temp dir."""
      return "%s/%s" % ( self.tmpdir, p )

   def parseInfo( self, lines ):
      """
      Parse the info file, e.g.
      version: 1
      line: 0
      major: 0
      minor: 0
      name: ...
      sections: section1 0ffe1..., section2 f6e0a...

      Each line is split on its first ':' only, so a value (e.g. the name) may
      itself contain ':'. Blank lines are ignored. Each listed section's image
      must exist at payloads/<name>.img in the temp dir, match its sha256, and
      have an entry in self.layout (so parseLayout must have run first).

      Raises SyntaxError for a malformed line or an unknown key, TypeError for
      a hash that is not 64 characters long, ValueError on a hash mismatch,
      and KeyError for a section that is missing from the layout.
      """

      for rawLine in lines:
         line = rawLine.strip()
         if not line:
            continue
         parts = line.split( ":", 1 )
         if len( parts ) != 2:
            raise SyntaxError( "Invalid info line %s" % line )
         key = parts[ 0 ].strip()
         value = parts[ 1 ].strip()

         if key == "version":
            self.version = int( value )
         elif key == "line":
            self.line = VersionField( value )
         elif key == "major":
            self.major = VersionField( value )
         elif key == "minor":
            self.minor = VersionField( value )
         elif key == "name":
            self.name = value
         elif key == "sections":
            sections = [ x.strip() for x in value.split( "," ) ]
            for s in sections:
               if s == "":
                  continue
               parts = s.split( " " )
               if len( parts ) != 2:
                  raise SyntaxError( "Invalid section format %s" % s )
               name, sha = parts[ 0 ], parts[ 1 ]
               if len( sha ) != 64:
                  raise TypeError( "Invalid hash" )

               imgLoc = self.toTmpDir( "payloads/%s.img" % name )
               if calcSha( imgLoc ) != sha:
                  raise ValueError( "hash mismatch" )

               offsets = self.layout[ name ]
               self.sections[ name ] = Payload( sha, imgLoc, offsets )
         else:
            raise SyntaxError( "Unknown key %s in info" % key )

   def updateInfo( self ):
      """Write the current metadata and section hashes to the temp dir's info
      file, in the format parseInfo reads."""
      sectionStr = ", ".join(
            [ "%s %s" % ( sectionName, payload.sha )
               for sectionName, payload in self.sections.items() ])
      info = \
'''version: %d
line: %s
major: %s
minor: %s
name: %s
sections: %s
''' % ( self.version, self.line, self.major, self.minor, self.name, sectionStr )
      with open( self.toTmpDir( 'info' ), 'w' ) as i:
         i.write( info )

   def parseLayout( self ):
      """
      Parse the layout file to get the ROM placement of each section, e.g.
      21000:3FFFFF me
      2000:10FFF prefdl
      1000:1FFF mac
      20000:20FFF mfgdata
      1000:20FFF pdr
      A00000:FFFFFF normal
      400000:9EFFFF fallback

      Each line is "<start>:<end> <name>" with inclusive hex addresses, stored
      as self.layout[ name ] = ( start, end - start + 1 ). Fields may be
      separated by any run of whitespace, and blank lines are skipped. Any
      other malformed line raises SyntaxError.
      """

      # Remove existing layout defs
      self.layout = {}
      # As the new layout may redefine where current payloads are located, clear
      # current payloads
      self.sections = {}

      with open( self.toTmpDir( "layout" ) ) as f:
         sections = [ x.split() for x in f.readlines() ]

      for section in sections:
         if not section:
            # Tolerate empty lines, e.g. an extra blank line at end of file.
            continue
         try:
            addrRange = section[ 0 ].split( ':' )
            start, end = (
                  int( addrRange[ 0 ], 16 ),
                  int( addrRange[ 1 ], 16 ) )
            self.layout[ section[ 1 ] ] = ( start, end-start+1 )
         except ( IndexError, ValueError ):
            raise SyntaxError( "Error parsing layout, line %s" % section )

   def addSectionFromRom( self, sectionName, rom ):
      """Extract section `sectionName` from ROM image file `rom` into
      payloads/<sectionName>.img, using the parsed layout for its placement.

      Raises IOError if no layout has been added, and KeyError if the layout
      has no such section.
      """
      if not os.path.exists( self.toTmpDir( "layout" ) ):
         raise IOError( "missing layout" )

      start, size = self.layout[ sectionName ]

      sectionImg = self.toTmpDir( "payloads/%s.img" % sectionName )
      copyFileSection( rom, sectionImg, start, size )

      checksum = calcSha( sectionImg )
      self.sections[ sectionName ] = Payload( checksum, sectionImg, ( start, size ) )

   def addCompatabilityScript( self, compat ):
      """Copy script `compat` into the temp dir as the compatibility script,
      replacing (with a warning) any existing one."""
      if os.path.exists( self.toTmpDir( self.compatScript ) ):
         logWarn( "Overriding compatability script" )
      self.copyToTmpDir( compat, self.compatScript )

   def getCompatibilityScript( self ):
      """Return the temp-dir path of the compatibility script, or None."""
      if os.path.exists( self.toTmpDir( self.compatScript ) ):
         return self.toTmpDir( self.compatScript )
      return None

   def getSections( self ):
      """Return the dict of section name -> Payload."""
      return self.sections

   def genAuf( self ):
      """(Re)write self.aufFile as a zip containing info, layout, the optional
      compatibility script and every section's payload image."""
      self.updateInfo()
      zf = zipfile.ZipFile( self.aufFile, mode='w' )
      # Always close (and so finalize) the archive, even if a write fails.
      try:
         zf.write( self.toTmpDir( "info" ), "info" )
         zf.write( self.toTmpDir( "layout" ), "layout" )
         if os.path.exists( self.toTmpDir( self.compatScript ) ):
            zf.write( self.toTmpDir( self.compatScript ), self.compatScript )
         for sectionName in self.sections:
            loc = "payloads/%s.img" % sectionName
            zf.write( self.toTmpDir( loc ), loc )
      finally:
         zf.close()

   def genSignedAuf( self, useDevCA=False, user=None, passwd=None ):
      """Regenerate the auf and sign it.

      With useDevCA the signature comes from the local dev CA key pair
      (AUF_DEV_SIGNING_CERT / AUF_DEV_SIGNING_KEY). Otherwise it is requested
      from AUF_SIGN_URL using `user` / `passwd`. On
      SignatureRequest.SigningServerError the auf is rewritten from the temp
      dir via genAuf and the original error is re-raised.
      """
      if os.path.exists( self.toTmpDir( "swi-signature" ) ):
         os.remove( self.toTmpDir( "swi-signature" ) )
      self.genAuf()

      sig = SignatureFile.Signature()
      aufData = SignatureFile.prepareDataForServer( self.aufFile, self.version, sig )
      try:
         if useDevCA:
            sigData = SignatureRequest.getDataFromDevCA( self.aufFile, aufData,
                    devCaKeyPair=( AUF_DEV_SIGNING_CERT, AUF_DEV_SIGNING_KEY ) )
         else:
            sigData = SignatureRequest.getDataFromServer( self.aufFile, aufData,
                    licenseServerUrl=AUF_SIGN_URL,
                    user=user, passwd=passwd )
         SignatureFile.generateSigFileFromServer( sigData, self.aufFile, sig )
      except SignatureRequest.SigningServerError:
         # Rebuild the auf from the temp dir (presumably so nothing from the
         # failed signing attempt is left in it -- TODO confirm), then re-raise
         # with the original traceback intact.
         if os.path.exists( self.toTmpDir( "swi-signature" ) ):
            os.remove( self.toTmpDir( "swi-signature" ) )
         self.genAuf()
         raise

   def _isSigned( self, cert ):
      """Return True if the auf verifies against chain `cert`, an
      ( intermediate CA cert, root CA cert ) pair.

      NOTE(review): cert[ 0 ] (the intermediate) is passed as `rootCA` to
      verifySwiSignature, and verifyAufSig receives ( root, intermediate ).
      Presumably this checks the signature up to the intermediate, then the
      intermediate against the root -- confirm against VerifySwi.
      """
      valid = SwiSignLib.verifySwiSignature( self.aufFile, rootCA=cert[ 0 ] )
      intCertValid = VerifySwi.verifyAufSig( cert[ 1 ], cert[ 0 ] )
      return valid[ 0 ] and intCertValid == VerifySwi.VERIFY_SWI_RESULT.SUCCESS

   def isSigned( self, useDevCA=False ):
      """Return True if the auf verifies against the dev chain (useDevCA) or
      the Arista chain."""
      cert = AUF_DEV_CHAIN if useDevCA else AUF_ARISTA_CHAIN
      return self._isSigned( cert )

   def addLayout( self, layout ):
      """Copy layout file `layout` into the temp dir and parse it. This
      replaces (with a warning) any existing layout and clears all sections."""
      if os.path.exists( self.toTmpDir( "layout" )  ):
         logWarn( "Overriding layout" )
      self.copyToTmpDir( layout, "layout" )
      self.parseLayout()

   def setAbootVersion( self, version ):
      """Set line/major/minor from a "line.major.minor" string.

      Each part may be anything VersionField accepts (e.g. 'x' or '1-2'), and
      parts after the third are ignored. Raises SyntaxError if fewer than
      three parts are given or a part is not a valid field.
      """
      v = version.split(".")
      try:
         self.line = VersionField( v[ 0 ] )
         self.major = VersionField( v[ 1 ] )
         self.minor = VersionField( v[ 2 ] )
      except ( IndexError, ValueError ):
         raise SyntaxError( "Invalid version string %s" % version )

   def setName( self, name ):
      """Set the auf's free-form name (written to the info file)."""
      self.name = name

   def __str__( self ):
      """Human-readable summary: file, versions, name, the chains the auf
      verifies against (when a signature exists), and every section's
      placement and sha256."""
      s = "File %s\nAuf version %d, Aboot version %s.%s.%s\n" % \
          ( self.aufFile, self.version, self.line, self.major, self.minor )
      s += "Name: %s\n" % self.name

      signed = SwiSignLib.swiSignatureExists( self.aufFile )
      if signed:
         for cert in certs:
            if self._isSigned( cert ):
               s += "Auf signed:\n"
               s += "\tCert %s, %s\n" % ( cert[ 0 ], cert[ 1 ] )

      s += "Sections:\n"
      for sectionName, payload in self.sections.items():
         s += "\t%s: offset: 0x%x bytes, size 0x%x bytes\n\t\tsha256(%s)\n" % \
               ( sectionName, payload.offset, payload.size, payload.sha )
      return s

