# Copyright (c) 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import sys

import BasicCliUtil
import CliParser
import ConfigMgmtMode
import ConfigMount
import DscpCliLib
import IpLibConsts
import LazyMount
import PyClient
import Tac
import XmppModel

# Characters that may not appear in a JID: every ASCII control character
# (0x00-0x1f), the space, all ASCII punctuation except '-', and DEL (0x7f).
ILLEGAL_CHARS = ( "".join( chr( code ) for code in range( 0x20 ) ) +
                  " !\"#$%&'()*+,./:;<=>?@[\\]^_`{|}~" +
                  "\x7f" )

# Sysdb mounts used by this module; all are assigned by Plugin() below.
xmppConfig = xmppStatus = ipConfig = dscpConfig = None

ConnectionState = Tac.Type( "Mgmt::Xmpp::Status::ConnectionState" )
# Default-constructed interface id, written to srcIntfName to clear the
# configured source interface.
defaultSrcIntf = Tac.Value( "Arnet::IntfId" )

# ----------------------------------------------
# Xmpp send command, from enable mode:
#    xmpp send NEIGHBOR command COMMAND

def xmppPyClient( mode ):
   """Returns the PyClient for the "Xmpp" agent used by this CLI session.

   The client is stored in the session data under 'xmppPyClient'. A new one
   is created only when the session has none yet, so later commands in the
   same session reuse it.
   """
   client = ( mode.session.sessionData( 'xmppPyClient' ) or
              PyClient.PyClient( mode.sysname, "Xmpp",
                                 initConnectCmd="import Xmpp" ) )
   mode.session.sessionDataIs( 'xmppPyClient', client )
   return client

def sendXmppMessage( mode, args ):
   """Sends args[ 'MESSAGE' ] to args[ 'JID' ] without waiting for a reply.

   When the XMPP server connection is not established, a warning is added to
   the mode and nothing is sent.
   """
   # PyClient is unicode incompatible, so hand it utf8-encoded 'str' values.
   jid = args[ 'JID' ].encode( "utf8" )
   text = args[ 'MESSAGE' ].encode( "utf8" )
   connected = xmppStatus.connectionState == ConnectionState.connected
   if connected:
      client = xmppPyClient( mode )
      client.eval( "Xmpp.sendXmppMessage( %r, %r, group=False, "
                   "captureResponse=False )" % ( jid, text ) )
   else:
      mode.addWarning( "Connection to the XMPP server is not established" )

def _reportUnresponsiveUsers( mode, error, statuses, isResponsive, what ):
   """Reports a wait timeout and lists the recipients that did not respond.

   Args:
      mode: CLI mode on which the errors are reported.
      error: the Tac.Timeout raised by Tac.waitFor.
      statuses: recipient -> message status mapping, as last fetched from
         Xmpp.getXmppMessageStatus.
      isResponsive: predicate deciding whether a status counts as a response.
      what: word used in the error text ("acknowledgement" or "response").

   Returns:
      The set of recipients whose status satisfied isResponsive.
   """
   responsiveUsers = set( user for user, status in statuses.items()
                          if isResponsive( status ) )
   mode.addError( error )
   mode.addError( "Did not receive command %s "
                  "from the following switch(es):" % what )
   for user in sorted( u for u in statuses if u not in responsiveUsers ):
      mode.addError( user )
   return responsiveUsers

