#!/usr/bin/env python
# Copyright (c) 2010  Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import errno
import fcntl
import optparse
import os
import re
import sys
import tempfile

from collections import defaultdict

import Tac

###########################################################################
##### Hardcoded values for different supported platforms ##################
###########################################################################
# Dictionary mapping each cpu/flash architecture group (archGroup) to the
# names of its member platforms. It is used to derive the archGroup from
# the platform name, since different actions need to be taken on
# different architectures.
# We now primarily use the "Aboot" cmdline option, but fall back on this
# for older Aboot versions.
platToArchGroup = {
   'norcal1' : [ ], # flashUtil was not used for norcal1
   'norcal2' : [ 'raven' ],
   'norcal3' : [ 'oak', 'blackbird', 'eaglepeak' ],
   'norcal4' : [ 'crow', 'mendocino' ],
   'norcal7' : [ 'oldfaithful' ],
   'norcal9' : [ 'woodpecker' ],
   'norcal10' : [ 'lorikeet' ],
}

# Pseudo-targets, per archGroup: a SECTION name that expands to several real
# layout regions, all of which are passed to a single flashrom write (one
# '-i' per region). Pseudo-targets are rejected for reads and require an
# image whose size equals the layout's 'total' length.
pseudoTargetLUT = {
   'norcal7' : {
      'image' : ( 'fallback', 'normal' ),
   },
   'norcal9' : {
      'image' : ( 'coreboot', 'agesa', 'bootblock' ),
   },
   'norcal10' : {
      'normal' : ( 'sec_pei', 'dxe' ),
      'fallback' : ( 'bkp_sec_pei', 'bkp_dxe' ),
      'image' : ( 'agesa', 'sec_pei', 'bkp_sec_pei', 'dxe', 'bkp_dxe' ),
   },
}

# Directory holding one flashrom layout file per archGroup (the file name is
# the archGroup name). Each line has the form '<start>:<end> <region>' with
# inclusive hexadecimal byte offsets.
LAYOUT_DIR = '/usr/share/EosAbootFirmwareUtils'

# flashrom binary to invoke; the script switches it to "flashrom-diag" for
# writes. Kept as a variable for debugging purposes, to test different
# flashroms.
flashrom = "flashrom"

###########################################################################
##### End hardcoded values section ########################################
###########################################################################

def run( cmd, verboseOut=False, stdoutSetting=Tac.CAPTURE ):
   """Run `cmd` (an argv list) as root through Tac.run and return its result.

   With verboseOut set, Tac.run is called with only asRoot=True, i.e. with its
   own default stdout/stderr handling. Otherwise stdout is handled according to
   `stdoutSetting` (Tac.CAPTURE unless overridden) and stderr is discarded.
   Callers in this script treat the captured result as the command's stdout
   text.
   """
   tacArgs = { 'asRoot' : True }
   if not verboseOut:
      tacArgs.update( stdout=stdoutSetting, stderr=Tac.DISCARD )
   return Tac.run( cmd, **tacArgs )

def exitWithPrint( code, err=None, msg=None ):
   """Optionally print messages, then terminate with exit status `code`.

   msg, when non-empty, goes to stdout; err, when non-empty, goes to stderr.
   Each is followed by a newline, and msg is written before err.
   Always raises SystemExit via sys.exit().
   """
   for stream, text in ( ( sys.stdout, msg ), ( sys.stderr, err ) ):
      if text:
         stream.write( text + '\n' )
   sys.exit( code )

def flashromDetect():
   """Probe the SPI flash by running flashrom with no arguments.

   Returns ( chipName, sizeInBytes ), parsed from flashrom's
   'flash chip "NAME" (N kB, ...)' line. Returns ( None, 0 ) when no such line
   is present in the output.
   """
   probeOutput = run( [ flashrom ] )
   match = re.search( r"flash chip \"(.*?)\" \((\d+) kB, .*\)", probeOutput )
   if match is None:
      return ( None, 0 )

   chipName, sizeKb = match.groups()
   # flashrom reports the size in kB; convert to bytes.
   return ( chipName, int( sizeKb ) * 1024 )

