#!/usr/bin/python
# Copyright (c) 2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import errno
import hashlib
import os
import re
from datetime import datetime

import Ark
import Cell
import Logging
import SwiSignLib
import Tac
import Tracing
import Url

from ExtensionMgr import (
      errors,
      logs,
      pypilib,
      rpmutil,
      yumlib,
)

# Default location of an extension status file. Not referenced within this
# section; presumably used by callers of saveInstalledExtensionStatus --
# TODO confirm.
DEFAULT_STATUS_FILE = "/var/run/extension-status"

# Tac type whose certPath() turns a trusted certificate name into a path; used
# by signatureVerificationCerts().
# pkgdeps: library MgmtSecuritySsl
MgmtSslConstants = Tac.Type( "Mgmt::Security::Ssl::Constants" )

def rpmHook( rpm_ ):
   """Quieten the rpm module handed to us by the lazy importer.

   RPMLOG_CRIT is the most useful (quietest) form of RPM output; the lower
   integer value RPMLOG_EMERG still emits faulty RPM format errors.
   """
   quietest = rpm_.RPMLOG_CRIT
   rpm_.setVerbosity( quietest )

# As this module is imported by the CLI, we need to use
# the lazy importer, because rpm is a large memory
# consumer blacklisted in /src/Eos/test/CliTests.py.
# rpmHook is handed to the lazy importer so it can set rpm's verbosity
# (presumably when the module is first actually imported -- TODO confirm
# Ark.LazyModule semantics).
rpm = Ark.LazyModule( 'rpm', rpmHook )

# Trace handle for this module; t0..t3 are shorthands for Tracing.trace0
# through Tracing.trace3.
__defaultTraceHandle__ = Tracing.Handle( 'ExtensionMgr' )
t0 = Tracing.trace0
t1 = Tracing.trace1
t2 = Tracing.trace2
t3 = Tracing.trace3

# Repository format names. Not referenced within this section; presumably
# used by callers to validate user input -- TODO confirm.
VALID_REPO_FORMATS = ( 'yum', 'pypi' )

# Presence error messages for Extension::Info records
errUnsupportedFormat = 'Unsupported format'
errInvalidSwix = 'Invalid SWIX file'

# Maps an extension format to the rpmutil function that installs it;
# installExtension calls it as install( status, info, force ). Formats not
# listed here (formatPyPi's installer is commented out) are rejected by
# installExtension as unsupported.
INSTALL_FUNCS =  { 'formatRpm': rpmutil.installRpm,
                   'formatYum': rpmutil.installYum,
                   'formatSwix': rpmutil.installSwix,
                   # 'formatPyPi': rpmutil.installPyPi,
                }

# File to log to, syslog if None (log() then forwards to Logging.log).
LOGFILE = None

def log( *args ):
   """Log a message to LOGFILE, or to syslog when LOGFILE is not set.

   When LOGFILE is None the arguments are forwarded unchanged to
   Logging.log. Otherwise one line of the form
   '<ISO timestamp> <arg0> <arg1> ...' is appended to LOGFILE. Every argument
   is converted with str() and separated by a single space, so non-string
   arguments (e.g. numbers or exception objects) are accepted.
   """
   if LOGFILE:
      with open( LOGFILE, 'a' ) as logf:
         fields = [ datetime.now().isoformat() ]
         fields.extend( str( arg ) for arg in args )
         logf.write( ' '.join( fields ) + '\n' )
   else:
      Logging.log( *args )

# Sysdb paths
def configPath():
   """Return the Cell-qualified Sysdb path of the extension config."""
   relative = 'sys/extension/config'
   return Cell.path( relative )

def repoConfigPath():
   """Return the Sysdb path of the extension repository config.

   Unlike the config and status paths, this one is returned as-is and is not
   passed through Cell.path.
   """
   return '/'.join( ( 'sys', 'extension', 'repoConfig' ) )

def statusPath():
   """Return the Cell-qualified Sysdb path of the extension status."""
   relative = 'sys/extension/status'
   return Cell.path( relative )


