# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import Arnet
import BothTrace
from McsHttpErrorCodes import errorCodes

# Short module-level aliases for the BothTrace trace functions, bound in order:
# log -> tracef0, warn -> tracef1, error -> tracef2, info -> tracef3,
# debug -> tracef4.
log, warn, error, info, debug = ( BothTrace.tracef0,
                                  BothTrace.tracef1,
                                  BothTrace.tracef2,
                                  BothTrace.tracef3,
                                  BothTrace.tracef4 )

def errorTemplate( ec ):
   """Return a new, empty error template dict.

   `ec` is currently ignored; every call returns a fresh empty dict.
   """
   return dict()

# Interface speed name -> bandwidth in kbits/second (from CliPlugin
# EthIntfCli.py). Built once at import time instead of on every call.
_BANDWIDTH_KBPS_BY_SPEED = { 'speed10Mbps': 10000,
                             'speed100Mbps': 100000,
                             'speed1Gbps': 1000000,
                             'speed10Gbps': 10000000,
                             'speed25Gbps': 25000000,
                             'speed40Gbps': 40000000,
                             'speed50Gbps': 50000000,
                             'speed100Gbps': 100000000,
                             'speedUnknown': 0,
                             'U/A': 0 }

def get_bandwidth_num( bandwidthStr ):
   """ Convert an interface speed string to a bandwidth number.

   Returns the bandwidth in kbits/second for a known speed name
   (e.g. 'speed10Gbps' -> 10000000). 'speedUnknown', 'U/A' and any
   unrecognised value all map to 0.
   """
   return _BANDWIDTH_KBPS_BY_SPEED.get( bandwidthStr, 0 )

def validateSourceIP( addr ):
   """Validate that `addr` is an acceptable unicast source IP address.

   Returns ( True, '' ) when Arnet.IpGenAddr parses `addr` and reports it
   as unicast. Otherwise the error is logged and ( False, ec ) is returned:
     '127' - Arnet.IpGenAddr raised ValueError (address not parsable)
     '125' - the address parsed but is not unicast
   """
   try:
      ipAddr = Arnet.IpGenAddr( addr )
   except ValueError:
      errCode = '127'
      error( errorCodes[ errCode ] )
      return False, errCode

   if not ipAddr.isUnicast:
      errCode = '125'
      error( errorCodes[ errCode ].replace( '<address>', addr ) )
      return False, errCode
   return True, ''

def validateMcastGroup( addr ):
   """Validate that the multicast destination `addr` is a multicast address.

   Multicast-ness is decided by Arnet.IpGenAddr's isMulticast (for IPv4 that
   is 224.0.0.0-239.255.255.255). Returns ( True, '' ) on success; otherwise
   the error is logged and ( False, ec ) is returned:
     '129' - Arnet.IpGenAddr raised ValueError (address not parsable)
     '126' - the address parsed but is not multicast
   """
   try:
      groupAddr = Arnet.IpGenAddr( addr )
   except ValueError:
      errCode = '129'
      error( errorCodes[ errCode ] )
      return False, errCode

   if not groupAddr.isMulticast:
      errCode = '126'
      error( errorCodes[ errCode ].replace( '<address>', addr ) )
      return False, errCode
   return True, ''

def findMissingKeys( expected_keys, flow_keys ):
   """Report which of `expected_keys` are absent from the mapping `flow_keys`.

   Returns ( ec, missingKeys ) where missingKeys is the set of absent keys.
   ec is '131' when that set is non-empty (the error is also logged),
   otherwise ''.
   """
   absent = set( expected_keys ).difference( flow_keys.keys() )
   if not absent:
      return '', absent
   errCode = '131'
   error( errorCodes[ errCode ].replace( '<key>', str( absent ) ) )
   return errCode, absent

def validateBandwidthType( bwType ):
   """Return True if `bwType` is one of the unit suffixes 'k', 'm' or 'g'.

   The comparison is case-sensitive; anything else returns False.
   """
   return bwType in ( 'k', 'm', 'g' )

def parsed_mac( mac ):
   """
   Normalise a MAC address string (e.g. dotted X.X.X, dashed or colon
   separated) to the string form produced by Arnet.EthAddr's stringValue,
   presumably colon-separated -- TODO confirm the exact output format.
   Invalid input raises whatever Arnet.EthAddr raises (validateDeviceMac in
   this module treats IndexError as "invalid").
   """
   ethAddr = Arnet.EthAddr( mac )
   return ethAddr.stringValue

