# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
from CliModel import Bool, Dict, Enum, Int, List, Str
from CliModel import Model, Submodel
from ArnetModel import IpGenAddrWithFullMask, MacAddress
from TableOutput import createTable, Format
from HaloLib import TableType
from AclLib import AclDirection, AclType, AclAction

#---------------------------------------
# Render helper methods
#---------------------------------------

# For rendering output, tabLevels maintains the level of indentation
# applied before displaying an output. All non-tabular rendering in
# this module is done via printt() which honours tabLevels.
def printt( msg, tabLevel ):
   '''Print 'msg' prefixed by three spaces per indentation level.

   msg      -- the string to print
   tabLevel -- indentation depth; each level adds three spaces
   '''
   # A parenthesised single argument behaves identically under the Python 2
   # print statement and the Python 3 print function.
   print( ' ' * 3 * tabLevel + msg )

def tabular_printt( table, tabLevel ):
   '''Print a rendered table, indenting each of its lines via printt().

   table    -- object whose output() returns the rendered table text
   tabLevel -- indentation depth passed through to printt()

   The text is split on '\n'. Text after the final newline is printed as well,
   so a table whose output does not end in a newline keeps its last row.
   '''
   lines = table.output().split( '\n' )
   # When the text ends in '\n', split() yields a trailing '' that is not a
   # real line, so drop it.
   if lines and lines[ -1 ] == '':
      lines.pop()
   for line in lines:
      printt( line, tabLevel )

def strBool( b ):
   '''Return the single-character flag 'Y' for a truthy value, else 'N'.'''
   if not b:
      return 'N'
   return 'Y'

#---------------------------------------
# Cli models
#---------------------------------------

#---------------------------------------
# Cli models for hidden show commands
#---------------------------------------

class AclId( Model ):
   # Identifies an ACL by table type, chip rule list ID and direction.
   tableType = Enum( values=TableType.attributes, help='Table type' )
   chipRuleListId = Int( help='Chip rule list ID' )
   direction = Enum( values=AclDirection.attributes, help='ACL direction' )

   def render( self, tabLevel ): #pylint:disable=W0221
      '''Print "ACL ID <tableType>:<chipRuleListId>:<direction>" at tabLevel.'''
      idParts = ( self.tableType, self.chipRuleListId, self.direction )
      printt( 'ACL ID %s:%s:%s' % idParts, tabLevel )

class AclKey( Model ):
   # Fields identifying an ACL by its CLI name, type and direction, plus
   # whether it is a standard ACL and the associated VLAN ID.
   name = Str( help='ACL CLI name' )
   aclType = Enum( values=AclType.attributes, help='ACL type' )
   direction = Enum( values=AclDirection.attributes, help='ACL direction' )
   standard = Bool( help='Standard ACL' )
   vlanId = Int( help='VLAN ID' )

class L4PortRange( Model ):
   # Inclusive-looking [start,end] pair of L4 port numbers, as formatted below.
   rangeStart = Int( help='Port range start' )
   rangeEnd = Int( help='Port range end' )

   def formatStr( self ):
      '''Return the range formatted as "[start,end]".'''
      bounds = ( self.rangeStart, self.rangeEnd )
      return '[%d,%d]' % bounds

class TcpFlag( Model ):
   # One optional boolean per TCP header flag.
   fin = Bool( help='TCP flag fin', optional=True )
   syn = Bool( help='TCP flag syn', optional=True )
   rst = Bool( help='TCP flag rst', optional=True )
   psh = Bool( help='TCP flag psh', optional=True )
   ack = Bool( help='TCP flag ack', optional=True )
   urg = Bool( help='TCP flag urg', optional=True )

