# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import os
import re

import Arnet
import Assert
import BasicCli
import BasicCliUtil
import CliCommand
import CliMatcher
import CliParser
import CliPlugin.Ssl as Ssl
import CliPlugin.IpGenAddrMatcher as IpGenAddrMatcher
import CliPlugin.Security as Security
import CliMode.Pki as PkiProfileMode
import CommonGuards
import ConfigMount
import LazyMount
import SslCertKey
import Tac
import Tracing

# Sysdb mounts used by this plugin; assigned by Plugin() below.
config = None
status = None
execRequest = None

Constants = Tac.Type( 'Mgmt::Security::Ssl::Constants' )
Digest = Tac.Type( 'Mgmt::Security::Ssl::Digest' )

__defaultTraceHandle__ = Tracing.Handle( 'PkiCli' )
t0 = Tracing.trace0
# Error shown when SslCertKey.isSslDirsCreated() reports the SSL dirs missing.
pkiNotReadyError = 'PKI not ready'

# Keyword / value matchers shared by the commands defined in this file.
pkiEnKwMatcher = CliMatcher.KeywordMatcher( 'pki',
   helpdesc='Configure PKI related options' )
keyKwMatcher = CliMatcher.KeywordMatcher( 'key', helpdesc='modify keys used' )
# Blocked on standby (CommonGuards.standbyGuard); noResult=True keeps the
# keyword out of the handler's args.
generateNode = CliCommand.Node(
      matcher=CliMatcher.KeywordMatcher( 'generate', helpdesc='create new item' ),
      guard=CommonGuards.standbyGuard,
      noResult=True )
rsaMatcher = CliMatcher.KeywordMatcher( 'rsa', helpdesc='Use RSA algorithm' )
validRsaKeySizes = [ 2048, 3072, 4096 ]
# Enum keys are the decimal strings '2048', '3072', '4096'; handlers convert
# the matched value with int() where a number is needed.
rsaKeySizeMap = { '%d' % bitLength: 'Use %d-bit keys' % bitLength
                  for bitLength in validRsaKeySizes }
rsaKeySizeMatcher = CliMatcher.EnumMatcher( rsaKeySizeMap )
certKwMatcher = CliMatcher.KeywordMatcher( 'certificate',
      helpdesc='work with x509 certificate' )
signRequestKwMatcher = CliMatcher.KeywordMatcher( 'signing-request',
      helpdesc='Certificate Signing Request ( CSR )' )
digestKwMatcher = CliMatcher.KeywordMatcher( 'digest',
      helpdesc='Digest to sign with' )
digests = { Digest.sha256: 'Use 256 bit SHA',
            Digest.sha384: 'Use 384 bit SHA',
            Digest.sha512: 'Use 512 bit SHA' }
validityMatcher = CliMatcher.KeywordMatcher( 'validity',
      helpdesc='Validity of certificate' )
# Characters accepted in a subject-alternative-name DNS entry. Used both by
# the CLI matcher and by _validateParams (anchored there with ^...$).
dnsNameRe = r'[0-9a-zA-Z_\.-]+'
digestMatcher = CliMatcher.EnumMatcher( digests )

class PkiProfileConfigMode( PkiProfileMode.PkiProfileMode,
                            Ssl.ProfileConfigModeBase ):
   # CLI mode for editing one PKI profile, entered with
   # 'pki profile PROFILE_NAME' from the security config mode. Profile
   # handling is shared with Ssl.ProfileConfigModeBase, instantiated here with
   # the profile type 'profileTypePki'.
   name = 'PKI profile configuration'
   modeParseTree = CliParser.ModeParseTree()

   def __init__( self, parent, session, profileName ):
      # Both bases need initialising explicitly; they take different args.
      PkiProfileMode.PkiProfileMode.__init__( self, profileName )
      Ssl.ProfileConfigModeBase.__init__( self, parent, session, profileName,
                                          'profileTypePki' )

def _gotoPkiProfileConfigMode( mode, args ):
   '''
   Handler for 'pki profile PROFILE_NAME': switch into PkiProfileConfigMode.

   If the child mode's profile config is not of type 'profileTypePki', a CLI
   error is reported and the session stays in the current mode.
   '''
   pkiMode = mode.childMode( PkiProfileConfigMode,
                             profileName=args[ 'PROFILE_NAME' ] )
   if pkiMode.profileConfig_.profileType == 'profileTypePki':
      mode.session_.gotoChildMode( pkiMode )
   else:
      mode.addError( 'Not a PKI profile' )

