# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import Tracing
from ApiBaseModels import BaseModel, BaseType, Bool, Int, List, Str

# Module-wide alias for Tracing.trace0; the models below call it to emit traces.
t0 = Tracing.trace0

# Parameters differ from overridden 'fromSysdb' and 'toSysdb' methods
# pylint: disable-msg=arguments-differ

class RuleHitsModel( BaseModel ):
   """Output-only model carrying the hit counters of a single rule.

   All fields are declared with inputOk=False and outputOk=True.
   """
   ruleId = Str( apiName='rule_id', tacName='id', versions=[ 1 ],
                 inputOk=False, outputOk=True, description='' )
   # NOTE(review): 'packet_count' and 'hit_count' are both bound to the same
   # tacName ( 'hitCount' ) -- presumably intentional; confirm.
   pktCount = Int( apiName='packet_count', tacName='hitCount', versions=[ 1 ],
                   inputOk=False, outputOk=True, description='' )
   hitCount = Int( apiName='hit_count', tacName='hitCount', versions=[ 1 ],
                   inputOk=False, outputOk=True, description='' )
   # Arista hardware doesn't support session and byte hit counts.
   # So, no session or byte counter is sent; only the packet/hit counters
   # declared above are.
   # NOTE(review): presumably the omitted API fields are 'session_count' and
   # 'byte_count' -- verify the names against the API specification.

class HitCountModel( BaseModel ):
   """Output-only model reporting the per-rule hit counts of one section."""
   sectionId = Str( apiName='section_id', tacName='', versions=[ 1 ],
                    inputOk=False, outputOk=True, description='' )
   results = List( apiName='results', tacName='', versions=[ 1 ],
                   inputOk=False, outputOk=True, description='',
                   valueType=RuleHitsModel )
   resultCount = Int( apiName='result_count', tacName='', versions=[ 1 ],
                      inputOk=False, outputOk=True, description='' )

   def fromSysdb( self, section ):
      """Populate this model from 'section'.

      'section_id' is set to section.id, 'results' to one RuleHitsModel per
      value of section.prs, and 'result_count' to the number of those models.
      """
      self.getModelField( 'section_id' ).value = section.id

      def _buildRuleHits( sysdbRule ):
         # fromSysdb fills the model in place, so wrap it to get the model back
         hits = RuleHitsModel()
         hits.fromSysdb( sysdbRule )
         return hits

      perRuleHits = [ _buildRuleHits( r ) for r in section.prs.itervalues() ]
      self.getModelField( 'results' ).value = perRuleHits
      self.getModelField( 'result_count' ).value = len( perRuleHits )

class SwitchStatusDetail( BaseModel ):
   """Output-only model describing the status of one sub-system.

   StatusModel.fromSysdb in this file creates one instance per switch, with
   'sub_system_type' set to 'switch'.
   """
   # Kind of sub-system this entry describes.
   subSystemType = Str( apiName='sub_system_type', tacName='', versions=[ 1 ],
                        inputOk=False, outputOk=True, description='' )
   # Identifier of the sub-system.
   subSystemId = Str( apiName='sub_system_id', tacName='', versions=[ 1 ],
                      inputOk=False, outputOk=True, description='' )
   # State reported for the sub-system.
   state = Str( apiName='state', tacName='', versions=[ 1 ],
                inputOk=False, outputOk=True, description='' )
   # Message associated with the sub-system.
   failureMsg = Str( apiName='failure_message', tacName='', versions=[ 1 ],
                     inputOk=False, outputOk=True, description='' )