class IpFilter( Model ):
   # Match criteria of an IP/IPv6 rule: address prefixes and L4 port ranges,
   # followed by optional fragment, DSCP, protocol, TTL, TCP-flag and ICMP
   # matches.
   source = Submodel( valueType=IpGenAddrWithFullMask, help='Source IP Prefix' )
   destination = Submodel( valueType=IpGenAddrWithFullMask,
                           help='Destination IP Prefix' )
   sportRange = Submodel( valueType=L4PortRange, help='Source L4 port range' )
   dportRange = Submodel( valueType=L4PortRange,
                          help='Destination L4 port range' )
   fragments = Bool( help='Match packet fragments', optional=True )
   # Value/flag pairs: dscp+matchDscp, proto+matchProto, ttl+includeTtl.
   dscp = Int( help='DSCP value', optional=True )
   matchDscp = Bool( help='Match DSCP', optional=True )
   proto = Int( help='IP protocol number', optional=True )
   matchProto = Bool( help='Match IP protocol', optional=True )
   ttl = Int( help='TTL value', optional=True )
   includeTtl = Bool( help='Match TTL', optional=True )
   tcpFlag = Submodel( valueType=TcpFlag, help='TCP flags', optional=True )
   established = Bool( help='Established', optional=True )
   icmpType = Int( help='ICMP type', optional=True )
   icmpCode = Int( help='ICMP code', optional=True )
   l4Rule = Bool( help='Match L4 information', optional=True )

class MacFilter( Model ):
   # Match criteria of a MAC rule: source/destination address with masks,
   # plus the EtherType.
   sourceAddr = MacAddress( help='Source MAC address' )
   sourceMask = MacAddress( help='Source MAC mask' )
   destAddr = MacAddress( help='Destination MAC address' )
   destMask = MacAddress( help='Destination MAC mask' )
   proto = Int( help='EtherType protocol' )

class IpRuleConfig( Model ):
   # A single IP/IPv6 rule: its filter, the action taken on a match, and
   # optional logging/counter settings.
   ipFilter = Submodel( valueType=IpFilter, help='IP Filter' )
   action = Enum( values=AclAction.attributes, help='Rule action' )
   log = Bool( help='Logging enabled for the rule', optional=True )
   count = Bool( help='Counters enabled for the rule', optional=True )

class MacRuleConfig( Model ):
   # A single MAC rule: its filter, the action taken on a match, and
   # optional logging/counter settings.
   macFilter = Submodel( valueType=MacFilter,
                         help='MAC Filter' )
   action = Enum( values=AclAction.attributes,
                  help='Rule action' )
   log = Bool( help='Logging enabled for the rule', optional=True )
   count = Bool( help='Counters enabled for the rule', optional=True )

class HaloAcl( Model ):
   # An ACL with its ID, type, and its IP/IPv6 and MAC rules keyed by rule ID.
   name = Str( help='ACL CLI name' )
   aclId = Submodel( valueType=AclId,
                     help='ACL ID' )
   aclType = Enum( values=AclType.attributes,
                   help='ACL type' )
   ipRules = Dict( keyType=str,
                   valueType=IpRuleConfig,
                   help='A mapping from rule ID to IP/IPv6 rule' )
   macRules = Dict( keyType=str,
                    valueType=MacRuleConfig,
                    help='A mapping from rule ID to MAC rule' )

   def render( self, tabLevel ): #pylint:disable=W0221
      '''Print the ACL ID, then the total rule count one level deeper.

      The count is the number of IP/IPv6 rules plus the number of MAC rules.
      '''
      self.aclId.render( tabLevel )
      ruleCount = len( self.ipRules ) + len( self.macRules )
      printt( 'ACL rule count: %d' % ruleCount, tabLevel + 1 )

class AclKeyToAclId( Model ):
   # Associates an ACL key with the ACL IDs it maps to (string-keyed).
   aclIds = Dict( keyType=str, valueType=AclId, help='ACL ID' )
   aclKey = Submodel( valueType=AclKey, help='ACL key' )

class RootConfig( Model ):
   # Per-chip ACL configuration: the ACLs keyed by ACL ID and the list of
   # ACL key to ACL ID mappings.
   chipName = Str( help='Chip name' )
   acls = Dict( keyType=str,
                valueType=HaloAcl,
                help='A mapping from ACL ID to ACL' )
   keyToIds = List( valueType=AclKeyToAclId,
                    help='A list of ACL key to ACL ID mappings' )

   def render( self, tabLevel ): #pylint:disable=W0221
      '''Print the chip name, then each ACL (sorted by ACL ID) one level deeper.

      Prints "No ACL Applied" instead when there are no ACLs.
      '''
      printt( 'Chip name {}'.format( self.chipName ), tabLevel )
      if not self.acls:
         printt( 'No ACL Applied', tabLevel + 1 )
         return
      for _aclId, haloAcl in sorted( self.acls.items() ):
         haloAcl.render( tabLevel + 1 )