def _noPkiProfile( mode, args ):
   '''
   Handler for '[no|default] pki profile PROFILE_NAME': delete the profile.

   Deleting a profile that does not exist is a no-op. A profile of a type
   other than 'profileTypePki' is not deleted; a CLI error is reported
   instead.
   '''
   profileName = args[ 'PROFILE_NAME' ]
   profileConfig = config.profileConfig.get( profileName )
   if profileConfig is None:
      # Nothing to delete; don't rely on the collection tolerating a delete
      # of a missing key.
      return
   if profileConfig.profileType != 'profileTypePki':
      mode.addError( 'Not a PKI profile' )
      return
   del config.profileConfig[ profileName ]

class GotoPkiProfileModeCmd( CliCommand.CliCommandClass ):
   # [no|default] pki profile PROFILE_NAME
   #
   # Enters PkiProfileConfigMode; the no/default form deletes the profile.
   # Kept hidden: for now only the enable-mode 'security pki ...' commands
   # are offered to users.
   hidden = True
   syntax = 'pki profile PROFILE_NAME'
   noOrDefaultSyntax = 'pki profile PROFILE_NAME'
   data = {
            'PROFILE_NAME': Ssl.profileNameMatcher,
            'profile': Ssl.profileMatcher,
            'pki': pkiEnKwMatcher,
          }
   noOrDefaultHandler = _noPkiProfile
   handler = _gotoPkiProfileConfigMode

Security.SecurityConfigMode.addCommandClass( GotoPkiProfileModeCmd )

#--------------------------------------------------
# PKI profile mode reuses the SSL profile commands from CliPlugin.Ssl.
#
# [no|default] certificate CERT_NAME key KEY_NAME
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.CertificateCmd )

#--------------------------------------------------
# [no|default] trust certificate CERT_NAME
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.TrustCertificateCmd )
PkiProfileConfigMode.addCommandClass( Ssl.NoTrustCertificateCmd )

#--------------------------------------------------
# [no|default] chain certificate CERT_NAME
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.ChainCertificateCmd )
PkiProfileConfigMode.addCommandClass( Ssl.NoChainCertificateCmd )

#--------------------------------------------------
# [no|default] crl ... (see Ssl.CrlCmd / Ssl.NoCrlCmd for the syntax)
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.CrlCmd )
PkiProfileConfigMode.addCommandClass( Ssl.NoCrlCmd )

#--------------------------------------------------
# [no|default] certificate requirement hostname match
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.VerifyHostnameActionCmd )

#--------------------------------------------------------
# [no|default] certificate requirement extended-key-usage
#--------------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.VerifyExtendedParametersCmd )

#----------------------------------------------------------------------------
# [no|default] [trust|chain] certificate requirement basic-constraint ca true
#----------------------------------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.VerifyChainBasicConstraintTrustCmd )

#------------------------------------------------------------
# [no|default] chain certificate requirement include root-ca
#------------------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.ChainIncludeRootCACmd )

#---------------------------------------------------
# [no|default] certificate policy expiry-date ignore
#---------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.IgnoreExpiryDateCmd )

