# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import BasicCli
from CliMode.Models import (
      MgmtModelsMode,
      ProviderAFTMode,
      ProviderSmashMode,
      ProviderSysdbMode,
)
import CliParser
import ConfigMount

# Octa configuration entity at "mgmt/octa/config" (type "Octa::Config").
# Plugin() binds it via ConfigMount.mount( ..., "w" ); it is None until then.
# The mode classes and command handlers below read/write config through it.
octaConfig = None

# ------------------------------------------------------
# Management models config commands
# ------------------------------------------------------

class MgmtModelsConfigMode( MgmtModelsMode, BasicCli.ConfigModeBase ):
   """CLI configuration mode 'management api models'.

   Entered through gotoMgmtModelsConfigMode; noMgmtModelsConfigMode resets the
   configuration managed under it.
   """

   name = "Models configuration"
   modeParseTree = CliParser.ModeParseTree()

   def __init__( self, parent, session ):
      # Keep a handle to the mounted Octa config on the mode instance.
      # NOTE(review): config_ is assigned before the base-class constructors
      # run; presumably so they can see it -- keep this ordering.
      self.config_ = octaConfig

      MgmtModelsMode.__init__( self, "api-models" )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

def gotoMgmtModelsConfigMode( mode, args ):
   """Switch the session into a MgmtModelsConfigMode child of `mode`.

   `args` is accepted for the handler signature but not used.
   """
   submode = mode.childMode( MgmtModelsConfigMode )
   session = mode.session_
   session.gotoChildMode( submode )

def noMgmtModelsConfigMode( mode, args ):
   """Resets Models configuration to default.

   Clears the settings of every provider submode:
     * AFT: aftOptions.ipv4Unicast and aftOptions.ipv6Unicast set to False,
     * Smash: 'smashincludes' / 'smashexcludes' removed from octaConfig.option,
     * Sysdb: 'sysdbexcludes' removed from octaConfig.option.

   Args:
      mode: the CLI mode the command was issued in; forwarded to each reset.
      args: parsed command arguments (unused).
   """
   # The AFT provider was previously skipped here, so its options survived a
   # full reset. None of the provider reset helpers read their `args`, hence
   # None is passed.
   noProviderAFTConfigMode( mode, None )
   noProviderSmashConfigMode( mode, None )
   noProviderSysdbConfigMode( mode, None )

# ------------------------------------------------------
# provider aft
# ------------------------------------------------------

class ProviderAFTConfigMode( ProviderAFTMode, BasicCli.ConfigModeBase ):
   """CLI configuration submode 'provider aft'.

   Commands handled in this file write octaConfig.aftOptions directly; there
   is no staged state to commit on exit.
   """

   name = 'Provider for AFT'
   modeParseTree = CliParser.ModeParseTree()

   def __init__( self, parent, session ):
      # NOTE(review): config_ is set before the base-class constructors run;
      # keep this order in case they rely on it.
      self.config_ = octaConfig
      ProviderAFTMode.__init__( self, "aft" )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

def gotoProviderAFTConfigMode( mode, args ):
   """Switch the session into a ProviderAFTConfigMode child of `mode`.

   `args` is accepted for the handler signature but not used.
   """
   aftMode = mode.childMode( ProviderAFTConfigMode )
   session = mode.session_
   session.gotoChildMode( aftMode )

def noProviderAFTConfigMode( mode, args ):
   """Turn off both AFT unicast options (ipv4Unicast, then ipv6Unicast).

   `mode` and `args` are accepted for the handler signature but not used.
   """
   # octaConfig.aftOptions is re-read for each write, exactly as a pair of
   # chained assignments would.
   for flagName in ( 'ipv4Unicast', 'ipv6Unicast' ):
      setattr( octaConfig.aftOptions, flagName, False )

def aftIPv4Unicast( mode, args ):
   """Set octaConfig.aftOptions.ipv4Unicast to True.

   `mode` and `args` are accepted for the handler signature but not used.
   """
   octaConfig.aftOptions.ipv4Unicast = True

