# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AclCliLib
import BasicCli
from CliMode.Restconf import MgmtRestconfMode, RestconfTransportMode
import CliParser
import CliPlugin.AclCli as AclCli
import CliMatcher
import ConfigMount
import DscpCliLib
from IpLibConsts import DEFAULT_VRF
import LazyMount

import OpenConfigCliLib

# Entities mounted by Plugin(); every one stays None until Plugin() runs.
sslConfig = restconfConfig = restconfStatus = None
restconfCheckpoint = aclConfig = aclCpConfig = None

# ------------------------------------------------------
# RESTCONF config commands
# ------------------------------------------------------

class MgmtRestconfConfigMode( MgmtRestconfMode, BasicCli.ConfigModeBase ):
   """CLI configuration mode 'management api restconf'.

   Command handlers reach the RESTCONF config entity through self.config_,
   which is bound to the module-level restconfConfig.
   """

   name = "RESTCONF configuration"
   # Presumably holds the commands registered for this mode -- TODO confirm
   # against CliParser.ModeParseTree.
   modeParseTree = CliParser.ModeParseTree()

   def __init__( self, parent, session ):
      # NOTE(review): config_ is bound before the base initializers run; it is
      # unclear from here whether they depend on that, so keep this order.
      self.config_ = restconfConfig

      MgmtRestconfMode.__init__( self, "api-restconf" )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

def gotoMgmtRestconfConfigMode( mode, args ):
   """Enter the 'management api restconf' configuration mode."""
   mode.session_.gotoChildMode( mode.childMode( MgmtRestconfConfigMode ) )

def noMgmtRestconfConfigMode( mode, args ):
   """Resets RESTCONF configuration to default.

   Clears the top-level enabled flag, then removes every configured transport
   via noRestconfTransportConfigMode.
   """
   restconfConfig.enabled = False
   # Snapshot the names first: noRestconfTransportConfigMode deletes each
   # entry from restconfConfig.endpoints, and removing items from a
   # collection while iterating over it is unsafe.
   for name in list( restconfConfig.endpoints ):
      noRestconfTransportConfigMode( mode, { 'TRANSPORT_NAME': name } )

class RestconfTransportConfigMode( RestconfTransportMode, BasicCli.ConfigModeBase ):
   """CLI configuration submode 'transport https <name>'.

   Each instance edits one entry of restconfConfig.endpoints, keyed by the
   transport name passed to __init__.
   """

   name = 'Transport for RESTCONF'
   # Presumably holds the commands registered for this submode -- TODO
   # confirm against CliParser.ModeParseTree.
   modeParseTree = CliParser.ModeParseTree()

   def __init__( self, parent, session, name ):
      self.config_ = restconfConfig
      # The instance attribute shadows the class-level display `name` above.
      # Command handlers in this file use mode.name as the key into
      # config_.endpoints.
      # NOTE(review): confirm the base initializers do not overwrite it.
      self.name = name

      RestconfTransportMode.__init__( self, name )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

def _newHttpsEndpoint( transportName ):
   """Create transport `transportName` as an enabled HTTPS endpoint in the
   default VRF on its default port, and return it."""
   ep = restconfConfig.newEndpoints( transportName )
   ep.transport = 'https'
   ep.vrfName = DEFAULT_VRF
   # 'initially' attributes are not attrlogged, so write the default port
   # explicitly.
   ep.port = ep.portDefault
   ep.enabled = True
   return ep

def gotoRestconfTransportConfigMode( mode, args ):
   """Enter the 'transport https <name>' submode, creating the transport
   first when it does not exist yet.

   If the transport is new and OpenConfigCliLib.otherEnabledTransportExists()
   returns true, nothing is created and the submode is not entered.
   """
   transportName = args[ 'TRANSPORT_NAME' ]
   if transportName not in restconfConfig.endpoints:
      if OpenConfigCliLib.otherEnabledTransportExists( mode, transportName ):
         return
      _newHttpsEndpoint( transportName )
      OpenConfigCliLib.updateLevelEnabledFlag( restconfConfig )

   mode.session_.gotoChildMode(
      mode.childMode( RestconfTransportConfigMode, name=transportName ) )

