#!/usr/bin/env python2.7
#
# Copyright (c) 2017-2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

''' Access a Check Point Software Inc. Management Server and firewall/gateway
    via the REST API.  Reference:
    https://sc1.checkpoint.com/documents/latest/APIs/index.html#introduction~v1.1
'''
# pylint: disable=E1101

import re
import json
import requests
import urllib3
from copy import deepcopy
from MssPolicyMonitor import Lib
from MssPolicyMonitor.Error import ServiceDeviceError, FirewallAPIError
from MssPolicyMonitor.Lib import ( t0, t2, t4,
                                   LINK_STATE_UNKNOWN, LINK_STATE_UP,
                                   LINK_STATE_DOWN, HA_ACTIVE, HA_PASSIVE,
                                   HA_ACTIVE_PASSIVE, HA_ACTIVE_ACTIVE,
                                   isValidIpAddr )

from MssPolicyMonitor.PluginLib import ( ServiceDeviceHAState, ServiceDevicePolicy,
                                         NetworkInterface,
                                         ServiceDeviceRoutingTables )
import pprint
pp = pprint.PrettyPrinter( indent=2 )

def printObj( obj ):
   ''' Debug helper: pretty-print obj to stdout using the module printer pp
       (2-space indentation). Used by the unit tests at the end of the file.
   '''
   printer = pp
   printer.pprint( obj )

# TODO: suppress warning message until we support TLS/SSL certificate validation
# NOTE(review): getUrl already verifies against trustedCertsPath when an
# sslProfileName is configured, so this TODO may be partly stale; the suppressed
# warning would otherwise be emitted for the unverified (verify=False) requests.
urllib3.disable_warnings( urllib3.exceptions.InsecureRequestWarning )


# Extracts the MSS tag list from a policy rule comment: "tags(...)" or
# "tags[...]", case-insensitive, possibly spanning lines. group(1) is the text
# up to the first ')' or ']' and is split on commas by getLayerPolicies,
# e.g. "TAGS[a, b]" -> "a, b".
POLICY_TAGS_REGEX = re.compile( r'tags[[(](.+?)[])]', re.IGNORECASE | re.DOTALL )
# VLAN range attached to physical (non-subinterface) interfaces: 1-4094
DEFAULT_VLAN_RANGE = '1-4094'
INTF_TYPE_FILTER = [ 'eth', 'bond' ]  # ignore other intf types, bond = LAG intf

# The Check Point Management Server defines "Network Layers", which are groups of
# security policies. The MSS configuration must specify which network layer is to
# be considered. The APIs refer to each layer as "<name> Network".
# The default network layer is called "Standard". However, the APIs do not use
# this name: all replies refer to this layer as just "Network", without the
# layer's name.
# Layer name produced from the configured group "Standard" (group + ' Network'),
# and the name the API actually reports for that layer; getPolicies maps the
# former to the latter before comparing layer names.
DEFAULT_POLICY_LAYER_NAME = 'Standard Network'
DEFAULT_POLICY_LAYER_NAME_IN_API = 'Network'

# Trace which requests version is in use (helps when debugging HTTP issues)
t4( 'imported python requests version:', requests.__version__ )

