# Copyright (c) 2012 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliMode.HttpService import HttpServerVrfConfigModeBase
import CliSave
import Management
import Tac
from HttpServiceConstants import ServerConstants

# Tac type handles used by the savers below. capiConstants supplies the
# default port / enabled values per protocol (see saveCommonHttpServerCmds);
# logLevels supplies HttpService::LogLevel values (saveLoggingParams compares
# against logLevels.none).
capiConstants = Tac.Type( "HttpService::Constants" )
logLevels = Tac.Type( "HttpService::LogLevel" )

class HttpServerConfigMode( Management.MgmtConfigMode ):
   """CliSave mode holding the global http-server configuration.

   Built on Management.MgmtConfigMode with the fixed name "http-server"
   (presumably rendered as a 'management http-server' block -- confirm
   against Management.MgmtConfigMode).
   """

   def __init__( self, param ):
      # `param` is ignored: every instance is named "http-server". In this
      # file the mode is only ever created with the key 'http-server'.
      Management.MgmtConfigMode.__init__( self, "http-server" )

# Register the mode under global config and give it the command sequence that
# saveConfig() fills with the global http-server commands.
CliSave.GlobalConfigMode.addChildMode( HttpServerConfigMode )
HttpServerConfigMode.addCommandSequence( 'Mgmt.http-server' )

class HttpServerVrfConfigMode( HttpServerVrfConfigModeBase, CliSave.Mode ):
   """CliSave submode for per-VRF http-server configuration.

   saveVrf() creates instances keyed by a ( vrf, config ) tuple. That tuple
   is passed unchanged to both base-class constructors.
   """

   def __init__( self, param ):
      # Each base is initialized explicitly, in this fixed order.
      HttpServerVrfConfigModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, param )

   def skipIfEmpty( self ):
      # NOTE(review): presumably tells CliSave to omit this submode when it
      # collected no commands -- confirm against CliSave.Mode.
      return True

# Nest the VRF submodes under the http-server mode. `after` presumably orders
# them after the parent's 'Mgmt.http-server' sequence (confirm in CliSave).
HttpServerConfigMode.addChildMode( HttpServerVrfConfigMode,
                                   after=[ 'Mgmt.http-server' ] )
HttpServerVrfConfigMode.addCommandSequence( 'Mgmt.http-server-vrf' )

@CliSave.saver( 'HttpService::Config', 'mgmt/capi/config',
                requireMounts=( 'acl/cpconfig/cli', ) )
def saveConfig( config, root, sysdbRoot, options, requireMounts ):
   """Save HttpService::Config (mounted at mgmt/capi/config) as CLI commands.

   When config.useHttpServiceCli is set, the commands returned by
   saveCommonHttpServerCmds() go into the 'Mgmt.http-server' sequence of the
   'http-server' mode. That mode instance is only created when at least one
   command exists. Per-VRF settings are saved by saveVrf() regardless of
   useHttpServiceCli.

   sysdbRoot is part of the saver signature but is not used here.
   """
   globalCmds = []
   if config.useHttpServiceCli:
      globalCmds = saveCommonHttpServerCmds( config, options )

   if globalCmds:
      httpServerMode = root[ HttpServerConfigMode ].getOrCreateModeInstance(
            'http-server' )
      cmdSeq = httpServerMode[ 'Mgmt.http-server' ]
      for cmd in globalCmds:
         cmdSeq.addCommand( cmd )

   saveVrf( config, root, options, requireMounts )

def saveCommonHttpServerCmds( config, options ):
   """Build and return the list of global http-server commands.

   The commands are emitted in this fixed order:
   1. the https, http and 'http localhost' protocol settings;
   2. the unix-socket protocol;
   3. QoS DSCP;
   4. log level;
   5. default services;
   6. CSP frame-ancestors;
   7. CORS origins;
   8. the https SSL profile.
   """
   saveAll = options.saveAll
   cmds = []

   # https is saved before http on purpose (per the original note, because
   # it defaults to "enabled"); keep this ordering.
   protocolTable = (
      ( 'https', config.httpsConfig,
        capiConstants.defaultSecurePort,
        capiConstants.defaultSecureEnabled ),
      ( 'http', config.httpConfig,
        capiConstants.defaultInsecurePort,
        capiConstants.defaultInsecureEnabled ),
      ( 'http localhost', config.localHttpConfig,
        capiConstants.defaultInsecureLocalPort,
        capiConstants.defaultInsecureLocalEnabled ),
   )
   for proto, protoConf, defaultPort, defaultEnabled in protocolTable:
      saveProtocolConf( cmds, proto, protoConf, defaultPort, defaultEnabled,
                        saveAll )

   saveUnixProtocolConf( cmds, config.unixConfig,
                         capiConstants.defaultUnixEnabled, saveAll )
   saveQosParams( cmds, options, config.qosDscp )
   saveLoggingParams( cmds, options, config.syslogLevel )
   saveDefaultServicesEnabled( cmds, options, config.defaultServicesEnabled )
   saveContentFrameAncestorsParams( cmds, options, config.contentFrameAncestor )
   saveCorsOrigins( cmds, options, config.corsAllowedOrigins )
   saveSslProfile( cmds, options, config.httpsConfig.sslProfile )
   return cmds
      