def noRestconfTransportConfigMode( mode, args ):
   """Delete the transport named args['TRANSPORT_NAME'], if present.

   An enabled transport first has its service ACL mapping removed and is
   disabled. The top-level enabled flag is recomputed in every case.
   """
   transportName = args[ 'TRANSPORT_NAME' ]
   ep = restconfConfig.endpoints.get( transportName )
   if ep is not None:
      if ep.enabled:
         aclMap = restconfConfig.serviceAclTypeVrfMap
         AclCliLib.noServiceAclTypeVrfMap( mode, aclMap, None, "ip",
                                           ep.vrfName )
         ep.enabled = False
      ep.port = ep.portDefault
      try:
         del restconfConfig.endpoints[ transportName ]
      except KeyError:
         pass

   OpenConfigCliLib.updateLevelEnabledFlag( restconfConfig )

def shutdown( mode, args ):
   """Disable the current transport and drop its service ACL mapping.

   Does nothing when the transport is already disabled.
   """
   ep = mode.config_.endpoints[ mode.name ]
   if not ep.enabled:
      return
   aclMap = mode.config_.serviceAclTypeVrfMap
   AclCliLib.noServiceAclTypeVrfMap( mode, aclMap, None, "ip", ep.vrfName )
   ep.enabled = False
   OpenConfigCliLib.updateLevelEnabledFlag( restconfConfig )

def noShutdown( mode, args ):
   """Enable the current transport and install its service ACL mapping.

   Only one transport of a given 'transport' type may be enabled. If another
   endpoint of the same type is already enabled, an error is reported and the
   configuration is left untouched.
   """
   config = mode.config_
   current = config.endpoints.get( mode.name )
   conflicting = next( ( other for other in config.endpoints.itervalues()
                         if other.name != current.name and other.enabled and
                         other.transport == current.transport ), None )
   if conflicting is not None:
      mode.addError( "transport '%s' of type '%s' already "
                     "enabled; can not enable another" %
                     ( conflicting.name, conflicting.transport ) )
      return
   current.enabled = True
   config.enabled = True
   AclCliLib.setServiceAclTypeVrfMap( mode, config.serviceAclTypeVrfMap,
                                      current.serviceAcl, "ip",
                                      current.vrfName )

def setSslProfile( mode, args ):
   """Point the current transport at SSL profile args['PROFILENAME']."""
   profile = args[ 'PROFILENAME' ]
   endpoint = mode.config_.endpoints[ mode.name ]
   endpoint.sslProfile = profile

def noSslProfile( mode, args ):
   """Clear the SSL profile of the current transport (empty string)."""
   endpoint = mode.config_.endpoints[ mode.name ]
   endpoint.sslProfile = ''

def _sslProfileNames( mode ):
   """Return sslConfig.profileConfig, the candidates for an SSL profile name."""
   return sslConfig.profileConfig

profileNameMatcher = CliMatcher.DynamicNameMatcher( _sslProfileNames,
                                                    'Profile name' )

def setVrfName( mode, args ):
   """Set the VRF of the current transport (DEFAULT_VRF when 'VRFNAME' is
   absent from args).

   While a transport is enabled, its service ACL is mapped under
   endpoint.vrfName (see noShutdown / shutdown). When the VRF changes on an
   enabled transport, that mapping is therefore moved from the old VRF to the
   new one. Otherwise the old VRF keeps a stale mapping, and a later shutdown
   would try to remove it from the wrong VRF.
   """
   vrfName = args.get( 'VRFNAME', DEFAULT_VRF )
   endpoint = mode.config_.endpoints[ mode.name ]
   if endpoint.vrfName == vrfName:
      return
   if endpoint.enabled:
      AclCliLib.noServiceAclTypeVrfMap( mode, mode.config_.serviceAclTypeVrfMap,
                                        None, "ip", endpoint.vrfName )
   endpoint.vrfName = vrfName
   if endpoint.enabled:
      AclCliLib.setServiceAclTypeVrfMap( mode, mode.config_.serviceAclTypeVrfMap,
                                         endpoint.serviceAcl, "ip",
                                         endpoint.vrfName )

def setPort( mode, args ):
   """Set the current transport's port to args['PORT'], or restore the
   endpoint's portDefault when 'PORT' is absent."""
   endpoint = mode.config_.endpoints[ mode.name ]
   endpoint.port = args.get( 'PORT', endpoint.portDefault )

# ---------------------------------------------------------------------
# switch(config-mgmt-api-restconf-transport-<name>)# qos dscp <dscpValue>
# ---------------------------------------------------------------------
def setDscp( mode, args ):
   """Store args['DSCP'] as the current transport's QoS DSCP value."""
   dscpValue = args[ 'DSCP' ]
   endpoint = mode.config_.endpoints[ mode.name ]
   endpoint.qosDscp = dscpValue