#------------------------------------------------------------------------------------
# security pki key generate rsa RSA_BIT_LENGTH KEY_NAME
#------------------------------------------------------------------------------------
class PkiGenerateKey( CliCommand.CliCommandClass ):
   syntax = 'security pki key generate rsa RSA_BIT_LENGTH KEY_NAME'
   data = {
            'security': Security.securityKwMatcher,
            'pki': pkiEnKwMatcher,
            'key': keyKwMatcher,
            'generate': generateNode,
            'rsa': rsaMatcher,
            'RSA_BIT_LENGTH': rsaKeySizeMatcher,
            'KEY_NAME': Ssl.keyNameMatcher,
          }

   @staticmethod
   def handler( mode, args ):
      '''
      Generate an RSA private key for use in PKI.

      The key is written to sslkey:KEY_NAME (a file under
      Constants.keysDirPath()) and a syslog message is generated on success.
      Failures while generating are reported with mode.addError.
      '''
      key = args[ 'KEY_NAME' ]
      if not SslCertKey.isSslDirsCreated():
         mode.addError( pkiNotReadyError )
         return

      genericError = 'Error generating key'
      # Only RSA is supported for now; the matched value is the key size as a
      # decimal string.
      keyBits = int( args[ 'RSA_BIT_LENGTH' ] )
      Assert.assertIn( keyBits, validRsaKeySizes )
      keyFile = os.path.join( Constants.keysDirPath(), key )
      try:
         # Captured before the key is written; passed to the syslog helper
         # afterwards. Inside the try so a failure here is reported too.
         keyLogAction, keyHash = SslCertKey.getLogActionAndFileHash(
               keyFile, 'sslkey:', 'created' )
         SslCertKey.generateRsaPrivateKey( keyFile, keyBits )
         SslCertKey.generateSslKeyCertSysLog( keyFile, 'sslkey:',
                                              keyLogAction, keyHash )
      except SslCertKey.SslCertKeyError as e:
         mode.addError( '%s (%s)' % ( genericError, str( e ) ) )
      except EnvironmentError as e:
         # strerror is None for EnvironmentErrors raised without an errno.
         mode.addError( '%s (%s)' % ( genericError, e.strerror or str( e ) ) )
      except StandardError as e:
         # Keep the user-facing message generic, but don't lose the cause.
         t0( '%s %s: %r' % ( genericError, keyFile, e ) )
         mode.addError( genericError )

BasicCli.EnableMode.addCommandClass( PkiGenerateKey )

#------------------------------------------------------------------------------------
# security pki certificate generate { signing-request | self-signed <cert-name> }
#             [ key <key-name> [ generate rsa <2048|3072|4096> ] ] 
#             [ digest <sha256|sha384|sha512> ]
#             [ validity <days> ]
#             [ parameters common-name <common-name>
#                          [ country <country-code> ]
#                          [ state <state-name> ]
#                          [ locality <locality-name> ]
#                          [ organization <org-name> ]
#                          [ organization-unit <org-unit-name> ]
#                          [ email <email> ]
#                          [ subject-alternative-name [ ip <ip1 ip2 ..> ] 
#                                                     [ dns <name1 name2 ..> ] 
#                                                     [ email <em1 em2 ..> ] ] ]
#------------------------------------------------------------------------------------
class ParamExpression( CliCommand.CliExpression ):
   # Subject / subject-alternative-name parameters shared by the CSR and
   # self-signed certificate commands. The adapter folds the matched tokens
   # into a single args[ 'SIGN_REQ_PARAMS' ] dict.
   expression = ( 'parameters common-name COMMON_NAME'
                         '[ country COUNTRY_CODE ]'
                         '[ state STATE_NAME ]'
                         '[ locality LOCALITY_NAME ]'
                         '[ organization ORG_NAME ]'
                         '[ organization-unit ORG_UNIT_NAME ]'
                         '[ email EMAIL ]'
                         '[ subject-alternative-name { ( ip { IP } ) | '
                                                    '( dns { DNS } ) | '
                                                    '( EMAIL_KW { SAN_EMAIL } ) } ]'
                         )
   data = {
            'parameters': 'Signing request parameters',
            'common-name': 'Common name for use in subject',
            'COMMON_NAME': CliMatcher.QuotedStringMatcher(),
            'country': 'Two-Letter Country Code for use in subject',
            'COUNTRY_CODE': CliMatcher.QuotedStringMatcher(),
            'state': 'State for use in subject',
            'STATE_NAME': CliMatcher.QuotedStringMatcher(),
            'locality': 'Locality Name for use in subject',
            'LOCALITY_NAME': CliMatcher.QuotedStringMatcher(),
            'organization': 'Organization Name for use in subject',
            'ORG_NAME': CliMatcher.QuotedStringMatcher(),
            'organization-unit': 'Organization Unit Name for use in subject',
            'ORG_UNIT_NAME': CliMatcher.QuotedStringMatcher(),
            'email': 'Email address for use in subject',
            'EMAIL': CliMatcher.PatternMatcher( r'\S+', helpname='WORD',
               helpdesc='Email address' ),
            'subject-alternative-name': 'Subject alternative name extension',
            'ip': CliCommand.singleKeyword( 'ip',
               helpdesc='IP addresses for use in subject-alternative-name' ),
            'IP': IpGenAddrMatcher.ipGenAddrMatcher,
            'dns': CliCommand.singleKeyword( 'dns',
               helpdesc='DNS names for use in subject-alternative-name' ),
            # Negative lookaheads keep the next SAN keyword from being
            # swallowed as a value.
            'DNS': CliMatcher.PatternMatcher( r'^(?!email$|ip$)(%s)' % dnsNameRe,
               helpname='WORD', helpdesc='DNS name' ),
            'EMAIL_KW': CliCommand.singleKeyword( 'email',
               helpdesc='Email addresses for use in subject-alternative-name' ),
            'SAN_EMAIL': CliMatcher.PatternMatcher( r'^(?!dns$|ip$)(\S+)',
               helpname='WORD', helpdesc='Email address' )
          }

   @staticmethod
   def adapter( mode, args, argsList ):
      # Nothing to do if the dict was already built or no 'parameters'
      # keyword was entered.
      if 'SIGN_REQ_PARAMS' in args or 'parameters' not in args:
         return
      # ( CLI token, parameter name ) pairs for the subject attributes.
      subjectFields = ( ( 'COMMON_NAME', 'commonName' ),
                        ( 'COUNTRY_CODE', 'country' ),
                        ( 'STATE_NAME', 'state' ),
                        ( 'LOCALITY_NAME', 'locality' ),
                        ( 'ORG_NAME', 'orgName' ),
                        ( 'ORG_UNIT_NAME', 'orgUnitName' ),
                        ( 'EMAIL', 'emailAddress' ) )
      params = { field: args.get( token ) for token, field in subjectFields }
      # SAN values are kept in a nested dict; getSignRequestParams flattens it.
      params[ 'san' ] = { 'sanIp': args.get( 'IP' ),
                          'sanDns': args.get( 'DNS' ),
                          'sanEmailAddress': args.get( 'SAN_EMAIL' ) }
      args[ 'SIGN_REQ_PARAMS' ] = params