####################################################################################
class CheckPointDevice( object ):
   ''' Base class for REST API access to a Check Point device (Management
       Server or gateway).

       Every API command is sent as an HTTP POST of a JSON body to
       <protocol>://<ipAddr>:<port>/<urlApiPath>/<command> (see getUrl). The
       session id returned by 'login' is sent back in the X-chkp-sid header.
       Login happens lazily on the first call and is repeated when the device
       reports that the session expired.
   '''
   def __init__( self, config, urlApiPath, mgmtIp=None ):
      ''' config: device config dict. Reads ipAddress (unless mgmtIp is given),
             protocol, protocolPortNum, username, password, timeout, retries,
             verifyCertificate, sslProfileName and, optionally,
             trustedCertsPath.
          urlApiPath: API path under the base URL, e.g. 'web_api' or
             'gaia_api'.
          mgmtIp: address to connect to instead of config[ 'ipAddress' ].
      '''
      self.config = config
      self.ipAddr = mgmtIp if mgmtIp else config[ 'ipAddress' ]
      protocol = config[ 'protocol' ]
      portNum = config[ 'protocolPortNum' ]
      self.username = config[ 'username' ]
      self.password = config[ 'password' ]
      self.timeout = config[ 'timeout' ]
      self.retries = config[ 'retries' ]
      self.deviceInfo = None
      self.baseUrl = '%s://%s:%s/%s/' % ( protocol, self.ipAddr, portNum,
                                          urlApiPath )
      t4('CheckPointDevice API baseUrl:', self.baseUrl )
      # 'X-chkp-sid' holds the API session id; '' means "not logged in"
      self.requestHeaders = { 'content-Type': 'application/json',
                              'accept':       'application/json',
                              'X-chkp-sid': '' }
      self.chkpApi = requests.session()  # use HTTP 1.1 persistent (TCP) connection
      # NOTE: per the requests docs, adapter max_retries covers connection-level
      # failures and comes in addition to the retry loop in getUrl.
      self.chkpApi.mount( 'http://',
                          requests.adapters.HTTPAdapter( max_retries=self.retries ) )
      self.chkpApi.mount( 'https://',
                          requests.adapters.HTTPAdapter( max_retries=self.retries ) )
      self.chkpApi.verify = config[ 'verifyCertificate' ]  # also see getUrl
      self.logins = 0
      self.sslProfileName = config[ 'sslProfileName' ]
      self.trustedCertsPath = config.get( 'trustedCertsPath', '' )

   def closeApiConnection( self ):
      ''' Log out of the API session, if one is open, and close the HTTP
          session. The local session id is cleared and the connection closed
          even when the logout request itself raises; that exception still
          propagates to the caller.
      '''
      if not self.requestHeaders[ 'X-chkp-sid' ]:
         return
      t4('Closing API connection')
      try:
         self.getUrl( 'logout' )
      finally:
         self.chkpApi.close()
         self.requestHeaders[ 'X-chkp-sid' ] = ''

   def loginAsNeededAndConfirm( self ):
      """ Login to the Management Server unless a session id is already held.

          Returns True on success. Raises ServiceDeviceError( 'login failed' )
          when the reply carries no session id. getUrl may also raise on
          connection or API errors.
      """
      if self.requestHeaders[ 'X-chkp-sid' ]:
         return True
      credentials = { 'user': self.username, 'password' : self.password }
      resp = self.getUrl( 'login', requestBody=credentials, loginRequest=True )
      if not resp or 'sid' not in resp:
         raise ServiceDeviceError( 'login failed' )
      self.requestHeaders[ 'X-chkp-sid' ] = resp[ 'sid' ]
      self.logins += 1
      t2('login successful')
      return True

   def getUrl( self, command, requestBody=None, loginRequest=False ):
      """ Make a REST API call (HTTP POST) to a Check Point device.

          command: API command appended to baseUrl, e.g. 'show-hostname'.
          requestBody: dict sent as the JSON body (default {}).
          loginRequest: True for the login call itself. It skips the implicit
             login and keeps the request body (credentials) out of the trace.

          The request is attempted once plus self.retries more times.
          Connection errors, 5xx replies and expired sessions (after a
          re-login) are retried. Other 4xx replies are not retried, because
          resending the same request cannot succeed and, for login, would
          repeat a bad-password attempt.

          Returns the decoded JSON reply on HTTP 200.
          Raises ServiceDeviceError when:
             - the connection fails on every attempt,
             - the SSL configuration is incomplete or invalid,
             - the login fails.
          Raises FirewallAPIError( status, code ) when:
             - any other non-200 reply is received,
             - the reply body is not JSON.
      """
      if not loginRequest and not self.loginAsNeededAndConfirm():
         return None  # login failed
      if not requestBody:
         requestBody = {}
      url = self.baseUrl + command

      if loginRequest:
         t2('API REQ LOGIN', 'logins:', self.logins )
      else:
         t2('API REQ URL', url, 'BODY:', requestBody, 'logins:', self.logins )

      # re-read so a certificate path configured at runtime is picked up
      self.trustedCertsPath = self.config.get( 'trustedCertsPath', '' )
      if self.sslProfileName and not self.trustedCertsPath:
         raise ServiceDeviceError( Lib.SSL_ERROR_MSG )

      resp = None
      connectionFailed = False
      # one initial attempt plus self.retries retries (never fewer than one)
      for attempt in range( 1, max( self.retries, 0 ) + 2 ):
         resp = None
         try:
            resp = self.chkpApi.post(
               url, data=json.dumps( requestBody ), headers=self.requestHeaders,
               verify=( self.trustedCertsPath if self.sslProfileName else False ),
               timeout=self.timeout )
            connectionFailed = False

            if resp.status_code == requests.codes.ok:  # pylint: disable=E1101
               break

            # FIXME: remove when gateway API session keep alives work properly
            elif ( resp.status_code == requests.codes.unauthorized and resp.text and
                   ( 'Session may be expired' in resp.text or
                     'Session was expired' in resp.text ) ):
               t2( self.ipAddr, 'API session expired, trying re-login:', resp.text )
               self.requestHeaders[ 'X-chkp-sid' ] = ''  # clear expired session id
               self.loginAsNeededAndConfirm()
            elif 400 <= resp.status_code < 500:
               # deterministic client error: resending will not change the
               # outcome, so report it now
               break
         except ServiceDeviceError:
            raise
         except FirewallAPIError:
            raise
         except requests.exceptions.SSLError:
            raise ServiceDeviceError( Lib.SSL_ERROR_MSG )
         except Exception as ex:  # pylint: disable=W0703
            if loginRequest:
               # log only the exception type: the message could echo credentials
               t4( '%s CheckPoint API login access attempt %s, %s' % (
                    self.ipAddr, attempt, type( ex ) ) )
            else:
               t4( '%s CheckPoint API access attempt %s, %s' % (
                    self.ipAddr, attempt, ex ) )

            connectionFailed = True

      if connectionFailed or resp is None:
         # Connection failed after max retries
         raise ServiceDeviceError( 'Connection error' )
      try:
         respJson = resp.json()
      except Exception:  # pylint: disable=W0703
         # NB: test resp against None, not its truthiness: a requests Response
         # is falsy for error statuses, which are exactly the ones to log
         t4( '%s CheckPoint API call %s status %s resp.json() failed, '
              'text: %s' % ( self.ipAddr, command, resp.status_code,
                             ( resp.text or '' )[ :210 ] ) )
         raise FirewallAPIError( resp.status_code, None )
      t4('API RESP', resp.status_code, json.dumps( respJson, indent=3, sort_keys=1 ))
      if resp.status_code != requests.codes.ok:  # pylint: disable=E1101
         errCode = respJson.get( 'code' ) if isinstance( respJson, dict ) else None
         if errCode == 'err_login_failed':
            raise ServiceDeviceError( 'login failed' )
         raise FirewallAPIError( resp.status_code, errCode )
      return respJson