def sha1sum( filename ):
   """Computes a sha1sum of the file contents and returns it as a hex string.

   The file is read in binary mode, 64 KiB at a time, so the digest is taken
   over the exact bytes on disk and large files are not loaded into memory.
   Does no error checking: any exceptions that occur propagate to the caller."""
   h = hashlib.sha1()  # pylint: disable-msg=E1101
   with open( filename, 'rb' ) as f:
      while True:
         chunk = f.read( 65536 )
         # An empty read (b'' / '') marks end of file.
         if not chunk:
            break
         h.update( chunk )
   return h.hexdigest()

def latestExtensionForName( filename, extensionStatus ):
   """Returns the extension with the given name that has the highest
   generation id, or None if no extension with that name has been
   added.

   Names are compared case-insensitively when the "extension:" filesystem
   ignores case. That filesystem is consulted at most once per call, and not
   at all when extensionStatus.info is empty.
   """
   best = None
   ignoreCase = None
   target = filename
   for info in extensionStatus.info.values():
      if ignoreCase is None:
         # Loop-invariant: resolve the filesystem's case sensitivity once,
         # lazily, rather than on every iteration.
         ignoreCase = Url.getFilesystem( "extension:" ).ignoresCase()
         if ignoreCase:
            target = filename.lower()
      name = info.filename.lower() if ignoreCase else info.filename
      if name != target:
         continue
      if best is None or best.generation < info.generation:
         best = info
   return best

def installedExtensionForName( filename, extensionStatus ):
   """Return the Extension::Info named filename whose status is 'installed'
   or 'forceInstalled', or None when no such extension exists.

   Only one extension with a given name may be installed, so the first match
   is returned. Names are compared exactly (case-sensitively).
   """
   installedStates = ( 'installed', 'forceInstalled' )
   matches = ( candidate for candidate in extensionStatus.info.values()
               if candidate.filename == filename
               and candidate.status in installedStates )
   return next( matches, None )

def getPackageFormat( path ):
   """Classify an extension path by its scheme prefix or file suffix.

   'yum:<name>' and 'pypi:<name>' give formatYum / formatPyPi. Otherwise a
   path ending in .rpm, .tar.gz or .swix gives formatRpm, formatPyPi or
   formatSwix. Prefixes are checked before suffixes. A bare marker with
   nothing else (e.g. 'yum:' or '.rpm') is not recognized. Matching is
   case-sensitive; anything unrecognized is 'formatUnknown'.
   """
   byPrefix = ( ( 'yum:', 'formatYum' ),
                ( 'pypi:', 'formatPyPi' ) )
   bySuffix = ( ( '.rpm', 'formatRpm' ),
                ( '.tar.gz', 'formatPyPi' ),
                ( '.swix', 'formatSwix' ) )
   for marker, fmt in byPrefix:
      if len( path ) > len( marker ) and path.startswith( marker ):
         return fmt
   for marker, fmt in bySuffix:
      if len( path ) > len( marker ) and path.endswith( marker ):
         return fmt
   return 'formatUnknown'

def readExtensionRpmData( path, fmt=None ):
   """Pass in the path to an extension file and the type of file it is (in
   case of a temporary file name) and return a dictionary mapping rpm name to
   rpm info.

   fmt defaults to getPackageFormat( path ). formatRpm yields a single entry
   keyed by the file's basename; formatSwix yields the rpm header info read
   from the SWIX archive. A SWIX that is not a zip file, or any other format,
   fails an assertion. Errors from rpmutil (e.g. errors.RpmReadError from
   readRpmHeaderIntoDict) propagate to the caller.
   """
   if not fmt:
      fmt = getPackageFormat( path )
   if fmt == 'formatRpm':
      transactionSet = rpmutil.newTransactionSet()
      header = rpmutil.readRpmHeaderIntoDict( transactionSet, path )
      return { os.path.basename( path ): header }
   if fmt == 'formatSwix':
      # zipfile is blacklisted in CliTests.py, so import it only when needed.
      import zipfile
      assert zipfile.is_zipfile( path ), 'Invalid SWIX file: %s' % path
      transactionSet = rpmutil.newTransactionSet()
      _, rpmHeaderInfo = rpmutil.readSwix( path, transactionSet )
      return rpmHeaderInfo
   # fmt is a format name string, hence %s (not %d).
   assert False, 'Unsupported extension format: %s' % fmt
   return {}

