#!/usr/bin/env python
# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import os, shutil, tempfile, re
import subprocess 
import Tac, Tracing

# Trace handle for this script, registered under the name 'Sswan'.
# t0 is a shortcut to its trace0 method (used by the __main__ block below).
traceHandle = Tracing.Handle( 'Sswan' )
t0 = traceHandle.trace0

# Location and names of the strongSwan files rewritten by this script.
ipsecConfigDir = "/etc/strongswan/"
ipsecConfigFile = "ipsec.conf"
ipsecSecretsFile = "ipsec.secrets"

# IKE Diffie-Hellman group number -> key size in bits
# (19 is the 256-bit elliptic-curve group; the others are MODP groups).
diffieHillGroup = {
   '1' : '768',
   '14' : '2048',
   '15' : '3072',
   '16' : '4096',
   '19' : '256',
   '2' : '1024',
   '5' : '1536',
}

# CLI authentication mode -> value written as the ipsec.conf 'authby' option.
keyType = {
   'pre-share' : 'secret',
   'rsa-sig' : 'pubkey',
}

# Fallback values used when the corresponding CLI option is not supplied.
authbyDefault = 'pre-share'   # key of keyType
ikeLifetimeDef = 8            # hours
ipsecLifetimeDef = 1          # hours
keymargin = "3m"              # rekey margin
ikeVersionDef = "ikev1"       # 'keyexchange' value
autoDef = "add"               # 'auto' value

def isRunning( sswanPidFile='/etc/strongswan/ipsec.d/run/charon.pid' ):
   """Return True if the strongSwan charon daemon appears to be running.

   The first line of ``sswanPidFile`` is read as a PID and the daemon is
   considered alive when a matching /proc/<pid> entry exists.

   Args:
      sswanPidFile: path of charon's PID file; defaults to the location this
         script has always used.

   Returns:
      True if the file holds a numeric PID whose /proc entry exists; False if
      the file is missing/unreadable or its first line is empty/non-numeric.
   """
   try:
      with open( sswanPidFile, 'r' ) as pidFile:
         # strip() (rather than dropping the last character) keeps every digit
         # even when the file has no trailing newline.
         pid = pidFile.readline().strip()
   except ( IOError, OSError ):
      # No PID file, or it disappeared / is unreadable: treat as not running.
      return False
   # isdigit() rejects empty content, which would otherwise probe '/proc/'
   # (a directory that always exists) and report a false positive.
   return pid.isdigit() and os.path.exists( '/proc/%s' % pid )

def connectionCommand( intfName, enable=False ):
   """Run ``sudo strongswan up|down <intfName>`` if charon is running.

   Args:
      intfName: connection name handed to the strongswan CLI.
      enable: True runs ``up``, False (default) runs ``down``.

   The command is launched with subprocess.Popen and not waited for, so its
   exit status is not checked.  Nothing happens when isRunning() is False.
   """
   verb = "up" if enable else "down"
   if isRunning():
      subprocess.Popen( [ "sudo", "strongswan", verb, intfName ] )


