# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.


from CliModel import Bool
from CliModel import Dict
from CliModel import List
from CliModel import Model
from CliModel import Str
from CliModel import Enum
from CliModel import Int
from CliModel import Float
from CliModel import Submodel
import TableOutput
import Tac
import textwrap
import datetime

# Minimal UTC tzinfo, adapted from the example in the datetime documentation,
# so this module does not need a dependency on pytz.
ZERO = datetime.timedelta( 0 )

class UTC( datetime.tzinfo ):
   """tzinfo for Coordinated Universal Time: zero offset, no DST."""

   def tzname( self, dt ):
      return "UTC"

   def utcoffset( self, dt ):
      # UTC is, by definition, never offset from itself.
      return ZERO

   def dst( self, dt ):
      return ZERO

# Shared instance used to build timezone-aware epoch datetimes.
utc = UTC()

# Tac type handles from the Mgmt::Security::Ssl namespace. Their `.attributes`
# supply the allowed values of the Enum fields in ProfileError and
# ProfileStatus below (presumably these are Tac enum types -- confirm).
ProfileState = Tac.Type( "Mgmt::Security::Ssl::ProfileState" )
ErrorType = Tac.Type( "Mgmt::Security::Ssl::ErrorType" )
ErrorAttr = Tac.Type( "Mgmt::Security::Ssl::ErrorAttr" )

def _printLineItem( label, content, space=30 ):
   """Print one 'label: content' line.

   With a non-zero `space`, 'label:' is left-justified in a field at least
   `space` characters wide (longer labels are not truncated) and `content`
   follows immediately. With space=0 the label and content are separated by
   a single space instead.
   """
   # Built unconditionally, exactly as before, so a bad `space` fails the same way.
   template = '{:%ds}{:}' % space
   print( template.format( '%s:' % label, content ) if space
          else '%s: %s' % ( label, content ) )

def printDN( dn ):
   """Print a DistinguishedName as indented 'label: value' lines.

   The common name is always printed. The remaining attributes follow in a
   fixed order (email, organizational unit, organization, locality, state,
   country) and are printed only when truthy (set and non-empty).
   """
   _printLineItem( "      Common name", "%s" % dn.commonName )
   optionalFields = ( ( "email", "      Email address" ),
                      ( "organizationUnit", "      Organizational unit" ),
                      ( "organization", "      Organization" ),
                      ( "locality", "      Locality" ),
                      ( "stateOrProvince", "      State" ),
                      ( "country", "      Country" ) )
   for attrName, label in optionalFields:
      value = getattr( dn, attrName )
      if value:
         _printLineItem( label, "%s" % value )

class PublicKey( Model ):
   """Public key parameters; used by certificates, CSRs and standalone keys."""
   # Currently only RSA and ECDSA keys are supported
   encryptionAlgorithm = Enum( help="Encryption algorithm of the key",
                               values=( "RSA", "ECDSA", ) )
   size = Int( help="Size of the key in bits" )
   modulus = Int( help="Modulus of the key (a value of 0 means no modulus)",
                  default=0 )
   # Set only for RSA keys; the other supported algorithm (ECDSA) has no
   # public exponent, so the help text names ECDSA rather than DSA.
   publicExponent = Int( help=( "Public exponent of the RSA key. "
                                "This field is not present for ECDSA key" ),
                         optional=True )

class DistinguishedName( Model ):
   """X.509 distinguished name; only the common name is mandatory."""
   commonName = Str( help="Common name" )
   email = Str( help="Email address", optional=True )
   organization = Str( help="Name of the organization", optional=True )
   organizationUnit = Str( help="Name of the organizational unit",
                           optional=True )
   # Help text is passed by keyword, consistent with the fields above.
   locality = Str( help="Name of the locality", optional=True )
   stateOrProvince = Str( help="Name of the state or province", optional=True )
   country = Str( help="Name of the country", optional=True )