def sendXmppCommand( mode, args ):
   """Sends a CLI command over XMPP and prints the replies.

   args[ 'COMMAND' ] is sent to args[ 'JID' ]. The JID is treated as a
   switch-group when it contains '@conference', and as a single neighbor
   otherwise. The function works in three steps:

   1. Waits (Tac.waitFor timeout=15) until every recipient reports a
      non-empty message status, i.e. has acknowledged the command.
   2. Waits (timeout=600) until the acknowledging recipients reach the
      'active' status.
   3. Prints the reply of every recipient that became 'active'.

   Recipients that miss a deadline are listed as CLI errors. When none of
   them made it, the pending reply is abandoned via
   Xmpp.stopAwaitingXmppReply.
   """
   # Ensure we send only 'str' type to PyClient (which is unicode incompatible)
   to = args[ 'JID' ].encode( "utf8" )
   command = args[ 'COMMAND' ].encode( "utf8" )

   group = '@conference' in to
   if not xmppStatus.connectionState == ConnectionState.connected:
      mode.addWarning( "Connection to the XMPP server is not established" )
      return
   if group and to not in xmppStatus.group:
      mode.addError( "switch-group %s does not exist" % to )
      return
   pc = xmppPyClient( mode )
   id_ = pc.eval( "Xmpp.sendXmppMessage( %r, %r, group=%s )" %
                  ( to, command, group ) )
   assert isinstance( id_, str )

   # Latest recipient -> status mapping. _waitStatus refreshes it in place so
   # the timeout handlers below see the last values fetched.
   xmppResponseStatus = {}

   def _waitStatus( pc, func, to, id_, xmppResponseStatus, targetUsers=None ):
      ''' Refreshes xmppResponseStatus from the agent and returns True if func
      holds for the status of every recipient (only those in targetUsers when
      targetUsers is given). '''
      xmppResponseStatus.clear()
      # %r renders the values as Python literals. A quote character in the
      # JID therefore cannot break the expression evaluated in the agent.
      xmppResponseStatus.update( pc.eval( "Xmpp.getXmppMessageStatus( %r, %r )"
                                          % ( to, id_ ) ) )
      if not xmppResponseStatus:
         return False
      messageValues = [ status for user, status in xmppResponseStatus.items()
                        if not targetUsers or user in targetUsers ]
      return all( func( status ) for status in messageValues )

   # Explicit predicates: bound dunders such as ''.__ne__ / 'active'.__eq__
   # return NotImplemented (which is truthy) for non-str statuses, e.g.
   # unicode or None. That could end a wait before anyone responded.
   def _isAcknowledged( status ):
      return bool( status )

   def _isActive( status ):
      return status == 'active'

   try:
      # wait for status to become 'composing' or 'active' signaling acknowledgment
      Tac.waitFor( lambda : _waitStatus( pc, _isAcknowledged, to, id_,
                                         xmppResponseStatus ),
                   timeout=15,
                   warnAfter=False,
                   sleep=True,
                   description="command acknowledgement" )
   except Tac.Timeout as e:
      responsiveUsers = _reportUnresponsiveUsers( mode, e, xmppResponseStatus,
                                                  _isAcknowledged,
                                                  "acknowledgement" )
      if not responsiveUsers:
         mode.addError( "No one acknowledged your command. "
                        "No longer waiting for a response" )
         pc.eval( "Xmpp.stopAwaitingXmppReply( %r, %r )" % ( to, id_ ) )
         return

   # some switches sent a compose, we will wait for them
   composedUsers = set( user for user, status in xmppResponseStatus.items()
                        if _isAcknowledged( status ) )

   try:
      # wait for status to become 'active' signaling completion.
      Tac.waitFor( lambda :
                   _waitStatus( pc, _isActive, to, id_, xmppResponseStatus,
                                composedUsers ),
                   timeout=600,
                   warnAfter=60,
                   maxDelay=3,
                   sleep=True,
                   description="command response" )
   except Tac.Timeout as e:
      responsiveUsers = _reportUnresponsiveUsers( mode, e, xmppResponseStatus,
                                                  _isActive, "response" )
      if not responsiveUsers:
         mode.addError( "No one responded to your command. "
                        "No longer waiting for a response." )
         pc.eval( "Xmpp.stopAwaitingXmppReply( %r, %r )" % ( to, id_ ) )
         return

   # some switches finished responding, we will print them out
   activeUsers = set( user for user, status in xmppResponseStatus.items()
                      if _isActive( status ) )
   reply = pc.eval( "Xmpp.getXmppReply( %r, %r )" % ( to, id_ ) )
   for user in sorted( reply ):
      if user in activeUsers:
         # Single parenthesized argument: prints identically on py2 and py3.
         print( 'message from user: %s\n%s\n%s' %
                ( user, '-' * 50, reply[ user ] ) )