class StatusModel( BaseModel ):
   """Output-only model reporting the state of a section, with one detail
   entry per switch that has a message recorded against any of its rules."""
   details = List( apiName='details', tacName='', versions=[ 1 ],
                   inputOk=False, outputOk=True, description='',
                   valueType=SwitchStatusDetail )
   state = Str( apiName='state', tacName='', versions=[ 1 ],
                inputOk=False, outputOk=True, description='' )

   def fromSysdb( self, section ):
      """Populate 'state' and 'details' from 'section'.

      The state name is section.state with everything up to and including
      'prs' removed, lower-cased, and with 'inprogress' spelled 'in_progress'.
      Every detail entry carries that same state.
      """
      # section.state is an enum whose value names carry a 'prs' prefix
      stateName = section.state.split( 'prs' )[ 1 ].lower()
      if stateName == 'inprogress':
         stateName = 'in_progress'
      self.getModelField( 'state' ).value = stateName

      reportedSwitches = set()
      switchDetails = []
      for rule in section.prs.itervalues():
         for switch, msg in rule.info.iteritems():
            # When several rules of this section have a message for the same
            # switch, only the first message encountered is reported.
            if switch in reportedSwitches:
               continue
            reportedSwitches.add( switch )
            detail = SwitchStatusDetail()
            detail.getModelField( 'sub_system_type' ).value = 'switch'
            detail.getModelField( 'sub_system_id' ).value = switch
            detail.getModelField( 'state' ).value = stateName
            detail.getModelField( 'failure_message' ).value = msg
            switchDetails.append( detail )
      self.getModelField( 'details' ).value = switchDetails

class GroupNotifIdModel( BaseModel ):
   """Input-only model pairing a notification id with a list of URIs."""
   notificationId = Str( apiName='notification_id', tacName='', versions=[ 1 ],
                         inputOk=True, outputOk=False, description='' )
   uris = List( apiName='uris', tacName='', versions=[ 1 ],
                inputOk=True, outputOk=False, description='', valueType=Str )

   def notifId( self ):
      """Return the notification id, or '' when it has not been set."""
      idField = self.getModelField( 'notification_id' )
      if not idField.hasBeenSet:
         return ''
      return idField.value

   def getUris( self ):
      """Return the value of 'uris', or an empty list when it has not been set."""
      urisField = self.getModelField( 'uris' )
      if not urisField.hasBeenSet:
         return []
      return urisField.value

   def groupNames( self ):
      """Return the last '/'-separated component of every URI's value."""
      return [ entry.value.rsplit( '/', 1 )[ -1 ] for entry in self.getUris() ]

class GroupNotificationModel( BaseModel ):
   """Input-only model holding a list of GroupNotifIdModel results."""
   count = Int( apiName='result_count', tacName='', versions=[ 1 ],
                 inputOk=True, outputOk=False, description='' )
   results = List( apiName='results', tacName='', versions=[ 1 ],
                   inputOk=True, outputOk=False, description='',
                   valueType=GroupNotifIdModel )
   refreshNeeded = Bool( apiName='refresh_needed', tacName='', versions=[ 1 ],
                         inputOk=True, outputOk=False, description='' )

   def uris( self, notifId ):
      """Return the URIs of the first result whose notification id equals
      'notifId'; return an empty list when 'results' is unset or nothing
      matches."""
      resultsField = self.getModelField( 'results' )
      if not resultsField.hasBeenSet:
         return []
      for entry in resultsField.value:
         if notifId == entry.notifId():
            return entry.getUris()
      return []

class GroupIpSetModel( BaseModel ):
   """Input-only model carrying the IP addresses of a group."""
   ipSet = List( apiName='results', tacName='', versions=[ 1 ],
                 inputOk=True, outputOk=False, description='', valueType=Str )

   def toSysdb( self, group ):
      """Add every received address to group.ipAddr.

      Nothing is done when 'results' has not been set; addresses already in
      group.ipAddr are not removed here.
      """
      addrField = self.getModelField( 'results' )
      if not addrField.hasBeenSet:
         return
      for addr in addrField.value:
         group.ipAddr.add( Tac.Value( "Arnet::IpGenAddrWithMask", addr ) )