class Extension( Model ):
   """A single X.509 v3 extension, keyed by name in Certificate.extension."""
   value = Str( help="Value of the X.509 version 3 extension" )
   critical = Bool( help="Whether the extension is critical" )
   
class Certificate( Model ):
   version = Int( help="X.509 version" )
   serialNumber = Int( help="Serial number" )
   subject = Submodel( help="Entity associated with the certificate",
                       valueType=DistinguishedName )
   issuer = Submodel( help="Entity who has signed and issued the certificate",
                      valueType=DistinguishedName )
   notBefore = Int( help=( "Timestamp on which the certificate "
                           "validity period begins" ) )
   notAfter = Int( help=( "Timestamp on which the certificate "
                          "validity period ends" ) )
   publicKey = Submodel( help="Public key information",
                         valueType=PublicKey )
   extension = Dict( help=( "Mapping from X.509 version 3 "
                            "extension name to extension value" ),
                     valueType=Extension,
                     optional=True )

   def render( self ):
      """Print the certificate in a human-readable, openssl-like layout.

      notBefore/notAfter are interpreted as seconds since the Unix epoch and
      printed in UTC (labelled "GMT"). The serial number and the key modulus
      are printed in hexadecimal; the modulus and each extension value are
      wrapped to 85 columns.
      """
      _printLineItem( "   Version", "%d" % ( self.version ) )
      _printLineItem( "   Serial Number", "%x" % self.serialNumber )
      _printLineItem( "   Issuer", "" )
      printDN( self.issuer )

      epochDt = datetime.datetime( 1970, 1, 1, tzinfo=utc )
      notBeforeDt = epochDt + datetime.timedelta( seconds=self.notBefore )
      notAfterDt = epochDt + datetime.timedelta( seconds=self.notAfter )

      _printLineItem( "   Validity", "" )
      _printLineItem( "      Not before", "%s GMT" %
                      notBeforeDt.strftime( "%b %d %H:%M:%S %Y" ) )
      _printLineItem( "      Not After", "%s GMT" %
                      notAfterDt.strftime( "%b %d %H:%M:%S %Y" ) )

      _printLineItem( "   Subject", "" )
      printDN( self.subject )
      _printLineItem( "   Subject public key info", "" )
      _printLineItem( "      Encryption Algorithm",
                      "%s" % self.publicKey.encryptionAlgorithm )
      _printLineItem( "      Size", "%d bits" % self.publicKey.size )
      if self.publicKey.publicExponent:
         _printLineItem( "      Public exponent",
                         "%d" % self.publicKey.publicExponent )
      if self.publicKey.modulus:
         hexmod = "%x" % self.publicKey.modulus
         output = textwrap.fill( hexmod, initial_indent=30 * ' ',
                                 subsequent_indent=30 * ' ',
                                 width=85 )
         # The first line's indent is stripped because _printLineItem already
         # pads the label to 30 columns; continuation lines keep theirs.
         _printLineItem( "      Modulus", output.lstrip() )

      # `extension` is optional: treat an unset value as "no extensions"
      # instead of failing on len()/iteritems() of None.
      extensions = self.extension if self.extension is not None else {}
      if len( extensions ):
         _printLineItem( "   X509v3 extensions", "" )
      for name, ext in extensions.iteritems():
         _printLineItem( "      %s" % name,
                         "Critical" if ext.critical else "",
                         space=0 )
         print( textwrap.fill( ext.value,
                               initial_indent='         ',
                               subsequent_indent='         ',
                               width=85 ) )
   
class Certificates( Model ):
   certificates = Dict( help=( "Mapping from certificate name to certificate "
                               "used in SSL/TLS" ),
                        valueType=Certificate )

   def render( self ):
      """Render each certificate under a 'Certificate <name>:' heading."""
      for certName, certModel in self.certificates.iteritems():
         heading = "Certificate %s:" % ( certName )
         print( heading )
         certModel.render()