####################################################################################
class CheckPointMgmtServer( CheckPointDevice ):
   ''' For API access to a Check Point Security Management Server (web_api).

       Besides reading policies, it caches the managed topology read by
       getDeviceInfo:
          nodes: object name -> object type
          interfaces: gateway/cluster name -> { intf name: ipv4 address }
          gwToIntfAndZoneMap: gateway/cluster name -> { intf name: zone name }
          clusters: cluster name -> { 'ipv4-address', 'members' (names) }
          managedGateways: gateway name -> CheckPointGateway (proxied via
             this server), None until getDeviceInfo has run
   '''
   def __init__( self, config ):
      super( CheckPointMgmtServer, self ).__init__( config, 'web_api' )
      # the configured group selects the policy layer "<group> Network"
      if self.config[ 'group' ]:
         self.policyLayerName = self.config[ 'group' ] + ' Network'
      else:
         self.policyLayerName = ''
      # Dict keys:  gateway names, values: CheckPointGateway object
      self.managedGateways = None
      self.clusters = {}
      self.gwToIntfAndZoneMap = {}
      self.interfaces = {}
      self.nodes = {}

   def getDeviceInfo( self, cachedOk=True ):
      ''' Read all gateways and servers known to the Management Server.

          Returns { 'name', 'ipAddr', 'model' } of the object that has
          network-policy-management enabled, or {} if none is found or the
          reply is empty. As a side effect it rebuilds nodes, interfaces,
          gwToIntfAndZoneMap, clusters and managedGateways. With cachedOk, a
          previously found result is returned without an API call.
      '''
      if cachedOk and self.deviceInfo:
         t4('get device info, cached')
         return self.deviceInfo

      resp = self.getUrl( 'show-gateways-and-servers', { 'details-level': 'full' } )
      if not resp:
         return {}
      # NOTE(review): the API pages results ('limit'/'offset'); only the first
      # page is read here -- confirm this covers large deployments.
      # Rebuild the topology caches so objects deleted on the server do not
      # linger from a previous read
      self.nodes = {}
      self.interfaces = {}
      self.gwToIntfAndZoneMap = {}
      self.clusters = {}
      devInfo = {}
      gatewayAddrs = {}
      for obj in resp[ 'objects' ]:
         name = obj[ 'name' ]
         self.nodes[ name ] = obj[ 'type' ]
         if 'interfaces' in obj:
            gwIntfs = self.interfaces.setdefault( name, {} )
            gwZones = self.gwToIntfAndZoneMap.setdefault( name, {} )
            for intf in obj[ 'interfaces' ]:
               intfName = intf[ 'interface-name' ]
               gwIntfs[ intfName ] = str( intf[ 'ipv4-address' ] )
               if 'topology' in intf and 'security-zone' in intf[ 'topology' ]:
                  gwZones[ intfName ] = str(
                     intf[ 'topology' ][ 'security-zone' ][ 'name' ] )
         blades = obj.get( 'management-blades' ) or {}
         if blades.get( 'network-policy-management' ) is True:
            devInfo[ 'name' ] = str( obj[ 'name' ] )
            devInfo[ 'ipAddr' ] = str( obj[ 'ipv4-address' ] )
            devInfo[ 'model' ] = str( obj[ 'hardware' ] )
            self.deviceInfo = devInfo
         elif obj[ 'type' ] in ( 'simple-gateway', 'CpmiClusterMember' ):
            t4( obj[ 'type' ], obj[ 'name' ], obj[ 'ipv4-address' ] )
            gatewayAddrs[ str( obj[ 'name' ] ) ] = str( obj[ 'ipv4-address' ] )
         elif obj[ 'type' ] == 'CpmiGatewayCluster':
            t4( obj[ 'type' ], obj[ 'name' ], obj[ 'ipv4-address' ] )
            self.clusters[ str( name ) ] = {
               'ipv4-address': str( obj[ 'ipv4-address' ] ),
               'members': [ str( member )
                            for member in obj[ 'cluster-member-names' ] ] }
      self.managedGateways = { gwName: CheckPointGateway( self.config, addr,
                                                          mgmtServer=self )
                               for gwName, addr in gatewayAddrs.items() }
      return devInfo

   def getPolicies( self, mssTags=None ):
      ''' Returns { gateway mgmt IP: [ ServiceDevicePolicy ] } for the
          configured policy layer (see getLayerPolicies). mssTags: collection
          of tag strings to match; None disables tag filtering.
      '''
      resp = self.getUrl( 'show-access-layers' )
      if not resp:
         return {}
      wantedLayer = self.policyLayerName
      if wantedLayer == DEFAULT_POLICY_LAYER_NAME:
         # The API uses 'Network' for the Standard layer name
         wantedLayer = DEFAULT_POLICY_LAYER_NAME_IN_API
      policies = {}
      for layerObj in resp.get( 'access-layers', [] ):
         layer = layerObj[ 'name' ]
         t2( 'Access layer:', layer )
         if layer != wantedLayer:
            continue
         for mgmtIp, layerPols in self.getLayerPolicies( layer, mssTags ).items():
            policies.setdefault( mgmtIp, [] ).extend( layerPols )
      return policies

   @staticmethod
   def _asStr( value ):
      ''' str() that encodes Python 2 unicode as UTF-8 instead of raising
          UnicodeEncodeError on non-ASCII text; ASCII input is unchanged.
      '''
      if isinstance( value, str ):
         return value
      try:
         return str( value )
      except UnicodeEncodeError:
         return value.encode( 'utf-8' )

   @staticmethod
   def _iterRules( rulebase ):
      ''' Yield the rules of a rulebase, descending into any entry that holds
          its own nested 'rulebase' (an access section) so rules placed
          inside sections are not skipped.
      '''
      for entry in rulebase:
         if 'rulebase' in entry:
            for rule in CheckPointMgmtServer._iterRules( entry[ 'rulebase' ] ):
               yield rule
         else:
            yield entry

   def getLayerPolicies( self, layer, mssTags ):
      ''' Retrieves the access rulebase of policy layer `layer` from the Check
          Point Security Management Server. Each enabled rule whose comment
          carries MSS tags (see POLICY_TAGS_REGEX) becomes one
          ServiceDevicePolicy per target gateway.

          mssTags: collection of tag strings. A rule is kept when any of its
             tags is in mssTags. None disables the filter, so any rule with at
             least one tag is kept.
          Returns a dict where key=gateway mgmt IP address,
          value=ServiceDevicePolicy list.
      '''
      t2('getPolicies from Management Server for layer:', layer )
      resp = self.getUrl( 'show-access-rulebase', { 'name': layer,
                                                    'details-level': 'full' } )
      if not resp:
         return {}
      # NOTE(review): the rulebase is paged by the API ('limit'/'offset'); only
      # the first page is read here -- confirm large layers are fully covered.
      if self.managedGateways is None:
         # the gateway name -> CheckPointGateway map is built by getDeviceInfo
         self.getDeviceInfo( cachedOk=False )
      managedGateways = self.managedGateways or {}
      policies = {}
      allGateways = set()  # filled lazily by getRuleTargetGateways
      objectsDict, zoneIntfs = parseObjectsDict( resp )
      for rule in self._iterRules( resp[ 'rulebase' ] ):
         policyName = rule[ 'name' ] if 'name' in rule \
                                     else 'ruleNum%s' % rule[ 'rule-number' ]
         policyComment = self._asStr( rule[ 'comments' ] ) \
                         if 'comments' in rule else ''
         match = POLICY_TAGS_REGEX.search( policyComment )
         if not match:
            continue
         tags = [ t.strip() for t in match.group( 1 ).split( ',' ) ]
         tags = set( [ t for t in tags if t ] ) # filter empty strings, make set
         t4( policyName, 'extracted policy tags:', tags )
         if not tags or ( mssTags is not None and
                          not any( tag in mssTags for tag in tags ) ):
            t2( 'skipping rule. No matching tags:', policyName )
            continue
         if not rule[ 'enabled' ]:
            t2('skipping disabled rule:', policyName )
            continue
         policyName = self._asStr( policyName )
         policy = ServiceDevicePolicy( policyName, number=rule[ 'rule-number' ] )
         policy.dstZoneType = Lib.LAYER3
         policy.srcZoneType = Lib.LAYER3
         resolveL4Ports( objectsDict, rule[ 'service' ], policy )
         policy.action = str( objectsDict[ rule[ 'action' ] ][ 'name' ] )
         policy.tags = tags
         ( policy.srcZoneName,
           policy.srcIpAddrList ) = getPolicyZoneInfo( rule[ 'source' ],
                                                       objectsDict, zoneIntfs )
         ( policy.dstZoneName,
           policy.dstIpAddrList ) = getPolicyZoneInfo( rule[ 'destination' ],
                                                       objectsDict, zoneIntfs )
         # generate a policy for each gateway the rule is installed on
         for gateway in self.getRuleTargetGateways( rule, objectsDict, allGateways ):
            gwObj = managedGateways.get( gateway )
            if gwObj is None:
               t2( 'skipping unmanaged gateway', gateway, 'for rule', policyName )
               continue
            mgmtIp = gwObj.mgmtIp
            gwPolicy = deepcopy( policy )
            gwPolicy.managementIp = mgmtIp
            # the zone maps are keyed by gateway *name*, not management IP
            gwPolicy.srcZoneInterfaces = self.getZoneInterfacesForGw(
               policy.srcZoneName, gateway )
            gwPolicy.dstZoneInterfaces = self.getZoneInterfacesForGw(
               policy.dstZoneName, gateway )
            policies.setdefault( mgmtIp, [] ).append( gwPolicy )
      return policies

   def getZoneInterfacesForGw( self, zoneName, gateway ):
      ''' Returns NetworkInterface objects for the interfaces of `gateway`
          that are assigned to zoneName. `gateway` is a gateway or cluster
          name, the keys of gwToIntfAndZoneMap and clusters. For a cluster,
          the members' matching interfaces are added too, skipping names
          already present.
      '''
      intfs = []
      seenNames = set()
      zones = self.getGwToIntfAndZoneMap()
      for intfName, zone in zones.get( gateway, {} ).items():
         if zone == zoneName:
            intfs.append( NetworkInterface(
               intfName, zone=zoneName,
               isEthernet=intfName.lower().startswith( 'eth' ) ) )
            seenNames.add( intfName )
      if gateway in self.clusters:
         for memberGw in self.clusters[ gateway ][ 'members' ]:
            for intf in self.getZoneInterfacesForGw( zoneName, memberGw ):
               if intf.name not in seenNames:
                  intfs.append( intf )
                  seenNames.add( intf.name )
      return intfs

   def getRuleTargetGateways( self, rule, objectsDict, allGatewayIpAddrs ):
      ''' Determine target gateway name(s) for the rule from its 'install-on'
          uids:
             - a simple gateway contributes its own name;
             - the global 'Policy Targets' object contributes every managed
               gateway;
             - a known cluster contributes its member names.
          Mutates allGatewayIpAddrs. This set of gateway *names* is filled on
          first need from getAllGatewayIPAddrs() keys and reused across rules.
      '''
      gatewayTargets = set()
      for uid in rule[ 'install-on' ]:
         if uid in objectsDict:
            obj = objectsDict[ uid ]
            if obj[ 'type' ] == 'simple-gateway':
               gatewayTargets.update( [ obj[ 'name' ] ] )
            elif obj[ 'type' ] == 'Global' and obj[ 'name' ] == 'Policy Targets':
               if not allGatewayIpAddrs:  # init when needed
                  allGatewayIpAddrs.update( self.getAllGatewayIPAddrs() )
               gatewayTargets.update( allGatewayIpAddrs )
            elif obj[ 'type' ] == 'CpmiGatewayCluster':
               clusterName = obj[ 'name' ]
               if clusterName in self.clusters:
                  gatewayTargets.update( self.clusters[ clusterName ][ 'members' ] )
      return [ str( g ) for g in gatewayTargets ]

   def getAllGatewayIPAddrs( self ):
      ''' Returns { gateway name: management IP string } for all managed
          gateways, reading them from the server on first use ({} if that
          read yields nothing).
      '''
      if self.managedGateways is None:
         self.getDeviceInfo( cachedOk=False )
      return { gw: str( gwObj.mgmtIp )
               for gw, gwObj in ( self.managedGateways or {} ).items() }

   def getGwToIntfAndZoneMap( self ):
      ''' Returns dict of dicts where key=gateway, value=dict where
          key = intf, value = zone name. Reads the topology on first use.
      '''
      if self.managedGateways is None:
         self.getDeviceInfo( cachedOk=False )
      return self.gwToIntfAndZoneMap