def validateDeviceMac( deviceMac ):
   """Return True if Arnet.EthAddr accepts `deviceMac` as a MAC address.

   Only IndexError from the constructor counts as "invalid"; any other
   exception propagates to the caller.
   """
   try:
      Arnet.EthAddr( deviceMac )
   except IndexError:
      isValid = False
   else:
      isValid = True
   return isValid

def validateOui( data ):
   """Validate an OUI request body and normalise its OUI.

   `data` must have exactly the keys 'oui' and 'vendorName', both non-empty
   str values. The OUI may use '.', '-' or ':' as a delimiter. Only one kind
   is stripped (':' takes priority over '-', which takes priority over '.'),
   so OUIs with mixed delimiters are rejected.

   Returns ( True, oui ) with the delimiter removed when what remains is
   1-6 hexadecimal digits (original case preserved), otherwise ( False, '' ).
   """
   hexDigits = '0123456789abcdefABCDEF'

   if set( [ 'oui', 'vendorName' ] ) != set( data.keys() ):
      return False, ''

   oui = data[ 'oui' ]
   vendor = data[ 'vendorName' ]

   if not oui or not vendor:
      return False, ''

   if not isinstance( vendor, str ) or not isinstance( oui, str ):
      return False, ''

   # Strip at most one kind of delimiter, highest-priority first.
   for dlim in ( ':', '-', '.' ):
      if dlim in oui:
         oui = oui.replace( dlim, '' )
         break

   # An OUI is hex; isalnum() used to let letters such as 'z' through.
   if not oui or len( oui ) > 6:
      return False, ''
   if any( ch not in hexDigits for ch in oui ):
      return False, ''
   return True, oui

def validateBandwidth( bw ):
   """Return True if `bw` is a non-negative integer bandwidth, else False.

   bool is rejected even though it subclasses int, so True/False cannot pass
   as a bandwidth of 1/0.
   """
   # Explicit checks instead of `assert`: assertions are removed when Python
   # runs with -O, which would make every value "valid".
   if isinstance( bw, bool ) or not isinstance( bw, int ):
      return False
   return bw >= 0

def replaceAll( text, wordMap ):
   """Replace every occurrence of each key of `wordMap` in `text` with
   str() of its value.

   Replacements are applied one after another in wordMap.items() order, so
   text produced by an earlier replacement can be rewritten by a later one.
   """
   result = text
   for placeholder, replacement in wordMap.items():
      result = result.replace( placeholder, str( replacement ) )
   return result

def validatePort( port ):
   """Return True if `port` names an Ethernet or Vlan interface that
   Arnet.IntfId accepts.

   The name is case-normalised first (e.g. 'ethernet1/1' -> 'Ethernet1/1').
   Any other interface type is rejected without consulting Arnet, and only
   IndexError from Arnet.IntfId counts as "invalid".
   """
   normalized = port.lower().capitalize()
   if not normalized.startswith( ( 'Ethernet', 'Vlan' ) ):
      return False
   try:
      Arnet.IntfId( normalized )
   except IndexError:
      return False
   return True

def validateRp( data ):
   """Validate a reservation-percent request body.

   `data` must contain 'chassis-id', 'interface-name' and
   'reservation-percent'. Checks run in this order and stop at the first
   failure, which is logged:
     '131' - one or more required keys are missing
     '139' - reservation-percent cannot be converted to float
     '138' - reservation-percent is outside [0.0, 1.0] (NaN included)
     '137' - chassis-id is rejected by validateDeviceMac
     '156' - interface-name does not start with 'Ethernet'

   Returns ( True, '' ) on success, otherwise ( False, ec ).
   """
   expected_keys = [ 'chassis-id', 'interface-name', 'reservation-percent' ]
   # findMissingKeys logs error '131' itself when keys are missing.
   ec, _ = findMissingKeys( expected_keys, data )
   if ec:
      return False, ec
   try:
      rp = float( data[ 'reservation-percent' ] )
   except ( TypeError, ValueError ):
      # ValueError: unparsable string such as 'abc'. TypeError: a value that
      # float() cannot convert at all, e.g. None or a list.
      ec = '139'
      error( errorCodes[ ec ] )
      return False, ec
   # NaN fails both comparisons, so it is rejected here as well.
   if not 0.0 <= rp <= 1.0:
      ec = '138'
      error( errorCodes[ ec ] )
      return False, ec
   chassisId = data[ 'chassis-id' ]
   if not validateDeviceMac( chassisId ):
      ec = '137'
      error( errorCodes[ ec ].replace( '<deviceId>', chassisId ) )
      return False, ec
   if not data[ 'interface-name' ].startswith( 'Ethernet' ):
      ec = '156'
      error( errorCodes[ ec ] )
      return False, ec
   return True, ''