class RevokedCertificate( Model ):
   """One revoked-certificate entry of a CRL (see Crl.revokedList)."""
   # This is the serial of the revoked certificate, not of the CRL itself.
   serialNumber = Str( help="Serial number of the revoked certificate" )
   # Crl.render treats this as seconds since the Unix epoch.
   revocationDate = Float( help="Timestamp on which the certificate was revoked" )

class CertificateSigningRequest( Model ):
   version = Int( help="X.509 version" )
   subject = Submodel( help="Entity associated with the certificate",
                       valueType=DistinguishedName )
   publicKey = Submodel( help="Public key information",
                         valueType=PublicKey )
   pemValue = Str( help="CSR in PEM format" )

   def render( self ):
      """Print the CSR: version, subject DN, public key info and PEM text.

      The key modulus, when non-zero, is printed in hexadecimal wrapped to 85
      columns. The PEM value is printed on the lines following its label.
      """
      _printLineItem( "   Data", "" )
      _printLineItem( "   Version", "%d" % ( self.version ) )
      # Label the subject DN, as Certificate.render does. Without this header
      # the DN lines would appear directly under "Version" with no heading.
      _printLineItem( "   Subject", "" )
      printDN( self.subject )
      _printLineItem( "   Subject public key info", "" )
      _printLineItem( "      Encryption Algorithm",
                      "%s" % self.publicKey.encryptionAlgorithm )
      _printLineItem( "      Size", "%d bits" % self.publicKey.size )
      if self.publicKey.publicExponent:
         _printLineItem( "      Public exponent",
                         "%d" % self.publicKey.publicExponent )
      if self.publicKey.modulus:
         hexmod = "%x" % self.publicKey.modulus
         output = textwrap.fill( hexmod, initial_indent=30 * ' ',
                                 subsequent_indent=30 * ' ',
                                 width=85 )
         # First-line indent is dropped; _printLineItem pads the label to 30.
         _printLineItem( "      Modulus", output.lstrip() )
      _printLineItem( "      PEM Value", "\n" + self.pemValue )

class Crl( Model ):
   crlNumber = Int( help=( "CRL number" ) )
   issuer = Submodel( help="Entity who has signed and issued the certificate",
                      valueType=DistinguishedName )
   lastUpdate = Int( help=( "Timestamp on which the CRL "
                            "validity period begins" ) )
   nextUpdate = Int( help=( "Timestamp on which the CRL "
                            "validity period ends" ) )
   revokedList = List( help="Serial number and Timestamp of the revoked "
                            "certificate", valueType=RevokedCertificate )

   def render( self ):
      """Print the CRL: number (hex), issuer, validity window, revoked entries.

      Timestamps are interpreted as seconds since the Unix epoch and shown in
      UTC with a trailing "GMT"; revocation dates are truncated to whole
      seconds first. "none" is printed when no certificate is revoked.
      """
      timeFormat = "%b %d %H:%M:%S %Y"
      epoch = datetime.datetime( 1970, 1, 1, tzinfo=utc )

      _printLineItem( "   CRL Number", "%x" % self.crlNumber )
      _printLineItem( "   Issuer", "" )
      printDN( self.issuer )

      lastUpdateDt, nextUpdateDt = [
         epoch + datetime.timedelta( seconds=secs )
         for secs in ( self.lastUpdate, self.nextUpdate ) ]
      _printLineItem( "   Validity", "" )
      _printLineItem( "      Last Update",
                      "%s GMT" % lastUpdateDt.strftime( timeFormat ) )
      _printLineItem( "      Next Update",
                      "%s GMT" % nextUpdateDt.strftime( timeFormat ) )

      revokedHeader = "" if self.revokedList else "none"
      _printLineItem( "   Revoked Certificates", revokedHeader )
      for revoked in self.revokedList:
         _printLineItem( "    - Serial Number", "%s" % revoked.serialNumber )
         revokedDt = epoch + datetime.timedelta(
            seconds=int( revoked.revocationDate ) )
         _printLineItem( "      Revocation Date",
                         "%s GMT" % revokedDt.strftime( timeFormat ) )

