# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

#
# This contains all the generic logic to deal with the cloud side
# Currently we support AWS, Azure and GCP.
#
#
import Tracing
import threading
from AwsBackend import AwsBackend, AwsAccess, AwsParsedConfig
from AzureBackend import AzureBackend, AzureAdAccess, AzureSdkAuthAccess, \
   AzureMsiAccess, AzureParsedConfig
from GcpBackend import GcpBackend, GcpServiceAccountAccess, \
   GcpDefaultCredentialsAccess, GcpParsedConfig
from CloudHaConfig import BaseCloudConfigHandler
import socket
import CloudHaBackend
from CloudException import ConfigInvalid, \
   BackendException
import os
import re
from CloudUtil import defaultRecoveryWaitTime

# pkgdeps: import googleapiclient, oauth2client

# This can be set by test infra to skip any backend SDK calls
# This should only be set for testing code.
def skipBackend():
   """Return the value of the SKIP_BACKEND environment variable.

   Returns None when the variable is not set. Test infrastructure sets it to
   bypass every backend SDK call; production code should never set it.
   """
   return os.getenv( 'SKIP_BACKEND' )

# Short module-local names for Tracing.t0/t1/t2, used for trace output
# throughout this file.
t0, t1, t2 = Tracing.t0, Tracing.t1, Tracing.t2

def validateAddress( ipAddress, requirePrefix=False ):
   """Validate a dotted-quad IPv4 address, optionally with a '/len' prefix.

   ipAddress: the value to check, e.g. '10.0.0.1', or with requirePrefix
      '10.0.0.0/8'.
   requirePrefix: when True, ipAddress must be '<address>/<len>', where len
      is an unsigned decimal integer in [0, 32].

   Returns True when the value is valid.
   Raises ConfigInvalid otherwise, including for non-string input.
   """
   try:
      if requirePrefix:
         ip, mask = ipAddress.split( '/', 1 )
         # isdigit() rejects signs and whitespace (' 8', '+8', '-0') that
         # int() would quietly accept.
         if not mask.isdigit() or int( mask ) > 32:
            t0( 'Bad ip address prefix %s' % ipAddress )
            raise ConfigInvalid( 'Bad Ip address prefix %s' % \
               ipAddress )
      else:
         ip = ipAddress
      if ip.count( '.' ) != 3 or '/' in ip:
         t0( 'Bad ip address in config: %s' % ipAddress )
         raise ConfigInvalid( 'Bad Ip address %s' % \
               ipAddress )
      # inet_pton only accepts a plain dotted-decimal address. inet_aton is
      # more lenient; glibc's version, for example, ignores anything after
      # whitespace, so '10.0.0.1 junk' would have been accepted.
      socket.inet_pton( socket.AF_INET, ip )
   except ( socket.error, ValueError, AttributeError, TypeError ) as e:
      # AttributeError / TypeError: ipAddress was not a string at all.
      t0( 'Error parsing ipAddress %s, exception: %s '\
          % ( ipAddress, e ) )
      raise ConfigInvalid( "Bad ip address %s" % ( ipAddress, ) )

   return True

# cloudType passed in because all validator functions must take it as
# second argument. It is not used in general config validation.
@BaseCloudConfigHandler.haConfigValidator
def generalConfigValidation( haConfig, cloudType ):
   """Validate the 'generalConfig' section of haConfig.

   Checks that:
   - enable_optional (default 'true') is 'true' or 'false' in any case;
   - hysteresis_time_optional (default defaultRecoveryWaitTime) can be
     converted with int();
   - source_ip_optional, when present, is a valid IPv4 address.

   cloudType is unused; it is accepted only because every validator takes
   it as its second argument.

   Returns True when valid; raises ConfigInvalid otherwise.
   """
   t1( 'General Config Validation invoked' )
   try:
      generalConfig = haConfig[ 'generalConfig' ]
      enableOption = generalConfig.get( 'enable_optional', 'true' )
      # A non-string value (e.g. a boolean) has no lower(); report it as a
      # config error instead of letting AttributeError escape.
      if not hasattr( enableOption, 'lower' ) or \
            enableOption.lower() not in [ 'true', 'false' ]:
         raise ConfigInvalid( 'enable must be true or false' )
      hysteresisTime = generalConfig.get( 'hysteresis_time_optional',
         defaultRecoveryWaitTime )
      try:
         int( hysteresisTime )
      except ( TypeError, ValueError ):
         # TypeError covers values such as None or a list.
         raise ConfigInvalid( 'hysteresis_time_optional must be an integer' )
      sourceIp = generalConfig.get( 'source_ip_optional', None )
      if sourceIp is not None:
         validateAddress( sourceIp )
   except KeyError as e:
      raise ConfigInvalid( 'Error parsing general section of config: %s' % e )
   return True

