#!/usr/bin/env python
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, print_function

import datetime
import fcntl
import os
import struct
import subprocess
import sys

import cjson
# pylint: disable=import-error
import HwEpochPolicy

# Constants for formatting data in bde ioctls.
#
# Width, in bits, of each field packed into an ioctl request number.
_IOC_NRBITS = 8
_IOC_TYPEBITS = 8
_IOC_SIZEBITS = 14
_IOC_DIRBITS = 2

# All-ones mask for each field, derived from the widths above.
_IOC_NRMASK = ( 1 << _IOC_NRBITS ) - 1
_IOC_TYPEMASK = ( 1 << _IOC_TYPEBITS ) - 1
_IOC_SIZEMASK = ( 1 << _IOC_SIZEBITS ) - 1
_IOC_DIRMASK = ( 1 << _IOC_DIRBITS ) - 1

# Bit offset of each field. Fields are stacked from bit 0 upwards in the
# order nr, type, size, dir.
_IOC_NRSHIFT = 0
_IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS
_IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS
_IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS

# Values for the direction field.
_IOC_NONE = 0
_IOC_WRITE = 1
_IOC_READ = 2

# Module-wide verbosity flag; AsuArp.run() overwrites it with its "debug"
# argument before doing any work.
debugMode = False

def _IOC( _dir, _type, nr, size ):
   """Pack the four ioctl fields into a single request number.

   Each argument is shifted left by the matching _IOC_*SHIFT offset and the
   four results are OR-ed together. The arguments are not masked to their
   field width, so callers must pass values that fit.
   """
   placement = (
      ( _dir, _IOC_DIRSHIFT ),
      ( _type, _IOC_TYPESHIFT ),
      ( nr, _IOC_NRSHIFT ),
      ( size, _IOC_SIZESHIFT ),
   )
   request = 0
   for fieldValue, shift in placement:
      request |= fieldValue << shift
   return request

def _IO( _type, nr ):
   """Return the request number for an ioctl with direction _IOC_NONE.

   The size field is always encoded as 0.
   """
   encodedSize = 0
   return _IOC( _IOC_NONE, _type, nr, encodedSize )

def log( msg ):
   """Print msg to stdout, prefixed with the current local date and time."""
   timestamp = str( datetime.datetime.now() )
   print( timestamp, msg )

def runAndMaybeLogCmd( cmd, debug ):
   """Run cmd (an argv list) and log the command line and its exit status.

   A non-zero exit status is tolerated: it is logged, not raised. When debug
   is true the captured stdout of the command is logged as well. Nothing is
   returned.

   Only subprocess.CalledProcessError is handled here; any other exception
   raised by subprocess.check_output() propagates to the caller.
   """
   try:
      output = subprocess.check_output( cmd )
      returnCode = 0
   except subprocess.CalledProcessError as err:
      output = err.output
      returnCode = err.returncode

   cmdLine = ' '.join( cmd )
   if not debug:
      log( "cmd = %s, rc = %d" % ( cmdLine, returnCode ) )
      return
   log( "cmd = %s, output = %s, rc= %d" % ( cmdLine, output, returnCode ) )

def logIntf( intf ):
   """Log the detailed link state of intf (output of "ip -d link show")."""
   showLink = [ 'ip', '-d', 'link', 'show' ]
   # Always run with debug=True so the command output ends up in the log.
   runAndMaybeLogCmd( showLink + [ intf ], True )

def logIntfAddr( intf ):
   """Log the addresses configured on intf (output of "ip -d addr show")."""
   showAddr = [ 'ip', '-d', 'addr', 'show' ]
   # Always run with debug=True so the command output ends up in the log.
   runAndMaybeLogCmd( showAddr + [ intf ], True )