def noAFTIPv4Unicast( mode, args ):
   """Set octaConfig.aftOptions.ipv4Unicast to False.

   `mode` and `args` are accepted for the handler signature but not used.
   """
   octaConfig.aftOptions.ipv4Unicast = False

def aftIPv6Unicast( mode, args ):
   """Set octaConfig.aftOptions.ipv6Unicast to True.

   `mode` and `args` are accepted for the handler signature but not used.
   """
   octaConfig.aftOptions.ipv6Unicast = True

def noAFTIPv6Unicast( mode, args ):
   """Set octaConfig.aftOptions.ipv6Unicast to False.

   `mode` and `args` are accepted for the handler signature but not used.
   """
   octaConfig.aftOptions.ipv6Unicast = False

# ------------------------------------------------------
# provider smash
# ------------------------------------------------------

class ProviderSmashConfigMode( ProviderSmashMode, BasicCli.ConfigModeBase ):
   """CLI configuration submode 'provider smash'.

   On entry the comma-separated 'smashincludes' / 'smashexcludes' entries of
   octaConfig.option are loaded into the sets smashIncludes / smashExcludes.
   The path handlers edit only these in-memory sets; onExit writes them back
   (sorted and comma-joined), deleting an option whose set is empty.
   """

   name = 'Provider for Smash'
   modeParseTree = CliParser.ModeParseTree()

   def __init__( self, parent, session ):
      self.config_ = octaConfig
      options = octaConfig.option

      # A missing/empty option means "no paths" -- "".split( ',' ) would
      # otherwise produce a spurious {""} entry.
      rawIncludes = options.get( 'smashincludes', "" )
      self.smashIncludes = (
            set( rawIncludes.split( ',' ) ) if rawIncludes != "" else set() )

      rawExcludes = options.get( 'smashexcludes', "" )
      self.smashExcludes = (
            set( rawExcludes.split( ',' ) ) if rawExcludes != "" else set() )

      ProviderSmashMode.__init__( self, "smash" )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

   def onExit( self ):
      """Commit onExit: write the pending path sets to octaConfig.option."""
      options = octaConfig.option
      # Includes are committed before excludes.
      pending = ( ( 'smashincludes', self.smashIncludes ),
                  ( 'smashexcludes', self.smashExcludes ) )
      for key, paths in pending:
         if not paths:
            # NOTE(review): assumes deleting an absent key is harmless, as
            # noProviderSmashConfigMode also does.
            del options[ key ]
         else:
            options[ key ] = ','.join( sorted( paths ) )

def getSmashPathFromInput( smashPath ):
   """Normalize a user-supplied Smash path.

   Leading/trailing '/' characters are removed, then one leading "Smash/"
   component is dropped, e.g. "/Smash/foo/bar/" -> "foo/bar".
   """
   prefix = "Smash/"
   trimmed = smashPath.strip( '/' )
   if not trimmed.startswith( prefix ):
      return trimmed
   return trimmed[ len( prefix ): ]

def gotoProviderSmashConfigMode( mode, args ):
   """Switch the session into a ProviderSmashConfigMode child of `mode`.

   `args` is accepted for the handler signature but not used.
   """
   smashMode = mode.childMode( ProviderSmashConfigMode )
   session = mode.session_
   session.gotoChildMode( smashMode )

def noProviderSmashConfigMode( mode, args ):
   """Remove the Smash exclude and include path options from octaConfig.

   `mode` and `args` are accepted for the handler signature but not used.
   """
   options = octaConfig.option
   for key in ( 'smashexcludes', 'smashincludes' ):
      del options[ key ]

def setSmashPath( mode, args ):
   """Record SMASHPATH as included, or as excluded when 'disabled' is given.

   The normalized path is first removed from the opposite set so that it is
   never both included and excluded. Only the mode's pending sets change;
   they are committed by ProviderSmashConfigMode.onExit.
   """
   path = getSmashPathFromInput( args.get( 'SMASHPATH', "" ) )
   if 'disabled' in args:
      target, other = mode.smashExcludes, mode.smashIncludes
   else:
      target, other = mode.smashIncludes, mode.smashExcludes
   other.discard( path )
   target.add( path )