####################################################################################
class CheckPointGateway( CheckPointDevice ):
   ''' For API access to a Check Point Gateway/Firewall through its Gaia API
       (gaia_api). When mgmtServer is given, commands are proxied through the
       Management Server (see getGwUrl); otherwise the gateway is contacted
       directly.
   '''
   def __init__( self, config, mgmtIp, mgmtServer=None ):
      super( CheckPointGateway, self ).__init__( config, 'gaia_api', mgmtIp )
      self.deviceInfo = {}
      self.mgmtIp = mgmtIp
      self.mgmtServer = mgmtServer
      self.name = ''  # hostname, set by getDeviceInfo

   def getGwUrl( self, command, requestBody=None ):
      ''' Run a Gaia API command on this gateway.

          Without a Management Server the command goes straight to the gateway
          via getUrl. Otherwise it is sent to the Management Server as
          'gw/<command>' with this gateway's address added as "gwIP". In that
          case the proxied 'responseMessage' is returned, or {} when absent.
          The caller's requestBody dict is never modified.
      '''
      if not self.mgmtServer:
         return self.getUrl( command, requestBody )

      proxyBody = dict( requestBody ) if requestBody else {}
      proxyBody[ "gwIP" ] = self.ipAddr
      resp = self.mgmtServer.getUrl( 'gw/' + command, requestBody=proxyBody )

      if resp and 'responseMessage' in resp:
         return resp[ 'responseMessage' ]
      return {}

   @staticmethod
   def _assetSection( assetResp, section ):
      ''' Flatten one show-asset section, a list of { 'key', 'value' } items,
          into a dict. A missing section yields {}.
      '''
      return { item[ 'key' ]: item[ 'value' ]
               for item in assetResp.get( section, [] ) }

   def getDeviceInfo( self, cachedOk=True ):
      ''' Returns { 'name', 'ipAddr', 'model' } for this gateway, or {} if any
          of the hostname, interface or asset reads comes back empty.
          'ipAddr' is the address of the 'Mgmt' interface, or '' when there is
          none. Sets self.name as a side effect.
      '''
      if cachedOk and self.deviceInfo:
         t4('get device info, cached')
         return self.deviceInfo

      resp = self.getGwUrl( 'show-hostname' )
      if not resp:
         return {}
      devInfo = {}
      self.name = str( resp[ 'name' ] )
      devInfo[ 'name' ] = self.name

      devInfo[ 'ipAddr' ] = ''
      resp = self.getGwUrl( 'show-interfaces', {} )
      if not resp:
         return {}
      intfs = resp.get( 'objects' )
      if not intfs:
         return {}
      for intf in intfs:
         if intf.get( u'name' ) == 'Mgmt':
            devInfo[ 'ipAddr' ] = str( intf[ u'ipv4-address' ] )

      resp = self.getGwUrl( 'show-asset' )
      if not resp:
         return {}
      systemDict = self._assetSection( resp, 'system' )
      devInfo[ 'model' ] = str( systemDict[ 'Model' ] )
      self.deviceInfo = devInfo
      return devInfo

   def getDeviceResources( self ):
      ''' Returns { 'resourceInfo': text } summarizing total memory, CPU cores
          and total disk size from show-asset, or {} when that read fails.
      '''
      # cli command: cpstat os -f [ perf | cpu | memory | disk ]
      resp = self.getGwUrl( 'show-asset' )
      if not resp:
         return {}
      memoryDict = self._assetSection( resp, 'memory' )
      systemDict = self._assetSection( resp, 'system' )
      diskDict = self._assetSection( resp, 'disk' )
      info = 'Mem: ' + str( memoryDict[ 'Total Memory' ] )
      info += '\nCPU Cores: ' + str( systemDict[ 'Number of Cores' ] )
      info += '\nDisk: ' + str( diskDict[ 'Total Disks size' ] )
      return { 'resourceInfo' : info }

   def getDeviceRoutingTables( self ):
      ''' Returns a ServiceDeviceRoutingTables object with the gateway's static
          routes, all in the 'default' VRF. Non-static routes and routes
          without a usable next hop are skipped. When the route list cannot be
          read, the table is returned empty with featureSupported left at its
          default.
      '''
      routingTables = ServiceDeviceRoutingTables()
      routes = self.getGwUrl( 'show-routes-static', {} ) or {}
      if 'objects' not in routes:
         t2( 'No static route list returned by', self.ipAddr )
         return routingTables
      vrfName = 'default'
      for route in routes[ 'objects' ]:
         if route.get( 'protocol' ) != 'Static':
            t2( 'Skipping route:', route )
            continue
         destination = str( route[ 'address' ] )
         destination += '/' + str( route[ 'mask-length' ] )
         hop = route.get( 'next-hop' ) or {}
         gateways = hop.get( 'gateways' )
         if gateways:
            interface = str( gateways[ 0 ][ 'interface' ] )
            nexthop = str( gateways[ 0 ][ 'address' ] )
         elif 'interface' in hop:
            interface = str( hop[ 'interface' ] )
            nexthop = ''
         else:
            t2( 'Skipping route without next hop:', route )
            continue
         routingTables.addRoute( vrfName, destination, interface, nexthop )
      routingTables.featureSupported = True
      return routingTables

   def getHighAvailabilityState( self ):
      ''' Returns a ServiceDeviceHAState from show-cluster-state. HA is
          considered enabled when the reply has a 'mode'. The mode is
          active/passive for 'high-availability' modes and active/active
          otherwise. Only the first peer member is reported; its management IP
          is looked up via the Management Server when one is available.
      '''
      haState = ServiceDeviceHAState()
      haState.mgmtIp = self.ipAddr
      resp = self.getGwUrl( 'show-cluster-state' ) or {}
      haState.enabled = 'mode' in resp
      if not haState.enabled:
         haState.mode = ''
         return haState

      if 'high-availability' in resp[ 'mode' ]:
         haState.mode = HA_ACTIVE_PASSIVE
      else:
         haState.mode = HA_ACTIVE_ACTIVE
      thisMember = resp.get( 'this-cluster-member' ) or {}
      if 'active' in thisMember.get( 'status', '' ):
         haState.state = HA_ACTIVE
      else:
         haState.state = HA_PASSIVE
      peers = resp.get( 'other-cluster-members' ) or []
      if peers:
         peer = peers[ 0 ]
         peerGw = peer[ 'name' ]
         if 'active' in peer.get( 'status', '' ):
            haState.peerDeviceState = HA_ACTIVE
         else:
            haState.peerDeviceState = HA_PASSIVE
         if self.mgmtServer:
            gwAddrs = self.mgmtServer.getAllGatewayIPAddrs()
            if peerGw in gwAddrs:
               haState.peerMgmtIp = gwAddrs[ peerGw ]
      return haState

   def _getClusterVIPs( self ):
      ''' Returns { intf name: VIP } of the cluster this gateway belongs to,
          from the Management Server's cached topology, or None. Membership is
          matched on self.name, which getDeviceInfo sets.
      '''
      if not self.mgmtServer:
         t4( 'No management server for gateway:', self.name )
         return None
      clusterVIPs = None
      for cluster, info in self.mgmtServer.clusters.items():
         if self.name in info[ 'members' ] and cluster in self.mgmtServer.interfaces:
            clusterVIPs = self.mgmtServer.interfaces[ cluster ]
            t4( 'clusterVIPs:', clusterVIPs )
      return clusterVIPs

   def getInterfacesInfo( self, resolveZoneNames=True ):
      ''' Get all necessary interface information for service devices.

          Returns { 'root': { intf name: NetworkInterface } }, or {} if the
          interface list cannot be read. Only names passing INTF_TYPE_FILTER
          are considered. Cluster VIPs replace the member's own addresses where
          known. LAG link state is UP if any member is UP, and subinterfaces
          inherit their parent's state and physical members.
      '''
      # FIXME hold ms object for persistent login session
      # TODO: zone resolution is not implemented; intfToZone stays empty
      intfToZone = {}
      #intfToZone = CheckPointMgmtServer( self.config ).getGwToIntfAndZoneMap()

      resp = self.getGwUrl( 'show-interfaces' )
      if not resp:
         return {}
      clusterVIPs = self._getClusterVIPs()
      interfaces = {}
      for intfInfo in resp.get( u'objects', [] ):
         intfName = str( intfInfo[ 'name' ] )
         intfType = str( intfInfo[ 'type' ] )
         linkState = LINK_STATE_UNKNOWN
         if not Lib.filterOnStartsWith( intfName, INTF_TYPE_FILTER ):
            continue
         if intfType == 'physical':
            physInfo = self.getGwUrl( 'show-physical-interface',
                                      { 'name': intfName } ) or {}
            status = physInfo.get( 'status' ) or {}
            if 'link-state' in status:
               linkState = translateIntfState( status[ 'link-state' ] )
            netIntf = NetworkInterface( intfName, state=linkState,
                                        vlans=[ DEFAULT_VLAN_RANGE ],
                                        isEthernet=True )
         elif intfType == 'vlan':
            vlan = intfName.split( '.' )[ 1 ]
            vlanInfo = self.getGwUrl( 'show-vlan-interface',
                                      { 'name': intfName } ) or {}
            # linkstate will be assigned later by looking at the parent intf
            netIntf = NetworkInterface( intfName, state=linkState, isSubIntf=True,
                                        vlans=[ vlan ] )
            # Add the parent interface as a physical interface. This will be
            # expanded later for lags, and the link state will be set
            if 'parent' in vlanInfo:
               netIntf.addPhysicalIntf( str( vlanInfo[ 'parent' ] ) )
         elif intfType == 'bond':
            # LAGs are called bond interfaces in CheckPoint lingo
            # linkstate and physical interfaces will be assigned later by
            # calling getLagIntfMembers
            netIntf = NetworkInterface( intfName, state=linkState, isLag=True )
         else:
            t2( 'Unknown interface type:', intfName, intfType, linkState )
            continue
         if clusterVIPs and intfName in clusterVIPs:
            ipAddr = str( clusterVIPs[ intfName ] )
         else:
            ipAddr = str( intfInfo[ 'ipv4-address' ] )
         if isValidIpAddr( ipAddr ):
            netIntf.ipAddr = ipAddr
         interfaces[ intfName ] = netIntf

         if resolveZoneNames and netIntf.name in intfToZone:
            netIntf.zone = intfToZone[ netIntf.name ]

      # for LAGs populate physical intf and link states
      for intfName, netIntf in interfaces.items():
         if netIntf.isLag:
            for ethIntfName in self.getLagIntfMembers( intfName ):
               if ethIntfName in interfaces:
                  netIntf.addPhysicalIntf( ethIntfName,
                                           interfaces[ ethIntfName ].state )
               else:
                  netIntf.addPhysicalIntf( ethIntfName )
            netIntf.state = self.determineLagLinkState( netIntf.physicalIntfs )

      # for subinterfaces populate physical intf and set the link state
      for intfName, netIntf in interfaces.items():
         if netIntf.isSubIntf:
            if not netIntf.physicalIntfs:
               t2( 'no parent intf reported for', netIntf )
               continue
            parentIntfName = netIntf.physicalIntfs[ 0 ].name
            parentIntf = interfaces.get( parentIntfName )
            if not parentIntf:
               t2('unable to find parent intf', parentIntfName, 'for', netIntf )
               continue
            # Clear the physical interfaces and repopulate with the correct
            # link state and expand lag interfaces
            netIntf.physicalIntfs = []
            if parentIntf.isLag:
               for ethIntf in parentIntf.physicalIntfs:
                  netIntf.addPhysicalIntf( ethIntf.name, ethIntf.state )
            else:
               netIntf.addPhysicalIntf( parentIntfName, state=parentIntf.state )
            netIntf.state = parentIntf.state
         t2( 'getInterfacesInfo:', str( netIntf ) )

      return { 'root': interfaces }

   def determineLagLinkState( self, physicalIntfs ):
      ''' A LAG is UP when at least one member is UP, otherwise DOWN. '''
      return ( LINK_STATE_UP
               if any( intf.state == LINK_STATE_UP for intf in physicalIntfs )
               else LINK_STATE_DOWN )

   def getLagIntfMembers( self, intfName ):
      ''' LAGs are called bond interfaces in CheckPoint lingo.
          Returns the list of member interface names ([] if unknown).
      '''
      resp = self.getGwUrl( 'show-bond-interface', { 'name': intfName } )
      if not resp:
         return []
      lagIntfMembers = []
      if 'slaves' in resp:
         # LAGs in checkpoint lingo are called Bonded interfaces and
         # LAG members are called slaves
         lagIntfMembers = resp[ 'slaves' ]
      elif 'members' in resp:
         lagIntfMembers = resp[ 'members' ]
      t4( 'lagIntfMembers:', lagIntfMembers )
      return lagIntfMembers

   def getInterfaceNeighbors( self ):
      ''' Returns { firewall intf name: neighbor info dict } from show-lldp.
          Each neighbor info dict has the keys switchChassisId, switchIntf,
          nborMgmtIp, nborSysName and nborDesc. Entries missing the interface
          name, chassis id or port id are skipped, as are interfaces outside
          INTF_TYPE_FILTER.
      '''
      neighbors = {}
      #resp = self.getGwUrl( 'show-lldp/neighbors' )
      resp = self.getGwUrl( 'show-lldp' )
      if not resp:
         return {}
      for nbor in resp.values():
         if ( 'Interface' not in nbor or 'Name' not in nbor[ 'Interface' ] or
              'Chassis' not in nbor or 'ChassisID' not in nbor[ 'Chassis' ] or
              'Port' not in nbor or 'PortID' not in nbor[ 'Port' ] ):
            t4('skipping this neighbor, incomplete data:', nbor )
            continue
         fwIntf = str( nbor[ 'Interface' ][ 'Name' ] )
         if not Lib.filterOnStartsWith( fwIntf, INTF_TYPE_FILTER ):
            continue

         chassis = nbor[ 'Chassis' ]
         # strip the LLDP subtype prefixes ('mac ...', 'ifname ...')
         switchChassisId = str( chassis[ 'ChassisID' ] ).replace( 'mac', '' ).strip()
         switchIntf = str( nbor[ 'Port' ][ 'PortID' ] ).replace( 'ifname', '' ).strip()
         neighbors[ fwIntf ] = {
            'switchChassisId': switchChassisId,
            'switchIntf': switchIntf,
            'nborMgmtIp': str( chassis[ 'MgmtIP' ] ) if 'MgmtIP' in chassis else '',
            'nborSysName': str( chassis[ 'SysName' ] ) if 'SysName' in chassis else '',
            'nborDesc': str( chassis[ 'SysDescr' ] ) if 'SysDescr' in chassis else '' }
         t4( fwIntf, 'neighbor:', neighbors[ fwIntf ] )
      return neighbors