class Driver( object ):
   def __init__( self ):
      """Record the paths of the two character devices this class manages."""
      # Device node created by fabDevIs() and probed by load().
      self.fabDevName_ = "/dev/fabric"
      # Device node created by bdeDevIs() and opened by bdeIoctlIs().
      self.bdeDevName_ = "/dev/arista-bde"

   def load( self ):
      """Load the kernel driver and create the fabric and BDE device nodes.

      If the fabric device node can already be opened, the process exits
      with status 1 instead. Otherwise the steps run in this order: write
      the BDE module options, load the DMA driver, then create the fabric
      and BDE device nodes.
      """
      if self.dev( self.fabDevName_ ):
         log( "%s exists. exiting" % self.fabDevName_ )
         # sys.exit() rather than the site-provided exit() helper, which is
         # meant for interactive sessions and is absent under "python -S".
         sys.exit( 1 )

      self.bdeModuleParamIs()
      self.drvIs()
      self.fabDevIs()
      self.bdeDevIs()

   def bdeIoctlIs( self, modId, busId ):
      """Issue the attach and enable-interrupt ioctls on the BDE device.

      Both requests use the same packed layout ('IIIIQII256B'). After each
      call the second 32-bit word of the returned buffer is logged as "rc".
      NOTE(review): the meaning of the individual fields is defined by the
      arista-bde driver and is not visible here -- confirm against the
      driver before changing any argument position.
      """
      layout = 'IIIIQII256B'
      with open( self.bdeDevName_, "w" ) as bde:
         # Attach request: busId and modId sit in the third and fourth
         # words; the byte array is all zeroes.
         request = struct.pack( layout, 0, 0, busId, modId, 0, 0, 0,
                                *( [ 0 ] * 256 ) )
         attachCmd = _IO( ord( 'L' ), 52 )
         reply = fcntl.ioctl( bde, attachCmd, request )
         status = struct.unpack( layout, reply )[ 1 ]
         log( "%s attach dev modId=%d busId=%d rc=%d" % (
              self.bdeDevName_, modId, busId, status ) )

         log( "calling enable interrupts ioctl" )
         # Interrupt request: the last 32-bit word is 1 and busId is the
         # first entry of the byte array.
         byteArray = [ busId ] + [ 0 ] * 255
         request = struct.pack( layout, 0, 0, 0, 0, 0, 0, 1, *byteArray )
         intrCmd = _IO( ord( 'L' ), 6 )
         reply = fcntl.ioctl( bde, intrCmd, request )
         status = struct.unpack( layout, reply )[ 1 ]
         log( "%s enable intr modId=%d busId=%d rc=%d" % (
              self.bdeDevName_, modId, busId, status ) )

   def devUpIs( self, modId, busId ):
      """Attach the device via the BDE ioctls, then mark it up with "fab".

      Runs "fab devrole <modId> switch" followed by "fab devstate <modId> up".
      """
      self.bdeIoctlIs( modId, busId )

      modIdArg = str( modId )
      fabCalls = (
         [ "devrole", modIdArg, "switch" ],
         [ "devstate", modIdArg, "up" ],
      )
      for fabArgs in fabCalls:
         runAndMaybeLogCmd( [ "fab" ] + fabArgs, debugMode )

   def fabIntfUpIs( self, mac ):
      """Assign mac to the "fabric" interface and bring it up."""
      fabricIntf = "fabric"
      self.intfUpIs( fabricIntf, mac )

   def sviIntfUpIs( self, intf, mac ):
      """Assign mac to the SVI named intf and bring it up."""
      # Thin alias for intfUpIs(), kept as a separate entry point for SVIs.
      self.intfUpIs( intf=intf, mac=mac )

   def intfIpIs( self, intf, ip, ipv=4 ):
      """Add one or more IP addresses to intf.

      ip may be a single address or a list of addresses. ipv selects the
      address family and must be 4 or 6; any other value is logged and
      nothing is configured. For IPv6, disable_ipv6 is first cleared on the
      interface. In debug mode the resulting addresses are logged.
      """
      addresses = ip if isinstance( ip, list ) else [ ip ]

      if ipv not in ( 4, 6 ):
         log( "Wrong ip version supplied for %s" % intf )
         return

      addPrefix = [ '/usr/sbin/ip', 'address', 'add' ]
      if ipv == 6:
         # Make sure IPv6 is enabled on the interface, then read the
         # setting back so it shows up in the log.
         runAndMaybeLogCmd(
            [ 'sudo', 'sysctl', '-w',
              'net.ipv6.conf.%s.disable_ipv6=0' % intf ], debugMode )
         runAndMaybeLogCmd(
            [ 'cat', '/proc/sys/net/ipv6/conf/%s/disable_ipv6' % intf ],
            debugMode )
         addPrefix = [ '/usr/sbin/ip', '-6', 'address', 'add' ]

      for addr in addresses:
         runAndMaybeLogCmd( addPrefix + [ addr, 'dev', intf ], debugMode )

      if debugMode:
         logIntfAddr( intf )

   def intfIpv6DisableIs( self, intf ):
      """Turn off IPv6 forwarding on intf and log the resulting setting.

      Despite the name, this writes net.ipv6.conf.<intf>.forwarding=0; it
      does not touch disable_ipv6.
      """
      # Per Linux kernel config information,
      # /proc/sys/net/ipv6/conf/<interface>/forwarding defaults to FALSE if
      # global forwarding (/proc/sys/net/ipv6/conf/all/forwarding) is
      # disabled, otherwise defaults to TRUE.
      #
      # We explicitly disable IPv6 forwarding on any interface that does not
      # have IPv6 forwarding enabled.
      sysctlKey = 'net.ipv6.conf.%s.forwarding=0' % intf
      runAndMaybeLogCmd( [ 'sudo', 'sysctl', '-w', sysctlKey ], debugMode )

      procPath = '/proc/sys/net/ipv6/conf/%s/forwarding' % intf
      runAndMaybeLogCmd( [ 'cat', procPath ], debugMode )

   def ipv6ForwardingIs( self, enabled ):
      """Set the global IPv6 forwarding sysctl and log the resulting value.

      enabled is formatted with %d, so True/False become 1/0.
      """
      # NOTE WELL:
      #
      # Per Linux kernel proc filesystem configuration documentation, the
      # value of /proc/sys/net/ipv6/conf/<interface>/forwarding defaults to
      # the value of /proc/sys/net/ipv6/conf/all/forwarding. Any ports that
      # don't have IPv6 forwarding turned on in EOS config should explicitly
      # have forwarding turned off before globally enabling IPv6 forwarding
      # logic via this function.
      setting = 'net.ipv6.conf.all.forwarding=%d' % enabled
      runAndMaybeLogCmd( [ 'sudo', 'sysctl', '-w', setting ], debugMode )

      readBack = [ 'cat', '/proc/sys/net/ipv6/conf/all/forwarding' ]
      runAndMaybeLogCmd( readBack, debugMode )

   def ipv6NbrSolicitRuleIs( self ):
      """Append an ip6tables INPUT rule accepting neighbour solicitations.

      The INPUT chain is listed afterwards so the result appears in the log.
      """
      ip6tables = [ 'sudo', 'ip6tables' ]

      acceptRule = [ '-A', 'INPUT', '-p', 'icmpv6', '--icmpv6-type',
                     'neighbour-solicitation', '-j', 'ACCEPT' ]
      runAndMaybeLogCmd( ip6tables + acceptRule, debugMode )

      runAndMaybeLogCmd( ip6tables + [ '-L', 'INPUT' ], debugMode )

   def sviIntfIs( self, intf, vlan ):
      """Create a VLAN link named intf on top of the "fabric" interface.

      vlan is the VLAN id passed to "ip link add ... type vlan id <vlan>".
      In debug mode the new link is dumped to the log.
      """
      cmd = [ "/usr/sbin/ip", "link", "add", "link", "fabric", "name", intf,
              "type", "vlan", "id", str( vlan ) ]
      runAndMaybeLogCmd( cmd, debugMode )
      if debugMode:
         # Dump the link that was just created. The name is whatever the
         # caller passed in, which need not be "vlan<id>".
         logIntf( intf )

   def intfUpIs( self, intf, mac ):
      """Set the MAC address of intf, then bring the link up.

      In debug mode the link state is logged afterwards.
      """
      linkSet = [ '/usr/sbin/ip', 'link', 'set', 'dev', intf ]
      for action in ( [ 'addr', mac ], [ 'up' ] ):
         runAndMaybeLogCmd( linkSet + action, debugMode )
      if debugMode:
         logIntf( intf )

   def bdeModuleParamIs( self ):
      """Write the modprobe options file for the arista-bde module.

      Any existing file is overwritten; the only option set is
      fixed_modid=1.
      """
      log( "configure BDE module param" )
      modprobeConf = '/etc/modprobe.d/arista-bde.conf'
      with open( modprobeConf, 'w' ) as conf:
         conf.write( 'options arista-bde fixed_modid=1\n' )

   def dev( self, devName ):
      ret = False
      try:
         with open( "%s" % devName, "r" ) as _:
            ret = True
      except IOError:
         pass
      return ret

   def drvIs( self ):
      """Load the strata-dma-drv kernel module via modprobe.

      The module is given app_en=1 when HwEpochPolicy.getPdpModeEnabled()
      is true, and app_en=0 (app dma disabled) otherwise.
      """
      if HwEpochPolicy.getPdpModeEnabled():
         appDmaArg = 'app_en=1'
      else:
         appDmaArg = 'app_en=0' # app dma disabled
      modprobe = [ "/sbin/modprobe", "strata-dma-drv", appDmaArg ]
      runAndMaybeLogCmd( modprobe, debugMode )

   def devIs( self, devName, major, minor ):
      """Create the character device node devName with mknod.

      major and minor are the device numbers; the node is created with the
      mode argument "-m=666".
      """
      numbers = [ str( major ), str( minor ) ]
      mknod = [ "/bin/mknod", devName, "c" ] + numbers + [ "-m=666" ]
      runAndMaybeLogCmd( mknod, debugMode )

   def fabQdiscIs( self ):
      """Install a 3-band prio root qdisc on the "fabric" device.

      Sets up the queues and the 16-entry priority map for the fabric
      device in the kernel.
      """
      priomap = "1 2 2 2 1 2 0 0 1 1 1 1 1 1 1 1".split()
      tcArgs = [ "/usr/sbin/tc", "qdisc", "add", "dev", "fabric", "root",
                 "handle", "1:", "prio", "bands", "3", "priomap" ]
      runAndMaybeLogCmd( tcArgs + priomap, debugMode )

   def fabDevIs( self ):
      """Create the fabric device node (char 240:0) and set up its qdisc.

      Exits the process with status 1 if the node cannot be opened after
      mknod has run.
      """
      self.devIs( self.fabDevName_, 240, 0 )
      if not self.dev( self.fabDevName_ ):
         log( "%s error. exiting" % self.fabDevName_ )
         # sys.exit() rather than the site-provided exit() helper, which is
         # meant for interactive sessions and is absent under "python -S".
         sys.exit( 1 )

      self.fabQdiscIs()

   def bdeDevIs( self ):
      """Create the BDE device node (char 127:0).

      Exits the process with status 1 if the node cannot be opened after
      mknod has run.
      """
      self.devIs( self.bdeDevName_, 127, 0 )
      if not self.dev( self.bdeDevName_ ):
         log( "%s error. exiting" % self.bdeDevName_ )
         # sys.exit() rather than the site-provided exit() helper, which is
         # meant for interactive sessions and is absent under "python -S".
         sys.exit( 1 )

   # Using schan utility to program registers
   def readReg( self, busId, address ):
      """Read a register with the schan utility.

      Runs "schan <busId> readreg 0 3 <address>" and returns its stdout
      split on whitespace. Raises subprocess.CalledProcessError if schan
      exits non-zero.
      NOTE(review): the meaning of the fixed "0 3" arguments is defined by
      schan and is not visible here.
      """
      schanArgs = [ "schan", busId, "readreg", "0", "3", address ]
      words = subprocess.check_output( schanArgs ).split()
      return words

   def writeReg( self, busId, address, value ):
      """Write a register with the schan utility.

      value is an iterable of words appended to
      "schan <busId> writereg 0 3 <address>". Returns schan's exit status.
      """
      schanArgs = [ "schan", busId, "writereg", "0", "3", address ]
      schanArgs += list( value )
      return subprocess.call( schanArgs )

   def enableT2CpuPortIs( self, busId ):
      """Set the CPU-port enable bits in three registers through schan.

      busId is formatted as "<busId>:00.0" to address the device. Each
      register is read, one word is OR-ed with a mask, and the result is
      written back; every read and write is logged.
      """
      devId = "%d:00.0" % busId

      # 0xae007400 MMU_THDM_DB_PORTSP_RX_ENABLE0_64.mmu0
      self._orRegWordIs( devId, "MMU_THDM_DB_PORTSP_RX_ENABLE0_64",
                         "0xae007400", 1, 0x100000 )

      # 0xb2007400 MMU_THDM_MCQE_PORTSP_RX_ENABLE0_64.mmu0
      self._orRegWordIs( devId, "MMU_THDM_MCQE_PORTSP_RX_ENABLE0_64",
                         "0xb2007400", 1, 0x100000 )

      # 0x08000034 THDI_INPUT_PORT_XON_ENABLES.cpu0
      self._orRegWordIs( devId, "THDI_INPUT_PORT_XON_ENABLES",
                         "0x08000034", 0, 0x20000 )

   def _orRegWordIs( self, devId, regName, address, wordIdx, mask ):
      """Read a register, OR mask into word wordIdx, write it back, and log.

      The words before wordIdx are written back exactly as read; words
      after wordIdx are not written. regName is used only in log messages.
      """
      out = self.readReg( devId, address )
      log( "devId=%s %s read: %s" % ( devId, regName, out ) )

      # Format with "0x%x" instead of hex(): on Python 2, hex() appends an
      # "L" suffix to long values, and that text would be handed to schan
      # unchanged.
      updated = "0x%x" % ( int( out[ wordIdx ], 16 ) | mask )
      value = out[ : wordIdx ] + [ updated ]
      rc = self.writeReg( devId, address, value )
      log( "devId=%s %s write: %s rc=%d" % ( devId, regName, value, rc ) )

