#!/var/virtualenv/CloudVirtualEnv/bin/python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import ArPyUtils.Decorators
import CloudHaBackend
import Tracing
import threading
import json
import os
import tempfile
from CloudException import BackendException
import logging

# Enables DEBUG-level logging (including the Azure SDK's) to a file.
def azureDebugSetup( logFile='/var/log/agents/CloudHa_backend_azure.logs' ):
   """Send DEBUG-level root-logger output (which includes the Azure SDK's
   loggers) to logFile.

   The file handler is attached only once per log path. AzureBackend.__init__
   calls this for every backend it builds. Previously each call attached a new
   FileHandler, which duplicated every log line and leaked an open file per
   call. The root level is (re)set to DEBUG and a marker line is logged on
   every call.

   logFile: path of the log file; defaults to the CloudHa Azure backend log.
   """
   rootLogger = logging.getLogger()
   logPath = os.path.abspath( logFile )
   # FileHandler stores the absolute path in baseFilename; use it to detect a
   # handler attached by an earlier call.
   alreadyAttached = any(
      isinstance( handler, logging.FileHandler ) and
      getattr( handler, 'baseFilename', None ) == logPath
      for handler in rootLogger.handlers )
   if not alreadyAttached:
      azureLogFileHandler = logging.FileHandler( logPath )
      formatter = logging.Formatter(
         '%(asctime)s - %(name)s - %(process)d - %(levelname)s - %(message)s' )
      azureLogFileHandler.setFormatter( formatter )
      rootLogger.addHandler( azureLogFileHandler )
   rootLogger.setLevel( logging.DEBUG )
   rootLogger.debug( 'Azure Backend debugging turned on.' )

# Keep handles to the original Tracing functions. The t0/t1/t2 wrappers
# defined below reuse those names at module scope and call through to these.
t0Saved, t1Saved, t2Saved = Tracing.t0, Tracing.t1, Tracing.t2

def debugMsg( *s ):
   """Return '** <current thread name> **: ' followed by the str() of every
   argument in s, concatenated with no separator."""
   prefix = '** %s **: ' % threading.current_thread().name
   # join instead of repeated += avoids quadratic string building.
   return prefix + ''.join( str( i ) for i in s )

def t0( *msg ):
   """Trace msg through the saved Tracing.t0, prefixed with the thread name.

   The arguments are forwarded individually (*msg). Passing the tuple itself
   made debugMsg() render its repr, e.g. "('text',)", instead of "text".
   """
   t0Saved( debugMsg( *msg ) )

def t1( *msg ):
   """Trace msg through the saved Tracing.t1, prefixed with the thread name.

   The arguments are forwarded individually (*msg). Passing the tuple itself
   made debugMsg() render its repr, e.g. "('text',)", instead of "text".
   """
   t1Saved( debugMsg( *msg ) )

def t2( *msg ):
   """Trace msg through the saved Tracing.t2, prefixed with the thread name.

   The arguments are forwarded individually (*msg). Passing the tuple itself
   made debugMsg() render its repr, e.g. "('text',)", instead of "text".
   """
   t2Saved( debugMsg( *msg ) )

try:
   # pylint: disable-msg=E0611
   # pylint: disable-msg=F0401
   from azure.common.credentials import UserPassCredentials
   from azure.common.client_factory import get_client_from_auth_file
   from azure.mgmt.network import NetworkManagementClient
   from azure.mgmt.network.models import Route
   from azure.mgmt.resource import SubscriptionClient
   from msrestazure.azure_active_directory import MSIAuthentication
   import urllib3
except ImportError:
   # This is necessary so that Abuilds can pass. The Azure Python SDK is included in
   # the virtual environment created by CloudVirtualEnv, inside which the CloudHa
   # agent runs on vEOS Router. Abuild breadth tests do not run inside that virtual
   # environment, so without this fallback they would fail on these import lines.
   t0( 'Error importing Azure Python SDK' )

class AzureAdAccess( object ):
   """Azure Active Directory user credentials (email/password) together with
   the id of the subscription the network client is created for."""
   def __init__( self, email, password, subscriptionId ):
      self.email, self.password, self.subscriptionId = (
         email, password, subscriptionId )