class ServiceModel( BaseModel ):
   """Input-only model describing one service: protocol, ICMP type/code and
   source/destination ports."""
   icmpType = Int( apiName='icmp_type', tacName='', versions=[ 1 ],
                   inputOk=True, outputOk=False, description='' )
   icmpCode = Int( apiName='icmp_code', tacName='', versions=[ 1 ],
                   inputOk=True, outputOk=False, description='' )
   srcPort = List( apiName='source_ports', tacName='', versions=[ 1 ],
                   inputOk=True, outputOk=False, description='', valueType=Str )
   destPort = List( apiName='destination_ports', tacName='', versions=[ 1 ],
                    inputOk=True, outputOk=False, description='', valueType=Str )
   ipProto = Str( apiName='protocol', tacName='', versions=[ 1 ],
                  inputOk=True, outputOk=False, description='' )
   l4Proto = Str( apiName='l4_protocol', tacName='', versions=[ 1 ],
                  inputOk=True, outputOk=False, description='' )

   def parseService( self ):
      """Return ( ipProto, icmp, srcPorts, dstPorts ) for this service.

      Protocol selection, in order of precedence:
        - 'protocol' set: 'icmp' if it starts with 'icmp' ( case-insensitive ),
          otherwise the service is unsupported.
        - 'l4_protocol' set: its lower-cased value if that is an attribute of
          Pcs::PolicyIpProtocol, otherwise the service is unsupported.
        - only ports set: 'tcp'.
        - nothing set: 'ip'.
      An unsupported service is returned as 'pipUnknown' with a default
      Pcs::PolicyIcmp and two empty port lists.
      """
      icmp = Tac.Value( "Pcs::PolicyIcmp" )
      srcPorts = []
      dstPorts = []
      unsupported = ( 'pipUnknown', icmp, srcPorts, dstPorts )

      protoField = self.getModelField( 'protocol' )
      l4Field = self.getModelField( 'l4_protocol' )
      srcField = self.getModelField( 'source_ports' )
      dstField = self.getModelField( 'destination_ports' )

      if protoField.hasBeenSet:
         if not protoField.value.lower().startswith( 'icmp' ):
            return unsupported
         ipProto = 'icmp'
      elif l4Field.hasBeenSet:
         l4Name = l4Field.value.lower()
         if l4Name not in Tac.Type( "Pcs::PolicyIpProtocol" ).attributes:
            return unsupported
         ipProto = l4Name
      elif srcField.hasBeenSet or dstField.hasBeenSet:
         t0( 'Service protocol not specified. Defaulting to TCP' )
         ipProto = 'tcp'
      else:
         ipProto = 'ip'

      typeField = self.getModelField( 'icmp_type' )
      codeField = self.getModelField( 'icmp_code' )
      if typeField.hasBeenSet:
         icmp.type = typeField.value
      if codeField.hasBeenSet:
         icmp.code = codeField.value

      # No real ranges are sent, so a port entry is never a space separated
      # list of port numbers; entries are passed through unchanged.
      if srcField.hasBeenSet:
         srcPorts.extend( srcField.value )
      if dstField.hasBeenSet:
         dstPorts.extend( dstField.value )

      return ipProto, icmp, srcPorts, dstPorts

class ServiceListModel( BaseModel ):
   """Input-only wrapper around a single 'service' dictionary."""
   service = BaseType( apiName='service', tacName='', fieldType=dict, versions=[ 1 ],
               inputOk=True, outputOk=False, description='' )
   # Attach ServiceModel as the field's valueType; parseServiceDict relies on
   # the field value exposing parseService().
   service.valueType = ServiceModel

   def parseServiceDict( self ):
      """Return the parsed service tuple of the wrapped service, or
      ( 'pipUnknown', None, None, None ) when 'service' has not been set."""
      svcField = self.getModelField( 'service' )
      if not svcField.hasBeenSet:
         return 'pipUnknown', None, None, None
      return svcField.value.parseService()