class Interface( object ):
   """Plain record describing one interface from the reload configuration."""
   def __init__( self ):
      # Interface name; empty until ConfigParser.interfaces() fills it in.
      self.ifName = ''
      # IPv4 addresses, IPv6 addresses and VLAN id; all start out unset.
      self.ip = self.ip6 = self.vlan = None

class ConfigParser( object ):
   def __init__( self ):
      """Start with an empty configuration; parse() loads the real one."""
      # Decoded JSON document; stays empty until parse() succeeds.
      self.jsonConfig_ = {}
      # File read by parse().
      self.configFile_ = "/mnt/flash/.asu-reload-data.conf"

   def parse( self ):
      """Load and validate the reload configuration, then rename the file.

      Exits the process with status 0 when the file cannot be opened, and
      with status 1 when it does not decode to a JSON object or its
      "version" is not "2.0". On success the file is renamed to
      "<configFile_>.bak".
      """
      try:
         with open( self.configFile_, "r" ) as f:
            self.jsonConfig_ = cjson.decode( f.read() )
      except IOError:
         log( "no config file. exiting" )
         sys.exit( 0 )
      except cjson.DecodeError:
         log( "config parse error. exiting" )
         sys.exit( 1 )

      # Valid JSON whose top level is not an object (e.g. a list) cannot be
      # queried with .get(); treat it like any other malformed config.
      if not isinstance( self.jsonConfig_, dict ):
         log( "config parse error. exiting" )
         sys.exit( 1 )

      version = self.jsonConfig_.get( 'version', '' )
      log( "configFile: %s, version: %s" % ( self.configFile_, version ) )
      if version != "2.0":
         log( "Unsupported version %s" % version )
         sys.exit( 1 )

      # Rename the file that was actually read instead of repeating the
      # path, so the two cannot drift apart.
      os.rename( self.configFile_, self.configFile_ + ".bak" )

   def devices( self ):
      """Return the "devices" list with unusable entries removed.

      An entry is dropped, with a log message, when it is not a dict or
      when any of busId, modId or chipModel is missing or falsy. Entries
      are removed from the list held in jsonConfig_ itself, and that same
      list is returned.
      NOTE(review): the truthiness test also rejects a busId or modId of
      0 -- confirm that 0 is never a valid id.
      """
      devices = self.jsonConfig_.get( "devices", [] )
      # Walk a snapshot so entries can be removed from the live list.
      for entry in devices[ : ]:
         if isinstance( entry, dict ):
            busId = entry.get( "busId" )
            modId = entry.get( "modId" )
            chipModel = entry.get( "chipModel" )
            if busId and modId and chipModel:
               continue
            log( "Device info missing or incomplete. busId: %s, modId: %s, chip: %s"
                 % ( busId, modId, chipModel ) )
         else:
            log( "Unknown device info format %s" % entry )
         devices.remove( entry )
      return devices

   def bridgeMac( self ):
      """Return the configured "bridgeMac", or the all-zero MAC if absent."""
      allZeroMac = "00:00:00:00:00:00"
      return self.jsonConfig_.get( "bridgeMac", allZeroMac )

   def interfaces( self ):
      """Return one Interface object per entry of the "interfaces" mapping.

      The mapping key becomes ifName. "ip" and "ip6" default to an empty
      list and "vlan" to None when the entry does not provide them.
      """
      perIntfConfig = self.jsonConfig_.get( "interfaces", {} )
      result = []
      for name in perIntfConfig:
         attrs = perIntfConfig[ name ]
         entry = Interface()
         entry.ifName = name
         entry.ip = attrs.get( "ip", [] )
         entry.ip6 = attrs.get( "ip6", [] )
         entry.vlan = attrs.get( 'vlan' )
         result.append( entry )
      return result