# Taken from RFC5280, Upper Bounds ( page 123 of May 2008 rev. )
upperBounds = { 'country' : 2,
                'state' : 128,
                'locality' : 128,
                'orgName' : 64,
                'orgUnitName' : 64,
                'commonName' : 64,
                'emailAddress' : 128,
              }
 
def _validateParams( paramName, paramValue ):
   if not paramValue:
      return
   
   printName = { 'country' : 'Country code', 
                 'state' : 'State', 
                 'locality' : 'Locality', 
                 'orgName' : 'Organization name',
                 'orgUnitName' : 'Organization unit name', 
                 'commonName' : 'Common name', 
                 'emailAddress' : 'Email address' }
     
   if paramName in upperBounds:
      if len( paramValue ) > upperBounds[ paramName ]:
         raise SslCertKey.SslCertKeyError( '%s can be at most %d characters.' % (
            printName[ paramName ], upperBounds[ paramName ] ) )
     
   if paramName == 'sanIp':
      for ip in paramValue:
         try:
            Arnet.IpGenAddr( ip )
         except ( IndexError, ValueError ):
            raise SslCertKey.SslCertKeyError( 'IP address \'%s\' is not a valid v4 '
                                              'or v6 address' % ip )
        
   if paramName == 'sanDns':
      for dnsName in paramValue:
         if not re.match( '^%s$' % dnsNameRe, dnsName ):
            raise SslCertKey.SslCertKeyError( 'DNS name \'%s\' is not valid.' )
           
   if paramName == 'sanEmailAddress':
      for emailAddress in paramValue:
         if len( emailAddress ) > upperBounds[ 'emailAddress' ]:
            raise SslCertKey.SslCertKeyError( '%s \'%s\' can be at most '
               '%d characters.' % (
                  printName[ 'emailAddress' ], emailAddress,
                  upperBounds[ 'emailAddress' ] ) )

  