def parseAllServices( services ):
   """
   Helper method to parse all services specified as part of a rule.

   'services' is an iterable of objects providing parseServiceDict(). The
   first service that parses to a protocol other than 'pipUnknown' fixes the
   rule protocol; later services contribute only if they parse to that same
   protocol. Ports of all contributing services are accumulated, while the
   ICMP value is that of the last contributing service.

   Returns ( protocol, icmp, srcPortSet, dstPortSet ). The protocol is 'ip'
   when no service is valid.
   """
   svcProto = 'pipUnknown'
   srcPortSet = Tac.Value( "Pcs::PortSet" )
   dstPortSet = Tac.Value( "Pcs::PortSet" )
   svcIcmp = Tac.Value( "Pcs::PolicyIcmp" )
   for s in services:
      ipProto, icmp, srcPorts, dstPorts = s.parseServiceDict()
      if ipProto == 'pipUnknown':
         # Invalid service
         continue
      if svcProto != 'pipUnknown' and ipProto != svcProto:
         # Protocol differs from the one already chosen for the rule
         continue
      svcProto = ipProto
      svcIcmp = icmp
      # Explicit loops rather than map(): map() is lazy on Python 3, so using
      # it only for its side effects would silently add nothing there.
      for port in srcPorts:
         srcPortSet.range.add( port )
      for port in dstPorts:
         dstPortSet.range.add( port )
   # The statement below takes care of the condition where all services are
   # invalid, in which case the rule protocol is set to the default of 'ip'.
   if svcProto == 'pipUnknown':
      svcProto = 'ip'
   return svcProto, svcIcmp, srcPortSet, dstPortSet

class EndpointModel( BaseModel ):
   """Input-only model describing one source or destination of a rule."""
   target = Str( apiName='target_id', tacName='', versions=[ 1 ],
                 inputOk=True, outputOk=False, description='' )
   displayName = Str( apiName='target_display_name', tacName='', versions=[ 1 ],
                      inputOk=True, outputOk=False, description='' )
   targetType = Str( apiName='target_type', tacName='', versions=[ 1 ],
                     inputOk=True, outputOk=False, description='' )
   # Maps the 'target_type' received over the API to the value stored in the
   # rule's source/destination collections.
   targetTypeMap = { 'IPv4Address' : 'v4Host',
                     'IPv6Address' : 'v6Host',
                     'NSGroup' : 'group',
                     'any' : 'any',
                     'policy-group' : 'group' }

   def toSysdbAttr( self, rule, attr ):
      """Record this endpoint in 'rule'.

      'attr' selects the collection: 'source' writes rule.source and
      'destination' writes rule.destination; any other value writes neither.
      Nothing is written when 'target_id' has not been set. For endpoints of
      type 'group' the display name ( or "" ) is also stored in
      rule.targetDisplayName.
      """
      targetType = 'epUnknown'
      typeField = self.getModelField( 'target_type' )
      if typeField.hasBeenSet:
         # An unmapped target_type falls back to 'epUnknown' -- the same value
         # used when no type is given -- instead of None.
         targetType = self.targetTypeMap.get( typeField.value, 'epUnknown' )

      idField = self.getModelField( 'target_id' )
      if not idField.hasBeenSet:
         return
      targetId = idField.value
      if attr == 'source':
         rule.source[ targetId ] = targetType
      elif attr == 'destination':
         rule.destination[ targetId ] = targetType
      if targetType == 'group':
         name = self.getModelField( 'target_display_name' ).value or ""
         rule.targetDisplayName[ targetId ] = name