#-----------------------------------------------------------------------------------
def translateIntfState( state ):
   ''' Map a Check Point interface link-state value to an MSS link state:
       any truthy value -> LINK_STATE_UP, anything falsy -> LINK_STATE_DOWN.
   '''
   # NOTE(review): assumes the API reports link-state as a boolean; a string
   # such as 'down' would be truthy and map to UP -- confirm.
   return LINK_STATE_UP if state else LINK_STATE_DOWN

def parseObjectsDict( resp ):
   ''' Index the 'objects-dictionary' of a show-access-rulebase reply.

       Returns ( objectsDict, zoneIntfs ):
         objectsDict: key = object uid, value = the object dict.
         zoneIntfs: key = security zone name, value = list of names of the
            simple-gateway interfaces configured with that specific zone.
       Gateways listed without 'interfaces', and interfaces whose zone
       settings are absent or null, are tolerated.
   '''
   objectsDict = {}  # key is object uid, value is the object
   zoneIntfs = {}
   for obj in resp[ 'objects-dictionary' ]:
      objectsDict[ obj[ 'uid' ] ] = obj
      if obj[ 'type' ] != 'simple-gateway':
         continue
      for intf in obj.get( 'interfaces', [] ):
         zoneSettings = intf.get( 'security-zone-settings' ) or {}
         if 'specific-zone' in zoneSettings:
            zoneName = str( zoneSettings[ 'specific-zone' ] )
            zoneIntfs.setdefault( zoneName, [] ).append( str( intf[ 'name' ] ) )
   t2( 'zoneIntfs:', zoneIntfs )
   return objectsDict, zoneIntfs

