#!/usr/bin/env python
# Copyright (c) 2009 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import collections
import functools
import hashlib
import importlib
import operator
import os
import shutil
import subprocess
import sys
import tempfile
import yaml

import EosVersionValidator
import TableOutput

from Swix import schema


def sha1sum( filename ):
   """Return the hex SHA-1 digest of the contents of *filename*.

   The file is read in binary mode, 64 KiB at a time, so large RPMs are never
   loaded into memory in one piece. Any failure (missing file, permission
   error, ...) terminates the process via sys.exit with a descriptive
   message."""
   try:
      h = hashlib.sha1()
      # Binary mode: RPMs are binary data; text mode could translate line
      # endings on some platforms (and would decode bytes on Python 3).
      with open( filename, "rb" ) as f:
         while True:
            chunk = f.read( 65536 )
            if not chunk:   # empty read means EOF (works for str and bytes)
               break
            h.update( chunk )
      return h.hexdigest()
   except Exception as e:
      sys.exit( "Error computing sha1 sum for %s: %s\n" % ( filename, e ) )


def createManifest( filename, primaryRpm, rpms ):
   """Given a set of RPM file paths, creates a manifest file with the specified
   filename.

   The manifest contains a "format: 1" line, a "primaryRpm: <basename>" line,
   and one "<basename>-sha1: <hexdigest>" line for primaryRpm followed by one
   for each entry of rpms (any iterable of paths). Exits the process via
   sys.exit if a checksum cannot be computed or the file cannot be written."""
   lines = [ "format: 1",
             "primaryRpm: %s" % os.path.basename( primaryRpm ) ]
   for rpm in [ primaryRpm ] + list( rpms ):
      lines.append( "%s-sha1: %s" % ( os.path.basename( rpm ), sha1sum( rpm ) ) )
   try:
      # "with" guarantees the manifest is closed even if a write fails.
      with open( filename, "w" ) as outfile:
         outfile.write( "\n".join( lines ) + "\n" )
   except Exception as e:
      sys.exit( "%s: %s\n" % ( filename, e ) )


def renderVersionTable( versionsToRpms ):
   """Print a table of EOS versions and their RPMs, then ask for confirmation.

   versionsToRpms maps each EOS version to a collection of RPM names. Rows are
   emitted in sorted version order, with each version's RPMs sorted and
   joined by ", ". Unless the user answers "y" or "Y", the process exits via
   sys.exit."""
   headings = [ "EOS Versions", "Compatible RPMs" ]
   table = TableOutput.createTable( headings )
   # Both columns share the same left-justified, wrapped 20-column format,
   # but each column gets its own Format instance.
   columnFormats = []
   for _ in headings:
      fmt = TableOutput.Format( justify="left", maxWidth=20, wrap=True )
      fmt.noPadLeftIs( True )
      fmt.noTrailingSpaceIs( True )
      columnFormats.append( fmt )
   table.formatColumns( *columnFormats )

   for version in sorted( versionsToRpms ):
      table.newRow( version, ", ".join( sorted( versionsToRpms[ version ] ) ) )
   print( table.output() )
   confirm = raw_input( "The above table shows which RPMs will be installed "
                        "for each EOS version using the YAML file packaged "
                        "with this swix. Are the versions and RPMs correct? "
                        "[y/n]: " )
   if confirm not in ( "y", "Y" ):
      sys.exit( "Abort: Undesired versions/RPMs in install.yaml file." )


def _firstKey( mapping ):
   """Return the first key of a dict (the single key of a one-entry YAML map)."""
   return next( iter( mapping ) )


def checkInfo( infoFile, allRpms, noReleaseDb ):
   """Given a YAML file path, check if that file is well-formed and valid. If
   so, return True; Else ends with an error message

   infoFile:    path of the YAML info file to validate.
   allRpms:     basenames of the RPMs being packaged; every RPM named in the
                "version" section (other than "all") must be among them.
   noReleaseDb: when true, version expressions are only syntax-checked via
                EosVersionValidator.parse and no version table is shown.

   Each "version" entry is presumed to be a one-key mapping from a version
   expression to a list of RPM names or {rpmName: ...} dicts (inferred from
   how it is indexed below; the schema is checked by
   schema.checkInfoSchema). Any error, including failure to open the file,
   terminates the process via sys.exit."""
   # pylint: disable-msg=too-many-nested-blocks
   try:
      with open( infoFile, "r" ) as stream:
         info = yaml.safe_load( stream )
      schema.checkInfoSchema( info )
      if "version" in info:
         versionsToRpms = collections.defaultdict( set )
         for versionGroup in info[ "version" ]:
            version = _firstKey( versionGroup )
            if noReleaseDb:
               # Nothing can be matched without the release DB, but the
               # version syntax is still always checked.
               EosVersionValidator.parse( version )
               matchedVersions = {}
            else:
               matchedVersions = \
                  EosVersionValidator.getMatchingVersions( version )

            # An entry is either a plain RPM name or a dict of
            # {rpmName: instructions}; only the name matters here.
            rpmNames = [ _firstKey( entry ) if isinstance( entry, dict ) else entry
                         for entry in versionGroup[ version ] ]

            # Flatten the per-key lists of matching versions into one list.
            for matched in functools.reduce( operator.iconcat,
                                             matchedVersions.values(), [] ):
               versionsToRpms[ matched ].update( rpmNames )

            for rpm in rpmNames:
               if rpm != "all" and rpm not in allRpms:
                  sys.exit( "%s required for version %s doesn't exist" %
                            ( rpm, version ) )
         if versionsToRpms:
            renderVersionTable( versionsToRpms )
      return True
   except Exception as e:
      # SystemExit is not an Exception subclass, so deliberate exits above
      # pass through unchanged.
      sys.exit( "%s: %s\n" % ( infoFile, e ) )