class RuleModel( BaseModel ):
   """Input-only model describing one policy rule."""
   id = Str( apiName='id', tacName='id', versions=[ 1 ],
             inputOk=True, outputOk=False, description='' )
   displayName = Str( apiName='display_name', tacName='', versions=[ 1 ],
                      inputOk=True, outputOk=False, description='' )
   priority = Int( apiName='priority', tacName='priority', versions=[ 1 ],
                   inputOk=True, outputOk=False, description='' )
   src = List( apiName='sources', tacName='', versions=[ 1 ],
            inputOk=True, outputOk=False, description='', valueType=EndpointModel )
   dst = List( apiName='destinations', tacName='', versions=[ 1 ],
            inputOk=True, outputOk=False, description='', valueType=EndpointModel )
   service = List( apiName='services', tacName='', versions=[ 1 ],
         inputOk=True, outputOk=False, description='', valueType=ServiceListModel )
   logged = Bool( apiName='logged', tacName='logged', versions=[ 1 ],
                  inputOk=True, outputOk=False, description='' )
   action = Str( apiName='action', tacName='action', versions=[ 1 ],
                  inputOk=True, outputOk=False, description='' )
   direction = Str( apiName='direction', tacName='direction', versions=[ 1 ],
                  inputOk=True, outputOk=False, description='' )
   ipProtocol = Str( apiName='ip_protocol', tacName='', versions=[ 1 ],
                  inputOk=True, outputOk=False, description='' )

   def _syncEndpoints( self, apiName, rule, attr ):
      """Synchronise the rule collection named 'attr' ( 'source' or
      'destination' ) with the endpoint list held in field 'apiName'.

      When the field is unset the collection is reset to the single entry
      'any' -> 'any'. Otherwise entries whose key is not among the received
      target ids are deleted and every received endpoint is written.
      """
      endpointsField = self.getModelField( apiName )
      sysdbColl = getattr( rule, attr )
      if not endpointsField.hasBeenSet:
         sysdbColl.clear()
         sysdbColl[ 'any' ] = 'any'
         return
      # First drop the entries that are no longer present
      wantedIds = [ ep.getModelField( 'target_id' ).value
                    for ep in endpointsField.value ]
      for existingId in sysdbColl.keys():
         if existingId not in wantedIds:
            del sysdbColl[ existingId ]
      # Then write the received ones
      for ep in endpointsField.value:
         ep.toSysdbAttr( rule, attr )

   def toSysdb( self, rule ):
      """Write this rule's attributes into the Sysdb entity 'rule'.

      Unset fields fall back to: priority 0, display name "", endpoints
      'any', logged False, action 'paUnknown', direction 'pdUnknown'. Without
      'services', the ICMP and port attributes are reset and the protocol
      comes from 'ip_protocol' ( or 'pipUnknown' ).
      """
      rule.priority = self.getModelField( 'priority' ).value or 0
      rule.displayName = self.getModelField( 'display_name' ).value or ""

      self._syncEndpoints( 'sources', rule, 'source' )
      self._syncEndpoints( 'destinations', rule, 'destination' )

      servicesField = self.getModelField( 'services' )
      if servicesField.hasBeenSet:
         rule.ipProto, rule.icmp, rule.srcPort, rule.dstPort = parseAllServices(
            servicesField.value
         )
      else:
         rule.icmp = Tac.Value( "Pcs::PolicyIcmp" )
         rule.srcPort = Tac.Value( "Pcs::PortSet" )
         rule.dstPort = Tac.Value( "Pcs::PortSet" )
         protoName = 'pipUnknown'
         ipProtoField = self.getModelField( 'ip_protocol' )
         if ipProtoField.hasBeenSet:
            requested = ipProtoField.value.lower()
            # Any value starting with 'ip' is collapsed to plain 'ip'
            if requested.startswith( 'ip' ):
               requested = 'ip'
            protoName = requested or 'pipUnknown'
         rule.ipProto = protoName

      loggedField = self.getModelField( 'logged' )
      rule.logged = loggedField.value if loggedField.hasBeenSet else False

      actionField = self.getModelField( 'action' )
      if not actionField.hasBeenSet:
         rule.action = 'paUnknown'
      else:
         action = actionField.value.lower()
         try:
            rule.action = action
         except TypeError:
            # The assignment was rejected; record the rule as 'paUnknown'
            t0( "Received unsupport action", action.upper(), "for rule", rule.id )
            rule.action = 'paUnknown'

      directionField = self.getModelField( 'direction' )
      if not directionField.hasBeenSet:
         rule.direction = 'pdUnknown'
      else:
         requestedDir = directionField.value
         if requestedDir == 'IN_OUT':
            rule.direction = 'both'
         else:
            rule.direction = requestedDir.lower()