def noSmashPath( mode, args ):
   """Forget SMASHPATH from the pending exclude set ('disabled') or include set.

   Removing a path that is not present is a no-op.
   """
   path = getSmashPathFromInput( args.get( 'SMASHPATH', "" ) )
   pathSet = mode.smashExcludes if 'disabled' in args else mode.smashIncludes
   pathSet.discard( path )

# ------------------------------------------------------
# provider sysdb
# ------------------------------------------------------

class ProviderSysdbConfigMode( ProviderSysdbMode, BasicCli.ConfigModeBase ):
   """CLI configuration submode 'provider sysdb'.

   On entry the comma-separated 'sysdbexcludes' entry of octaConfig.option is
   loaded into the set sysdbExcludes. The path handlers edit only this set;
   onExit writes it back (sorted and comma-joined), or deletes the option when
   the set is empty.
   """

   name = 'Provider for Sysdb'
   modeParseTree = CliParser.ModeParseTree()

   def __init__( self, parent, session ):
      self.config_ = octaConfig
      # Empty/missing option -> empty set (avoid "".split( ',' ) == [ "" ]).
      rawExcludes = octaConfig.option.get( 'sysdbexcludes', "" )
      if rawExcludes == "":
         self.sysdbExcludes = set()
      else:
         self.sysdbExcludes = set( rawExcludes.split( ',' ) )

      ProviderSysdbMode.__init__( self, "sysdb" )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

   def onExit( self ):
      """Commit onExit: write the pending exclude set to octaConfig.option."""
      options = octaConfig.option
      if not self.sysdbExcludes:
         # NOTE(review): assumes deleting an absent key is harmless, as
         # noProviderSysdbConfigMode also does.
         del options[ 'sysdbexcludes' ]
         return
      options[ 'sysdbexcludes' ] = ','.join( sorted( self.sysdbExcludes ) )

def getSysdbPathFromInput( sysdbPath ):
   """Normalize a user-supplied Sysdb path.

   Leading/trailing '/' characters are removed, then one leading "Sysdb/"
   component is dropped, e.g. "/Sysdb/a/b/" -> "a/b".
   """
   prefix = "Sysdb/"
   trimmed = sysdbPath.strip( '/' )
   return trimmed[ len( prefix ): ] if trimmed.startswith( prefix ) else trimmed

def gotoProviderSysdbConfigMode( mode, args ):
   """Switch the session into a ProviderSysdbConfigMode child of `mode`.

   `args` is accepted for the handler signature but not used.
   """
   sysdbMode = mode.childMode( ProviderSysdbConfigMode )
   session = mode.session_
   session.gotoChildMode( sysdbMode )

def noProviderSysdbConfigMode( mode, args ):
   """Remove the Sysdb exclude path option from octaConfig.

   `mode` and `args` are accepted for the handler signature but not used.
   """
   del octaConfig.option[ 'sysdbexcludes' ]

def setSysdbPath( mode, args ):
   """Add the normalized SYSDBPATH to the mode's pending Sysdb exclude set.

   The change is committed by ProviderSysdbConfigMode.onExit.
   """
   excludes = mode.sysdbExcludes
   excludes.add( getSysdbPathFromInput( args.get( 'SYSDBPATH', "" ) ) )

def noSysdbPath( mode, args ):
   """Drop the normalized SYSDBPATH from the mode's pending Sysdb exclude set.

   Removing a path that is not present is a no-op.
   """
   excludes = mode.sysdbExcludes
   excludes.discard( getSysdbPathFromInput( args.get( 'SYSDBPATH', "" ) ) )

def Plugin( entityManager ):
   """Plugin entry point: mount the Octa config used by this file's handlers.

   Binds the module-level octaConfig to the "mgmt/octa/config" entity of type
   "Octa::Config", obtained from ConfigMount.mount with mode "w".
   """
   global octaConfig

   mountPath, entityType, mountMode = "mgmt/octa/config", "Octa::Config", "w"
   octaConfig = ConfigMount.mount( entityManager, mountPath, entityType,
                                   mountMode )