def getPolicyZoneInfo( zoneUids, objectsDict, zoneIntfs ):
   ''' Resolve the uids of a rule's source or destination.

       Returns ( zoneName, addrs ):
         zoneName: name of the last 'security-zone' object seen ('' if none).
         addrs: the 'ipv4-address' of each 'host' object, plus
            '<subnet4>|<subnet-mask>' for each 'network' object (added through
            Lib.appendOrExtend).
       Other object types are ignored. zoneIntfs is currently unused.
   '''
   zone = ''
   addrs = []
   for uid in zoneUids:
      entry = objectsDict[ uid ]
      kind = entry[ 'type' ]
      if kind == 'security-zone':
         zone = str( entry[ 'name' ] )
      elif kind == 'host':
         addrs.append( str( entry[ 'ipv4-address' ] ) )
      elif kind == 'network':
         Lib.appendOrExtend( addrs, '%s|%s' % ( str( entry[ 'subnet4' ] ),
                                                str( entry[ 'subnet-mask' ] ) ) )
   t2( 'zoneInfo:', zone, addrs )
   return zone, addrs

def resolveL4Ports( objectsDict, serviceIds, policy ):
   ''' Fill policy.dstL4Services from a rule's service uids.

       The result maps an upper-cased protocol, taken from the object type
       after 'service-' (e.g. 'service-tcp' -> 'TCP'), to the list of port
       strings of that protocol. Objects whose type does not start with
       'service-' are skipped. A service without a 'port' still registers its
       protocol, with no ports added.
   '''
   prefix = 'service-'
   policy.dstL4Services = {}
   for uid in serviceIds:
      svc = objectsDict[ uid ]
      svcType = svc[ 'type' ]
      if not svcType.startswith( prefix ):
         t0( 'Unknown Service Type:', svcType )
         continue
      proto = str( svcType[ len( prefix ): ].upper() )
      hasPort = 'port' in svc
      if not hasPort:
         t0( 'No Service port' )
      portList = [ str( svc[ 'port' ] ) ] if hasPort else []
      policy.dstL4Services.setdefault( proto, [] ).extend( portList )