def neighborsAndGroups( mode ):
   """Returns a dict mapping each XMPP neighbor and switch-group name to itself.

   Args:
      mode: unused.
   """
   names = dict( ( name, name ) for name in xmppStatus.neighbor )
   names.update( ( name, name ) for name in xmppStatus.group )
   return names

# ----------------------------------------------
# Xmpp show commands, from unpriv mode:
#    show xmpp status
#    show xmpp neighbors
def _buildStatusModel():
   """Returns an XmppModel.XmppStatus filled from the XMPP config and status."""
   cfg = xmppConfig
   status = xmppStatus
   model = XmppModel.XmppStatus()
   # Configured server endpoint; an unset VRF is shown as the default VRF.
   model.serverAddress = cfg.ipAddrOrHostname
   model.port = cfg.port
   model.vrfName = cfg.vrfName or IpLibConsts.DEFAULT_VRF
   # Values read from the status mount.
   model.srcInterface = status.srcIntfName
   model.srcIpAddr = status.srcIpAddr
   model.connectionState = status.connectionState
   model.statusMessage = status.presenceStatus
   model.enabled = status.enabled
   # Remaining configured values.
   model.initialPrivLevel = cfg.initialPrivLevel
   model.username = cfg.username
   model.domainName = cfg.domainName
   return model

def _buildNeighbors():
   """Returns an XmppModel.Neighbors with one entry per xmppStatus.neighbor."""
   model = XmppModel.Neighbors()
   for neighborName, status in xmppStatus.neighbor.items():
      entry = XmppModel.Neighbor()
      entry.state = status.state
      entry.statusMessage = status.presenceStatus
      lastUpdate = status.lastUpdateTime
      if lastUpdate:
         # Rebase the time from the Tac.now() clock onto the Tac.utcNow() one.
         entry.lastUpdateTimestamp = lastUpdate + Tac.utcNow() - Tac.now()
      else:
         entry.lastUpdateTimestamp = 0.0  # the model field must be a float
      model.neighbors[ neighborName ] = entry
   return model

def _buildGroups():
   """Returns an XmppModel.Groups with an enabled Group per xmppStatus.group."""
   model = XmppModel.Groups()
   for groupName in xmppStatus.group:
      # Being listed in xmppStatus.group means the group is enabled. A dict of
      # Group models, rather than a list of names, leaves room for per-group
      # details (such as authentication state) from a later Xmpp.tac update.
      entry = XmppModel.Group()
      entry.enabled = True
      model.groups[ groupName ] = entry
   return model

def showXmppStatus( mode, args ):
   """Handler for 'show xmpp status'; returns an XmppModel.XmppStatus."""
   statusModel = _buildStatusModel()
   return statusModel

def showXmppNeighbors( mode, args ):
   """Handler for 'show xmpp neighbors'; returns an XmppModel.Neighbors.

   The model is always returned. A warning is added first when XMPP is
   disabled or not connected to the server.
   """
   warning = None
   if not xmppStatus.enabled:
      warning = "XMPP is not enabled"
   elif not xmppStatus.connectionState == ConnectionState.connected:
      warning = "Connection to the XMPP server is not established"
   if warning:
      mode.addWarning( warning )
   return _buildNeighbors()

# ----------------------------------------------
# Xmpp session command, from xmpp enable mode:
#    xmpp session SWITCHGROUP

@BasicCliUtil.EapiIncompatible()
def xmppSession( mode, args ):
   """Runs an interactive XmppCli session with args[ 'JID' ].

   A JID without a '@' is treated as a short name, and the configured XMPP
   domain is appended to it. XmppCli is run with this process's stdin, stdout
   and stderr. A Tac.SystemCommandError from Tac.run is reported as CLI
   errors, together with any captured output.
   """
   to = args[ 'JID' ]
   if "@" not in to:
      # Shortname given
      to += "@%s" % xmppConfig.domainName
   # A separate name keeps the handler's 'args' parameter from being shadowed.
   argv = [ "XmppCli", "--sysname=%s" % mode.session_.sysname, to ]
   try:
      Tac.run( argv, stdin=sys.stdin, stderr=sys.stderr, stdout=sys.stdout )
   except Tac.SystemCommandError as e:
      cmd = ' '.join( argv )
      mode.addError( "'%s' returned error code: %s" % ( cmd, str( e.error ) ) )
      if e.output:
         mode.addError( e.output )