def saveVrf( config, root, options, requireMounts ):
   """Save per-VRF http-server settings into the VRF submodes.

   First pass: for every entry in config.vrfConfig, emit 'no shutdown'
   ("enabled"), 'shutdown' ("disabled") or 'default shutdown' (anything
   else). VRFs in the "globalDefault" state are skipped unless
   options.saveAll is set.

   Second pass: for 'ip' and then 'ipv6', walk the control-plane service ACLs
   from requireMounts[ 'acl/cpconfig/cli' ] in sorted VRF order. Emit
   '<type> access-group <acl> in' wherever the ServerConstants.serviceName
   service has an ACL name.

   Submodes are created lazily, only when a command is added to them.
   """
   def vrfCmdSeq( vrf ):
      # Fetch (creating on first use) the VRF submode beneath the
      # 'http-server' mode and return its command sequence.
      parentMode = root[ HttpServerConfigMode
                       ].getOrCreateModeInstance( 'http-server' )
      vrfMode = parentMode[ HttpServerVrfConfigMode
                          ].getOrCreateModeInstance( ( vrf, config ) )
      return vrfMode[ 'Mgmt.http-server-vrf' ]

   # Server state of each VRF.
   for vrf, vrfConfig in config.vrfConfig.items():
      state = vrfConfig.serverState
      if state == "globalDefault" and not options.saveAll:
         continue
      if state == "enabled":
         stateCmd = 'no shutdown'
      elif state == "disabled":
         stateCmd = 'shutdown'
      else:
         stateCmd = 'default shutdown'
      vrfCmdSeq( vrf ).addCommand( stateCmd )

   # Per-VRF service ACLs, saved after the state commands.
   aclCpConfig = requireMounts[ 'acl/cpconfig/cli' ]
   for aclType in ( 'ip', 'ipv6' ):
      serviceAcls = aclCpConfig.cpConfig[ aclType ].serviceAcl
      for vrf in sorted( serviceAcls ):
         aclConfig = serviceAcls[ vrf ].service.get(
               ServerConstants.serviceName )
         if not ( aclConfig and aclConfig.aclName ):
            continue
         vrfCmdSeq( vrf ).addCommand(
               '%s access-group %s in' % ( aclType, aclConfig.aclName ) )

def saveUnixProtocolConf( cmds, conf, defaultEnabled, saveAll ):
   """Append the unix-socket protocol command to cmds when needed.

   A command is emitted whenever conf.enabled differs (by truthiness) from
   defaultEnabled. It is also emitted when saveAll is set. The command is
   'protocol unix-socket', or the 'no ' form when the protocol is disabled.
   """
   matchesDefault = bool( conf.enabled ) == bool( defaultEnabled )
   if matchesDefault and not saveAll:
      return
   negation = "" if conf.enabled else "no "
   cmds.append( "%sprotocol unix-socket" % negation )

def saveProtocolConf( cmds, proto, conf, defaultPort, defaultEnabled, saveAll ):
   """Append the 'protocol <proto>' command for one listener to cmds.

   The first matching rule applies:
   * Enabled on a non-default port: 'protocol <proto> port <port>'.
   * Enabled state differs from defaultEnabled: 'protocol <proto>', or
     'no protocol <proto>' when disabled.
   * saveAll: 'protocol <proto> port <defaultPort>', prefixed with 'no '
     when defaultEnabled is false.
   Otherwise nothing is appended.
   """
   if conf.port != defaultPort and conf.enabled:
      cmds.append( "protocol %s port %d" % ( proto, conf.port ) )
      return
   if conf.enabled != defaultEnabled:
      negation = "" if conf.enabled else "no "
      cmds.append( "%sprotocol %s" % ( negation, proto ) )
      return
   if not saveAll:
      return
   negation = "" if defaultEnabled else "no "
   cmds.append( negation + "protocol %s port %s" % ( proto, defaultPort ) )