class SectionModel( BaseModel ):
   """Input-only model describing a section and the rules it contains."""
   sectionId = Str( apiName='id', tacName='id', versions=[ 1 ],
                    inputOk=True, outputOk=False, description='' )
   priority = Int( apiName='priority', tacName='', versions=[ 1 ],
                   inputOk=True, outputOk=False, description='' )
   displayName = Str( apiName='display_name', tacName='', versions=[ 1 ],
                      inputOk=True, outputOk=False, description='' )
   sectionType = Str( apiName='section_type', tacName='type', versions=[ 1 ],
                      inputOk=True, outputOk=False, description='' )
   ruleCount = Int( apiName='rule_count', tacName='', versions=[ 1 ],
                    inputOk=True, outputOk=False, description='' )
   rules = List( apiName='rules', tacName='', versions=[ 1 ],
                 inputOk=True, outputOk=False, description='', valueType=RuleModel )
   revision = Int( apiName='_revision', tacName='', versions=[ 1 ],
                   inputOk=True, outputOk=False, description='' )

   def toSysdb( self, section ):
      """Write this section and its rules into the Sysdb entity 'section'.

      Rules absent from the received list are deleted from section.rule.
      Each received rule is marked 'dirty' ( already present ) or 'init'
      ( new ) while it is being written and 'complete' afterwards. When
      'rules' is unset, section.rule is cleared.

      Raises ValueError when 'section_type' is set to a value that is not an
      attribute of Pcs::SectionType.
      """
      section.displayName = self.getModelField( 'display_name' ).value or ""
      section.priority = self.getModelField( 'priority' ).value or 0

      typeField = self.getModelField( 'section_type' )
      if typeField.hasBeenSet:
         typeName = typeField.value.lower()
         if typeName not in Tac.Type( 'Pcs::SectionType' ).attributes:
            raise ValueError( "Unsupported section_type %r" % typeName )
         section.type = typeName
      else:
         section.type = 'stUnknown'

      rulesField = self.getModelField( 'rules' )
      if not rulesField.hasBeenSet:
         section.rule.clear()
         return

      rules = rulesField.value
      if self.getModelField( 'rule_count' ).value != len( rules ):
         # A mismatch is only traced, not treated as an error.
         # TODO charanjith : Vmware has to fix their side
         t0( "rule_count != len(rules)" )

      newRuleIds = set( r.getModelField( 'id' ).value for r in rules )
      # Remove deleted rules. Iterate over a snapshot so that entries are not
      # deleted from the collection while it is being iterated.
      for ruleId in list( section.rule ):
         if ruleId not in newRuleIds:
            del section.rule[ ruleId ]

      for rule in rules:
         ruleId = rule.getModelField( 'id' ).value
         alreadyPresent = ruleId in section.rule
         sysdbRule = section.newRule( ruleId )
         sysdbRule.state = 'dirty' if alreadyPresent else 'init'
         rule.toSysdb( sysdbRule )
         sysdbRule.state = 'complete'

class TagModel( BaseModel ):
   """Input-only model for a scope/tag pair; IpSetModel in this file uses it
   as the element type of its 'tags' list."""
   # Scope of the tag; IpSetResultModel.toSysdb only accepts scope 'cvx'.
   scope = Str( apiName='scope', tacName='', versions=[ 1 ],
                inputOk=True, outputOk=False, description='' )
   # The tag itself.
   tag = Str( apiName='tag', tacName='', versions=[ 1 ],
              inputOk=True, outputOk=False, description='' )