# The purpose of this model is to produce a JSON model that is used as input
# to HaloReplay, which is why the 'show platform acl' command is hidden.
class RootConfigColl( Model ):
   __public__ = False

   configs = Dict( keyType=int,
                   valueType=RootConfig,
                   help='A mapping from Root ID to config' )
   lastAttemptConfigs = Dict(
      keyType=int,
      valueType=RootConfig,
      help='A mapping from Root ID to last attempt config' )

   def render( self, tabLevel=0 ): #pylint:disable=W0221
      '''Print the current root configs, then any last-attempt root configs.

      Each section is a title line followed by its RootConfigs in root-ID
      order, one level deeper. If there are no current configs, only
      "No chips" is printed. An empty last-attempt section is omitted.
      '''
      if not self.configs:
         printt( 'No chips', tabLevel )
         return
      sections = ( ( 'Current Root Configs:', self.configs ),
                   ( 'Last Attempt Root Configs:', self.lastAttemptConfigs ) )
      for title, configs in sections:
         if not configs:
            continue
         printt( title, tabLevel )
         for _rootId, rootConfig in sorted( configs.items() ):
            rootConfig.render( tabLevel + 1 )

#--------------------------------------------------
# Cli models for non-hidden AlgoMatch show commands
#--------------------------------------------------

# Models for representing Halo information
class IpMaskGroup( Model ):
   # Describes an IP mask group: prefix/port-range mask lengths plus which
   # optional header fields participate in the match.
   srcIp = Int( help='Source IP mask length' )
   dstIp = Int( help='Destination IP mask length' )
   srcPort = Int( help='Source port range mask length' )
   dstPort = Int( help='Destination port range mask length' )
   tcpFlag = Str( help='TCP flags' )
   established = Bool( help='Established set' )
   ttl = Bool( help='TTL set' )
   dscp = Bool( help='DSCP set' )
   fragment = Bool( help='IP fragment bit set' )
   protocol = Bool( help='IP protocol field set' )

   def renderMaskGroup( self ):
      '''Return a comma-separated descriptor of this mask group.

      The four mask lengths always appear; tcpFlag and the boolean fields are
      appended only when set.
      '''
      pieces = [ 'sip:%d' % self.srcIp, 'dip:%d' % self.dstIp,
                 'sport:%d' % self.srcPort, 'dport:%d' % self.dstPort ]
      if self.tcpFlag:
         pieces.append( 'tcpFlag(%s)' % self.tcpFlag )
      flagLabels = ( ( self.established, 'est' ),
                     ( self.ttl, 'ttl' ),
                     ( self.dscp, 'dscp' ),
                     ( self.fragment, 'frag' ),
                     ( self.protocol, 'proto' ) )
      pieces.extend( label for isSet, label in flagLabels if isSet )
      return ','.join( pieces )

class MacMaskGroup( Model ):
   # Describes a MAC mask group: source/destination MAC masks and whether
   # the Ethernet protocol field participates.
   srcMac = MacAddress( help='Source MAC' )
   dstMac = MacAddress( help='Destination MAC' )
   ethProto = Bool( help='Ethernet protocol set' )

   def renderMaskGroup( self ):
      '''Return "smac:<mac>,dmac:<mac>", with ",eth proto" when ethProto is set.'''
      protoSuffix = ',eth proto' if self.ethProto else ''
      return 'smac:%s,dmac:%s%s' % ( self.srcMac, self.dstMac, protoSuffix )

class MaskGroup( Model ):
   # A mask group: its ID, an IP or MAC descriptor, and the number of entries
   # using it.
   id = Int( help='Mask group ID' )
   ip = Submodel( IpMaskGroup, help='IP mask group descriptor', optional=True )
   mac = Submodel( MacMaskGroup, help='MAC mask group descriptor', optional=True )
   entries = Int( help='Number of entries using this mask group' )

   def renderDescriptor( self ):
      '''Return the descriptor string of the IP mask group, or else the MAC one.

      Returns an empty string when neither descriptor is populated, instead of
      failing with AttributeError on None.
      '''
      if self.ip:
         return self.ip.renderMaskGroup()
      if self.mac is not None:
         return self.mac.renderMaskGroup()
      return ''

   def renderMaskGroup( self, maskGroupTable ):
      '''Append one ( ID, descriptor, entries ) row to maskGroupTable.'''
      maskGroupTable.newRow( self.id, self.renderDescriptor(), self.entries )