def _getSignRequestParamsInteractive( mode ):
   '''
   Prompt, one line at a time, for every subject / subject-alternative-name
   parameter used to generate a CSR or certificate.

   Each answer is validated with _validateParams as soon as it is entered, so
   an invalid value aborts before the remaining prompts are shown. Empty
   answers are stored as None; subject-alternative-name answers are split on
   whitespace into lists.

   Returns a dict keyed by commonName, country, state, locality, orgName,
   orgUnitName, emailAddress, sanIp, sanDns and sanEmailAddress.

   Raises SslCertKey.SslCertKeyError if no common name is entered or a value
   fails validation.
   '''
   # The common name is mandatory, so it is handled ahead of the rest.
   inp = BasicCliUtil.getSingleLineInput( mode,
         'Common Name for use in subject: ' )
   if not inp:
      raise SslCertKey.SslCertKeyError( 'Common Name is needed' )
   signReqParams = { 'commonName': inp }
   _validateParams( 'commonName', inp )

   # ( parameter name, prompt, answer is a space separated list ), in the
   # order the user is asked.
   optionalPrompts = [
      ( 'country', 'Two-Letter Country Code for use in subject: ', False ),
      ( 'state', 'State for use in subject: ', False ),
      ( 'locality', 'Locality Name for use in subject: ', False ),
      ( 'orgName', 'Organization Name for use in subject: ', False ),
      ( 'orgUnitName', 'Organization Unit Name for use in subject: ', False ),
      ( 'emailAddress', 'Email address for use in subject: ', False ),
      ( 'sanIp', 'IP addresses (space separated) for use in '
                 'subject-alternative-name: ', True ),
      ( 'sanDns', 'DNS names (space separated) for use in '
                  'subject-alternative-name: ', True ),
      ( 'sanEmailAddress', 'Email addresses (space separated) for use in '
                           'subject-alternative-name: ', True ),
   ]
   for paramName, prompt, isList in optionalPrompts:
      inp = BasicCliUtil.getSingleLineInput( mode, prompt )
      if not inp:
         value = None
      elif isList:
         value = inp.split()
      else:
         value = inp
      signReqParams[ paramName ] = value
      _validateParams( paramName, value )

   return signReqParams

def _getKeyParams( mode, keyParams ):
   '''
   Resolve the private key to use for CSR / certificate generation.

   keyParams is None (or empty) when the command named no key; the key name
   is then prompted for. Otherwise it is a dict with 'key' (the key name) and
   'newKeyBits' (the RSA bit length to generate a new key with, or None to
   use an existing key).

   Returns ( keyFilepath, newKeyBits ), where keyFilepath is a file directly
   under Constants.keysDirPath() (sslkey:).

   Raises SslCertKey.SslCertKeyError if no key name is given, if the name
   would resolve outside sslkey:, or if an existing key was requested but the
   file is not present.
   '''
   newKeyBits = None
   if keyParams:
      key = keyParams[ 'key' ]
      newKeyBits = keyParams[ 'newKeyBits' ]
   else:
      key = BasicCliUtil.getSingleLineInput( mode, 'PKI Key to use for CSR: ' )

   if not key:
      raise SslCertKey.SslCertKeyError( 'Key is needed' )

   keysDir = Constants.keysDirPath()
   keyFilepath = os.path.join( keysDir, key )
   # The interactive answer is free text: names such as '../x', '.' or an
   # absolute path would otherwise resolve outside sslkey:.
   if ( os.path.dirname( os.path.normpath( keyFilepath ) ) !=
        os.path.normpath( keysDir ) ):
      raise SslCertKey.SslCertKeyError( 'Key not found under sslkey:' )

   # A key that is about to be generated need not exist yet.
   if not newKeyBits and not os.path.isfile( keyFilepath ):
      raise SslCertKey.SslCertKeyError( 'Key not found under sslkey:' )

   return ( keyFilepath, newKeyBits )
  