def readExtensionFile( path, extensionStatus, transactionSet=None, pType=None,
                       deps=None ):
   """Pass me the path to an extension file, a rpm.TransactionSet instance,
   and an instance of Extension::Status and I will examine the file and
   populate an entry in Extension::Status::info.  The status of the entry
   will indicate whether the extension is valid or not.

   A new entry is always created, keyed by (file name, format, generation),
   where generation is one more than that of the latest existing entry with
   the same file name. It starts with presence 'absent'. Rpm/yum files whose
   headers can be read become 'present'. SWIX zip files are handed to
   rpmutil.saveSwixInfo. Unsupported or invalid files get a presenceDetail
   explaining why.

   Args:
     - transactionSet: rpm transaction set to use; a new one is created when
       None.
     - pType: package format to use instead of deriving it from path (e.g.
       when path is a temporary file name).
     - deps: for rpm/yum packages, file names of additional rpms located in
       the same directory as path; each is recorded on the same entry.

   Raises:
     - errors.RpmReadError if a dependency rpm cannot be read; the entry's
       presenceDetail is set first. An unreadable primary rpm is recorded in
       presenceDetail instead of being raised.
   """
   if transactionSet is None:
      transactionSet = rpmutil.newTransactionSet()
   fn = os.path.basename( path )
   fmt = pType or getPackageFormat( path )
   latest = latestExtensionForName( fn, extensionStatus )
   generation = 1 if latest is None else latest.generation + 1

   key = Tac.Value( 'Extension::InfoKey', fn, fmt, generation )
   info = extensionStatus.info.newMember( key )
   info.filepath = path
   info.presence = 'absent'
   if not fmt or fmt == 'formatUnknown':
      t0( path, 'is not an extension! Skipping...' )
      info.presenceDetail = errUnsupportedFormat
   elif fmt == 'formatRpm' or fmt == 'formatYum':
      try:
         header = rpmutil.readRpmHeaderIntoDict( transactionSet, path )
         rpmutil.addRpm( info, fn, header )
      except errors.RpmReadError as e:
         t0( 'Error reading rpm', path, ':', e )
         info.presenceDetail = str( e )
         return

      # Dependency rpms live next to the primary one. os.path.join also copes
      # with a bare file name, where dirname() is '' and a '/' prefix would
      # wrongly point at the filesystem root.
      basedir = os.path.dirname( path )
      for dep in deps or ():
         # skip the primary rpm and empty names
         if dep in ( fn, '' ):
            continue
         depPath = os.path.join( basedir, dep )
         try:
            header = rpmutil.readRpmHeaderIntoDict( transactionSet, depPath )
         except errors.RpmReadError as e:
            # Leave a reason on the (still absent) entry, then let the caller
            # handle the error.
            t0( 'Error reading dependency rpm', depPath, ':', e )
            info.presenceDetail = str( e )
            raise
         rpmutil.addRpm( info, dep, header )

      info.primaryPkg = fn
      info.presence = 'present'
      info.presenceDetail = ''
   elif fmt == 'formatSwix':
      # This import must be delayed until here, not done at module
      # level, due to zipfile being blacklisted in CliTests.py.
      import zipfile
      if not zipfile.is_zipfile( path ):
         t0( path, 'is not a zipfile! Skipping...' )
         info.presenceDetail = errInvalidSwix
         return
      rpmutil.saveSwixInfo( info, transactionSet )
   elif fmt == 'formatPyPi':
      # XXX handle pypi archives; until then the entry stays absent with no
      # presenceDetail.
      pass
   else: # unknown format type
      t0( path, 'has unsupported type', fmt )
      info.presenceDetail = errUnsupportedFormat

def readExtensionsDir( dirpath, extensionStatus, continueOnError=False ):
   """Scan dirpath for extension files and record each in extensionStatus.

   Every regular file directly inside dirpath is passed to readExtensionFile,
   sharing a single rpm transaction set; other entries are ignored. An
   exception while processing a file is traced and re-raised unless
   continueOnError is set, in which case that file is skipped. Exceptions
   raised while listing the directory itself are left to the caller.
   """
   t0( "readExtensionsDir for", dirpath )
   entries = os.listdir( dirpath )
   sharedTs = rpmutil.newTransactionSet()
   for candidate in ( os.path.join( dirpath, name ) for name in entries ):
      if not os.path.isfile( candidate ):
         continue
      t0( "processing possible extension file", candidate )
      try:
         readExtensionFile( candidate, extensionStatus, sharedTs )
      except Exception as e:  # pylint: disable-msg=W0703
         # Deliberately broad: we really want to catch the user top level
         # exception so it can optionally be skipped.
         t0( "Caught exception while processing", candidate, ":", e )
         if not continueOnError:
            raise