class IpSetModel( BaseModel ):
   """Input-only model describing one ipset: id, tags, addresses, revision."""
   ipSetId = Str( apiName='id', tacName='id', versions=[ 1 ],
                  inputOk=True, outputOk=False, description='' )
   tags = List( apiName='tags', tacName='name', versions=[ 1 ],
                inputOk=True, outputOk=False, description='', valueType=TagModel )
   ipaddrs = List( apiName='ip_addresses', tacName='ipAddr', versions=[ 1 ],
                   inputOk=True, outputOk=False, description='', valueType=Str )
   revision = Int( apiName='_revision', tacName='', versions=[ 1 ],
                   inputOk=True, outputOk=False, description='' )

   def toSysdb( self, ipsetsync ):
      """Write the received addresses and revision into 'ipsetsync'.

      When 'ip_addresses' is set, ipsetsync.ipAddr is cleared and refilled;
      when it is unset the existing addresses are left untouched.
      ipsetsync.revision is set to the received revision, or 0 when unset.
      """
      ipaddrs = self.getModelField( 'ip_addresses' )
      if ipaddrs.hasBeenSet:
         t0( "ipset %s, ip addresses %s" % ( ipsetsync.id, ipaddrs.value ) )
         ipsetsync.ipAddr.clear()
         for ip in ipaddrs.value:
            ipAddrWithMask = Tac.Value( "Arnet::IpGenAddrWithMask", ip )
            ipsetsync.ipAddr.add( ipAddrWithMask )
      revision = self.getModelField( '_revision' )
      if revision.hasBeenSet:
         # Trace the revision number itself, not the field object holding it
         t0( "ipset %s revision is %s" % ( ipsetsync.id, revision.value ) )
         ipsetsync.revision = revision.value
      else:
         ipsetsync.revision = 0

class IpSetResultModel( BaseModel ):
   """Input-only model holding a list of IpSetModel results."""
   resultCount = Int( apiName='result_count', tacName='', versions=[ 1 ],
                      inputOk=True, outputOk=False, description="rule count" )
   results = List( apiName='results', tacName='', versions=[ 1 ],
         inputOk=True, outputOk=False, description='', valueType=IpSetModel )

   def toSysdb( self, controllerTagDb ):
      """Create one ipset in 'controllerTagDb' per ( ipset, CVX tag ) pair.

      For every received ipset, each tag with scope 'cvx' and a non-empty
      name yields controllerTagDb.newIpset( id, revision, tag ), which is
      filled with the ipset's addresses and marked parseComplete. Ipsets
      without 'tags' are skipped. A missing revision is treated as 0 and
      missing addresses as an empty list.
      """
      ipsets = self.getModelField( 'results' )
      if not ipsets.hasBeenSet:
         return
      for ipset in ipsets.value:
         ipsetid = ipset.getModelField( 'id' ).value
         revision = ipset.getModelField( '_revision' ).value
         if not revision:
            t0( "Expected revision, assuming 0" )
            revision = 0
         ipAddrs = ipset.getModelField( 'ip_addresses' ).value
         if not ipAddrs:
            t0( "No ip addresses for this ipset, ignoring" )
            # Deliberate fallthrough: the ipset is still created for its CVX
            # tags. Normalise to an empty list so that the address loop below
            # is a no-op rather than an iteration over an unset value.
            ipAddrs = []
         tags = ipset.getModelField( 'tags' )
         if not tags.hasBeenSet:
            t0( "No tags set in ipset, ignoring" )
            continue
         for tagModel in tags.value:
            scope = tagModel.getModelField( 'scope' ).value
            tag = tagModel.getModelField( 'tag' ).value
            # filter on tag scope.
            t0( "Tag %s scope is %s" % ( tag, scope ) )
            if not tag or scope != 'cvx':
               t0( "Tag is not owned by CVX" )
               continue
            ipsetSysdb = controllerTagDb.newIpset( ipsetid, revision, tag )
            for ip in ipAddrs:
               ipAddrWithMask = Tac.Value( "Arnet::IpGenAddrWithMask", ip )
               ipsetSysdb.ipAddr.add( ipAddrWithMask )
            ipsetSysdb.parseComplete = True