# ----------------------------------------------
# Xmpp config commands, from xmpp config mode:
#    [no] shutdown
#    server HOSTNAME port [PORT]
#    username NAME password PASSWORD
class XmppConfigMode( ConfigMgmtMode.ConfigMgmtMode ):
   """CLI configuration mode 'management xmpp'.

   Every handler writes directly into the Mgmt::Xmpp::Config mount. Handlers
   that change the server, port, VRF or DSCP value also rebuild the 'xmpp'
   entry of the Mgmt::Dscp::Config mount through updateDscpRules().
   """

   name = "XMPP configuration"
   modeParseTree = CliParser.ModeParseTree()

   def __init__( self, parent, session ):
      ConfigMgmtMode.ConfigMgmtMode.__init__( self, parent, session, "xmpp" )
      # Keep references to the module-level mounts that Plugin() set up.
      self.config_ = xmppConfig
      self.dscpConfig_ = dscpConfig

   def shutdown( self, args ):
      """Clears the 'enabled' flag in the XMPP config."""
      self.config_.enabled = False

   def noShutdown( self, args ):
      """Sets the 'enabled' flag in the XMPP config."""
      self.config_.enabled = True

   def configureServer( self, args ):
      """Sets the server address and, if given, the port."""
      cfg = self.config_
      server = args[ 'SERVER' ]
      if 'PORT' in args:
         cfg.port = args[ 'PORT' ]
      cfg.ipAddrOrHostname = server
      self.updateDscpRules()

   def noServer( self, args ):
      """Restores the default port and VRF and clears the server address."""
      cfg = self.config_
      cfg.port = cfg.defaultPort
      cfg.vrfName = IpLibConsts.DEFAULT_VRF
      cfg.ipAddrOrHostname = ""
      self.updateDscpRules()

   def configureVrf( self, args ):
      """Sets the VRF used to reach the server.

      A missing VRF or the legacy default VRF name maps to DEFAULT_VRF.
      """
      vrfName = args.get( 'VRF' )
      useDefault = not vrfName or vrfName == IpLibConsts.DEFAULT_VRF_OLD
      # config_.vrfName is a L3::VrfName, so we must assign it a string.
      self.config_.vrfName = IpLibConsts.DEFAULT_VRF if useDefault else vrfName
      self.updateDscpRules()

   def noVrf( self, args ):
      """Restores the default VRF."""
      self.config_.vrfName = IpLibConsts.DEFAULT_VRF
      self.updateDscpRules()

   def configureDomainName( self, domainName ):
      """Validates the XMPP domain name and stores it lower cased.

      Unlike the other handlers, this method takes the domain string itself
      rather than a parsed-args dict; configureUsername calls it directly.
      An invalid domain is reported with addError, and the config is left
      unchanged.
      """
      error = _checkDomain( domainName )
      if error is not None:
         self.addError( error )
         return
      self.config_.domainName = domainName.lower().encode( "utf8" )

   def noDomainName( self, args ):
      """Clears the XMPP domain name."""
      self.config_.domainName = ""

   def configureUsername( self, args ):
      """Configures the CLI's XMPP username and password.

      If the username is a mail address (contains '@') and no domain is
      configured yet, the part after the first '@' becomes the domain. The
      username is lower cased and completed via appendDomain(). Nothing is
      stored when appendDomain() rejects it.
      """
      username = args[ 'USERNAME' ]
      password = args[ 'PASSWORD' ]
      if '@' in username and not self.config_.domainName:
         self.configureDomainName( username.partition( '@' )[ 2 ] )
      try:
         fullName = appendDomain( self, username )
      except ValueError:
         return
      if not fullName:
         return
      self.config_.username = fullName.encode( "utf8" )
      self.config_.password = password.encode( "utf8" )

   def noUsername( self, args ):
      """Clears the username and password."""
      cfg = self.config_
      cfg.username = ""
      cfg.password = ""

   def configurePrivLevel( self, args ):
      """Sets the initial privilege level."""
      self.config_.initialPrivLevel = args[ 'PRIVILEGE' ]

   def noPrivLevel( self, args ):
      """Restores the default initial privilege level."""
      cfg = self.config_
      cfg.initialPrivLevel = cfg.defaultInitialPrivLevel

   def addGroup( self, args ):
      """Adds a switch-group, or updates the password of an existing one."""
      groupName = args[ 'GROUPNAME' ]
      password = args.get( 'PASSWORD' ) or ""
      try:
         groupName = appendDomain( self, groupName, group=True )
      except ValueError:
         return
      existing = self.config_.group.get( groupName )
      if existing:
         if existing.password == password:
            # Nothing changed
            return
         # The password changed: delete the group and recreate it below.
         del self.config_.group[ groupName ]
      self.config_.newGroup( groupName, password )

   def removeGroups( self, args ):
      """Removes the named switch-groups, or all of them if none are named."""
      groups = args.get( 'GROUPNAME' )
      if not groups:
         self.config_.group.clear()
         return
      for name in groups:
         try:
            fullName = appendDomain( self, name, group=True )
         except ValueError:
            return
         del self.config_.group[ fullName ]

   def configurePermitUnencrypted( self, args ):
      """Sets starttlsPermitUnencrypted in the XMPP config."""
      self.config_.starttlsPermitUnencrypted = True

   def defaultPermitUnencrypted( self, args ):
      """Clears starttlsPermitUnencrypted in the XMPP config."""
      self.config_.starttlsPermitUnencrypted = False

   def setSrcIntfName( self, args ):
      """Sets the source interface; warns when it has no IP address."""
      intf = args[ 'IPINTF' ]
      zeroAddr = Tac.Value( "Arnet::IpAddr" ).ipAddrZero
      intfAddr = zeroAddr
      if intf.name in ipConfig.ipIntfConfig:
         intfAddr = ipConfig.ipIntfConfig[ intf.name ].addrWithMask.address
      if zeroAddr == intfAddr:
         self.addWarning( "Interface IP address not configured" )
      self.config_.srcIntfName = intf.name

   def noSrcIntfName( self, args ):
      """Clears the configured source interface."""
      self.config_.srcIntfName = defaultSrcIntf

   def updateDscpRules( self ):
      """Rebuilds the DSCP config of the 'xmpp' protocol from the XMPP config.

      With no DSCP value set, the 'xmpp' protocol config is deleted.
      Otherwise its rules are cleared. When a server is configured, they are
      then rebuilt via DscpCliLib.addDscpRule for tcp traffic to that server.
      """
      dscpValue = self.config_.dscpValue
      if not dscpValue:
         del self.dscpConfig_.protoConfig[ 'xmpp' ]
         return

      rules = self.dscpConfig_.newProtoConfig( 'xmpp' ).rule
      rules.clear()
      server = self.config_.ipAddrOrHostname
      if not server:
         return
      # Traffic to external xmpp server.
      DscpCliLib.addDscpRule( rules, server, self.config_.port, False,
                              self.config_.vrfName, 'tcp', dscpValue )

   def setDscp( self, args ):
      """Sets the DSCP value and refreshes the DSCP rules."""
      self.config_.dscpValue = args[ 'DSCP' ]
      self.updateDscpRules()

   def noDscp( self, args=None ):
      """Restores the default DSCP value and refreshes the DSCP rules."""
      cfg = self.config_
      cfg.dscpValue = cfg.dscpValueDefault
      self.updateDscpRules()