def validateDeviceID( devID ):
   """
   Validate an ingress interface ID of the form '<deviceMac>-<port>'.

   Returns ( True, "" ) when valid, otherwise ( False, ec ) where ec is:
     '182' - devID is empty or None
     '183' - devID does not contain exactly one '-'
     '132' - the device (MAC) part is empty
     '133' - the port part is empty
     '137' - the device part is rejected by validateDeviceMac
     '184' - the port is rejected by validatePort (Ethernet/Vlan only)
   Unlike most validators in this module, nothing is logged here.
   """
   if not devID:
      return False, '182'
   if '-' not in devID or devID.count( '-' ) != 1:
      return False, '183'

   device, intfId = devID.split( '-' )
   if not device:
      return False, '132'
   if not intfId:
      return False, '133'
   if not validateDeviceMac( device ):
      return False, '137'
   if not validatePort( intfId ):
      return False, '184'
   return True, ""

def get_flows_active_template( bw, destinationIP, deviceId, iif, oifs, sourceIP ):
   """Build the record describing one programmed flow on one device.

   activityState/activityTime start as False/"" and deviceActive is always
   True. The other fields carry the arguments as given; `oifs` is stored by
   reference, not copied.
   """
   return dict( activityState=False,
                activityTime="",
                deviceActive=True,
                bandwidth=bw,
                destinationIP=destinationIP,
                deviceId=deviceId,
                iif=iif,
                oifs=oifs,
                sourceIP=sourceIP )

def getFlowBwProgram( agentStatus ):
   """Summarise programmed multicast flows from agentStatus.flowProgrammed.

   Returns ( flows, bwProgrammed ):
     flows        - one get_flows_active_template() record per
                    (flow, device) pair
     bwProgrammed - { deviceId: { intf: { 'rxTotal': n, 'txTotal': n } } }
                    where each flow's bw.value is added to rxTotal of its
                    iif and to txTotal of every oif on that device.
   """
   flows = []
   bwProgrammed = {}
   for mcastKey in agentStatus.flowProgrammed:
      programmed = agentStatus.flowProgrammed[ mcastKey ]
      bwValue = programmed.bw.value

      for _deviceKey, deviceValue in programmed.device.items():
         deviceId = deviceValue.ethAddr
         intfTotals = bwProgrammed.setdefault( deviceId, {} )

         iif = deviceValue.iif
         iifTotals = intfTotals.setdefault( iif, { 'rxTotal': 0, 'txTotal': 0 } )
         iifTotals[ 'rxTotal' ] += bwValue

         oifs = []
         for oif in deviceValue.oif:
            oifTotals = intfTotals.setdefault( oif,
                                               { 'rxTotal': 0, 'txTotal': 0 } )
            oifTotals[ 'txTotal' ] += bwValue
            oifs.append( oif )

         flows.append( get_flows_active_template( bwValue, mcastKey.group,
                                                  deviceId, iif, oifs,
                                                  mcastKey.source ) )
   return flows, bwProgrammed


def getStoredReceivers( apiConfig ):
   """Index the configured multicast receivers in apiConfig.mcastReceiver.

   Returns { source: { group: { trackingId: { device: record } } } } where
   record holds destinationIP, deviceId, messages (empty list),
   receiverIntfIds (every receiver interface sharing that trackingId on that
   device), sourceIP, success (True) and trackingID. Devices without
   receiver interfaces produce no entries.
   """
   storedReceivers = {}

   for mcastKey in apiConfig.mcastReceiver:
      receivers = apiConfig.mcastReceiver[ mcastKey ].receivers
      source, group = mcastKey.source, mcastKey.group

      for device in receivers.keys():
         recvIntfs = receivers[ device ].recvIntfs
         for recvIntf in recvIntfs.keys():
            trackingId = recvIntfs[ recvIntf ].trackingId
            byDevice = storedReceivers.setdefault( source, {} ) \
                                      .setdefault( group, {} ) \
                                      .setdefault( trackingId, {} )

            if device in byDevice:
               byDevice[ device ][ 'receiverIntfIds' ].append( recvIntf )
               continue
            byDevice[ device ] = {
               "destinationIP": group,
               "deviceId": device,
               "messages": [],
               "receiverIntfIds": [ recvIntf ],
               "sourceIP": source,
               "success": True,
               "trackingID": trackingId
            }
   return storedReceivers