class MaskGroups( Model ):
   maskGroups = List( valueType=MaskGroup, help='List of mask groups' )

   def render( self, tabLevel ): # pylint: disable-msg=W0221
      '''Print all mask groups as a left-justified, indented table.'''
      columnFormat = Format( justify='left' )
      columnFormat.noPadLeftIs( True )
      columnFormat.padLimitIs( True )
      columnHeaders = [ 'Mask group ID', 'Mask group descriptor', 'Entries' ]
      table = createTable( columnHeaders, tableWidth=100 )
      table.formatColumns( columnFormat, columnFormat, columnFormat )
      for group in self.maskGroups:
         group.renderMaskGroup( table )
      tabular_printt( table, tabLevel )

class TableLayout( Model ):
   layout = Dict( keyType=int, valueType=MaskGroup,
                  help='A mapping between a table and a mask group' )

   def render( self, tabLevel ): # pylint: disable-msg=W0221
      '''Print one row per table: table ID, mask group ID, mask group descriptor.

      Rows are emitted in ascending table-ID order so the output is
      deterministic rather than following dict iteration order.
      '''
      headers = [ 'Table ID', 'Mask group ID', 'Mask group descriptor' ]
      fmt = Format( justify='left' )
      fmt.noPadLeftIs( True )
      fmt.padLimitIs( True )
      layoutTable = createTable( headers, tableWidth=100 )
      layoutTable.formatColumns( fmt, fmt, fmt )
      for tableId in sorted( self.layout ):
         maskGroup = self.layout[ tableId ]
         layoutTable.newRow( tableId, maskGroup.id, maskGroup.renderDescriptor() )
      tabular_printt( layoutTable, tabLevel )

# Nested models for info per ( aclType, name, direction, hwAclType, hwAclId ) tuple
class HwAclInfo( Model ):
   maskGroups = Submodel( MaskGroups, help='A set of mask groups', optional=True )
   tableLayout = Submodel( TableLayout, help='A table to mask group layout',
                           optional=True )

   def renderHwAclId( self, hwAclId, direction, hwAclType, tabLevel ):
      '''Print the "Hw ACL ID: ..., direction: ..., Hw ACL type: ..." header.'''
      header = 'Hw ACL ID: %d, direction: %s, Hw ACL type: %s' % (
         hwAclId, direction, hwAclType )
      printt( header, tabLevel )

   def renderMaskGroups( self, aclType, name, direction, hwAclType, hwAclId,
                         tabLevel ):
      '''Print the header and the mask-group table, when mask groups exist.

      aclType and name are unused here; they keep the signature parallel
      with the other render helpers.
      '''
      if self.maskGroups:
         self.renderHwAclId( hwAclId, direction, hwAclType, tabLevel )
         self.maskGroups.render( tabLevel + 1 )

   def renderLayout( self, aclType, name, direction, hwAclType, hwAclId, tabLevel ):
      '''Print the header and the table layout, when a layout exists.

      aclType and name are unused here; they keep the signature parallel
      with the other render helpers.
      '''
      if self.tableLayout:
         self.renderHwAclId( hwAclId, direction, hwAclType, tabLevel )
         self.tableLayout.render( tabLevel + 1 )

