#!/usr/bin/env python
# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliSave
from CliMode.Gnmi import MgmtGnmiMode, GnmiTransportMode
from IpLibConsts import DEFAULT_VRF

class MgmtGnmiSaveMode( MgmtGnmiMode, CliSave.Mode ):
   """Save-side counterpart of the gNMI management API CLI mode.

   Pairs the CLI mode definition (MgmtGnmiMode) with CliSave.Mode so the
   saver in this file can attach the "Mgmt.gnmi" command sequence to it.
   """
   def __init__( self, param ):
      # Initialize each base explicitly, CLI mode first and save mode second,
      # instead of relying on cooperative super() chaining.
      for base in ( MgmtGnmiMode, CliSave.Mode ):
         base.__init__( self, param )

class GnmiTransportSaveMode( GnmiTransportMode, CliSave.Mode ):
   """Save-side counterpart of a gNMI transport sub-mode.

   Pairs the CLI mode definition (GnmiTransportMode) with CliSave.Mode so
   the saver in this file can attach the "Mgmt.gnmi.transport" command
   sequence to each per-endpoint instance.
   """
   def __init__( self, param ):
      # Initialize each base explicitly, CLI mode first and save mode second,
      # instead of relying on cooperative super() chaining.
      for base in ( GnmiTransportMode, CliSave.Mode ):
         base.__init__( self, param )

# Hook the gNMI save modes into the CliSave mode tree. The "api-gnmi" mode
# hangs off global configuration, and each per-endpoint transport sub-mode
# hangs off that mode. The command-sequence names registered here are the
# keys saveGnmi uses to reach each mode's command list.
# NOTE(review): the statement order is kept as-is in case registration order
# affects the order in which modes and commands are written out.
CliSave.GlobalConfigMode.addChildMode( MgmtGnmiSaveMode )
MgmtGnmiSaveMode.addCommandSequence( "Mgmt.gnmi" )

MgmtGnmiSaveMode.addChildMode( GnmiTransportSaveMode )
GnmiTransportSaveMode.addCommandSequence( 'Mgmt.gnmi.transport' )

def _transportCommands( endpoint, saveAll ):
   """Build the ordered transport sub-mode commands for one gNMI endpoint.

   Args:
      endpoint: an entry of ``gnmiConfig.endpoints``.
      saveAll: when true, also produce the explicit default form of the
         shutdown, ssl profile, port, vrf and access-group settings.

   Returns:
      A list of command strings, in the order they must be saved.
   """
   commands = []

   # Only a disabled endpoint needs an explicit 'shutdown'.
   if not endpoint.enabled:
      commands.append( "shutdown" )
   elif saveAll:
      commands.append( "no shutdown" )

   # An empty profile name means no SSL profile is configured.
   if endpoint.sslProfile != '':
      commands.append( "ssl profile %s" % endpoint.sslProfile )
   elif saveAll:
      commands.append( 'no ssl profile' )

   if endpoint.port != endpoint.portDefault or saveAll:
      commands.append( "port %s" % endpoint.port )

   if endpoint.vrfName != DEFAULT_VRF or saveAll:
      commands.append( "vrf %s" % endpoint.vrfName )

   # An empty ACL name means no service ACL is applied.
   if endpoint.serviceAcl != '':
      commands.append( 'ip access-group %s' % endpoint.serviceAcl )
   elif saveAll:
      commands.append( 'no ip access-group' )

   # NOTE(review): unlike the settings above, qos dscp and authorization
   # have no explicit default form under saveAll -- confirm this is intended.
   if endpoint.qosDscp:
      commands.append( "qos dscp %s" % endpoint.qosDscp )

   if endpoint.authorization:
      commands.append( "authorization requests" )

   return commands


@CliSave.saver( "Gnmi::Config", "mgmt/gnmi/config",
      requireMounts = ( 'mgmt/octa/config', ) )
def saveGnmi( gnmiConfig, root, sysdbRoot, options, requireMounts ):
   """Save the gNMI management API configuration.

   Emits 'provider eos-native' under the "api-gnmi" mode when the
   'mgmt/octa/config' mount reports enabled. Emits one transport sub-mode,
   named after the endpoint, for every configured endpoint.

   Args:
      gnmiConfig: the ``Gnmi::Config`` entity named in the decorator.
      root: save tree, indexed by save-mode class.
      sysdbRoot: unused here; presumably required by the saver signature.
      options: save options; only ``saveAll`` is consulted.
      requireMounts: mapping of the extra mounts requested in the decorator.
   """
   if requireMounts[ 'mgmt/octa/config' ].enabled:
      gnmiMode = root[ MgmtGnmiSaveMode ].getOrCreateModeInstance( "api-gnmi" )
      gnmiMode[ "Mgmt.gnmi" ].addCommand( "provider eos-native" )

   endpoints = gnmiConfig.endpoints
   for endpointName in endpoints:
      # The transport sub-mode is created even if it ends up with no commands.
      apiMode = root[ MgmtGnmiSaveMode ].getOrCreateModeInstance( "api-gnmi" )
      transportMode = apiMode[ GnmiTransportSaveMode ].getOrCreateModeInstance(
         endpointName )
      transportCmds = transportMode[ 'Mgmt.gnmi.transport' ]
      for command in _transportCommands( endpoints[ endpointName ],
                                         options.saveAll ):
         transportCmds.addCommand( command )