def _signSwix( filename ):
   """Sign the SWIX at *filename* if the Swix.sign module is importable."""
   try:
      SwixSign = importlib.import_module( 'Swix.sign' )
   except ImportError:
      # Swix signing only available in devel environments
      print( "Skipping SWIX signing because the service is unavailable." )
      return
   retCode = SwixSign.sign( filename, forceSign=True )
   if retCode != SwixSign.SWIX_SIGN_RESULT.SUCCESS:
      sys.exit( "Error occurred during SWIX signing: %s\n" % retCode )
   print( "SWIX %s successfully signed!" % filename )


def create( filename, primaryRpm, rpms, args=None, sign=False ):
   """Build the SWIX archive *filename* from primaryRpm plus rpms.

   The archive is an uncompressed (zip -0), path-stripped (zip -j) zip of a
   generated manifest.txt, the RPMs, and, when args.info is set and passes
   checkInfo, that info file. If *filename* exists it is replaced only when
   args.force is set. When *sign* is true the result is signed if the
   signing module is available. Errors terminate the process via sys.exit.
   The temporary manifest directory is always removed."""
   # Check if "--force" flag has been used
   force = bool( args and args.force )
   if os.path.exists( filename ):
      if force:
         os.remove( filename )
      else:
         msg = "File %s exists: use --force to overwrite\n" % filename
         sys.exit( msg )
   # Initialized before the try so the finally clause never sees an
   # unbound name if mkdtemp itself fails.
   tmpDir = None
   try:
      tmpDir = tempfile.mkdtemp( suffix=".dir",
                                 prefix=os.path.basename( filename ), dir="." )
      manifest = os.path.join( tmpDir, "manifest.txt" )
      createManifest( manifest, primaryRpm, rpms )
      filesToZip = [ manifest, primaryRpm ] + list( rpms )
      if args and args.info:
         allRpms = [ os.path.basename( x ) for x in [ primaryRpm ] + list( rpms ) ]
         if checkInfo( args.info, allRpms, args.no_release_db ):
            filesToZip.append( args.info )
      # The output file is created only after validation/confirmation, so an
      # aborted run does not leave an empty .swix that requires --force.
      with open( filename, "wb" ) as outfile:
         # The -j arg causes zip to strip the directory path from filenames so
         # the output zip archive contains no directories
         p = subprocess.Popen( [ "zip", "-", "-0", "-j" ] + filesToZip,
                               stdout=outfile )
         p.communicate()
      if p.returncode != 0:
         # Explicit check (not assert) so it survives python -O and carries
         # a useful message.
         raise RuntimeError( "zip exited with status %d" % p.returncode )

      if sign:
         _signSwix( filename )

   except Exception as e:
      sys.exit( "Error occurred during generation of SWIX file: %s\n" % e )
   finally:
      if tmpDir is not None:
         shutil.rmtree( tmpDir, ignore_errors=True )


def parseCommandArgs( args ):
   """Parse the argument list of "swix create" and return the Namespace.

   Resulting attributes: outputSwix, rpms (one or more paths), force, info,
   no_release_db and sign. --sign and --no-sign are mutually exclusive and
   sign is False unless --sign is given."""
   parser = argparse.ArgumentParser( prog="swix create" )
   addArg = parser.add_argument
   addArg( 'outputSwix', metavar="OUTFILE.swix", help="Name of output file" )
   addArg( 'rpms', metavar="PACKAGE.rpm", type=str, nargs='+',
           help='An RPM to add to the swix' )
   addArg( '-f', '--force', action="store_true",
           help='Overwrite OUTFILE.swix if it already exists' )
   addArg( '-i', '--info', metavar="manifest.yaml", action='store', type=str,
           help='Location of manifest.yaml file to add metadata to swix' )
   addArg( "-n", "--no-release-db", action="store_true",
           help="Do not check the release DB for version information" )

   # Both flags write the same "sign" destination; at most one may be given.
   signGroup = parser.add_mutually_exclusive_group( required=False )
   signFlags = (
      ( "--sign", 'store_true', "Sign the SWIX after creation" ),
      ( "--no-sign", 'store_false',
        "Do not sign the SWIX after creation (default)" ),
   )
   for flag, storeAction, helpText in signFlags:
      signGroup.add_argument( flag, dest='sign', action=storeAction,
                              help=helpText )
   parser.set_defaults( sign=False )

   parsed = parser.parse_args( args )
   return parsed


def createHandler( args=None ):
   """Entry point for "swix create": parse *args* and build the SWIX.

   args is the command-line argument list without the program name. When
   omitted (or None), sys.argv[1:] is used, read at call time."""
   if args is None:
      # Resolved here rather than as the default value: a default of
      # sys.argv[1:] would be frozen at module import time.
      args = sys.argv[1:]
   args = parseCommandArgs( args )
   create( args.outputSwix, args.rpms[ 0 ], args.rpms[ 1: ], args,
           args.sign )

if __name__ == "__main__":
   createHandler()