def getSignRequestParams( mode, signReqParams ):
   '''
   Return the validated subject / SAN parameters for CSR or certificate
   generation.

   signReqParams is the dict built by ParamExpression.adapter (subject fields
   plus a nested 'san' dict), or None / empty, in which case every parameter
   is prompted for interactively. The returned dict is flat: the 'san'
   entries are merged in as sanIp, sanDns and sanEmailAddress (missing ones
   set to None), and IP SAN entries are converted to their string form. The
   caller's dict is not modified.

   Raises SslCertKey.SslCertKeyError if a parameter fails validation.
   '''
   if not signReqParams:
      # The interactive path validates each field as it is entered.
      return _getSignRequestParamsInteractive( mode )

   # Work on a copy so the caller's dict (args[ 'SIGN_REQ_PARAMS' ]) is left
   # untouched.
   signReqParams = dict( signReqParams )
   sanParams = dict( signReqParams.pop( 'san', None ) or {} )
   for sanKey in ( 'sanIp', 'sanDns', 'sanEmailAddress' ):
      sanParams.setdefault( sanKey, None )
   if sanParams[ 'sanIp' ]:
      # IP SAN values come from the CLI matcher as address objects.
      sanParams[ 'sanIp' ] = [ ip.stringValue for ip in sanParams[ 'sanIp' ] ]
   signReqParams.update( sanParams )

   for paramName in ( 'commonName', 'country', 'state', 'locality',
                      'orgName', 'orgUnitName', 'emailAddress',
                      'sanIp', 'sanDns', 'sanEmailAddress' ):
      _validateParams( paramName, signReqParams.get( paramName ) )

   return signReqParams
      
def pkiGenerateCsrOrCert( mode, args ):
   '''
   Generate a Certificate Signing Request or a self-signed certificate.

   Shared handler for 'security pki certificate generate signing-request ...'
   and 'security pki certificate generate self-signed CERT_NAME ...'. With
   CERT_NAME, a self-signed certificate is written to certificate:CERT_NAME;
   without it, a CSR is generated and printed. If 'generate rsa' was given, a
   new key is generated as part of the operation. A key name or subject
   parameters not given on the command line are prompted for. Failures are
   reported with mode.addError.
   '''
   digest = args.get( 'DIGEST', Constants.defaultDigest )
   signReqParams = args.get( 'SIGN_REQ_PARAMS' )
   cert = args.get( 'CERT_NAME' )
   signRequest = not cert
   keyParams = None
   if 'KEY_NAME' in args:
      keyParams = { 'key': args[ 'KEY_NAME' ],
                    'newKeyBits': args.get( 'RSA_BIT_LENGTH' ) }
   # A validity period only applies to certificates, not to CSRs.
   validityDays = None
   if 'CERT_NAME' in args:
      validityDays = args.get( 'VALIDITY', Constants.defaultCertValidity )

   if not SslCertKey.isSslDirsCreated():
      mode.addError( pkiNotReadyError )
      return

   genericError = 'Error generating %s' % (
         'CSR' if signRequest else 'certificate' )

   try:
      keyFilepath, newKeyBits = _getKeyParams( mode, keyParams )
      genNewKey = bool( newKeyBits )

      signReqParams = getSignRequestParams( mode, signReqParams )

      # Log action and hash are captured before the files are written and
      # passed to generateSslKeyCertSysLog afterwards. The certificate ones
      # are only needed (and only computed) when writing a certificate.
      keyLogAction, keyHash = SslCertKey.getLogActionAndFileHash(
            keyFilepath, 'sslkey:', 'created' )
      certFilepath = None
      certLogAction = certHash = None
      if not signRequest:
         certFilepath = os.path.join( Constants.certsDirPath(), cert )
         certLogAction, certHash = SslCertKey.getLogActionAndFileHash(
               certFilepath, 'certificate:', 'created' )

      csr, _ = SslCertKey.generateCertificate( keyFilepath=keyFilepath,
                                               certFilepath=certFilepath,
                                               signRequest=signRequest,
                                               genNewKey=genNewKey,
                                               newKeyBits=newKeyBits,
                                               digest=digest,
                                               validity=validityDays,
                                               **signReqParams )
      if genNewKey:
         SslCertKey.generateSslKeyCertSysLog( keyFilepath, 'sslkey:',
                                              keyLogAction, keyHash )

      if signRequest:
         mode.addMessage( csr )
      else:
         mode.addMessage( 'certificate:%s generated' % cert )
         SslCertKey.generateSslKeyCertSysLog( certFilepath, 'certificate:',
                                              certLogAction, certHash )
   except SslCertKey.SslCertKeyError as e:
      mode.addError( '%s (%s)' % ( genericError, str( e ) ) )
   except EnvironmentError as e:
      # strerror is None for EnvironmentErrors raised without an errno.
      mode.addError( '%s (%s)' % ( genericError, e.strerror or str( e ) ) )
   except StandardError as e:
      # Keep the user-facing message generic, but don't lose the cause.
      t0( '%s: %r' % ( genericError, e ) )
      mode.addError( genericError )