# cloudType passed in because all validator functions must take it as
# second argument. It is not used in bfd config validation.
@BaseCloudConfigHandler.haConfigValidator
def bfdConfigValidation( haConfig, cloudType ):
   """Validate the 'bfdConfig' section of haConfig.

   'peerVeosIp' must be a valid IPv4 address. 'bfdSourceInterface' and
   'bfdSessionType' only need to be present. cloudType is unused; it is
   accepted only because every validator takes it as its second argument.

   Returns True when valid; raises ConfigInvalid otherwise.
   """
   t1( 'bfdConfigValidation invoked' )
   try:
      section = haConfig[ 'bfdConfig' ]
      validateAddress( section[ 'peerVeosIp' ] )
      # Presence-only checks: each lookup raises KeyError when missing.
      for requiredKey in ( 'bfdSourceInterface', 'bfdSessionType' ):
         section[ requiredKey ] # pylint: disable-msg=W0104
   except KeyError as err:
      raise ConfigInvalid( 'Error parsing bfd specific config: %s' % err )
   return True

# Helper function to get route entries from AWS config
def getAwsRouteEntries( routeConfig ):
   """Build the AWS route map from one routing-config section.

   routeConfig[ 'routeTableIdAndRouteNetworkInterface' ] must be a non-empty
   sequence of entries. Each entry has 'routeTableId', 'routeTarget' and
   'destination' (an IPv4 prefix such as '10.0.0.0/8').

   Returns a dict mapping ( routeTableId, destination ) -> routeTarget. A
   later entry with the same key overwrites an earlier one.
   Raises ConfigInvalid if the list is missing, empty or null, if an entry
   lacks a key, or if a destination is not a valid prefix.
   """
   routeEntries = {}
   try:
      rtAndNI = routeConfig[ 'routeTableIdAndRouteNetworkInterface' ]
      # Truthiness, rather than len(), also rejects an explicit null value,
      # which len() would turn into an uncaught TypeError.
      if not rtAndNI:
         t0( 'No route table ID specified for routing traffic' )
         raise ConfigInvalid( 'No Route table Id specified' )
      for entry in rtAndNI:
         rtID = entry[ 'routeTableId' ]
         target = entry[ 'routeTarget' ]
         dest = entry[ 'destination' ]
         validateAddress( dest, requirePrefix=True )
         routeEntries[ ( rtID, dest ) ] = target
   except KeyError as e:
      t0( ' Error parsing routeTable entries %s ' % e )
      raise ConfigInvalid( 'Key Error: %s' % str( e ) )
   return routeEntries