def saveSslProfile( cmds, options, sslProfile ):
   """Append the https SSL-profile command to cmds.

   A configured (truthy) profile name yields 'protocol https ssl profile
   <name>'. With no profile, 'no protocol https ssl profile' is appended only
   when options.saveAll is set.
   """
   if sslProfile:
      cmds.append( 'protocol https ssl profile %s' % sslProfile )
   elif options.saveAll:
      cmds.append( 'no protocol https ssl profile' )

def saveCorsOrigins( cmds, options, origins ):
   """Append one 'cors allowed-origin <origin>' command per configured origin.

   Origins are emitted in the collection's iteration order. When there are
   none, 'no cors allowed-origin' is appended only under options.saveAll.
   """
   if origins:
      cmds.extend( "cors allowed-origin %s" % origin for origin in origins )
   elif options.saveAll:
      cmds.append( "no cors allowed-origin" )

def saveContentFrameAncestorsParams( cmds, options, frameAncestor ):
   """Append the CSP frame-ancestors header command to cmds.

   A truthy frameAncestor yields 'header csp frame-ancestors <value>'.
   Otherwise the 'no' form is appended only when options.saveAll is set.
   """
   if not frameAncestor and not options.saveAll:
      return
   if frameAncestor:
      cmds.append( "header csp frame-ancestors %s" % frameAncestor )
   else:
      cmds.append( "no header csp frame-ancestors" )

def saveQosParams( cmds, options, dscpValue ):
   """Append 'qos dscp <value>' to cmds.

   The command is appended whenever dscpValue is non-zero. When dscpValue is
   0, it is appended only under options.saveAll.
   """
   isZero = dscpValue == 0
   if not isZero or options.saveAll:
      cmds.append( "qos dscp %s" % dscpValue )

def saveLoggingParams( cmds, options, loggingLevel ):
   """Append the http-server log-level command to cmds.

   Any level other than logLevels.none yields 'log-level <level>'. For
   logLevels.none, 'no log-level' is appended only under options.saveAll.
   """
   # Keep the == comparison from the original rather than !=, so any custom
   # equality on the Tac enum type is honored identically.
   isNone = loggingLevel == logLevels.none
   if not isNone:
      cmds.append( "log-level %s" % loggingLevel )
   elif options.saveAll:
      cmds.append( "no log-level" )

def saveDefaultServicesEnabled( cmds, options, enabled ):
   """Append the default-services command to cmds.

   When disabled, 'no default-services' is always appended. When enabled,
   'default-services' is appended only under options.saveAll.
   """
   if enabled and not options.saveAll:
      return
   cmds.append( "default-services" if enabled else "no default-services" )

def sslConfigs( cmds, cmd, config, options, defaultValues, hasConfig ):
   """Append one SSL-related command built from a collection of values.

   Args:
      cmds: list the command is appended to.
      cmd: format string with a single '%s' that receives the joined values.
      config: mapping-like collection. When it is non-empty, its values are
         joined with single spaces (they must be strings, as str.join
         requires).
      options: unused; kept for signature compatibility.
      defaultValues: iterable of strings, joined and used when config is
         empty.
      hasConfig: when config is empty, the default-valued command is
         appended only if this is truthy; otherwise nothing is appended.
   """
   if config:
      # values() works on both Python 2 and Python 3 mappings, whereas
      # itervalues() raises AttributeError on Python 3.
      cmds.append( cmd % " ".join( config.values() ) )
   elif hasConfig:
      cmds.append( cmd % " ".join( defaultValues ) )

# attrPath is a sequence of attribute names forming a path off of config.
# For example, [ 'httpsConfig', 'enabled' ] resolves to
# config.httpsConfig.enabled.
def getAttr( config, attrPath ):
   """Return the value reached by following attrPath from config.

   Each name in attrPath is looked up with getattr() on the previous result.
   An empty attrPath returns config itself. A missing attribute propagates
   getattr()'s AttributeError.
   """
   node = config
   for attrName in attrPath:
      node = getattr( node, attrName )
   return node