def jsonDump( obj ):
   ''' Print obj to stdout as JSON, indented by 3 with sorted keys. '''
   text = json.dumps( obj, sort_keys=True, indent=3 )
   print( text )

####################################################################################
# unit tests

def testRetriesAndTimeouts( cfg ):
   print '\n\nTEST ALL GOOD'
   testMsApi( cfg )
   testGwApi( cfg, 'fwchk101', '172.24.138.207' )

   print '\n\nTEST BAD MS ADDR'
   orig = cfg[ 'ipAddress' ]
   cfg[ 'ipAddress' ] = 'BAD_MS_ADDR'
   testMsApi( cfg )
   testGwApi( cfg, 'fwchk101', '172.24.138.207' )
   cfg[ 'ipAddress' ] = orig

   print '\n\nTEST BAD MS PASSWORD'
   orig = cfg[ 'password' ]
   cfg[ 'password' ] = 'BAD_PASSWORD'
   testMsApi( cfg )
   testGwApi( cfg, 'fwchk101', '172.24.138.207' )
   cfg[ 'password' ] = orig

   print '\n\nTEST BAD GW IP'
   testGwApi( cfg, 'fwchk101', 'BAD_GW_IP' )

   print '\n\nTEST TIMEOUT'
   orig = cfg[ 'timeout' ]
   cfg[ 'timeout' ] = 0.001
   testMsApi( cfg )
   testGwApi( cfg, 'fwchk101', '172.24.138.207' )
   cfg[ 'timeout' ] = orig