class HwAclTypeHwAclInfo( Model ):
   hwAclId = Dict( keyType=int,
                   valueType=HwAclInfo,
                   help='A mapping of hardware ACL ID to hardware ACL info' )

   def renderMapping( self, mappingTable, aclType, name, direction, hwAclType ):
      '''Add one mapping row whose last column lists the sorted Hw ACL IDs.'''
      sortedIds = ','.join( map( str, sorted( self.hwAclId ) ) )
      mappingTable.newRow( name, aclType, direction, hwAclType, sortedIds )

   def renderMaskGroups( self, aclType, name, direction, hwAclType, tabLevel ):
      '''Render mask groups of every Hw ACL ID in ascending ID order.'''
      for aclIdNum in sorted( self.hwAclId ):
         self.hwAclId[ aclIdNum ].renderMaskGroups( aclType, name, direction,
                                                    hwAclType, aclIdNum,
                                                    tabLevel )

   def renderLayout( self, aclType, name, direction, hwAclType, tabLevel ):
      '''Render the table layout of every Hw ACL ID in ascending ID order.'''
      for aclIdNum in sorted( self.hwAclId ):
         self.hwAclId[ aclIdNum ].renderLayout( aclType, name, direction,
                                                hwAclType, aclIdNum, tabLevel )

class DirectionHwAclInfo( Model ):
   hwAclType = Dict( keyType=str,
                     valueType=HwAclTypeHwAclInfo,
                     help='A mapping of hardware ACL type to hardware ACL info' )

   # All helpers visit hardware ACL types in sorted key order so the rendered
   # output is deterministic rather than following dict iteration order.

   def renderMapping( self, mappingTable, aclType, name, direction ):
      '''Add mapping rows for each hardware ACL type.'''
      for hwAclType, hwAclTypeInfo in sorted( self.hwAclType.items() ):
         hwAclTypeInfo.renderMapping( mappingTable, aclType, name, direction,
                                      hwAclType )

   def renderMaskGroups( self, aclType, name, direction, tabLevel ):
      '''Render mask groups for each hardware ACL type.'''
      for hwAclType, hwAclTypeInfo in sorted( self.hwAclType.items() ):
         hwAclTypeInfo.renderMaskGroups( aclType, name, direction, hwAclType,
                                         tabLevel )

   def renderLayout( self, aclType, name, direction, tabLevel ):
      '''Render table layouts for each hardware ACL type.'''
      for hwAclType, hwAclTypeInfo in sorted( self.hwAclType.items() ):
         hwAclTypeInfo.renderLayout( aclType, name, direction, hwAclType, tabLevel )

class NameHwAclInfo( Model ):
   direction = Dict( keyType=str,
                     valueType=DirectionHwAclInfo,
                     help='A mapping of direction to hardware ACL info' )

   # All helpers visit directions in sorted key order so the rendered output
   # is deterministic rather than following dict iteration order.

   def renderMapping( self, mappingTable, aclType, name ):
      '''Add mapping rows for each direction.'''
      for direction, dirInfo in sorted( self.direction.items() ):
         dirInfo.renderMapping( mappingTable, aclType, name, direction )

   def renderMaskGroups( self, aclType, name, tabLevel ):
      '''Print "<name> <ACLTYPE>", then mask groups per direction one level in.'''
      printt( '%s %s' % ( name, aclType.upper() ), tabLevel )
      for direction, dirInfo in sorted( self.direction.items() ):
         dirInfo.renderMaskGroups( aclType, name, direction, tabLevel + 1 )

   def renderLayout( self, aclType, name, tabLevel ):
      '''Print "<name> <ACLTYPE>", then table layouts per direction one level in.'''
      printt( '%s %s' % ( name, aclType.upper() ), tabLevel )
      for direction, dirInfo in sorted( self.direction.items() ):
         dirInfo.renderLayout( aclType, name, direction, tabLevel + 1 )

class AclTypeHwAclInfo( Model ):
   name = Dict( keyType=str,
                valueType=NameHwAclInfo,
                help='A mapping of name to hardware ACL info' )

   # All helpers visit ACL names in sorted key order so the rendered output
   # is deterministic rather than following dict iteration order.

   def renderMapping( self, mappingTable, aclType ):
      '''Add mapping rows for each ACL name.'''
      for name, nameInfo in sorted( self.name.items() ):
         nameInfo.renderMapping( mappingTable, aclType, name )

   def renderMaskGroups( self, aclType, tabLevel ):
      '''Render mask groups for each ACL name.'''
      for name, nameInfo in sorted( self.name.items() ):
         nameInfo.renderMaskGroups( aclType, name, tabLevel )

   def renderLayout( self, aclType, tabLevel ):
      '''Render table layouts for each ACL name.'''
      for name, nameInfo in sorted( self.name.items() ):
         nameInfo.renderLayout( aclType, name, tabLevel )