def secretCommand( parser, secretsFilePath=None, **kwargs ): # pylint: disable-msg=W0621
   """Add, replace or delete one PSK entry in the strongSwan secrets file.

   The command line is parsed via ``parser.parse_args()``.  With
   ``--secret %default`` the edited entry is the wildcard line
   `` : PSK ...``; otherwise it is the line beginning ``<peerIp> :``.
   The first matching line is dropped and, unless ``--delete`` was given,
   a fresh ``PSK`` line is appended.  The previous file is kept as
   ``<file>.bkup``.

   Args:
      parser: argparse parser whose parse_args() result provides ``delete``,
         ``secret``, ``peerIp`` and ``key``.
      secretsFilePath: file to edit; defaults to ipsec.secrets under
         ipsecConfigDir.
      kwargs: accepted for call compatibility; ignored.

   Raises:
      ValueError: neither ``--secret %default`` nor ``--peerIp`` selects an
         entry, or no ``--key`` was given for an add/replace.
   """
   arguments = parser.parse_args()
   delete = arguments.delete
   key = arguments.key
   if not delete:
      if not key:
         raise ValueError( "--key is required unless --delete is given" )
      # Key is always prefixed with a \, to stop the argument parser from
      # interpreting its special characters; remove that escape if present.
      if key.startswith( "\\" ):
         key = key[ 1: ]

   if arguments.secret == "%default":
      entryPattern = re.compile( ' : PSK.*' )
      newEntry = ' : PSK %s\n' % key
   elif arguments.peerIp:
      # Escape the address so '.' matches only a literal dot.
      entryPattern = re.compile( re.escape( arguments.peerIp ) + ' :' )
      newEntry = '%s : PSK %s\n' % ( arguments.peerIp, key )
   else:
      raise ValueError( "--secret %default or --peerIp is required" )

   if secretsFilePath is None:
      secretsFilePath = os.path.join( ipsecConfigDir, ipsecSecretsFile )
   shutil.copyfile( secretsFilePath, secretsFilePath + ".bkup" )

   # Build the new contents in a temporary file, then swap it into place.
   fd, newPath = tempfile.mkstemp()
   try:
      with os.fdopen( fd, 'w' ) as newFile:
         with open( secretsFilePath ) as oldFile:
            replaced = False
            lastLine = '\n'
            for line in oldFile:
               if not replaced and entryPattern.match( line ):
                  # Drop only the first matching entry.
                  replaced = True
                  continue
               newFile.write( line )
               lastLine = line
         # A delete only removes the entry; it must not write one back.
         if not delete:
            if not lastLine.endswith( '\n' ):
               # Don't glue the new entry onto an unterminated last line.
               newFile.write( '\n' )
            newFile.write( newEntry )
   except Exception:
      os.remove( newPath )
      raise
   os.remove( secretsFilePath )
   shutil.move( newPath, secretsFilePath )
   

def _isakmpDefaultSection( arguments ):
   """Return the ipsec.conf lines for the ``conn %default`` section."""
   if arguments.authby:
      authby = arguments.authby.replace( "_", "-" )
   else:
      authby = authbyDefault
   lines = [ '\nconn %default\n',
             "\tauthby=%s\n" % keyType[ authby ],
             "\tikelifetime=%dh\n" % ( arguments.ikelifetime or ikeLifetimeDef ),
             # keylife is a synonym for lifetime: how long the IPsec SA is
             # valid before rekeying starts (rekeymargin earlier).
             "\tlifetime=%dh\n" % ( arguments.lifetime or ipsecLifetimeDef ),
             "\trekeymargin=%s\n" % keymargin,
             "\tkeyingtries=1\n",
             "\tmobike=no\n",
             "\tkeyexchange=%s\n" % ( arguments.ikeVersion or ikeVersionDef ),
             "\tauto=%s\n" % autoDef ]
   if arguments.esp:
      lines.append( "\tesp=%s\n" % arguments.esp )
   if arguments.ike:
      lines.append( "\tike=%s\n" % arguments.ike )
   return lines