def _internalCheckInstallPrerequisites( extensionInfo ):
   ei = extensionInfo
   if ei.presence != 'present':
      msg = "Installation failed"
      if ei.presenceDetail:
         msg += ": " + ei.presenceDetail
      raise errors.InstallError( msg )
   if ei.status in ( 'installed', 'forceInstalled' ):
      raise errors.InstallError( "Extension is already installed" )
   pr = ei.package.get( ei.primaryPkg )
   if pr is None:
      raise errors.InstallError( "Extension %s cannot be installed: missing "
                                 "primary RPM\n" % ei.filename )

def checkInstallPrerequisites( extensionInfo ):
   """Non-raising form of the install prerequisite check.

   Returns ( True, None ) when extensionInfo can be installed, otherwise
   ( False, <reason> ).
   """
   try:
      _internalCheckInstallPrerequisites( extensionInfo )
   except errors.InstallError as e:
      return ( False, str( e ) )
   return ( True, None )

def installExtension( status, info, force=False ):
   """Installs an extension.

   Provide an instance of Extension::Status and Extension::Info and
   I will install the extension by following the procedure described in
   AID 522.  The process in which I am called must have an effective uid
   of 0 because I must have permission to manipulate the RPM database.

   The installer is looked up in INSTALL_FUNCS by info.format and called as
   installer( status, info, force ). The start of the installation (with the
   primary package version and the file's SHA-1) and its outcome are logged.

   Raises:
     - InstallError: if the prerequisites are not met, the file cannot be
       fingerprinted, or info.format has no installer. Any exception raised
       during installation is logged and re-raised unchanged.
   """
   assert os.geteuid() == 0, 'Not running with EUID of 0'
   _internalCheckInstallPrerequisites( info )
   # The prerequisite check guarantees the primary package exists.
   primary = info.package.get( info.primaryPkg )
   try:
      fingerprint = sha1sum( info.filepath )
   except Exception as e:
      raise errors.InstallError( "Error computing SHA-1 fingerprint of %s: %s"
                                 % ( info.filename, e ) )
   log( logs.EXTENSION_INSTALLING, info.filename, primary.version, fingerprint )
   try:
      installer = INSTALL_FUNCS.get( info.format )
      if installer is None:
         raise errors.InstallError( "Unsupported extension format '%s'" %
                                    info.format )
      installer( status, info, force )
      log( logs.EXTENSION_INSTALLED, info.filename )
   except ( errors.InstallError, Exception ) as e:  # pylint: disable-msg=W0703
      log( logs.EXTENSION_INSTALL_ERROR, info.filename, str( e ) )
      raise

def _internalUninstallCheck( extensionInfo ):
   ei = extensionInfo
   if ei.status not in ( 'installed', 'forceInstalled' ):
      raise errors.UninstallError( "Only installed extensions may be uninstalled" )

def checkUninstallPrerequisites( extensionInfo ):
   """Non-raising form of the uninstall check.

   Returns ( True, None ) when extensionInfo may be uninstalled, otherwise
   ( False, <reason> ).
   """
   try:
      _internalUninstallCheck( extensionInfo )
   except errors.UninstallError as e:
      return ( False, str( e ) )
   return ( True, None )

def _removeMountpointDir( mountpoint ):
   """Remove an already-unmounted mountpoint directory if it is empty.

   ENOTEMPTY could happen if we've mounted on a non-empty dir, and EBUSY if
   we've mounted on the same dir twice; both are tolerated. Any other
   failure raises UninstallError.
   """
   try:
      os.rmdir( mountpoint )
   except EnvironmentError as e:
      if e.errno not in ( errno.ENOTEMPTY, errno.EBUSY ):
         raise errors.UninstallError( 'Failed to remove directory %r: %s ' %
                                      ( mountpoint, e ) )

