# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import IntfCli
import MvrpCliLib
from CliModel import Bool, Enum, Dict, Model, List, Str
from IntfModels import Interface
from Arnet import sortIntf
from MrpCliLib import tacMrpAppState, tacMrpRegState
from MrpCliLib import mrpAppStateFmt, mrpRegStateFmt

def _intfFmt( intf ):
   """Return the short display name of intf, or '--' when intf is empty/None."""
   if not intf:
      return '--'
   return IntfCli.Intf.getShortname( intf )

def _vlanFmt( vlanId ):
   """Return vlanId unchanged when it is truthy; otherwise return '--'."""
   return vlanId or '--'

# Display names for MRP attribute types, keyed by the integer attribute type.
_attrTypeFmt = { 1: 'Vlan' }

# CLI model for per-interface MVRP state: the global MVRP enable flag plus, for
# each interface, its operational state and the VLANs registered/declared on it.
class MvrpIntfInfo( Model ):
   mvrpEnabled = Bool( help="Indicates whether Mvrp is enabled" )

   class Intf( Model ):
      operState = Enum( values=( 'Disabled', 'LinkDown', 'Blocked', 'Active' ),
                        help="Operational Mvrp state on this interface: "
                             "Disabled -- Mvrp is disabled, "
                             "LinkDown -- Physical link is down, "
                             "Blocked -- STP port state is blocked, "
                             "Active -- Mvrp is running successfully" )

      registeredVlans = List( valueType=int,
                              help="Mvrp vlans registered on a given port" )
      declaredVlans = List( valueType=int,
                            help="Mvrp vlans declared on a given port" )

   interfaces = Dict( keyType=Interface, valueType=Intf,
                      help="Mvrp information for a port" )

   def render( self ):
      """Print the global MVRP status followed by a per-interface VLAN table.

      Interfaces whose operState is 'Disabled' are skipped. For every other
      interface, its sorted registered VLANs go in the third column and its
      sorted declared VLANs in the fourth; each column wraps onto continuation
      lines independently when it does not fit its width. A blank line follows
      each interface that has at least one VLAN.
      """
      _mvrpEnabledDispFmt = { True : 'Enabled', False: 'Disabled' }

      # Row layout: the port and state columns occupy the first 28 characters.
      # Registered VLANs start at column 28 and a line holding them may not
      # grow past 48 characters; declared VLANs start at column 50 and may not
      # grow past 70 characters.
      regCol, regMaxCol = 28, 48
      decCol, decMaxCol = 50, 70

      def _packChunks( items, startCol, maxCol ):
         # Greedily group items into ', '-separated chunks, one per output line.
         # A chunk always takes at least one item (so an over-wide item is still
         # printed); more items are appended only while a line whose chunk
         # starts at startCol stays within maxCol characters.
         chunks = []
         pos = 0
         while pos < len( items ):
            chunk = items[ pos ]
            pos += 1
            while ( pos < len( items ) and
                    startCol + len( chunk ) + len( ', ' ) + len( items[ pos ] )
                    <= maxCol ):
               chunk += ', ' + items[ pos ]
               pos += 1
            chunks.append( chunk )
         return chunks

      # Global MVRP information
      globalStatus = _mvrpEnabledDispFmt[ self.mvrpEnabled ]
      print( '' )
      print( '  MVRP Global Status : %s' % globalStatus )
      print( '' )

      # MVRP information per interface.
      # NOTE(review): the column is headed "Admin State" but shows operState;
      # left unchanged because it is user-visible output.
      headings = ( "Port", "Admin State", "Registered Vlans", "Declared Vlans" )
      headingsFmt = '  %-10.10s  %-12.12s  %-20.20s  %-20.20s'
      print( headingsFmt % headings )
      print( '  ' + '-'*10 + ' '*2 + '-'*12 +' '*2 + '-'*20 + ' '*2 + '-'*20 )

      for intfId in sortIntf( self.interfaces ):
         interface = self.interfaces[ intfId ]
         if interface.operState == 'Disabled':
            continue

         regVlans = [ str( vlan ) for vlan in sorted( interface.registeredVlans ) ]
         decVlans = [ str( vlan ) for vlan in sorted( interface.declaredVlans ) ]
         regChunks = _packChunks( regVlans, regCol, regMaxCol )
         decChunks = _packChunks( decVlans, decCol, decMaxCol )

         # Print the VLANs registered and declared on this port. Row i holds the
         # i-th registered chunk and the i-th declared chunk; the first row is
         # always printed because it carries the port name and its state.
         firstPrefix = '  %-10.10s  %-12.12s  ' % ( _intfFmt( intfId ),
                                                    interface.operState )
         numRows = max( len( regChunks ), len( decChunks ), 1 )
         for row in range( numRows ):
            line = firstPrefix if row == 0 else ' ' * regCol
            if row < len( regChunks ):
               line += regChunks[ row ]
            if row < len( decChunks ):
               # Pad (or truncate) to the start of the declared-VLAN column.
               line = line[ :decCol ].ljust( decCol ) + decChunks[ row ]
            print( line )

         if regVlans or decVlans:
            # Add blank line only when there were VLANs to display
            print( '' )