def getAwsParsedConfig( haConfig ):
   """Parse and format-check the AWS sections of haConfig.

   Reads 'awsConfig' (region, optional credentials, optional HTTP proxy) and
   the 'awsPeerRoutingConfig' / 'awsLocalRoutingConfig' route sections. Only
   the shape of the values is checked here; whether they actually work is
   only known once AWS is contacted.

   Returns an AwsParsedConfig.
   Raises ConfigInvalid for:
   - missing keys;
   - a proxy that is not an IPv4 address;
   - a port that is non-numeric or outside 1-65535;
   - empty route sections.
   """
   def _toPort( value ):
      # Empty/absent -> None; otherwise an int in the TCP port range.
      # Raises TypeError/ValueError for anything else.
      if not value:
         return None
      portNum = int( value )
      if not 0 < portNum <= 65535:
         raise ValueError( 'port out of range: %s' % ( value, ) )
      return portNum

   try:
      awsConfig = haConfig[ 'awsConfig' ]
      # Get region
      region = awsConfig[ 'region' ]

      # Get aws credentials if provided. Two empty strings are treated the
      # same as no credentials at all.
      awsKey = awsAccessKey = None
      awsCreds = awsConfig.get( 'aws_credentials_optional', None )
      if awsCreds is not None:
         awsKey = awsCreds[ 'aws_access_key_id' ]
         awsAccessKey = awsCreds[ 'aws_secret_access_key' ]
         if awsKey == "" and awsAccessKey == "":
            awsKey = awsAccessKey = None

      # Get aws proxy if provided
      proxy = proxy_user = proxy_passwd = port_int = proxy_port_int = None
      awsProxy = awsConfig.get( 'http_proxy_optional', None )
      if awsProxy is not None:
         port = awsProxy.get( 'http_port_optional', None )
         proxy = awsProxy.get( 'http_proxy_optional', None )
         if proxy:
            validateAddress( proxy )
         proxy_port = awsProxy.get( 'http_proxy_port_optional', None )
         proxy_user = awsProxy.get( 'http_proxy_user_optional', None )
         proxy_passwd = awsProxy.get( 'http_proxy_password_optional', None )
         try:
            port_int = _toPort( port )
            proxy_port_int = _toPort( proxy_port )
         except ( TypeError, ValueError ):
            t0( ' Bad port/proxy port %s %s ' % ( port, proxy_port ) )
            raise ConfigInvalid( 'Bad port/proxy port: %s %s ' % \
               ( port, proxy_port ) )

      # Get route table configs. getAwsRouteEntries already rejects empty
      # lists. These explicit checks, unlike assert, survive python -O.
      peerRouteEntries = getAwsRouteEntries(
         haConfig[ 'awsPeerRoutingConfig' ] )
      if not peerRouteEntries:
         raise ConfigInvalid( 'No peer route entries specified' )
      localRouteEntries = getAwsRouteEntries(
         haConfig[ 'awsLocalRoutingConfig' ] )
      if not localRouteEntries:
         raise ConfigInvalid( 'No local route entries specified' )
   except KeyError as e:
      t0( 'Error parsing awsConfig %s' % e )
      raise ConfigInvalid( "Error parsing %s" % str( e ) )

   access = AwsAccess( region, awsKey, awsAccessKey, port_int,
      proxy, proxy_port_int, proxy_user, proxy_passwd )
   return AwsParsedConfig( access, localRouteEntries,
      peerRouteEntries )
   
#
# This will make sure the AWS config is good to work with
# It just parses to make sure the format is as expected.
# The real validation can only be done by interfacing with AWS
# and happens later.
#
@BaseCloudConfigHandler.haConfigValidator
def AwsParseConfigValidation( haConfig, cloudType ):
   """Format-check the AWS sections of haConfig when cloudType is 'AWS'.

   For any other cloud type nothing is checked and True is returned.
   Otherwise the result of getAwsParsedConfig( haConfig ) is returned. The
   values themselves are only verified against AWS later, at runtime.
   """
   if cloudType == 'AWS':
      t1( 'AwsConfigValidation invoked' )
      return getAwsParsedConfig( haConfig )
   t1( 'AwsConfigValidation skipped because cloudType is %s' % cloudType )
   return True