def _mountpointUsers( mountpoint ):
   """Return 'name (pid N), ...' for the processes lsof reports under
   mountpoint, or '' when that cannot be determined."""
   try:
      # '+c0' means don't truncate the procs' names.
      output = Tac.run( [ 'lsof', '+c0', mountpoint ], stdout=Tac.CAPTURE )
   except Tac.SystemCommandError as e:
      # lsof only enriches the error message; its own failure must not mask
      # the unmount error being reported.
      t0( 'lsof failed for', mountpoint, ':', e )
      return ''
   procs = set()
   for line in ( output or '' ).splitlines()[ 1: ]: # Behead the output.
      fields = line.split()
      if len( fields ) >= 2:
         procs.add( ( fields[ 0 ], fields[ 1 ] ) )
   return ', '.join( '%s (pid %s)' % procAndPid
                     for procAndPid in sorted( procs ) )

def _unmountExtension( extensionInfo ):
   """Unmount every mountpoint in extensionInfo.mountpoints, removing each
   directory afterwards. On an unmount failure the mountpoint is put back
   and UninstallError is raised."""
   # We want to unmount in the reverse mounting order in case it matters.
   while extensionInfo.mountpoints:
      mountpoint = extensionInfo.mountpoints.pop()
      try:
         Tac.run( [ 'umount', mountpoint ], asRoot=True, stderr=Tac.CAPTURE )
      except Tac.SystemCommandError as e:
         # Unmounting didn't work. Reinsert the mountpoint in the extension info.
         extensionInfo.mountpoints.push( mountpoint )
         error = 'Failed to unmount %r' % mountpoint
         # If the unmounting failed due to files open, find the procs and report.
         if 'target is busy' in ( e.output or '' ):
            users = _mountpointUsers( mountpoint )
            if users:
               error += ' due to files kept open by ' + users
         raise errors.UninstallError( error )
      _removeMountpointDir( mountpoint )

def uninstallExtension( extensionInfo, extensionStatus, force=False ):
   """Uninstall an extension.

   The process in which I am called must have an effective uid of 0 because
   I must have permission to manipulate the RPM database.

   The extension's mountpoints are unmounted (and their directories removed)
   first. Then its installed RPMs are uninstalled and its status is reset to
   'notInstalled'. An extension whose file is absent also has its
   Extension::Info removed from extensionStatus. Failures after the
   uninstall has started are logged before being re-raised.

   Args:
     - extensionInfo: The instance of Extension::Info of the extension to
       uninstall by following the procedure described in AID 522.
     - extensionStatus: Our Extension::Status.
     - force: Whether or not to force the uninstall.

   Raises:
     - UninstallError: if I fail (e.g. a mountpoint cannot be unmounted or
       removed, or the RPMs cannot be uninstalled).
   """
   assert os.geteuid() == 0, "Must run as root, not UID %d" % os.geteuid()
   _internalUninstallCheck( extensionInfo )
   filename = extensionInfo.filename
   log( logs.EXTENSION_UNINSTALLING, filename )

   ts = rpmutil.newTransactionSet()

   if force:
      rpmutil.setForceProblemFlags( ts )
   else:
      ts.setProbFilter( 0 )

   try:
      _unmountExtension( extensionInfo )

      packages = [ r.packageName for r in
                   extensionInfo.installedPackage.values() ]
      try:
         rpmutil.uninstallRpms( packages, ts, force )
      except errors.RpmUninstallError as e:
         extensionInfo.statusDetail = str( e )
         raise errors.UninstallError( "RPM uninstall error: %s" % e )

      extensionInfo.installedPackage.clear()
      extensionInfo.status = 'notInstalled'
      extensionInfo.statusDetail = ''
      # If this extension is absent (e.g. its file was removed) then
      # also garbage collect its associated status as there is no way
      # it can be re-installed anyway.
      if extensionInfo.presence == 'absent':
         del extensionStatus.info[ extensionInfo.key ]

      log( logs.EXTENSION_UNINSTALLED, filename )
   except Exception as e:  # pylint: disable-msg=W0703
      log( logs.EXTENSION_UNINSTALL_ERROR, filename, str( e ) )
      raise

def addToConfig( config, name, pFormat, forced, signatureIgnored=False ):
   """Append an extension to the config installation collection.

   Takes an Extension::Config, an extension name, an Extension::Format, a
   boolean indicating whether the extension was forcibly installed, and a
   boolean indicating whether to ignore the signature during verification.
   The new entry's index is one higher than any index currently in the
   collection (1 if none is positive), so it sorts last when entries are
   ordered by index. Nothing is added if an entry with that name exists.
   """
   existing = list( config.installation.values() )
   if any( entry.filename == name for entry in existing ):
      # This is unexpected, but there's no point to adding the same extension
      # twice.
      return
   nextIndex = max( [ 0 ] + [ entry.index for entry in existing ] ) + 1
   item = config.installation.newMember( name, pFormat, nextIndex )
   item.force = forced
   item.signatureIgnored = signatureIgnored