def _isakmpConnSection( arguments ):
   """Return the ipsec.conf lines for the ``conn <arguments.config>`` section."""
   lines = [ '\nconn %s\n' % arguments.config ]
   if arguments.ikeVersion:
      lines.append( "\tkeyexchange=%s\n" % arguments.ikeVersion )
   if arguments.authby:
      authby = arguments.authby.replace( "_", "-" )
      lines.append( "\tauthby=%s\n" % keyType[ authby ] )
   if arguments.ikelifetime:
      lines.append( "\tikelifetime=%sh\n" % arguments.ikelifetime )
   if arguments.lifetime:
      lines.append( "\tlifetime=%sh\n" % arguments.lifetime )
   lines.append( "\tleft=%s\n" % arguments.left )
   if arguments.leftsubnet:
      lines.append( "\tleftsubnet=%s\n" % arguments.leftsubnet )
   if arguments.localid:
      lines.append( "\tleftid=%s\n" % arguments.localid )
   lines.append( "\tright=%s\n" % arguments.right )
   if arguments.remoteid:
      lines.append( "\trightid=%s\n" % arguments.remoteid )
   lines.append( "\tauto=%s\n" % autoDef )
   if arguments.ike:
      lines.append( "\tike=%s\n" % arguments.ike )
   if arguments.esp:
      lines.append( "\tesp=%s\n" % arguments.esp )
   if arguments.tunnelMode == "IpsecGre":
      # Protect only GRE (IP protocol 47) traffic.
      lines += [ "\tleftprotoport=47\n", "\trightprotoport=47\n" ]
   elif arguments.tunnelMode == "IpsecVti":
      lines += [ "\tleftsubnet=0.0.0.0/0\n", "\trightsubnet=0.0.0.0/0\n" ]
   if arguments.mark:
      lines.append( "\tmark=%s\n" % arguments.mark )
   if arguments.action:
      lines.append( "\tdpdaction=%s\n" % arguments.action )
      # Only emit the timers that were supplied; never write "=None".
      if arguments.interval:
         lines.append( "\tdpddelay=%s\n" % arguments.interval )
      if arguments.timeout:
         lines.append( "\tdpdtimeout=%s\n" % arguments.timeout )
   if arguments.encap:
      lines.append( "\tforceencaps=yes\n" )
   if arguments.type:
      lines.append( "\ttype=%s\n" % arguments.type )
   lines.append( "\treplay_window=%d\n" % ( arguments.replayWindowSize or 0 ) )
   if arguments.packetLimit:
      # Delete at packetLimit, rekey at 0.75 * packetLimit (margin = 1/4).
      # Integer division keeps large limits exact.
      lines.append( "\tlifepackets=%d\n" % arguments.packetLimit )
      lines.append( "\tmarginpackets=%d\n" % ( arguments.packetLimit // 4 ) )
   if arguments.byteLimit:
      lines.append( "\tlifebytes=%d\n" % arguments.byteLimit )
      lines.append( "\tmarginbytes=%d\n" % ( arguments.byteLimit // 4 ) )
   lines.append( "\tmobike=no\n" )
   return lines


def isakmpCommand( parser, configFilePath=None, **kwargs ): # pylint: disable-msg=W0621
   """Add, replace or delete one ``conn`` section in the strongSwan config.

   The command line is parsed via ``parser.parse_args()``.  The section
   named by ``--default`` (written as ``conn %default``) or ``--config`` is
   removed: its header line plus every following line up to the next
   section header (``conn``/``ca``/``config``/``include`` at column 0).
   Unless ``--delete`` was given, a regenerated section is appended.  The
   previous file is kept as ``<file>.bkup``.

   Args:
      parser: argparse parser whose parse_args() result carries the options
         defined in this script's __main__ block.
      configFilePath: file to edit; defaults to ipsec.conf under
         ipsecConfigDir.
      kwargs: accepted for call compatibility; ignored.

   Raises:
      ValueError: neither ``--default`` nor ``--config`` was given.
   """
   arguments = parser.parse_args()
   connName = arguments.default or arguments.config
   if not connName:
      raise ValueError( "--default or --config is required" )

   if configFilePath is None:
      configFilePath = os.path.join( ipsecConfigDir, ipsecConfigFile )
   shutil.copyfile( configFilePath, configFilePath + ".bkup" )

   # Escape the name so characters such as '.' are matched literally.
   connPattern = re.compile( 'conn ' + re.escape( connName ) + '$' )
   sectionStart = re.compile( r'(conn|ca|config|include)\s' )

   # Remove the old section and append the new one via a temporary file.
   fd, newPath = tempfile.mkstemp()
   try:
      with os.fdopen( fd, 'w' ) as newFile:
         with open( configFilePath ) as oldFile:
            skipping = False
            removed = False
            for line in oldFile:
               if skipping and sectionStart.match( line ):
                  # Next section reached: the old one has been dropped.
                  skipping = False
                  removed = True
               elif not removed and not skipping and connPattern.match( line ):
                  skipping = True
               if not skipping:
                  newFile.write( line )

         if not arguments.delete:
            if arguments.default:
               newFile.writelines( _isakmpDefaultSection( arguments ) )
            else:
               newFile.writelines( _isakmpConnSection( arguments ) )
   except Exception:
      os.remove( newPath )
      raise
   os.remove( configFilePath )
   shutil.move( newPath, configFilePath )
   

if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.set_defaults( method=isakmpCommand )

   # The entry being edited is chosen by at most one of these flags.
   selectorGroup = parser.add_mutually_exclusive_group()
   for selectorFlag, selectorHelp in (
         ( '--default', "Set the Default profile" ),
         ( '--config', "Set the Connection profile" ),
         ( '--secret', "Set the Secrets key and value" ) ):
      selectorGroup.add_argument( selectorFlag, help=selectorHelp )

   parser.add_argument( '--delete', help="Delete the Connection",
                        action='store_true' )

   # ( flag, value type, help text ) for every remaining option, in order.
   optionTable = (
      ( "--name", str, "Connection name" ),
      ( "--authby", str, "Authentication mode" ),
      ( "--ike", str, "IKE parameteters" ),
      ( "--ikeVersion", str, "IKEv1 or IKEv2 mode" ),
      ( "--esp", str, "ESP parameteters" ),
      ( "--ikelifetime", int, "IKE lifetime" ),
      ( "--lifetime", int, "IPSEC SA lifetime" ),
      ( "--left", str, "Left tunnel Id" ),
      ( "--leftsubnet", str, "Left Subnet" ),
      ( "--right", str, "Right tunnel Id" ),
      ( "--auto", str, "Initiator/Reactor config" ),
      ( "--peerIp", str, "Peer IP Address" ),
      ( "--key", str, "Key value" ),
      ( "--tunnelMode", str, "GRE/VTI mode" ),
      ( "--up", str, "Enables Ipsec connection" ),
      ( "--down", str, "Disable Ipsec connection" ),
      ( "--mark", str, "packet mark for VTI" ),
      ( "--interval", str, "DPD keepalive interval" ),
      ( "--timeout", str, "DPD Timeout interval" ),
      ( "--action", str, "DPD Action" ),
      ( "--encap", str, "Force UDP encapsulation" ),
      ( "--type", str, "Ipsec Mode" ),
      ( "--localid", str, "Local identification" ),
      ( "--remoteid", str, "Remote peer identification" ),
      ( "--replayWindowSize", int, "IPsec Anti-Replay Window" ),
      ( "--packetLimit", int, "IPsec SA packet limit" ),
      ( "--byteLimit", int, "IPsec SA byte limit" ),
   )
   for optFlag, optType, optHelp in optionTable:
      parser.add_argument( optFlag, type=optType, help=optHelp )

   args = parser.parse_args()
   # ipsec.conf is edited for the %default profile, or for a named connection
   # being deleted or given both endpoints.
   if args.default == "%default" or (
         args.config and ( args.delete or ( args.left and args.right ) ) ):
      isakmpCommand( parser )
   elif args.secret:
      secretCommand( parser )
      subprocess.Popen( [ "sudo", "strongswan", "rereadall" ] )
   elif args.up:
      connectionCommand( args.up, enable=True )
   elif args.down:
      connectionCommand( args.down, enable=False )

   # Runs after every invocation, whichever branch above was taken.
   subprocess.Popen( [ "sudo", "strongswan", "update" ] )
   if args.config and args.delete:
      t0( " Stop connection" )
      connectionCommand( args.config, enable=False )