class AlgoMatchHwAclInfo( Model ):
   aclType = Dict( keyType=str,
                   valueType=AclTypeHwAclInfo,
                   help='A mapping of ACL type to hardware ACL info' )

   # All helpers visit ACL types in sorted key order so the rendered output
   # is deterministic rather than following dict iteration order.

   def renderMapping( self, mappingTable ):
      '''Add mapping rows for each ACL type.'''
      for aclType, aclTypeInfo in sorted( self.aclType.items() ):
         aclTypeInfo.renderMapping( mappingTable, aclType )

   def renderMaskGroups( self, tabLevel ):
      '''Render mask groups for each ACL type.'''
      for aclType, aclTypeInfo in sorted( self.aclType.items() ):
         aclTypeInfo.renderMaskGroups( aclType, tabLevel )

   def renderLayout( self, tabLevel ):
      '''Render table layouts for each ACL type.'''
      for aclType, aclTypeInfo in sorted( self.aclType.items() ):
         aclTypeInfo.renderLayout( aclType, tabLevel )

# Upper level models returned to CLI show commands
class AlgoMatchAclMapping( Model ):
   __public__ = False

   chipName = Str( help='Chip name' )
   hwAclMappingInfo = Submodel( AlgoMatchHwAclInfo,
                                help='Hardware ACL mapping information' )

   def render( self ):
      '''Print a five-column table mapping ACLs to hardware ACL IDs.'''
      columnFormat = Format( justify='left' )
      columnFormat.noPadLeftIs( True )
      columnFormat.padLimitIs( True )
      columnHeaders = [ 'Name', 'ACL Type', 'Direction', 'Hw ACL Type',
                        'Hw ACL IDs' ]
      table = createTable( columnHeaders, tableWidth=100 )
      table.formatColumns( *( [ columnFormat ] * 5 ) )
      self.hwAclMappingInfo.renderMapping( table )
      print( table.output() )

class AlgoMatchMaskGroup( Model ):
   __public__ = False

   chipName = Str( help='Chip name' )
   initMaskGroupInfo = Submodel( AlgoMatchHwAclInfo,
                                 help='Initial mask group information' )
   finalMaskGroupInfo = Submodel( AlgoMatchHwAclInfo,
                                  help='Final mask group information' )

   def render( self ):
      '''Print the initial and final mask-group revisions, then a blank line.'''
      baseLevel = 0
      printt( 'Initial revision', baseLevel )
      self.initMaskGroupInfo.renderMaskGroups( baseLevel + 1 )
      printt( 'Final revision', baseLevel )
      self.finalMaskGroupInfo.renderMaskGroups( baseLevel + 1 )
      # Trailing empty line separates this output from what follows.
      printt( '', baseLevel )

class AlgoMatchTableLayout( Model ):
   __public__ = False

   chipName = Str( help='Chip name' )
   finalLayoutInfo = Submodel( AlgoMatchHwAclInfo,
                               help='Final table layout information' )

   def render( self ):
      '''Print the final table-layout revision, then a blank line.'''
      baseLevel = 0
      printt( 'Final revision', baseLevel )
      self.finalLayoutInfo.renderLayout( baseLevel + 1 )
      # Trailing empty line separates this output from what follows.
      printt( '', baseLevel )

class AlgoMatchTableUsage( Model ):
   __public__ = False

   chipName = Str( help='Chip name' )
   usage = Dict( keyType=int, valueType=int,
                 help='The number of entries in every table' )

   def render( self ):
      '''Print a two-column table of entries used per table ID.

      Rows are emitted in ascending table-ID order so the output is
      deterministic rather than following dict iteration order.
      '''
      headers = [ 'Table ID', 'Entries used' ]
      fmt = Format( justify='left' )
      fmt.noPadLeftIs( True )
      fmt.padLimitIs( True )
      usageTable = createTable( headers, tableWidth=100 )
      usageTable.formatColumns( fmt, fmt )
      for tableId in sorted( self.usage ):
         usageTable.newRow( tableId, self.usage[ tableId ] )
      print( usageTable.output() )