def removeFromConfig( config, name ):
   """Remove the first config installation entry whose filename is name.

   Takes an Extension::Config and an extension name. An unknown name is
   ignored, and a failure to delete the entry is traced rather than raised.
   """
   match = next( ( entry for entry in config.installation.values()
                   if entry.filename == name ), None )
   if match is None:
      return
   try:
      del config.installation[ match.index ]
   except Exception as e:  # pylint: disable-msg=W0703
      # We really want to catch the user top level exception
      t0( "Failed to remove item from config.installation:", e )

  
def saveInstalledExtensions( sysdbRoot, dstFile ):
   """Writes the installed extensions to dstFile in the boot-extensions
   format. If any installed extensions do not have valid signatures and were
   not installed with "signature-verification ignored", SignatureVerificationError
   is raised with error message saying which extensions are affected.

   One line is written per configured extension, in config index order. Each
   line holds the file name, "force" or "no", and the format. Rpm/yum lines
   are followed by the keys of the extension's package collection. When
   signature verification is enabled, SWIX extensions that need a valid
   signature and lack one, or that have no status entry to verify, are left
   out of the file. If the signing certificates cannot be determined,
   verification is skipped and SignatureVerificationError is raised after
   the file has been written.
   """
   config = sysdbRoot.entity[ configPath() ]
   status = sysdbRoot.entity[ statusPath() ]
   mgmtSecConfig = sysdbRoot.entity[ 'mgmt/security/config' ]
   mgmtSslConfig = sysdbRoot.entity[ 'mgmt/security/ssl/config' ]
   checkSigs = signatureVerificationEnabled( mgmtSecConfig )
   warningMsg = ''
   signingCerts = []
   if checkSigs:
      try:
         signingCerts = signatureVerificationCerts( mgmtSecConfig, mgmtSslConfig )
      except errors.SignatureVerificationError as e:
         # str( e ) rather than e.message, which is deprecated since
         # Python 2.6 and absent in Python 3.
         warningMsg = "Skipping SWIX signature verification: " + str( e )
         checkSigs = False

   installed = sorted( config.installation.values(), key=lambda a: a.index )
   invalidSwixSigs = []
   for installCfg in installed:
      info = latestExtensionForName( installCfg.filename, status )
      # Don't add swix to boot-extensions file if it needs to be signed properly
      # and isn't
      if ( checkSigs and not installCfg.signatureIgnored and
           installCfg.format == 'formatSwix' ):
         # Without a status entry there is no file to verify; treat it as an
         # invalid signature instead of failing on info.filepath.
         sigValid = False
         if info is not None:
            sigValid, _ = SwiSignLib.verifySwixSignature( info.filepath,
                                                          signingCerts )
         if not sigValid:
            invalidSwixSigs.append( installCfg.filename )
            continue

      bootExtLine = "%s %s%s" % ( installCfg.filename,
                                  " force " if installCfg.force else " no ",
                                  installCfg.format )

      # For yum/rpm packages we need to write their dependencies that exist
      # in /mnt/flash/.extensions. The RPM transactionSet requires
      # them for proper installation.
      if installCfg.format in ( 'formatRpm', 'formatYum' ):
         if info:
            bootExtLine += ' ' + str( info.package.keys() )
         else:
            t3( 'Error! Installed pkg %s found but no info found' %
                installCfg.filename )

      dstFile.write( bootExtLine + '\n' )
   if invalidSwixSigs:
      warningMsg = ( "%s not copied to boot-extensions because of missing or"
                     " invalid signature." ) % ", ".join( invalidSwixSigs )
   if warningMsg:
      raise errors.SignatureVerificationError( warningMsg )