# CLI model for per-VLAN MVRP state: for each VLAN id, the ports on which it is
# registered and the ports on which it is declared.
class MvrpVlanInfo( Model ):
   class Vlan( Model ):
      registeredIntfs = List( valueType=str,
                              help="Ports on which a given Mvrp Vlan is registered" )
      declaredIntfs = List( valueType=str,
                            help="Ports on which a given Mvrp Vlan is declared" )

   vlans = Dict( keyType=int, valueType=Vlan,
                 help="Mvrp information for a vlan" )

   def render( self ):
      """Print a per-VLAN table of registered and declared ports.

      VLANs are printed in ascending order. Registered ports go in the second
      column and declared ports in the third; each column wraps onto
      continuation lines independently when it does not fit its width. The
      first row of every VLAN carries the VLAN id, even when the VLAN has no
      ports, and each VLAN is followed by a blank line.
      """
      # Row layout: the VLAN id column occupies the first 10 characters.
      # Registered ports start at column 10 and a line holding them may not
      # grow past 40 characters; declared ports start at column 42 and may not
      # grow past 72 characters.
      regCol, regMaxCol = 10, 40
      decCol, decMaxCol = 42, 72

      def _packChunks( items, startCol, maxCol ):
         # Greedily group items into ', '-separated chunks, one per output line.
         # A chunk always takes at least one item (so an over-wide item is still
         # printed); more items are appended only while a line whose chunk
         # starts at startCol stays within maxCol characters.
         chunks = []
         pos = 0
         while pos < len( items ):
            chunk = items[ pos ]
            pos += 1
            while ( pos < len( items ) and
                    startCol + len( chunk ) + len( ', ' ) + len( items[ pos ] )
                    <= maxCol ):
               chunk += ', ' + items[ pos ]
               pos += 1
            chunks.append( chunk )
         return chunks

      # MVRP information per VLAN
      headingsFmt = '  %-6.6s  %-30.30s  %-30.30s'
      headings = ( "Vlan", "Registered Ports", "Declared Ports" )
      print( headingsFmt % headings )
      print( '  ' + '-'*6 + ' '*2 + '-'*30 + ' '*2 + '-'*30 )

      for vlanId in sorted( self.vlans ):
         vlan = self.vlans[ vlanId ]

         regIntfs = sortIntf( _intfFmt( intf ) for intf in vlan.registeredIntfs )
         decIntfs = sortIntf( _intfFmt( intf ) for intf in vlan.declaredIntfs )
         regChunks = _packChunks( regIntfs, regCol, regMaxCol )
         decChunks = _packChunks( decIntfs, decCol, decMaxCol )

         # Print the ports on which this VLAN is registered/declared. Row i holds
         # the i-th registered chunk and the i-th declared chunk. The first row
         # is always printed so the VLAN id appears even with no ports (the
         # previous loop-based version printed a blank line instead).
         firstPrefix = '  %-6.6s  ' % vlanId
         numRows = max( len( regChunks ), len( decChunks ), 1 )
         for row in range( numRows ):
            line = firstPrefix if row == 0 else ' ' * regCol
            if row < len( regChunks ):
               line += regChunks[ row ]
            if row < len( decChunks ):
               # Pad (or truncate) to the start of the declared-port column.
               line = line[ :decCol ].ljust( decCol ) + decChunks[ row ]
            print( line )

         print( '' )

# CLI model for the MRP attribute database: per interface, per attribute type,
# per attribute key, the applicant and registrar state plus an optional value.
class MrpDatabase( Model ):

   class IntfDatabase( Model ):

      class AttrTypeDatabase( Model ):

         class Attribute( Model ):
            applicantState = Enum( values=tacMrpAppState.attributes,
                                   help="MRP applicant state" )
            registrarState = Enum( values=tacMrpRegState.attributes,
                                   help="MRP registrar state" )
            value = Str( help="String representation of the MRP attribute value",
                         optional=True )

         attributes = Dict( keyType=int, valueType=Attribute,
                            help="MRP attribute information indexed by keys" )
      attributeTypes = Dict( keyType=int, valueType=AttrTypeDatabase,
                             help="MRP type attribute information" )
   interfaces = Dict( keyType=Interface, valueType=IntfDatabase,
                      help="MRP attribute information indexed by interface" )

   def render( self ):
      """Print one row per (interface, attribute type, attribute key).

      Rows are ordered by interface, then attribute type, then key. Each row
      shows the applicant and registrar state; when an attribute carries a
      value, the formatted value and a blank line are printed after its row.
      """
      print( ' ' * 51 + 'State' )
      headingsFmt = '  %-10.10s %-17.17s %-20.20s %-3.3s %-3.3s'
      headings = ( "Port", "Type", "Key", "App", "Reg" )
      print( headingsFmt % headings )
      print( headingsFmt % ( '-'*10, '-'*17, '-'*20, '-'*3, '-'*3 ) )
      rowFmt = headingsFmt

      for intf in sortIntf( self.interfaces ):
         interface = self.interfaces[ intf ]
         for attrType in sorted( interface.attributeTypes ):
            attrTypeInfo = interface.attributeTypes[ attrType ]
            # The model accepts any integer attribute type, but only some have
            # display names; show the number rather than failing with KeyError.
            attrTypeName = _attrTypeFmt.get( attrType, str( attrType ) )
            for attrKey in sorted( attrTypeInfo.attributes ):
               attribute = attrTypeInfo.attributes[ attrKey ]

               print( rowFmt % ( _intfFmt( intf ), attrTypeName, attrKey,
                                 mrpAppStateFmt[ attribute.applicantState ],
                                 mrpRegStateFmt[ attribute.registrarState ] ) )

               if attribute.value:
                  # NOTE(review): confirm getAttrValues copes with attribute
                  # types other than those listed in _attrTypeFmt.
                  print( MvrpCliLib.getAttrValues( attribute.value, attrType ) )
                  print( '' )
