# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from ArnetModel import IpGenericAddress
from CliModel import Model, Bool, Enum, Dict, List, Int, Submodel
from IntfModels import Interface
from TableOutput import createTable, Format
import Tac
import sys

class DecapGroupModelBase( Model ):
   'Base model definition for all ip decap-group types'
   # Attributes shared by every decap-group flavour defined below.
   tunnelType = Enum(
         values=Tac.Type( 'Tunnel::Decap::TunnelType' ).attributes,
         help='Decap tunnel type' )
   # Whether the group is present in system configuration (per its help text).
   persistent = Bool( 'Decap group is present in system configuration' )

class DecapGroupCounterEntry( Model ):
   # Received packet and octet counts for one decap-group entry.
   packets = Int( help='Number of received packets on DecapGroup' )
   octets = Int( help='Number of received octets on DecapGroup' )

class DecapIntfModel( Model ):
   # One decap-interface entry of a decap group.
   intfId = Interface( 'Decap destination interface' )
   addressFamily = Enum(
         values=Tac.Type( 'Arnet::AddressFamily' ).attributes,
         help='Decap address family' )
   addressType = Enum(
         values=Tac.Type( 'Tunnel::Decap::AddressType' ).attributes,
         help='Decap address type' )
   oldConfig = Bool( 'Old way of writing this config' )
   # Optional: may be left unset when no counter is reported for the entry.
   counter = Submodel( help='Counter for the decap interface.',
                       valueType=DecapGroupCounterEntry,
                       optional=True )

class DecapGroupGreModel( DecapGroupModelBase ):
   'GRE decap-group'
   decapIntf = List( valueType=DecapIntfModel,
                     help='decap intf key list' )
   decapIp = List( valueType=IpGenericAddress, help='decap ip list' )
   # Optional per-address counters, keyed by decap IP address.
   decapIpCounters = Dict( help='Map decap IP address to decap group counter',
                           keyType=IpGenericAddress,
                           valueType=DecapGroupCounterEntry,
                           optional=True )

class DecapGroupIpIpModel( DecapGroupModelBase ):
   'IP-in-IP decap-group'
   decapIntf = List( valueType=DecapIntfModel,
                     help='decap intf key list' )
   decapIp = List( valueType=IpGenericAddress, help='decap ip list' )
   # Optional per-address counters, keyed by decap IP address.
   decapIpCounters = Dict( help='Map decap IP address to decap group counter',
                           keyType=IpGenericAddress,
                           valueType=DecapGroupCounterEntry,
                           optional=True )

class DecapGroupUdpModel( DecapGroupModelBase ):
   'UDP decap-group'
   decapIntf = List( valueType=DecapIntfModel,
                     help='decap intf key list' )
   decapIp = List( valueType=IpGenericAddress, help='decap ip list' )
   # UDP groups carry a destination port and payload type; unlike the GRE
   # and IP-in-IP models they declare no per-address counter map.
   destinationPort = Int( help='Decap destination port' )
   payloadType = Enum(
         values=Tac.Type( 'Tunnel::Decap::PayloadType' ).attributes,
         help='Decap Payload Type' )