def checkIfUpgrading( status, newInfo ):
   """Return the installed extension that newInfo would upgrade, or None.

   If we are upgrading an extension to a new version we have to mark the
   old extension as not installed. RPM will take care of actually upgrading
   the system itself; this only identifies the Sysdb entry involved. An
   installed (or force-installed) extension matches when its primary
   package has the same RPM package name as newInfo's primary package.
   """
   pkgNew = newInfo.package[ newInfo.primaryPkg ]
   installedStates = ( 'installed', 'forceInstalled' )
   candidates = ( info for info in status.info.values()
                  if info.status in installedStates )
   for candidate in candidates:
      # packageName is the actual RPM package name
      candidatePkg = candidate.package[ candidate.primaryPkg ]
      if candidatePkg.packageName == pkgNew.packageName:
         return candidate
   return None

def getInstalledVersion( info ):
   """Return ( version, release ) of info's installed primary package.

   For pypi extensions the version/release recorded in info are returned
   as-is (BUG94891). For yum, rpm and swix extensions the RPM database is
   queried by package name. ( None, None ) is returned unless exactly one
   matching header is found.
   """
   pkg = info.package[ info.primaryPkg ]
   # 'formatPyPi' is the format name used throughout this module; the former
   # misspelling 'formayPypi' made this branch unreachable.
   if info.key.format == 'formatPyPi':
      # BUG94891 once we support pip packages we need to get this
      # from the actual system. Right now I'm returning
      # what will always match so the warning does not print.
      return ( pkg.version, pkg.release )
   # yum, rpm, and swix types
   ts = rpmutil.newTransactionSet()
   hdrs = ts.dbMatch( rpm.RPMTAG_NAME, pkg.packageName )
   if len( hdrs ) != 1:
      # RPM may have been removed under the covers
      return ( None, None )
   hdr = next( hdrs )
   return ( hdr[ rpm.RPMTAG_VERSION ], hdr[ rpm.RPMTAG_RELEASE ] )

# Package manager interface

# Implementation modules for package formats. The functions below dispatch to
# them by format and call updateRepository, deleteRepository, packageDownload
# and packageSearch on them. Formats not listed here (e.g. formatRpm,
# formatSwix) have no repository implementation and are ignored by those
# dispatchers.
impls = { 'formatYum': yumlib,
          'formatPyPi': pypilib }


def updateRepository( repo ):
   """Ask the implementation for repo's format to update repo.

   A None repo, or one whose format has no entry in impls, is ignored.
   """
   if repo is not None:
      handler = impls.get( repo.format )
      if handler is not None:
         handler.updateRepository( repo )


def deleteRepository( repo ):
   """Ask the implementation for repo's format to delete repo.

   A None repo, or one whose format has no entry in impls, is ignored.
   """
   if repo is not None:
      handler = impls.get( repo.format )
      if handler is not None:
         handler.deleteRepository( repo )


def packageDownload( name, repos, repotype, status ):
   """Download package name from repos via the implementation registered in
   impls for repotype; repo types without an implementation are ignored."""
   t0( 'packageDownload', name, 'from type', repotype )
   handler = impls.get( repotype )
   if handler is None:
      return
   handler.packageDownload( name, repos, status )


def packageSearch( query, repos ):
   """Search every yum and pypi repository in repos for query.

   Results from all yum repositories come first, then those from all pypi
   repositories. Each group follows the order of repos. Repositories of any
   other format are ignored. Returns a flat list of results.
   """
   t0( 'packageSearch', query )

   repoList = list( repos )
   results = []
   for fmt in ( 'formatYum', 'formatPyPi' ):
      matching = [ repo for repo in repoList if repo.format == fmt ]
      for repo in matching:
         results.extend( impls[ fmt ].packageSearch( query, repo ) )
   return results

# Maps Extension format names to short type strings. Not referenced within
# this section; presumably used by callers for display -- TODO confirm.
repoTypeStrMap = {
   'formatYum': 'yum',
   'formatPyPi': 'pypi',
   'formatSwix': 'swix',
   'formatUnknown': 'unknown'
   }

def signatureVerificationEnabled( mgmtSecConfig ):
   ''' Returns a true value if signature-verification is enabled.

   A missing (falsy) mgmtSecConfig is returned unchanged; otherwise its
   enforceSignature setting is returned. '''
   if not mgmtSecConfig:
      return mgmtSecConfig
   return mgmtSecConfig.enforceSignature