def validateAzureRoutingConfig( config ):
   """Check the shape of one Azure routing-config section.

   Requirements:
   - config has a non-empty 'resourceGroupName' and a non-empty
     'routeTables' list;
   - every route table has a non-empty 'routeTableName' and a non-empty
     'routes' list;
   - every route has a 'prefix' (IPv4 prefix) and a 'nextHopIp' (IPv4
     address).

   Raises:
   - AssertionError, with a message, for empty fields;
   - KeyError for missing keys;
   - ConfigInvalid, via validateAddress, for bad addresses.
   getAzureParsedConfig converts the first two into ConfigInvalid.
   """
   def _require( condition, message ):
      # Raise explicitly rather than use `assert`. Asserts are stripped under
      # `python -O`, which would skip every check here, including the
      # address validation.
      if not condition:
         raise AssertionError( message )

   try:
      _require( config[ 'resourceGroupName' ], 'resourceGroupName is empty' )
      routeTables = config[ 'routeTables' ]
      _require( routeTables, 'routeTables is empty' )
      for routeTable in routeTables:
         tableName = routeTable[ 'routeTableName' ]
         _require( tableName, 'routeTableName is empty' )
         routes = routeTable[ 'routes' ]
         _require( routes,
                   'routes of route table %s is empty' % ( tableName, ) )
         for route in routes:
            validateAddress( route[ 'prefix' ], requirePrefix=True )
            validateAddress( route[ 'nextHopIp' ] )
   except AssertionError as e:
      t0( 'Error parsing azure routing config: %s' % e )
      # Bare raise keeps the original traceback.
      raise

def getAzureParsedConfig( haConfig ):
   """Parse and format-check the Azure sections of haConfig.

   Credentials come from at most one of the 'azureActiveDirectoryCredentials'
   or 'azureSdkAuthCredentials' sub-sections of 'azureConfig'. With neither,
   AzureMsiAccess is used. 'azureLocalRoutingConfig' and
   'azurePeerRoutingConfig' are checked with validateAzureRoutingConfig(),
   and the optional HTTP/HTTPS proxies are read from 'azureConfig'.

   Returns an AzureParsedConfig.
   Raises ConfigInvalid when:
   - more than one credential type is given;
   - a KeyError or AssertionError is hit while parsing (re-raised as
     ConfigInvalid).
   """
   # Credential sub-section -> ( access class, constructor fields in order ).
   credentialSpecs = {
      'azureActiveDirectoryCredentials':
         ( AzureAdAccess, ( 'email', 'password', 'subscriptionId' ) ),
      'azureSdkAuthCredentials':
         ( AzureSdkAuthAccess, ( 'clientId', 'clientSecret', 'subscriptionId',
                                 'tenantId', 'activeDirectoryEndpointUrl',
                                 'resourceManagerEndpointUrl',
                                 'activeDirectoryGraphResourceId',
                                 'sqlManagementEndpointUrl',
                                 'galleryEndpointUrl',
                                 'managementEndpointUrl' ) ),
   }
   try:
      azureConfig = haConfig[ 'azureConfig' ]
      givenTypes = set( credentialSpecs ).intersection( azureConfig.keys() )
      if len( givenTypes ) > 1:
         t0( 'Error parsing azure config: multiple credential types given' )
         raise ConfigInvalid( 'multiple credential types given' )
      if not givenTypes:
         # no explicit credentials provided, will use MSI
         access = AzureMsiAccess()
      else:
         ( credentialType, ) = givenTypes
         accessClass, fieldNames = credentialSpecs[ credentialType ]
         userCredentials = azureConfig[ credentialType ]
         access = accessClass(
            *[ userCredentials[ field ] for field in fieldNames ] )

      localRoutingConfig = haConfig[ 'azureLocalRoutingConfig' ]
      validateAzureRoutingConfig( localRoutingConfig )

      peerRoutingConfig = haConfig[ 'azurePeerRoutingConfig' ]
      validateAzureRoutingConfig( peerRoutingConfig )

      httpProxy = azureConfig.get( 'http_proxy_optional', None )
      httpsProxy = azureConfig.get( 'https_proxy_optional', None )
   except ( KeyError, AssertionError ) as e:
      t0( 'Error parsing azureConfig %s' % e )
      raise ConfigInvalid( str( e ) )
   return AzureParsedConfig( access, localRoutingConfig,
                             peerRoutingConfig, httpProxy, httpsProxy )