class PkiGenerateCsrOrCertCmd( CliCommand.CliCommandClass ):
   # security pki certificate generate signing-request
   #    [ key KEY_NAME [ generate rsa RSA_BIT_LENGTH ] ]
   #    [ digest DIGEST ] [ SIGN_REQ_PARAMS ]
   #
   # Generates a CSR and prints it; see pkiGenerateCsrOrCert.
   syntax = ( 'security pki certificate generate signing-request '
              '[ key KEY_NAME [ generate rsa RSA_BIT_LENGTH ] ] '
              '[ digest DIGEST ] [ SIGN_REQ_PARAMS ]' )
   handler = pkiGenerateCsrOrCert
   data = {
            # Command keywords.
            'security': Security.securityKwMatcher,
            'pki': pkiEnKwMatcher,
            'certificate': certKwMatcher,
            'generate': generateNode,
            'signing-request': signRequestKwMatcher,
            # Optional key to sign with, optionally generated first.
            'key': keyKwMatcher,
            'KEY_NAME': Ssl.keyNameMatcher,
            'rsa': rsaMatcher,
            'RSA_BIT_LENGTH': rsaKeySizeMatcher,
            # Optional digest and subject parameters.
            'digest': digestKwMatcher,
            'DIGEST': digestMatcher,
            'SIGN_REQ_PARAMS': ParamExpression,
          }

BasicCli.EnableMode.addCommandClass( PkiGenerateCsrOrCertCmd )

class PkiGenerateSelfSignedCertCmd( CliCommand.CliCommandClass ):
   # security pki certificate generate self-signed CERT_NAME
   #    [ key KEY_NAME [ generate rsa RSA_BIT_LENGTH ] ]
   #    [ digest DIGEST ] [ validity VALIDITY ] [ SIGN_REQ_PARAMS ]
   #
   # Writes a self-signed certificate to certificate:CERT_NAME; see
   # pkiGenerateCsrOrCert.
   syntax = ( 'security pki certificate generate self-signed CERT_NAME '
              '[ key KEY_NAME [ generate rsa RSA_BIT_LENGTH ] ] '
              '[ digest DIGEST ] [ validity VALIDITY ] [ SIGN_REQ_PARAMS ]' )
   handler = pkiGenerateCsrOrCert
   data = {
            # Command keywords and certificate name.
            'security': Security.securityKwMatcher,
            'pki': pkiEnKwMatcher,
            'certificate': certKwMatcher,
            'generate': generateNode,
            'self-signed': 'Self signed certificate',
            'CERT_NAME': Ssl.certificateNameMatcher,
            # Optional key to sign with, optionally generated first.
            'key': keyKwMatcher,
            'KEY_NAME': Ssl.keyNameMatcher,
            'rsa': rsaMatcher,
            'RSA_BIT_LENGTH': rsaKeySizeMatcher,
            # Optional digest, validity (days) and subject parameters.
            'digest': digestKwMatcher,
            'DIGEST': digestMatcher,
            'validity': validityMatcher,
            'VALIDITY': CliMatcher.IntegerMatcher( Constants.minCertValidity,
               Constants.maxCertValidity, helpdesc='Days' ),
            'SIGN_REQ_PARAMS': ParamExpression,
          }

BasicCli.EnableMode.addCommandClass( PkiGenerateSelfSignedCertCmd )

def Plugin( entityManager ):
   '''
   CLI plugin entry point: mount the SSL Sysdb paths this file uses and store
   them in the module globals config (ConfigMount, 'w'), status (LazyMount,
   'r') and execRequest (LazyMount, 'w').
   '''
   global config, status, execRequest
   config = ConfigMount.mount( entityManager, 'mgmt/security/ssl/config',
                               'Mgmt::Security::Ssl::Config', 'w' )
   status = LazyMount.mount( entityManager, 'mgmt/security/ssl/status',
                             'Mgmt::Security::Ssl::Status', 'r' )
   execRequest = LazyMount.mount( entityManager, 'mgmt/security/ssl/execRequest',
                                  'Mgmt::Security::Ssl::ExecRequest', 'w' )