def signatureVerificationCerts( mgmtSecConfig, mgmtSslConfig ):
   ''' Returns the list of potential signing certs to use for signature
   verification: the paths of the trusted certificates of the SSL profile
   named by mgmtSecConfig.sslProfile.

   Raises SignatureVerificationError if that profile does not exist or
   defines no trusted certificates. '''
   profileName = mgmtSecConfig.sslProfile
   profile = mgmtSslConfig.profileConfig.get( profileName )
   if not profile:
      raise errors.SignatureVerificationError(
         "Unable to verify SWIX signatures. SSL profile '%s' does not exist" %
         profileName )
   certPaths = []
   for certName in profile.trustedCert:
      certPaths.append( MgmtSslConstants.certPath( certName ) )
   if not certPaths:
      raise errors.SignatureVerificationError(
         "Unable to verify SWIX signatures. No trusted certificates defined" )
   return certPaths

def verifySignature( swixFile, mgmtSecConfig, mgmtSslConfig ):
   ''' Returns false if SWIX signature is invalid, true otherwise.

   SignatureVerificationError from locating the signing certificates
   propagates to the caller. '''
   certs = signatureVerificationCerts( mgmtSecConfig, mgmtSslConfig )
   valid, _detail = SwiSignLib.verifySwixSignature( swixFile, certs )
   return valid

def printExtensionInfo( output, info, sigValid=None ):
   """Write one tab-separated status line for info to output.

   Fields, in order: filename, format, presence, presenceDetail, status,
   statusDetail, signature state, and the space-separated keys of
   installedPackage. Tabs and newlines in the two detail fields are replaced
   by spaces. sigValid is tri-state: true gives 'signed', False gives
   'notSigned' (the SWIX is not signed properly), anything else gives ''
   (this isn't a swix).
   """
   if sigValid:
      signedStr = 'signed'
   elif sigValid is False:
      signedStr = 'notSigned'
   else:
      signedStr = ''

   def _flatten( text ):
      # Tabs/newlines would break the tab-separated, one-line-per-entry format.
      return re.sub( "[\t\n]", " ", text )

   fields = ( info.key.filename, info.key.format,
              info.presence, _flatten( info.presenceDetail ), info.status,
              _flatten( info.statusDetail ), signedStr,
              " ".join( info.installedPackage.keys() ) )
   line = "\t".join( [ "%s" ] * len( fields ) ) % fields
   t0( "status:", line )
   output.write( line + '\n' )

# save current extensions to a file
def saveInstalledExtensionStatus( sysdbRoot, dstFile ):
   """This is similar to saveInstalledExtensions, but the file is for
   LoadExtensionStatus instead of LoadExtension.

   One printExtensionInfo line is written per configured extension, in
   config index order. The data goes to dstFile + '.tmp' first and is then
   renamed onto dstFile. Does nothing if dstFile is empty or None.
   File-system errors (opening/writing the temporary file or renaming it)
   are logged rather than raised, and a leftover temporary file is removed
   on a best-effort basis.
   """
   if not dstFile:
      return
   t0( "save status for all extensions:", dstFile )
   config = sysdbRoot.entity[ configPath() ]
   status = sysdbRoot.entity[ statusPath() ]

   tmp = dstFile + '.tmp'
   try:
      with open( tmp, "w" ) as output:
         for installCfg in sorted( config.installation.values(),
                                   key=lambda a: a.index ):
            info = latestExtensionForName( installCfg.filename, status )
            if info:
               # Here is a bit of inconsistency:
               #
               # printExtensionInfo expects sigValid to be 3 values: None/True/False
               # which translate to ''/'signed'/'notSigned'. But LoadExtensionStatus
               # will turn it into signatureIgnored which is a boolean (notSigned
               # means True). We try to convert the boolean to the 3 values but
               # essentially None and True are the same.
               if info.format == "formatSwix":
                  sigValid = not installCfg.signatureIgnored
               else:
                  sigValid = None
               printExtensionInfo( output, info, sigValid=sigValid )
            else:
               t3( "Error: no info found for installed package",
                   installCfg.filename )

      os.rename( tmp, dstFile )
   except EnvironmentError as e:
      # EnvironmentError covers IOError from open/write and OSError from
      # os.rename (distinct classes on Python 2). The status file is
      # best-effort: log the error and continue.
      log( logs.EXTENSION_STATUS_SAVE_ERROR, dstFile, e.strerror or str( e ) )
      try:
         os.remove( tmp )
      except EnvironmentError:
         pass