class Crls( Model ):
   crls = Dict( help="Mapping from CRL name to CRL used in SSL/TLS",
                valueType=Crl )

   def render( self ):
      """Render each CRL under a 'CRL <name>:' heading."""
      for crlName, crlModel in self.crls.iteritems():
         heading = "CRL %s:" % ( crlName )
         print( heading )
         crlModel.render()

class PublicKeys( Model ):
   publicKeys = Dict( help="Mapping from key name to public key used in SSL/TLS",
                      valueType=PublicKey )

   def render( self ):
      """Print every key: algorithm and size, then the public exponent and the
      hexadecimal modulus (wrapped to 85 columns) when they are non-zero."""
      pad = 30 * ' '
      for keyName, pubKey in self.publicKeys.iteritems():
         print( "Key %s:" % ( keyName ) )
         _printLineItem( "   Encryption Algorithm",
                         "%s" % pubKey.encryptionAlgorithm )
         _printLineItem( "   Size", "%d bits" % pubKey.size )
         if pubKey.publicExponent:
            _printLineItem( "   Public exponent", "%d" % pubKey.publicExponent )
         if not pubKey.modulus:
            continue
         wrapped = textwrap.fill( "%x" % pubKey.modulus, initial_indent=pad,
                                  subsequent_indent=pad, width=85 )
         # The label is already padded to 30 columns, so drop the first indent.
         _printLineItem( "   Modulus", wrapped.lstrip() )

class DiffieHellmanParameters( Model ):
   """Current Diffie-Hellman parameters, printed by DiffieHellman.render."""
   size = Int( help="Size of the prime number in bits" )
   prime = Int( help="Prime number used in Diffie-Hellman key exchange" )
   generator = Int( help="Generator used in Diffie-Hellman key exchange" )
         
class DiffieHellman( Model ):
   dhparamsResetInProgress = Bool(
            help="Whether Diffie-Hellman parameters is being reset" )
   dhparamsLastResetFailed = Bool(
            help="Whether last attempt to reset Diffie-Hellman parameters failed",
            optional=True )
   dhparamsLastSuccessfulReset = Int(
            help="Last successful Diffie-Hellman parameters reset timestamp",
            optional=True )
   diffieHellmanParameters = Submodel(
            help="Diffie-Hellman parameters",
            valueType=DiffieHellmanParameters,
            optional=True )

   def render( self ):
      """Print the Diffie-Hellman reset status and the current parameters.

      While a reset is in progress only that fact is printed. Otherwise the
      reset history is printed (failure flag, and the last successful reset
      in local time via datetime.fromtimestamp). Then, if parameters are
      present, the prime size, the generator and the hexadecimal prime
      (wrapped to 85 columns) are printed.
      """
      if self.dhparamsResetInProgress:
         print( "Diffie-Hellman parameters reset in progress" )
         return

      if self.dhparamsLastResetFailed:
         print( "Last attempt to reset Diffie-Hellman parameters failed" )

      lastReset = self.dhparamsLastSuccessfulReset
      if lastReset:
         when = datetime.datetime.fromtimestamp( lastReset )
         print( "Last successful reset on %s" %
                when.strftime( "%b %d %H:%M:%S %Y" ) )

      params = self.diffieHellmanParameters
      if not params:
         return

      print( "Diffie-Hellman Parameters %s bits" % ( params.size ) )
      _printLineItem( "   Generator", "%s" % params.generator, space=20 )
      pad = 20 * ' '
      wrappedPrime = textwrap.fill( "%x" % params.prime, initial_indent=pad,
                                    subsequent_indent=pad, width=85 )
      _printLineItem( "   Prime", "%s" % wrappedPrime.lstrip(), space=20 )


class ProfileError( Model ):
   """One validation error of an SSL profile.

   SslStatus._errorMessage turns it into a line such as
   "Certificate '<errorAttrValue>' has expired".
   """
   errorAttr = Enum( help=( "SSL profile attribute to which the"
                            " error applies" ),
                           values=ErrorAttr.attributes )
   errorAttrValue = Str( help="SSL profile attribute value" )
   errorType = Enum( help="Error Type",
                     values=ErrorType.attributes )