def noDscp( mode, args ):
   """Reset the current transport's QoS DSCP value to 0."""
   endpoint = mode.config_.endpoints[ mode.name ]
   endpoint.qosDscp = 0

# Attach the 'qos dscp' command to the transport submode, using setDscp and
# noDscp as its handlers (the exact registration is done by DscpCliLib).
DscpCliLib.addQosDscpCommandClass( RestconfTransportConfigMode, setDscp, noDscp )

def setRestconfAcl( mode, args ):
   """Assign service ACL args['ACLNAME'] to the current transport.

   The name is passed to AclCliLib.checkServiceAcl() and then stored
   regardless of its result. The VRF mapping is installed only when the
   transport is enabled.
   """
   requestedAcl = args[ 'ACLNAME' ]
   AclCliLib.checkServiceAcl( mode, aclConfig, requestedAcl )
   ep = mode.config_.endpoints.get( mode.name )
   ep.serviceAcl = requestedAcl
   if not ep.enabled:
      return
   aclMap = mode.config_.serviceAclTypeVrfMap
   AclCliLib.setServiceAclTypeVrfMap( mode, aclMap, ep.serviceAcl, "ip",
                                      ep.vrfName )

def noRestconfAcl( mode, args ):
   """Remove the service ACL from the current transport.

   The VRF mapping is removed only when the transport is enabled. The stored
   ACL name is cleared in every case.
   """
   ep = mode.config_.endpoints.get( mode.name )
   if ep.enabled:
      aclMap = mode.config_.serviceAclTypeVrfMap
      AclCliLib.noServiceAclTypeVrfMap( mode, aclMap, None, "ip", ep.vrfName )
   ep.serviceAcl = ""

#-------------------------------------------------------------------------------
# The "show management api restconf access-list" command
#-------------------------------------------------------------------------------
def showRestconfAcl( mode, args ):
   """Render the RESTCONF service ACLs via AclCli.showServiceAcl().

   The params passed on are [ args.get( 'ACL' ), 'summary' in args ].
   """
   aclName = args.get( 'ACL' )
   summaryOnly = 'summary' in args
   return AclCli.showServiceAcl( mode,
                                 restconfConfig.serviceAclTypeVrfMap,
                                 restconfStatus.aclStatusService,
                                 restconfCheckpoint,
                                 'ip', [ aclName, summaryOnly ],
                                 supressVrf=True )

def clearRestconfAclCounters( mode, args ):
   """Clear the RESTCONF service ACL counters via
   AclCli.clearServiceAclCounters()."""
   statusService = restconfStatus.aclStatusService
   AclCli.clearServiceAclCounters( mode, statusService, restconfCheckpoint,
                                   'ip' )

def Plugin( entityManager ):
   """Mount the entities this CLI plugin uses and bind them to module globals.

   :param entityManager: passed to every ConfigMount.mount / LazyMount.mount
      call below.
   """
   global restconfConfig, sslConfig, aclConfig, aclCpConfig
   global restconfStatus, restconfCheckpoint
   # RESTCONF configuration, edited by the command handlers in this file.
   restconfConfig = ConfigMount.mount( entityManager, "mgmt/restconf/config",
                                   "Restconf::Config", "w" )
   # Read-only status; read for the ACL show/clear commands.
   restconfStatus = LazyMount.mount( entityManager, "mgmt/restconf/status",
                                     "Restconf::Status", "r" )
   # ACL counter checkpoint, mounted writable; passed to the ACL show/clear
   # helpers.
   restconfCheckpoint = LazyMount.mount( entityManager,
                                         "mgmt/restconf/checkpoint",
                                         "Acl::CheckpointStatus", "w" )
   # SSL config; its profileConfig supplies completions for profile names.
   sslConfig = LazyMount.mount( entityManager,
                                "mgmt/security/ssl/config",
                                "Mgmt::Security::Ssl::Config",
                                "r" )
   # ACL config handed to AclCliLib.checkServiceAcl().
   aclConfig = ConfigMount.mount( entityManager, "acl/config/cli",
                                  "Acl::Input::Config", "w" )
   # NOTE(review): aclCpConfig is not referenced elsewhere in this section;
   # confirm whether it is still needed.
   aclCpConfig = ConfigMount.mount( entityManager, "acl/cpconfig/cli",
                                    "Acl::Input::CpConfig", "w" )