def gotoXmppConfigMode( mode, args ):
   """Enters the 'management xmpp' configuration mode."""
   xmppMode = mode.childMode( XmppConfigMode )
   session = mode.session_
   session.gotoChildMode( xmppMode )

def noXmppConfigMode( mode, args ):
   """Resets Xmpp agent configuration to default.

   Every attribute that an XmppConfigMode handler can change is restored,
   including starttlsPermitUnencrypted. That attribute was previously missed,
   so 'permit unencrypted' survived a reset. The DSCP value is restored last
   through XmppConfigMode.noDscp(), which also rebuilds the DSCP rules.
   """
   # Disable first to trigger shutdown of an active XMPP client.
   xmppConfig.enabled = False
   xmppConfig.domainName = ""
   # Empty the switch group configuration collection
   xmppConfig.group.clear()
   xmppConfig.initialPrivLevel = xmppConfig.defaultInitialPrivLevel
   xmppConfig.ipAddrOrHostname = ""
   xmppConfig.password = ""
   xmppConfig.port = xmppConfig.defaultPort
   xmppConfig.username = ""
   xmppConfig.vrfName = IpLibConsts.DEFAULT_VRF
   xmppConfig.srcIntfName = defaultSrcIntf
   # Same value XmppConfigMode.defaultPermitUnencrypted() writes.
   xmppConfig.starttlsPermitUnencrypted = False
   mode.childMode( XmppConfigMode ).noDscp()