def testMsApi( deviceDict ):
   print '\nTest API CheckPointMgmtServer:\n'
   ms = CheckPointMgmtServer( deviceDict )

   info = ms.getDeviceInfo()
   print 'MS DeviceInfo:'
   printObj( info )

   print '\nmanagedGateways:'
   printObj( ms.managedGateways )

   print '\ngateways:'
   printObj( ms.getAllGatewayIPAddrs() )

   print '\nclusters:'
   printObj( ms.clusters )

   print '\nzones:'
   printObj( ms.gwToIntfAndZoneMap )

   print '\ninterfaces:'
   printObj( ms.interfaces )

   policies = ms.getPolicies()
   print '\nPolicies from MS:\n'
   for device, pols in policies.items():
      print 'Firewall:', device
      for pol in pols:
         print ' ', pol, '\n'

   zones = ms.getGwToIntfAndZoneMap()
   print 'Zones:'
   printObj( zones )

   for gw, gwObj in ms.managedGateways.items() :
      print "\n---------------------------------------------------------"
      print "Using Proxy APIs to connect to", gw, "at", gwObj.ipAddr
      print '\nDevice Info:'
      printObj( gwObj.getDeviceInfo() )
      print '\nDevice Resources:'
      printObj( gwObj.getDeviceResources() )
      print '\nDevice HA State:'
      print gwObj.getHighAvailabilityState()
      print '\nDevice Interfaces:'
      for intf in gwObj.getInterfacesInfo():
         print intf
      print '\nDevice Routing Table:'
      print gwObj.getDeviceRoutingTables()

   # ms.getUrl( 'show-api-status' )  # 'show-api-versions'
   # ms.getUrl( 'show-commands' )  # show all API cmds
   ms.closeApiConnection()


def testGwApi( deviceDict, name, mgmtIp ):
   print '\nTest API Check Point Gateway/Firewall'
   gw = CheckPointGateway( deviceDict, mgmtIp )

   info = gw.getDeviceInfo()
   print '\nGW INFO:', info, '\n'

   ha = gw.getHighAvailabilityState()
   print 'HA state:', ha, '\n'

   for netIntf in gw.getInterfacesInfo():
      print 'INTF INFO:', netIntf

   #print '\nNEIGHBORS:'
   #nbors = gw.getInterfaceNeighbors()
   # for intf, nbor in nbors.items():
   #    print intf, nbor

   res = gw.getDeviceResources()
   print '\nRESOURCES:\n', res[ 'resourceInfo' ] if res else ''

   routingTable = gw.getRoutingTable()
   print '\nRouting Table:\n', routingTable

   gw.closeApiConnection()


def runUnitTests():
   ''' Run the live-device smoke tests (testMsApi, then
       testRetriesAndTimeouts) against the lab Management Server config below.
   '''
   # alternate lab MS:
   #  'ipAddress': 'bizdev-chkp-ms', 'username': 'readonly', 'password': 'arista123',
   chassisId = '001c.7374.819e'
   cfg = dict(
      ipAddress='172.24.132.248', username='admin', password='arista123',
      protocol='https', protocolPortNum=4434, method='tls',
      verifyCertificate=False, timeout=15, retries=1,
      exceptionMode='bypass', group='Arista_MSS', sslProfileName='',
      interfaceMap={
         'eth1-01': { 'switchIntf': 'Ethernet11', 'switchChassisId': chassisId },
         'eth1-02': { 'switchIntf': 'Ethernet12', 'switchChassisId': chassisId },
      } )

   for runTest in ( testMsApi, testRetriesAndTimeouts ):
      runTest( cfg )


if __name__ == "__main__":
   runUnitTests()