class DecapGroups( Model ):
   __revision__ = 3
   decapGroups = Dict( help='A mapping from name to ip decap-group entry',
                       valueType=DecapGroupModelBase )
   globalUdpDestPortToPayloadType = \
         Dict( keyType=int, valueType=str,
               help='A global mapping from UDP destination port to payload type' )

   def _getColumnFormat( self, **kwargs ):
      '''
      Return a left-justified, bordered column Format with kwargs applied as
      overrides.  Columns given a maxWidth also wrap their content.
      '''
      baseFormat = { 'justify': 'left', 'border': True }
      baseFormat.update( kwargs )
      if 'maxWidth' in kwargs:
         baseFormat[ 'wrap' ] = True
      fmt = Format( **baseFormat )
      fmt.noPadLeftIs( True )
      return fmt

   def _tunnelTypeEnum( self ):
      '''
      Return the decap tunnel-type enum: the CLI plugin's tacTunnelType when it
      is set, otherwise the Tac type looked up directly.
      '''
      # Imported here rather than at module scope, presumably to avoid an
      # import cycle with the CLI plugin -- TODO confirm.
      from CliPlugin.DecapGroupCli import tacTunnelType
      if tacTunnelType is None:
         return Tac.Type( 'Tunnel::Decap::TunnelType' )
      return tacTunnelType

   def _tunnelTypeName( self, tunnelType ):
      '''Return the display name of tunnelType: 'IP-in-IP' or upper-cased.'''
      TunnelType = self._tunnelTypeEnum()
      names = { TunnelType.ipip: 'IP-in-IP' }
      if tunnelType in names:
         return names[ tunnelType ]
      return tunnelType.upper()

   def _isIp4Addr( self, ip ):
      return ip.af == 'ipv4'

   def countersPresent( self ):
      '''
      If there is at least one GRE or IP-in-IP decap group with counters present
      in the model, return True.  Otherwise, return False.
      '''
      for dg in self.decapGroups.values():
         if dg.tunnelType not in [ 'gre', 'ipip' ]:
            continue
         if getattr( dg, 'decapIpCounters', None ) is not None:
            return True
         for decapIntf in dg.decapIntf:
            if getattr( decapIntf, 'counter', None ) is not None:
               return True
      return False

   def _counterCells( self, counter ):
      '''Return the (packets, octets) cells for a counter entry, or blanks.'''
      if counter is None:
         return ( '', '' )
      return ( counter.packets, counter.octets )

   def _createTable( self, countersPresent ):
      '''
      Create the output table with its headings and column formats.  The
      Packets and Octets columns are included only when countersPresent.
      '''
      dynamicFmt = self._getColumnFormat()
      dynamicFmt.padLimitIs( True )
      # One padding-limited format shared by 'Version' and, when counters are
      # shown, by 'Payload Type' and 'Packets'.
      paddedFmt = self._getColumnFormat()
      paddedFmt.padLimitIs( True )
      columns = [ ( 'D', dynamicFmt ),
                  ( 'Name', self._getColumnFormat( maxWidth=20 ) ),
                  ( 'Type', self._getColumnFormat() ),
                  ( 'Info', self._getColumnFormat() ),
                  ( 'Version', paddedFmt ),
                  ( 'Address Type', self._getColumnFormat() ),
                  ( 'UDP Dest Port', self._getColumnFormat() ) ]
      # Whichever column comes last is created with border=False.
      lastFmt = self._getColumnFormat( border=False )
      if countersPresent:
         columns += [ ( 'Payload Type', paddedFmt ),
                      ( 'Packets', paddedFmt ),
                      ( 'Octets', lastFmt ) ]
      else:
         columns.append( ( 'Payload Type', lastFmt ) )

      table = createTable( tuple( h for h, _ in columns ), tableWidth=200 )
      table.formatColumns( *[ f for _, f in columns ] )
      return table

   def _groupEntries( self, dg, countersPresent, allIntfs ):
      '''
      Return one (info, version, addressType, counterCells) tuple per decap
      interface, then per decap IP, of group dg.  Counter cells are blank when
      counters are not shown or the entry has no counter -- a group may lack
      counters even when other groups have them.
      '''
      noCounters = ( '', '' )
      entries = []
      for key in dg.decapIntf:
         info = 'all' if key.intfId == allIntfs else key.intfId.stringValue
         family = 'IPv4' if key.addressFamily == 'ipv4' else 'IPv6'
         cells = noCounters
         if countersPresent:
            cells = self._counterCells( getattr( key, 'counter', None ) )
         entries.append( ( info, family, key.addressType, cells ) )

      # UDP groups have no decapIpCounters attribute; GRE/IP-in-IP groups may
      # leave it unset.
      ipCounters = getattr( dg, 'decapIpCounters', None ) or {}
      for ip in dg.decapIp:
         family = 'IPv4' if self._isIp4Addr( ip ) else 'IPv6'
         cells = noCounters
         if countersPresent and ip in ipCounters:
            cells = self._counterCells( ipCounters[ ip ] )
         entries.append( ( ip, family, '', cells ) )
      return entries

   def _udpCells( self, dg, TunnelType ):
      '''
      Return the (UDP dest port, payload type) cells for group dg.

      Non-UDP groups get blanks.  A UDP group without a destination port of its
      own falls back to the global port map when that map is non-empty; both
      cells are then parallel lists with one item per global port, and payload
      type 'ipvx' is displayed as 'ip'.
      '''
      if dg.tunnelType != TunnelType.udp:
         return '', ''
      portMap = self.globalUdpDestPortToPayloadType
      if not dg.destinationPort and portMap:
         # Build both lists from one key sequence so they stay aligned.
         ports = list( portMap )
         payloads = [ 'ip' if portMap[ p ] == 'ipvx' else portMap[ p ]
                      for p in ports ]
         return ports, payloads
      return dg.destinationPort, dg.payloadType

   def _addGroupRows( self, table, name, dg, countersPresent, TunnelType,
                      allIntfs ):
      '''
      Append the rows for decap group `name` to table: one row per decap
      interface/IP (at least one row per group), each followed by a
      continuation row for every additional global UDP port.
      '''
      if dg.tunnelType == TunnelType.unknown:
         entries, port, payload = [], '', ''
      else:
         entries = self._groupEntries( dg, countersPresent, allIntfs )
         port, payload = self._udpCells( dg, TunnelType )
      if not entries:
         # Still list the group even when it has no decap entries.
         entries = [ ( '', '', '', ( '', '' ) ) ]

      if isinstance( port, list ):
         firstPort, firstPayload = port[ 0 ], payload[ 0 ]
         extraPorts = list( zip( port[ 1: ], payload[ 1: ] ) )
      else:
         firstPort, firstPayload = port, payload
         extraPorts = []

      dynamic = ' ' if dg.persistent else '*'
      typeName = self._tunnelTypeName( dg.tunnelType )
      for info, family, addrType, counterCells in entries:
         row = [ dynamic, name, typeName, info, family, addrType,
                 firstPort, firstPayload ]
         if countersPresent:
            row.extend( counterCells )
         table.newRow( *row )
         # Continuation rows fill only the port and payload-type columns.
         for extraPort, extraPayload in extraPorts:
            row = [ '' ] * 6 + [ extraPort, extraPayload ]
            if countersPresent:
               row.extend( [ '', '' ] )
            table.newRow( *row )

   def render( self ):
      if not self.decapGroups:
         return

      countersPresent = self.countersPresent()
      table = self._createTable( countersPresent )

      TunnelType = self._tunnelTypeEnum()
      # A decapIntf entry with an empty interface id is displayed as 'all'.
      allIntfs = Tac.Value( 'Arnet::IntfId', '' )
      for name in sorted( self.decapGroups ):
         self._addGroupRows( table, name, self.decapGroups[ name ],
                             countersPresent, TunnelType, allIntfs )

      # Print note about 'D' column
      sys.stdout.write( 'NOTE: "D" column indicates dynamic entries\n' )

      # Render the table output
      sys.stdout.write( table.output() )

   def degrade( self, dictRepr, revision ):
      '''
      Convert the current (revision 3) dict representation in place to an
      older revision and return it.

      Revision 2: each ipip group's 'decapIntf' list of interface dicts is
      replaced by the 'intfId' of its first entry.
      Revision 1: each gre group's 'decapIp' list is replaced by its first
      address ('0.0.0.0' when empty); every other group drops 'decapIp' and
      has its 'decapIntf' collapsed as for revision 2.
      '''
      def collapseDecapIntf( entry ):
         # Keep only the first interface's id.  Leave the entry unchanged when
         # it is empty or its first element is already a string.
         intfs = entry[ 'decapIntf' ]
         if len( intfs ) > 0 and not isinstance( intfs[ 0 ], str ):
            entry[ 'decapIntf' ] = intfs[ 0 ][ 'intfId' ]

      if revision == 2:
         for entry in dictRepr[ 'decapGroups' ].values():
            if entry[ 'tunnelType' ] == 'ipip':
               collapseDecapIntf( entry )
      elif revision == 1:
         for entry in dictRepr[ 'decapGroups' ].values():
            if entry[ 'tunnelType' ] == 'gre':
               decapIps = entry[ 'decapIp' ]
               entry[ 'decapIp' ] = decapIps[ 0 ] if decapIps else '0.0.0.0'
            else:
               del entry[ 'decapIp' ]
               collapseDecapIntf( entry )

      return dictRepr