class AzureSdkAuthAccess( object ):
   """Service-principal credentials for SDK-auth-file based login.

   __str__ serializes every instance attribute as JSON. The attribute names
   therefore double as the JSON keys of the auth file that is written from
   str( access ), and adding an attribute here adds a key to that file.
   """
   def __init__( self, clientId, clientSecret, subscriptionId, tenantId,
                 activeDirectoryEndpointUrl, resourceManagerEndpointUrl,
                 activeDirectoryGraphResourceId, sqlManagementEndpointUrl,
                 galleryEndpointUrl, managementEndpointUrl ):
      # Assignment order is kept stable; it determines the key order of the
      # JSON emitted by __str__.
      self.clientId, self.clientSecret = clientId, clientSecret
      self.subscriptionId, self.tenantId = subscriptionId, tenantId
      # Endpoint URLs / resource id passed through to the auth file.
      self.activeDirectoryEndpointUrl = activeDirectoryEndpointUrl
      self.resourceManagerEndpointUrl = resourceManagerEndpointUrl
      self.activeDirectoryGraphResourceId = activeDirectoryGraphResourceId
      self.sqlManagementEndpointUrl = sqlManagementEndpointUrl
      self.galleryEndpointUrl = galleryEndpointUrl
      self.managementEndpointUrl = managementEndpointUrl

   def __str__( self ):
      # NOTE: the output contains clientSecret in clear text; avoid tracing it.
      return json.dumps( vars( self ), indent=4 )

class AzureMsiAccess( object ):
   """Marker credential selecting MSI login (MSIAuthentication); it carries
   no data of its own."""

class AzureParsedConfig( object ):
   """Configuration bundle consumed by AzureBackend.

   access must be an AzureAdAccess, AzureSdkAuthAccess or AzureMsiAccess
   instance and is stored as .cred. The route configs and proxy settings are
   stored as given.
   """
   def __init__( self, access, localRoutes, peerRoutes, httpProxy, httpsProxy ):
      assert access and isinstance(
         access, ( AzureAdAccess, AzureSdkAuthAccess, AzureMsiAccess ) )
      self.cred = access
      self.localRoutes, self.peerRoutes = localRoutes, peerRoutes
      self.httpProxy, self.httpsProxy = httpProxy, httpsProxy