# ----------------------------------------------
# Xmpp group commands, from XmppConfig mode:
#    [no] switch-group GROUP1 GROUP2 GROUP3
#    no switch-group [GROUP1 GROUP2 GROUP3]
#    show xmpp switch-group

def showXmppGroup( mode, args ):
   """Handler for 'show xmpp switch-group'; returns an XmppModel.Groups."""
   groupsModel = _buildGroups()
   return groupsModel

def appendDomain( mode, username, group=False ):
   """Returns the lower-cased username, completed with the configured domain.

   A name that already contains '@' is returned as is (lower cased).
   Otherwise the configured domain name is appended: as
   'name@conference.<domain>' when group is True, else as 'name@<domain>'.

   The username/mail address is lower cased to avoid triggering BUG25431.

   Raises:
      ValueError: no '@' in the name and no domain configured. A warning is
         added to mode first.
   """
   name = username.lower()
   if '@' in name:
      return name
   domain = xmppConfig.domainName
   if not domain:
      mode.addWarning( "No XMPP domain name configured. Use 'domain NAME' first" )
      raise ValueError
   template = "%s@conference.%s" if group else "%s@%s"
   return template % ( name, domain )

def _checkDomain( domain ):
   """Checks that the domain does not contain invalid characters.

   The sleekxmpp.jid.JID class checks JIDs, but does so in a broken manner,
   so we perform the necessary checks here prior to instantiating the client.

   In addition, we cannot import any of sleekxmpp as they all eventually
   import sleekxmpp.xmlstream, which imports ssl. ssl is on the list of
   memory hog packages in Eos/test/CliTests.py, so we must avoid importing
   it in the CLI process.

   Args:
      domain: str, the domain name to check for errors.

   Returns:
      None if there was no error, else a str error message.
   """
   if "\x00" in domain:
      return "Domain contains the null character"
   if domain.startswith( "-" ) or domain.endswith( "-" ):
      return "Domain cannot start or end with a '-'"

   illegalChars = []
   for part in domain.split( "." ):
      if not part:
         return "Domain contains consecutive '.', which is invalid"
      for char in part:
         if char in ILLEGAL_CHARS:
            illegalChars.append( char )
      if illegalChars:
         suffix = ""
         if len( illegalChars ) > 1:
            suffix = "s"
         return "Domain contains the illegal character%s: %s" % (
            suffix, ", ".join( "'%s'" % ( c, ) for c in illegalChars ) )
   return None

def Plugin( entityManager ):
   """CLI plugin entry point: mounts the Sysdb entities used by this module."""
   global xmppConfig, xmppStatus, ipConfig, dscpConfig
   em = entityManager
   # XMPP configuration (writable) and agent status (read-only, lazy).
   xmppConfig = ConfigMount.mount( em, "mgmt/xmpp/config",
                                   "Mgmt::Xmpp::Config", "w" )
   xmppStatus = LazyMount.mount( em, "mgmt/xmpp/status",
                                 "Mgmt::Xmpp::Status", "r" )
   # IP interface config, read when validating the source interface.
   ipConfig = LazyMount.mount( em, "ip/config", "Ip::Config", "r" )
   # DSCP config, written by XmppConfigMode.updateDscpRules().
   dscpConfig = ConfigMount.mount( em, "mgmt/dscp/config",
                                   "Mgmt::Dscp::Config", "w" )