#
# This will make sure the Azure config is good to work with
# It just parses to make sure the format is as expected.
# The real validation can only be done by interfacing with Azure
# and happens later.
#
@BaseCloudConfigHandler.haConfigValidator
def AzureParseConfigValidation( haConfig, cloudType ):
   """Format-check the Azure sections of haConfig when cloudType is 'Azure'.

   For any other cloud type nothing is checked and True is returned.
   Otherwise the result of getAzureParsedConfig( haConfig ) is returned. The
   values themselves are only verified against Azure later, at runtime.
   """
   if cloudType == 'Azure':
      t1( 'AzureConfigValidation invoked' )
      return getAzureParsedConfig( haConfig )
   t1( 'AzureConfigValidation skipped because cloudType is %s' % cloudType )
   return True

# Helper function to get route entries from GCP config
def getGcpRouteEntries( routeConfig ):
   """Turn a GCP route list into a set of route tuples.

   Each route needs a 'destination' IPv4 prefix. 'macAddr', 'vpc' and 'tag'
   are optional and default to ''.

   Returns a set of ( macAddr, vpc, destination, tag ) tuples.
   Raises ConfigInvalid when routeConfig is empty or None, when a route has
   no 'destination', or when a destination is not a valid prefix.
   """
   if not routeConfig:
      msg = 'No routes present'
      t0( msg )
      raise ConfigInvalid( msg )
   entries = set()
   try:
      for route in routeConfig:
         # Tuple elements are evaluated left to right, matching the
         # ( macAddr, vpc, destination, tag ) layout.
         entry = ( route.get( 'macAddr', '' ),
                   route.get( 'vpc', '' ),
                   route[ 'destination' ],
                   route.get( 'tag', '' ) )
         validateAddress( entry[ 2 ], requirePrefix=True )
         entries.add( entry )
   except KeyError as e:
      t0( ' Error parsing routing config entries %s ' % e )
      raise ConfigInvalid( 'Key Error: %s' % str( e ) )
   return entries

def getGcpParsedConfig( haConfig ):
   """Parse and format-check the GCP sections of haConfig.

   Reads 'gcpConfig' (project, optional service-account file, optional
   HTTP/HTTPS proxies) and the 'gcpPeerRoutingConfig' /
   'gcpLocalRoutingConfig' route lists.

   Returns a GcpParsedConfig. Its access object uses the service file when
   one is given, and default credentials otherwise.
   Raises ConfigInvalid for missing keys, an unreadable service file, a
   malformed proxy, or empty/malformed route lists.
   """
   def _isValidProxy( proxy, https=False ):
      # Verify that proxy has a valid pattern. The proxy address can be a
      # domain or an IP address. It can have a username and a password. The
      # proxy string must specify a port number. Some valid proxies are:
      #  - http://proxy:1000
      #  - http://proxy.site:100
      #  - http://my-proxy.corp-site:8080
      #  - https://user@www.proxy.site:2000
      #  - http://10.192.0.16:3128
      #  - https://user:pass@10.193.23.53:1234
      # Host labels may contain '-', as ordinary DNS names commonly do.
      proxyPattern = ( r'://(?:[^:]+(?:.+)?@)?'
            r'(?:[0-9a-zA-Z-]+\.)*[0-9a-zA-Z-]+:[\d]{1,5}$' )
      proxyPattern = ( r'^https' if https else r'^http' ) + proxyPattern
      return re.match( proxyPattern, proxy ) is not None

   def _fail( msg ):
      # Explicit raise instead of `assert`: asserts are stripped under
      # python -O, and AssertionError carried no message.
      t0( 'Error parsing gcpConfig: %s' % msg )
      raise ConfigInvalid( msg )

   try:
      gcpConfig = haConfig[ 'gcpConfig' ]
      # Get project
      project = gcpConfig[ 'project' ]
      # Get gcp credentials
      serviceFile = gcpConfig.get( 'service_file_optional', None )
      if serviceFile and not ( os.path.isfile( serviceFile ) and
                               os.access( serviceFile, os.R_OK ) ):
         _fail( 'service_file_optional %s is not a readable file' %
                ( serviceFile, ) )

      # Get gcp proxy if provided. The proxy value is kept out of the error
      # messages because it may embed a username and password.
      httpProxy = gcpConfig.get( 'http_proxy_optional', None )
      if httpProxy and not _isValidProxy( httpProxy, False ):
         _fail( 'http_proxy_optional must look like '
                'http://[user[:password]@]host:port' )
      httpsProxy = gcpConfig.get( 'https_proxy_optional', None )
      if httpsProxy and not _isValidProxy( httpsProxy, True ):
         _fail( 'https_proxy_optional must look like '
                'https://[user[:password]@]host:port' )

      # Get route configs
      peerRouteEntries = getGcpRouteEntries( haConfig[ 'gcpPeerRoutingConfig' ] )
      if not peerRouteEntries:
         _fail( 'No peer route entries specified' )
      localRouteEntries = getGcpRouteEntries(
         haConfig[ 'gcpLocalRoutingConfig' ] )
      if not localRouteEntries:
         _fail( 'No local route entries specified' )
   except KeyError as e:
      t0( 'Error parsing gcpConfig %s' % e )
      raise ConfigInvalid( "Error parsing %s" % str( e ) )

   if serviceFile:
      access = GcpServiceAccountAccess( project, serviceFile )
   else:
      access = GcpDefaultCredentialsAccess( project )
   return GcpParsedConfig( access, localRouteEntries, peerRouteEntries,
                           httpProxy, httpsProxy )