class AsuArp( object ):
   def __init__( self ):
      """Create the configuration reader and the driver helper used by run()."""
      # Hardware/kernel side: driver load, device nodes, interface setup.
      self.drv_ = Driver()
      # Source of the device list, bridge MAC and interface definitions.
      self.config_ = ConfigParser()

   def run( self, debug=False ):
      """Bring up the driver, devices and SVIs described by the config file.

      debug is stored in the module-level debugMode flag and controls how
      much command output is logged. The process exits with status 1 when
      the configuration lists no usable devices.
      """
      global debugMode
      debugMode = debug

      self.config_.parse()
      devices = self.config_.devices()
      if not devices:
         log( "No devices found in configuration. Aborting driver bringup" )
         # sys.exit() rather than the site-provided exit() helper, which is
         # meant for interactive sessions and is absent under "python -S".
         sys.exit( 1 )

      self.drv_.load()
      for device in devices:
         busId = device.get( "busId" )
         modId = device.get( "modId" )
         self.drv_.devUpIs( modId, busId )
         # Only these two chip models get the extra CPU-port register setup.
         if device.get( "chipModel" ) in ( "bcm56850", "bcm56855" ):
            self.drv_.enableT2CpuPortIs( busId )

      # The fabric interface and every SVI share the bridge MAC.
      bridgeMac = self.config_.bridgeMac()
      self.drv_.fabIntfUpIs( bridgeMac )

      enIpv6Forwarding = False
      for svi in self.config_.interfaces():
         # XXX v1 we support only SVIs
         if not ( svi.ifName.startswith( 'vlan' ) and svi.vlan ):
            continue
         self.drv_.sviIntfIs( svi.ifName, svi.vlan )
         self.drv_.sviIntfUpIs( svi.ifName, bridgeMac )
         if svi.ip:
            self.drv_.intfIpIs( svi.ifName, svi.ip )
         if svi.ip6:
            self.drv_.intfIpIs( svi.ifName, svi.ip6, ipv=6 )
            enIpv6Forwarding = True
         else:
            self.drv_.intfIpv6DisableIs( svi.ifName )

      # NOTE WELL:
      #
      # Per Linux kernel proc filesystem configuration documentation the
      # value of /proc/sys/net/ipv6/conf/all/forwarding controls the default
      # value of /proc/sys/net/ipv6/conf/<interface>/forwarding. We therefore
      # wait to globally enable until each port has been configured.
      # Otherwise forwarding could be enabled on all ports.
      log( "v6 forwarding %s" % enIpv6Forwarding )
      self.drv_.ipv6ForwardingIs( enIpv6Forwarding )
      if enIpv6Forwarding:
         self.drv_.ipv6NbrSolicitRuleIs()
      log( "v6 processing done" )