class ProfileStatus( Model ):
   """State of one SSL profile and, when it is invalid, the reasons why."""
   profileState = Enum( help="Ssl profile state", 
                        values=ProfileState.attributes )
   # Optional; may be unset when the profile has no errors.
   profileError = List( help="List of SSL profile errors in 'invalid' state",
                        valueType=ProfileError, optional=True )

class SslStatus( Model ):
   profileStatus = Dict( help="Mapping from SSL profile name to status",
                         valueType=ProfileStatus )
   _hasError = Bool( help="Whether there is at least one SSL profile with error" )

   def _errorMessage( self, profileError ):
      """Return a newline-separated description of a profile's errors.

      `profileError` is an iterable of ProfileError models, or None when the
      optional list is unset (which yields an empty string). Each error is
      rendered as "<Attr> '<value>' <reason>"; the quoted value is omitted
      for profile-level errors. Attribute or error values missing from the
      tables below fall back to their raw names instead of raising KeyError.
      """
      attrDict = { "profile": "Profile",
                   "certificate": "Certificate",
                   "key": "Key",
                   "trustedCertificate": "Certificate",
                   "chainedCertificate": "Certificate",
                   "crl": "CRL" }
      errTypeDict = { "noProfileData": "has no data",
                      "notExist": "does not exist",
                      "notMatchingCertKey": "does not match with key",
                      "certNotYetValid": "is not yet valid",
                      "certChainNotValid": "has invalid certificate chain",
                      "certTrustChainNotValid": ( "has invalid trusted certificate"
                                                  " chain" ),
                      "missingCrlForTrustChain": ( "has missing crl issued by"
                                                   " the trusted chain" ),
                      "certExpired": "has expired",
                      "noExtendedKeyUsage": "has no extended key usage value",
                      "noCABasicConstraintTrust": ( "is trusted certificate but"
                                                    " does not have CA basic"
                                                    " constraint set to True" ),
                      "noCABasicConstraintChain": ( "is chained certificate but"
                                                    " does not have CA basic"
                                                    " constraint set to True" ),
                      "noCrlSign": ( "is a CRL and signed by a CA who does not have"
                                     " the cRLSign key usage bits set" ),
                      "crlNotSignedByCa": ( "is a CRL but is not signed by any"
                                            " configured trusted certificate" ),
                      "hostnameMismatch": ( "hostname of this device does not match"
                                            " any entry of the Common Name nor"
                                            " Subject Alternative Name in the"
                                            " certificate" ) }
      errors = profileError if profileError is not None else []
      msgs = []
      for e in errors:
         parts = [ attrDict.get( e.errorAttr, str( e.errorAttr ) ) ]
         # Profile-level errors have no meaningful attribute value to quote.
         if e.errorAttr != "profile":
            parts.append( "'%s'" % e.errorAttrValue )
         parts.append( errTypeDict.get( e.errorType, str( e.errorType ) ) )
         msgs.append( " ".join( parts ) )
      return '\n'.join( msgs )

   def render( self ):
      """Print a table of profile name and state.

      An "Additional Info" column with the error messages is added only when
      _hasError is set.
      """
      tableHeadings = ( "Profile", "State" )
      if self._hasError:
         tableHeadings += ( "Additional Info", )
      table = TableOutput.createTable( tableHeadings )
      f = TableOutput.Format( justify="left", maxWidth=40, wrap=True )
      table.formatColumns( *( [ f ] * len( tableHeadings ) ) )
      for name, profStatus in self.profileStatus.iteritems():
         row = [ name, profStatus.profileState ]
         # Only supply the error cell when the table has a column for it.
         if self._hasError:
            row.append( self._errorMessage( profStatus.profileError ) )
         table.newRow( *row )
      print( table.output() )