# This will make sure the GCP config is good to work with.
# It just parses to make sure the format is as expected.
# The real validation can only be done by interfacing with GCP
# and happens later.
@BaseCloudConfigHandler.haConfigValidator
def GcpParseConfigValidation( haConfig, cloudType ):
   """Format-check the GCP sections of haConfig when cloudType is 'GCP'.

   For any other cloud type nothing is checked and True is returned.
   Otherwise the result of getGcpParsedConfig( haConfig ) is returned. The
   values themselves are only verified against GCP later, at runtime.
   """
   if cloudType == 'GCP':
      t1( 'GcpConfigValidation invoked' )
      return getGcpParsedConfig( haConfig )
   t1( 'GcpConfigValidation skipped because cloudType is %s' % cloudType )
   return True

#
# This class should be used from normal agent code
# This calls the backend, which runs in its own thread due to its
# blocking nature.
# If the caller cares about the RPC response, it needs to arrange to call
# back later: poll asyncResponseDone() and then read getRpcResponse().
#
class CloudHelper( object ):
   """Agent-facing wrapper around the backend for the current cloud.

   Backend calls can block, so each one runs on its own thread (see
   rpcHandler()). Only one call may be outstanding at a time. Poll
   asyncResponseDone() until it returns True, then read the outcome with
   getRpcResponse() or one of the getResponse*() accessors.

   When skipBackend() is set:
   - the update/validation methods start no backend thread;
   - asyncResponseDone() returns True immediately;
   - getRpcResponse() returns a fake result whose ret is True.
   """

   def __init__( self, config, cloudType ):
      """Parse config for cloudType ('AWS', 'Azure' or 'GCP') and build the
      matching backend."""
      self.cloudType = cloudType
      # Remains None for any other cloudType, and for Azure/GCP when
      # skipBackend() is set, so the attribute always exists.
      self.backend = None
      if cloudType == 'AWS':
         # NOTE(review): unlike Azure/GCP, the AWS backend is built even when
         # skipBackend() is set -- confirm this is intentional.
         self.backend = AwsBackend( getAwsParsedConfig( config ) )
      elif cloudType == 'Azure' and not skipBackend():
         self.backend = AzureBackend( getAzureParsedConfig( config ) )
      elif cloudType == 'GCP' and not skipBackend():
         self.backend = GcpBackend( getGcpParsedConfig( config ) )

      # setup RPC variables
      # result[ 0 ] holds the BackendResult of the latest call. The list
      # itself is what the backend thread receives (see rpcHandler()).
      self.result = [ None ]
      # Descriptive tag of the in-flight call, or None when idle.
      self.backendResponsePending = None
      self.thread = None

   def rpcHandler( self, func, args ):
      """Start a backend thread for func/args.

      A CloudHaBackend.BackendResult( func, args ) is stored in
      self.result[ 0 ], and self.backend.rpcHandler is run on a new thread
      with the self.result list as its argument.

      Raises BackendException if a previous call is still pending.
      """
      #
      # make sure there is nothing pending as we don't support
      # more than one thread concurrently for now for each backend
      # object. The user needs to make sure it doesn't intermix
      # APIs from different threads for now.
      #
      if self.backendResponsePending:
         raise BackendException(
            'Internal error: Already waiting for ' \
            'response: %s' % ( self.backendResponsePending ) )

      self.result[ 0 ] = CloudHaBackend.BackendResult( func, args )
      threadName = 'BackEndThread'
      self.thread = threading.Thread( target=self.backend.rpcHandler,
                                      name=threadName, args=[ self.result ] )
      t2( 'starting backend thread for api %s' % func )
      self.thread.start()
      self.backendResponsePending = threadName + '_' + str( self.result[ 0 ] )

   def getRpcResponse( self ):
      """Return the BackendResult of the last call.

      Only call this after asyncResponseDone() has returned True. With
      skipBackend() set, a fresh fake result with ret = True is returned.
      """
      if skipBackend():
         fakeResult = CloudHaBackend.BackendResult( None, None )
         fakeResult.ret = True
         return fakeResult
      assert self.result[ 0 ].ret is not None, "Internal Error"
      t0( 'Result from rpc call for %s ' % self.result[ 0 ] )
      return self.result[ 0 ]

   def asyncResponseDone( self ):
      """Return True once the latest result's ret is no longer None.

      Call this before getRpcResponse(). On completion it also clears the
      pending marker so the next call can be issued. Always returns True
      when skipBackend() is set.
      """
      if skipBackend():
         return True
      if self.result[ 0 ].ret is None:
         return False
      # Clean up for next call
      self.backendResponsePending = None
      return True

   def responsePending( self ):
      """Return True while a backend call is outstanding."""
      return self.backendResponsePending is not None

   def updatePeerRoutes( self ):
      """Start the backend call that updates the peer's route tables."""
      t0( 'Taking control of the peer\'s subnets routing' )
      t2( 'starting backend thread for peer Route update' )
      if not skipBackend():
         self.rpcHandler( self.backend.updatePeerRouteTables, None )

   def updateLocalRoutes( self ):
      """Start the backend call that updates the local route tables."""
      t0( 'updateLocalRoutes: taking back local default traffic' )
      t2( 'starting backend thread for local route update' )
      if not skipBackend():
         self.rpcHandler( self.backend.updateLocalRouteTables, None )

   def runTimeConfigValidation( self ):
      """Start the backend's runtime config validation."""
      # Initiate validation
      t0( 'starting back end runtime config validation' )
      t2( 'starting backend thread for config validation' )
      if not skipBackend():
         self.rpcHandler( self.backend.configValidator, None )

   def getResponse( self ):
      """Return the full result object (see getRpcResponse())."""
      return self.getRpcResponse()

   def getResponseResult( self ):
      """Return the ret field of the latest result."""
      return self.getRpcResponse().ret

   def getResponseMessage( self ):
      """Return the message field of the latest result."""
      return self.getRpcResponse().message

   def getResponseException( self ):
      """Return the exception field of the latest result."""
      return self.getRpcResponse().exception

   def getFailureTypeStatus( self ):
      """Return the failureType field of the latest result."""
      return self.getRpcResponse().failureType