def _flashromWriteFile( flashromBaseCmd, fileToWrite, start, length ):
   """Program the SPI flash with the data in `fileToWrite` placed at `start`.

   When the data does not already describe the whole chip (start != 0 or
   length != detected chip size), it is first copied at byte offset `start`
   into a zero-filled scratch file of exactly chip size. That scratch file is
   then handed to flashrom, which expects a chip-sized image.

   flashromBaseCmd: flashrom argv prefix (binary, layout, '-i' regions, ...);
                    '-w <file>' is appended here.
   fileToWrite:     path of the source data.
   start, length:   byte offset and length of the target region. length is
                    only used to decide whether the data spans the whole chip.

   Exits with status 1 if the flash chip size cannot be detected.
   """
   _, spiTotalLen = flashromDetect()
   # flashromDetect() signals failure as ( None, 0 ), never with a None size,
   # so the previous `is not None` assertion could never trigger. Test for a
   # zero size instead, and do not rely on assert (stripped under -O).
   if not spiTotalLen:
      exitWithPrint( 1, err="Failed to detect SPI flash size" )

   tmpFile = None
   try:
      if start != 0 or length != spiTotalLen:
         tmpFile = tempfile.NamedTemporaryFile()

         # first we make the image patch
         # All regions should be aligned at the 1KB boundary, if that is not
         # the case, we need to adjust the block size we use.
         blockSize = 1024
         while blockSize > 1:
            if spiTotalLen % blockSize == 0 and start % blockSize == 0:
               break
            blockSize //= 2

         # Zero-filled scratch image the size of the whole chip.
         run( [ 'dd', 'if=/dev/zero', 'of=%s' % tmpFile.name,
                'bs=%d' % blockSize,
                'count=%d' % ( spiTotalLen // blockSize ) ],
              opts.verbose )
         # Overlay the caller's data at `start` without truncating the rest.
         run( [ 'dd', 'if=' + fileToWrite, 'of=' + tmpFile.name,
                'bs=%d' % blockSize, 'seek=%d' % ( start // blockSize ),
                'conv=notrunc' ], opts.verbose )

         fileToWrite = tmpFile.name

      run( flashromBaseCmd + [ '-w', fileToWrite ], opts.verbose )
   finally:
      # Remove the scratch image as soon as flashrom is done with it.
      if tmpFile is not None:
         tmpFile.close()

# Command-line interface. The usage text must only advertise options that are
# actually registered below (-r, -w, -a, -n, -v).
op = optparse.OptionParser( prog='flashUtil', usage='''
   %prog -- Wrapper utility around flashrom to perform read and
            write actions on NorCal SPI flash.

   %prog [-v] [-a ARCHGROUP] <-r|-w SECTION> <FILENAME>
   %prog -n

   NorCal1 support is limited and should only be used with the 'total'
   SECTION option or none.
   On NorCal2 and up, the script can be directed to read or
   write any given SECTION.
   During writes <FILENAME> provides the source data and during reads
   data is written to it ('-' sends the read data to stdout).
   Caution should be used with the total option as the system can
   be rendered unbootable if a write does not complete.

   ARCHGROUP (e.g. norcalN) is normally detected from the Aboot kernel
   command line; use -a to override it.

   Section can be:
     total ( entire flash image )
     prefdl
     fdl
     mac
     image ( aboot )
     fallback

   * not all sections are supported on all platforms; the available
     sections come from the layout file for the archGroup.
   ''' )

op.add_option( "-r", "--read", action='store', default=None,
      help="Read from Flash" )

op.add_option( "-w", "--write", action='store', default=None,
      help="Write to Flash" )

op.add_option( "-a", "--archGroup", action='store', default=None,
      help="Specify an archGroup (e.g. norcalN)" )

op.add_option( "-n", "--flash-name", dest='name',
         action='store_true', default=None, help="Only probe for flash chip name" )

op.add_option( "-v", "--verbose", dest='verbose',
         action='store_true', default=None, help="verbose output" )

opts, args = op.parse_args()

# '-n' short-circuits everything else: print the detected chip model and exit
# successfully. exitWithPrint() prints nothing when the model is None.
if opts.name:
   chipModel = flashromDetect()[ 0 ]
   exitWithPrint( 0, msg=chipModel )

read = opts.read
write = opts.write

# Exactly one positional argument (the data file) is required.
if len( args ) != 1:
   op.print_usage()
   exitWithPrint( 1, err="Invalid argument count." )

# Exactly one of -r / -w must be given.
if bool( read ) == bool( write ):
   op.print_usage()
   exitWithPrint( 1, err="Invalid option." )

target = read or write
if write:
   # Writes use the 'flashrom-diag' binary instead of plain flashrom.
   flashrom = "flashrom-diag"

fileArg = args[ 0 ]

def layoutPathFromPlatform( p ):
   """Return the path of the layout file named `p` (an archGroup) in LAYOUT_DIR."""
   layoutFile = os.path.join( LAYOUT_DIR, p )
   return layoutFile

def loadLayoutFiles( layoutDir=None ):
   """Parse every flashrom layout file found in a directory.

   Each file is named after an archGroup (e.g. 'norcal9'). Each non-blank line
   has the form '<start>:<end> <name>', with inclusive hexadecimal byte
   offsets; this is the same file later passed to flashrom with '-l'.

   layoutDir: directory to scan; defaults to LAYOUT_DIR.

   Returns a defaultdict mapping file (archGroup) name to
   { regionName : { 'start' : int, 'length' : int } }. Looking up an unknown
   archGroup yields an empty dict instead of raising KeyError.

   Raises ValueError on a non-blank line that is not a valid layout entry.
   """
   if layoutDir is None:
      layoutDir = LAYOUT_DIR

   res = defaultdict( dict )

   for fileName in os.listdir( layoutDir ):
      path = os.path.join( layoutDir, fileName )
      with open( path, 'r' ) as f:
         for lineNo, targetContent in enumerate( f.read().splitlines(), 1 ):
            # Blank lines carry no region (flashrom's own layout parser skips
            # them too); previously they tripped a bare assert.
            if not targetContent.strip():
               continue
            m = re.search( r'([0-9a-fA-F]+):([0-9a-fA-F]+) (\w+)', targetContent )
            if not m:
               raise ValueError( "Malformed layout entry at %s:%d: %r"
                                 % ( path, lineNo, targetContent ) )
            start = int( m.group( 1 ), 16 )
            end = int( m.group( 2 ), 16 )
            name = m.group( 3 )
            # The end offset is inclusive.
            res[ fileName ][ name ] = { 'start' : start,
                                        'length' : end - start + 1 }

   return res

flashOffsetLUT = loadLayoutFiles()

# Read platform and archGroup from Aboot command line

platform = None
archGroup = None
try:
   # Use a context manager so the /proc/cmdline handle is closed promptly
   # rather than being left open for the garbage collector.
   with open( '/proc/cmdline', 'r' ) as cmdlineFile:
      cmdline = cmdlineFile.read()
except IOError:
   pass
else:
   platformRe = re.search( r"platform=(.*?)($| |\n)", cmdline )
   archGroupRe = re.search( r"Aboot=Aboot-(\w+)-", cmdline )
   if platformRe:
      platform = platformRe.group( 1 )
   if archGroupRe:
      archGroup = archGroupRe.group( 1 ).lower()
      # In case some Aboot version strings cause us to get garbage, only trust
      # names for which a layout file exists. (Membership tests do not insert
      # into the defaultdict.)
      if archGroup not in flashOffsetLUT:
         archGroup = None

# Allow user to override
if opts.archGroup:
   archGroup = opts.archGroup

# Older versions of Aboot did not include the archgroup in
# /proc/cmdline, so fall back on using the platform (eventually
# we should get rid of this code)
if not archGroup and platform:
   for key in platToArchGroup:
      if platform in platToArchGroup[ key ]:
         archGroup = key
         break

if not archGroup:
   exitWithPrint( 1,
         err="Cannot determine archGroup, please supply one with -a option." )

# Platform Offset Dictionary: region name -> { 'start', 'length' }. For an
# archGroup without a layout file this is an empty dict (flashOffsetLUT is a
# defaultdict), which the 'total' check below rejects.
pod = flashOffsetLUT[ archGroup.lower() ]
pseudoTargetDict = pseudoTargetLUT.get( archGroup.lower(), {} )

if 'total' not in pod:
   exitWithPrint( 1, err="'total' section layout must be hardcoded for archGroup %s."
         % ( archGroup ) )

# Resolve the requested SECTION into the tuple of layout regions to program:
# either a single real region, or the expansion of a write-only pseudo-target.
if target in pod:
   targets = ( target, )
elif target not in pseudoTargetDict:
   exitWithPrint( 1, err="Target section %s is not supported for archGroup %s."
         % ( target, archGroup ) )
else:
   if read:
      exitWithPrint( 1, err="PseudoTarget is not supported for read" )
   targets = pseudoTargetDict[ target ]
   # Every member region must exist in the layout, and a pseudo-target must
   # expand to more than one region.
   allPresent = all( region in pod for region in targets )
   valid = allPresent and len( targets ) > 1
   if not valid:
      exitWithPrint( 1,
                     err="PseudoTarget %s is invalid on archGroup %s."
                     % ( target, archGroup ) )

# Size in bytes of the whole flash image, per the layout's 'total' region.
totalLen = pod[ 'total' ][ 'length' ]

def writeTarget( region, layout ):
   """Write a single layout `region` from the image file `fileArg`.

   A chip-sized image (size == totalLen) is handed to flashrom as-is, with
   '-i region' selecting the region. Any other image is placed at the region's
   start offset. Exits with status 1 if such an image is larger than the region.
   """
   regionInfo = pod[ region ]
   imageSize = os.path.getsize( fileArg )
   baseCmd = [ flashrom, '-N', '-l', layout, '-i', region ]

   if imageSize == totalLen:
      _flashromWriteFile( baseCmd, fileArg, 0, totalLen )
      return

   regionLen = regionInfo[ 'length' ]
   if imageSize > regionLen:
      exitWithPrint( 1, err="Image size doesn't match for target section %s (%d)."
                     % ( region, regionLen ) )
   _flashromWriteFile( baseCmd, fileArg, regionInfo[ 'start' ], regionLen )

def writePseudoTarget( regions, layout ):
   """Write several layout `regions` of `fileArg` in one flashrom invocation.

   Only a chip-sized image (size == totalLen) is accepted. Each region is
   selected with its own '-i' flag. Exits with status 1 otherwise.
   """
   if os.path.getsize( fileArg ) != totalLen:
      exitWithPrint( 1, err="PseudoTarget needs a full-size image" )

   includeArgs = [ arg for region in regions for arg in ( '-i', region ) ]
   _flashromWriteFile( [ flashrom, '-N', '-l', layout ] + includeArgs,
                       fileArg, 0, totalLen )

def readTarget( region, layout ):
   """Read layout `region` from flash and save it to `fileArg` ('-' = stdout).

   flashrom's output ('-r -', captured from stdout) is sliced at the region's
   absolute offset, so it is expected to be laid out like the whole chip.
   Exits with status 1 if the output is too short to contain the region.
   A broken pipe while writing to stdout is ignored.
   """
   start = pod[ region ][ 'start' ]
   length = pod[ region ][ 'length' ]
   end = start + length

   cmd = [ flashrom, '-l', layout, '-i', region, '-q', '-r', '-' ]

   # Verbose is not supported with read as this will prevent us from getting the
   # expected flash content.
   flashromOutput = run( cmd, False )

   # The region is sliced out at its absolute offset, so the output must reach
   # the region's end. Checking only against `length` let a truncated read
   # silently produce a short (or empty) file.
   if len( flashromOutput ) < end:
      exitWithPrint( 1, err="Unexpected output size returned by flashrom" )

   dataOutput = flashromOutput[ start : end ]

   if fileArg != "-":
      with open( fileArg, 'w' ) as f:
         f.write( dataOutput )
   else:
      try:
         sys.stdout.write( dataOutput )
      except IOError as e:
         # Ignore broken pipes
         if e.errno != errno.EPIPE:
            raise

def acquireLock( lock, printStatus ):
   """Take an exclusive flock() on the open file object `lock`.

   A non-blocking attempt is made first. If the lock is held elsewhere (EAGAIN),
   it optionally reports that, then blocks until the lock is granted. Any other
   error from flock() propagates. When `printStatus` is true, progress messages
   are printed to stdout.
   """
   def report( text ):
      if printStatus:
         print( text )

   fd = lock.fileno()
   try:
      fcntl.flock( fd, fcntl.LOCK_EX | fcntl.LOCK_NB )
   except IOError as err:
      if err.errno != errno.EAGAIN:
         raise
      report( "Failed to acquire flashUtil exclusive lock." )
      report( "We will block until we get it." )
      fcntl.flock( fd, fcntl.LOCK_EX )

   report( "flashUtil exclusive lock acquired." )

# Several flashrom instances running in parallel can interfere with each
# other, so every invocation of this script serializes on one lock file.
with open( '/var/lock/flashUtil', 'w' ) as flashUtilLock:
   # Lock status messages go to stdout, so suppress them when the read data
   # itself is being sent to stdout ('-').
   acquireLock( flashUtilLock, fileArg != '-' )

   layoutFile = layoutPathFromPlatform( archGroup.lower() )
   if read:
      readTarget( target, layoutFile )
   elif len( targets ) != 1:
      # A pseudo-target expands to several regions written in one pass.
      writePseudoTarget( targets, layoutFile )
   else:
      writeTarget( target, layoutFile )