class AzureBackend( CloudHaBackend.BackendBase ):
   """CloudHa backend that repoints Azure route-table entries at a new next hop.

   NOTE(review): shape inferred from the key accesses in this class -- both
   parsedConfig.localRoutes and parsedConfig.peerRoutes look like
      { 'resourceGroupName': ...,
        'routeTables': [ { 'routeTableName': ...,
                           'routes': [ { 'prefix': ...,
                                         'nextHopIp': ... } ] } ] }
   """

   def __init__( self, parsedConfig ):
      super( AzureBackend, self ).__init__()
      assert parsedConfig and isinstance( parsedConfig, AzureParsedConfig )
      self.parsedConfig = parsedConfig
      # Created lazily and cached by getNetworkClient().
      self.networkClient_ = None
      # Disable SSL warnings. urllib3 is imported in the same try block as the
      # Azure SDK, so the name is unbound when that import failed. Tolerate it
      # so the backend can still be constructed and configValidator() can
      # report 'Azure Python SDK not installed' instead of __init__ raising a
      # NameError.
      try:
         # pylint: disable-msg=E1101
         urllib3.disable_warnings()
      except NameError:
         pass
      azureDebugSetup()

   def getNetworkClient( self ):
      """Return the cached NetworkManagementClient, creating it on first use.

      The login method is chosen from the type of parsedConfig.cred. Any
      configured http/https proxies are applied before the client is cached.

      Raises BackendException if the credential type is unsupported or if,
      with MSI credentials, no subscription is visible.
      """
      if self.networkClient_:
         return self.networkClient_
      access = self.parsedConfig.cred
      if isinstance( access, AzureAdAccess ):
         t0( 'Initiating Azure network client using AD credentials' )
         credentials = UserPassCredentials( access.email, access.password )
         client = NetworkManagementClient( credentials,
                                           str( access.subscriptionId ) )
      elif isinstance( access, AzureSdkAuthAccess ):
         t0( 'Initiating Azure network client using SDK Auth credentials' )
         client = self._networkClientFromSdkAuth( access )
      elif isinstance( access, AzureMsiAccess ):
         t0( 'Initiating Azure network client using MSI credentials' )
         client = self._networkClientFromMsi()
      else:
         # Unreachable while AzureParsedConfig's assert holds. Asserts are
         # stripped under python -O, so fail clearly here instead of going on
         # with a None client.
         raise BackendException( 'Unsupported Azure credential type: %s' %
                                 type( access ).__name__ )
      self._applyProxyConfig( client )
      # Cache only once fully configured. Otherwise a failure while applying
      # the proxies would leave a cached client without them.
      self.networkClient_ = client
      return client

   @staticmethod
   def _networkClientFromSdkAuth( access ):
      # Build a NetworkManagementClient from SDK-auth credentials.
      # get_client_from_auth_file() reads the credentials from a path, so they
      # are written to a fresh mkdtemp() directory. The file and the directory
      # are always removed afterwards, even if client creation fails, so the
      # client secret is not left on disk and temp directories do not pile up.
      azCredDir = tempfile.mkdtemp()
      azCredFilePath = os.path.join( azCredDir, 'az_credentials.json' )
      try:
         with open( azCredFilePath, 'w' ) as azCredFile:
            azCredFile.write( str( access ) + '\n' )
         return get_client_from_auth_file( NetworkManagementClient,
                                           azCredFilePath )
      finally:
         if os.path.exists( azCredFilePath ):
            os.remove( azCredFilePath )
         os.rmdir( azCredDir )

   def _networkClientFromMsi( self ):
      # Build a NetworkManagementClient from MSI credentials, using the first
      # subscription those credentials can list.
      credentials = MSIAuthentication()
      subscriptionClient = SubscriptionClient( credentials )
      # The subscription lookup is a management API call too, so honour the
      # configured proxies for it as well.
      self._applyProxyConfig( subscriptionClient )
      subscription = next( iter( subscriptionClient.subscriptions.list() ),
                           None )
      if subscription is None:
         # A bare StopIteration would surface with an empty error message.
         raise BackendException(
            'No Azure subscription accessible with MSI credentials' )
      return NetworkManagementClient( credentials,
                                      str( subscription.subscription_id ) )

   def _applyProxyConfig( self, client ):
      # Apply parsedConfig's http/https proxies to client (anything exposing
      # .config.proxies, as the SDK clients used here do). Configured proxies
      # override environment proxy settings. No-op when none are configured.
      httpProxy = self.parsedConfig.httpProxy
      httpsProxy = self.parsedConfig.httpsProxy
      if not ( httpProxy or httpsProxy ):
         return
      client.config.proxies.use_env_settings = False
      if httpProxy:
         client.config.proxies.add( 'http', httpProxy )
      if httpsProxy:
         client.config.proxies.add( 'https', httpsProxy )

   # returns the list of items used to identify a route to be used as args
   @staticmethod
   def getRouteId( networkClient, resourceGroupName, routeTableName, prefix ):
      """Return [ resourceGroupName, routeTableName, routeName ] for the first
      route in the table whose address_prefix equals prefix.

      Raises BackendException if no route in the table has that prefix.
      """
      for route in networkClient.routes.list( resourceGroupName,
                                              routeTableName ):
         if route.address_prefix == prefix:
            return [ resourceGroupName, routeTableName, route.name ]
      raise BackendException( 'Route with prefix %s not found' %
                                             prefix )

   def updateRouteTables( self, routeConfig ):
      """Point every route listed in routeConfig at its 'nextHopIp'.

      Each route is located by prefix via getRouteId(), which raises
      BackendException if it is missing. It is then rewritten as a
      VirtualAppliance route.
      """
      networkClient = self.getNetworkClient()
      for routeTable in routeConfig[ 'routeTables' ]:
         for route in routeTable[ 'routes' ]:
            routeId = AzureBackend.getRouteId( networkClient,
                                               routeConfig[ 'resourceGroupName' ],
                                               routeTable[ 'routeTableName' ],
                                               route[ 'prefix' ] )
            newRoute = Route( 'VirtualAppliance',
                              next_hop_ip_address=route[ 'nextHopIp' ],
                              address_prefix=route[ 'prefix' ] )
            # NOTE(review): the return value (presumably a long-running
            # operation poller) is discarded. The update is therefore not
            # waited on, and failures reported only through it are not seen.
            networkClient.routes.create_or_update( *( routeId + [ newRoute ] ) )

   def _tryUpdateRouteTables( self, routeConfig, tablesKind ):
      # Shared body of updatePeerRouteTables/updateLocalRouteTables. It applies
      # routeConfig, records the error text (or '') as the result message, and
      # returns True on success, False on any failure.
      try:
         self.updateRouteTables( routeConfig )
      # TODO make exception more specific
      # pylint: disable-msg=W0703
      except Exception as e:
         t0( 'Failed to update %s routing tables, exception: %s' %
             ( tablesKind, e ) )
         self.updateResultMessage( str( e ) )
         return False
      else:
         self.updateResultMessage( '' )
         return True

   def updatePeerRouteTables( self ):
      """Apply parsedConfig.peerRoutes; return True on success, else False."""
      return self._tryUpdateRouteTables( self.parsedConfig.peerRoutes, 'peer' )

   def updateLocalRouteTables( self ):
      """Apply parsedConfig.localRoutes; return True on success, else False."""
      return self._tryUpdateRouteTables( self.parsedConfig.localRoutes,
                                         'local' )

   def validateRoutingConfig( self, routeConfig ):
      #
      # Check the following:
      # - Credentials are valid
      # - Resource group exists
      # - Route tables exist
      # - Route entries with provided prefixes exist
      #
      # All four are checked by retrieving the route; any failure raises.
      #
      networkClient = self.getNetworkClient()
      for routeTable in routeConfig[ 'routeTables' ]:
         for route in routeTable[ 'routes' ]:
            routeId = AzureBackend.getRouteId( networkClient,
                                               routeConfig[ 'resourceGroupName' ],
                                               routeTable[ 'routeTableName' ],
                                               route[ 'prefix' ] )
            networkClient.routes.get( *routeId )

   @ArPyUtils.Decorators.retry( retryCheckEmbeddedMethodName='transientFailure',
                                attempts=4, retryInterval=30 )
   def configValidator( self ):
      """Validate SDK availability plus the peer and local routing configs.

      Returns True on success. On failure it returns False and sets the result
      message; errors that look transient (see shouldRetryError) are also
      marked with failure type 'TRANSIENT'.
      """
      try:
         # validate azure python sdk installed
         # if imports failed, NetworkManagementClient would cause name error
         try:
            NetworkManagementClient
         except NameError:
            raise BackendException( 'Azure Python SDK not installed' )
         # validate routing configs
         self.validateRoutingConfig( self.parsedConfig.peerRoutes )
         self.validateRoutingConfig( self.parsedConfig.localRoutes )
      # TODO make exception more specific
      # pylint: disable-msg=W0703
      except Exception as e:
         errorMsg = str( e )
         if self.shouldRetryError( errorMsg ):
            self.updateResultFailureType( 'TRANSIENT' )
            errorMsg = self.mapExceptionStr( errorMsg )
         t0( 'Config validation failed, exception: %s' % errorMsg )
         self.updateResultMessage( errorMsg )
         return False
      else:
         self.updateResultMessage( '' )
         return True

   @staticmethod
   def shouldRetryError( errorMsg ):
      """Return True if errorMsg contains a known transient-failure string."""
      return any( i in errorMsg for i in [
         # Probably a DNS issue.
         'Name or service not known',
         # Generalized version of the failure seen when the MSI extension is not
         # yet fully provisioned while provisioning/reloading the DUT. Retrying
         # is harmless in other cases too. Example:
         # exception: HTTPConnectionPool(host='localhost', port=50342): Max
         # retries exceeded with url: /oauth2/token (Caused by NewConnectionError
         # ('<urllib3.connection.HTTPConnection object at 0xf0c383ac>: Failed to
         # establish a new connection: [Errno 111] Connection refused'
         'Connection refused',
         # A newer urllib3 raises a different error when the network is
         # unreachable, so match that error string as well.
         'Network is unreachable',
      ] )

   @staticmethod
   def mapExceptionStr( errorMsg ):
      """Return a friendlier message for known errors, else errorMsg itself."""
      if "Name or service not known" in errorMsg:
         errorMsg = 'cannot access Azure, DNS not yet ready'
      return errorMsg

   
