# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from ApiBaseModels import BaseModel
from ApiBaseModels import Str
from ApiBaseModels import Int
from ApiBaseModels import Bool
from ApiBaseModels import Float
from ApiBaseModels import List

import Tac
import Tracing
import Logging
import OpenStackLogMsgs

# Trace handle for this module plus short aliases for trace levels 0 through 4.
traceHandle = Tracing.Handle( 'OpenStackUwsgiServer' )
log = traceHandle.trace0
warn = traceHandle.trace1
info = traceHandle.trace2
trace = traceHandle.trace3
debug = traceHandle.trace4

# Tac types looked up once at import time. SwitchInterface is built from a
# (switch, interface) pair when writing port bindings; SyncStatus supplies the
# region sync-state constants used by SyncModel.
SwitchInterface = Tac.Type( 'VirtualNetwork::Client::SwitchInterface' )
SyncStatus = Tac.Type( 'OpenStack::SyncStatus' )

class AgentModel ( BaseModel ):
   '''Read-only model describing an agent: its UUID, mode, supported API types
   and CVX leader status.

   Every field is declared inputOk=False / outputOk=True, so clients can read
   but not set these values.
   '''
   uuid = Str( apiName='uuid', tacName='agentUuid', versions=[ 1 ], inputOk=False,
               outputOk=True, description='Agent UUID' )
   agentMode = Str( apiName='agentMode', tacName='agentMode', versions=[ 1 ],
                    inputOk=False, outputOk=True,
                    description='Provision or visibility mode' )
   # List of plain strings read from the Tac attribute 'supportedApiTypes'.
   supportedApis = List( apiName='supportedApis', tacName='supportedApiTypes',
                         versions=[ 1 ], inputOk=False, outputOk=True,
                         description='API types this agent supports',
                         valueType=str )
   # tacName=None: no Tac attribute backs this field. It is presumably filled in
   # by the caller rather than by the base-class fromSysdb (TODO confirm).
   isLeader = Bool( apiName='isLeader', tacName=None,
                       versions=[ 1 ], inputOk=False, outputOk=True,
                       description='CVX leader status' )

class ServiceEndPointModel( BaseModel ):
   '''A named authentication endpoint (URL plus credentials) for a region.'''
   name = Str( apiName='name', tacName='name', versions=[ 1 ], inputOk=True,
               outputOk=True, description='Endpoint name' )
   authUrl = Str( apiName='authUrl', tacName='authUrl', versions=[ 1 ], inputOk=True,
                  outputOk=True, description='Endpoint auth URL' )
   user = Str( apiName='user', tacName='user', versions=[ 1 ], inputOk=True,
               outputOk=True, description='Endpoint username' )
   # NOTE(review): password is outputOk=True, so it is included in API output.
   # Confirm that exposing it is intended.
   password = Str( apiName='password', tacName='password', versions=[ 1 ],
                   inputOk=True, outputOk=True, description='Endpoint password' )
   tenant = Str( apiName='tenant', tacName='tenant', versions=[ 1 ], inputOk=True,
                 outputOk=True, description='Endpoint tenant' )

   def toSysdb( self, region ):
      '''Write this endpoint into region.serviceEndPoint.

      The entry is keyed by the 'name' field via newMember. authUrl, user,
      password and tenant are then copied onto it, in that order.
      '''
      endPointName = self.getModelField( 'name' ).value
      endPoint = region.serviceEndPoint.newMember( endPointName )
      # The API names of these fields match the Tac attribute names.
      for attrName in ( 'authUrl', 'user', 'password', 'tenant' ):
         setattr( endPoint, attrName, self.getModelField( attrName ).value )

class RegionModel( BaseModel ):
   '''An OpenStack region: its name, sync status and sync timing values.'''
   name = Str( apiName='name', tacName='name', versions=[ 1 ], inputOk=True,
               outputOk=True, description='Region name' )
   syncStatus = Str( apiName='syncStatus', tacName='syncStatus', versions=[ 1 ],
                     inputOk=False, outputOk=True, description='Sync status' )
   syncInterval = Float( apiName='syncInterval', tacName='syncInterval',
                         versions=[ 1 ], inputOk=True, outputOk=True,
                         description='Synchronization interval' )
   syncHeartbeat = Float( apiName='syncHeartbeat', tacName='syncHeartbeat',
                          versions=[ 1 ], inputOk=False, outputOk=True,
                          description='Synchronization heartbeat' )

   def fromSysdb( self, rConfig, rStatus ): # pylint: disable-msg=W0221
      '''Fill fields from the region config, then take syncStatus from rStatus.

      rStatus may be falsy (e.g. None). In that case syncStatus keeps whatever
      the base-class load left in it.
      '''
      super( RegionModel, self ).fromSysdb( rConfig )
      if not rStatus:
         return
      self.getModelField( 'syncStatus' ).value = rStatus.syncStatus

   def toSysdb( self, region ):
      '''Copy syncInterval to region, but only if the client supplied it.'''
      intervalField = self.getModelField( 'syncInterval' )
      if intervalField not in self.getPopulatedModelFields():
         return
      region.syncInterval = intervalField.value

class SyncModel( BaseModel ):
   '''Request to start or finish a region sync, identified by requester/id.'''
   requester = Str( apiName='requester', tacName='requester', versions=[ 1 ],
                    inputOk=True, outputOk=True,
                    description='Id of the client requesting the lock' )
   requestId = Str( apiName='requestId', tacName='requestId', versions=[ 1 ],
                    inputOk=True, outputOk=True,
                    description='The request id for the lock' )

   def toSysdb( self, region ):
      '''Update region's sync state from requester/requestId.

      - Both values truthy: record them and mark the sync as in progress.
      - Both exactly "": clear them and mark the sync as complete.
      - Any other combination: region is left untouched.
      '''
      requester = self.getModelField( 'requester' ).value
      requestId = self.getModelField( 'requestId' ).value

      if requester and requestId:
         region.requester = requester
         region.requestId = requestId
         region.syncStatus = SyncStatus.syncInProgress
         return

      if ( requester, requestId ) == ( "", "" ):
         region.syncStatus = SyncStatus.syncComplete
         region.requestId = ""
         region.requester = ""

class TenantModel( BaseModel ):
   '''A tenant, identified only by its id (API name 'id').'''
   tenantId = Str( apiName='id', tacName='id', versions=[ 1 ],
                   inputOk=True, outputOk=True, description='Tenant ID' )

class NetworkModel( BaseModel ):
   '''A tenant network. Only networkName and shared may change after creation.'''
   networkId = Str( apiName='id', tacName='id', versions=[ 1 ], inputOk=True,
                    outputOk=True, description='Network ID' )
   # No Tac attribute backs this field; fromSysdb copies it from the tenant pointer.
   tenantId = Str( apiName='tenantId', tacName=None, versions=[ 1 ], inputOk=True,
                    outputOk=True, description='Tenant this network belongs to' )
   networkName = Str( apiName='name', tacName='networkName', versions=[ 1 ],
                      inputOk=True, outputOk=True, description='Network Name' )
   shared = Bool( apiName='shared', tacName='shared', versions=[ 1 ],
                  inputOk=True, outputOk=True, description='Shared network' )

   def fromSysdb( self, t ):
      '''Load fields from network entity t and resolve tenantId from t.tenant.'''
      super( NetworkModel, self ).fromSysdb( t )
      owner = t.tenant
      if owner:
         self.getModelField( 'tenantId' ).value = owner.id

   def toSysdb( self, network ):
      '''Apply populated mutable fields (networkName, shared) to network.'''
      mutableAttrs = ( 'networkName', 'shared' )
      for field in self.getPopulatedModelFields():
         if field.tacName not in mutableAttrs:
            continue
         setattr( network, field.tacName, field.value )

class SwitchportModel( BaseModel ):
   '''A switch id plus interface pair.'''
   switchId = Str( apiName='id', tacName='switchId', versions=[ 1 ], inputOk=True,
                   outputOk=True, description='Switch ID' )
   interface = Str( apiName='interface', tacName='interface', versions=[ 1 ],
                    inputOk=True, outputOk=True, description='Switchports' )

   def fromSysdb( self, t ):
      # Only delegates to the base-class loader; no extra fields to resolve.
      super( SwitchportModel, self ).fromSysdb( t )

   def toSysdb( self, switchport ):
      # Deliberately a no-op: this model writes nothing to Sysdb.
      pass

class PortModel( BaseModel ):
   '''A port owned by an instance (VM, DHCP, router or baremetal) on a network.'''
   portId = Str( apiName='id', tacName='id', versions=[ 1 ], inputOk=True,
                 outputOk=True, description='Port ID' )
   portName = Str( apiName='name', tacName='portName', versions=[ 1 ], inputOk=True,
                   outputOk=True, description='Port name' )
   portVlanType = Str( apiName='vlanType', tacName='portVlanType', versions=[ 1 ],
                       inputOk=True, outputOk=True, description='Port VLAN type' )
   # Fields with tacName=None are resolved by hand in fromSysdb.
   network = Str( apiName='networkId', tacName=None, versions=[ 1 ],
                  inputOk=True, outputOk=True, description='Tenant network' )
   instanceId = Str( apiName='instanceId', tacName=None, versions=[ 1 ],
                     inputOk=True, outputOk=True,
                     description='The VM/DHCP/Router ID this port belongs to' )
   instanceType = Str( apiName='instanceType', tacName=None, versions=[ 1 ],
                     inputOk=True, outputOk=True,
                     description='Either vm/dhcp/router/baremetal' )
   tenant = Str( apiName='tenantId', tacName=None, versions=[ 1 ],
                 inputOk=True, outputOk=True,
                 description='Tenant ID' )

   def fromSysdb( self, t, iType ): # pylint: disable-msg=W0221
      '''Load fields from port entity t.

      Args:
         t: the port entity. Its network, instance and instance.tenant
            pointers may each be unset.
         iType: instance type string stored verbatim in 'instanceType'.
      '''
      super( PortModel, self ).fromSysdb( t )
      # We must serialize ptr manually
      if t.network:
         self.getModelField( 'networkId' ).value = t.network.id
      vInstance = t.instance
      if vInstance:
         self.getModelField( 'instanceId' ).value = vInstance.id
         # The tenant is reachable only through the instance. Reading it
         # outside this guard raised AttributeError for ports with no instance.
         if vInstance.tenant:
            self.getModelField( 'tenantId' ).value = vInstance.tenant.id
      self.getModelField( 'instanceType' ).value = iType
      self.getModelField( 'vlanType' ).value = t.portVlanType

   def toSysdb( self, region, pId, instance, network ):
      '''Create port pId in region and link it to instance and network.

      The VLAN type defaults to 'allowed' when not supplied. portName is set
      only when non-empty. instance must not be None: its port collection is
      updated here.
      '''
      portVlanType = self.getModelField( 'vlanType' ).value or 'allowed'
      port = region.newPort( pId, portVlanType )
      pName = self.getModelField( 'name' ).value
      if pName:
         port.portName = pName
      port.instance = instance
      instance.port.addMember( port )
      port.network = network

class SegmentModel( BaseModel ):
   '''A network segment (e.g. a VLAN), static or dynamic, attached to a network.'''
   segmentId = Str( apiName='id', tacName='id', versions=[ 1 ], inputOk=True,
                    outputOk=True, description='Segmentation ID' )
   segmentationType = Str( apiName='type', tacName='type', versions=[ 1 ],
                           inputOk=True, outputOk=True,
                           description='Segmentation Type' )
   segmentationId = Int( apiName='segmentationId', tacName='segmentationId',
                         versions=[ 1 ], inputOk=True, outputOk=True,
                         description='Segmentation Type Id' )
   networkId = Str( apiName='networkId', tacName=None, versions=[ 1 ],
                    inputOk=True, outputOk=True, description='Network ID' )
   segmentType = Str( apiName='segmentType', tacName=None, versions=[ 1 ],
                      inputOk=True, outputOk=True,
                      description='Indicates whether the segment type is static or '
                                  'dynamic.' )

   def toSysdb( self, region, network, vlanPool ): # pylint: disable-msg=W0221
      '''Create this segment in region and attach it to network.

      Args:
         region: region that receives the new segment.
         network: owning network. The segment is added to its staticSegment
            or dynamicSegment collection according to 'segmentType'.
         vlanPool: VLAN pool to validate against, or a falsy value to skip
            validation.

      Returns:
         The segmentation id if it is a VLAN outside the pool (the segment is
         not created and a log message is emitted). Otherwise None.
      '''
      segmentId = self.getModelField( 'id' ).value
      segmentationType = self.getModelField( 'type' ).value
      segmentationId = self.getModelField( 'segmentationId' ).value
      # If the arista_vlan type driver is being used, check that the vlan id is
      # in the configured pool. Test membership on the collections directly
      # rather than concatenating their key lists: that avoids an O(n) scan
      # over every pool VLAN, and works on Python 3, where keys() views
      # cannot be added together.
      if ( segmentationType == 'vlan' and vlanPool and
           segmentationId not in vlanPool.availableVlan and
           segmentationId not in vlanPool.allocatedVlan ):
         invalidVlanMessage = ( "VLAN segmentation id %d is not available" %
                                segmentationId )
         warn( invalidVlanMessage )
         Logging.log( OpenStackLogMsgs.CVX_OPENSTACK_INVALID_NETWORK_VLAN,
                      segmentationId,
                      network.id )
         return segmentationId
      segment = region.newSegment( segmentId, segmentationType, segmentationId,
                                   network.id )
      segmentType = self.getModelField( 'segmentType' ).value
      if 'static' == segmentType:
         network.staticSegment.addMember( segment )
      elif 'dynamic' == segmentType:
         network.dynamicSegment.addMember( segment )

   def fromSysdb( self, region, segment ): # pylint: disable-msg=W0221
      '''Load fields from segment and work out its static/dynamic type.

      segmentType is 'unknown' when the owning network is not in region. It is
      left unset when the segment is in neither of the network's segment
      collections.
      '''
      super( SegmentModel, self ).fromSysdb( segment )
      self.getModelField( 'networkId' ).value = segment.networkId
      network = region.network.get( segment.networkId )
      if network is None:
         self.getModelField( 'segmentType' ).value = 'unknown'
         return
      if segment.id in network.staticSegment:
         self.getModelField( 'segmentType' ).value = 'static'
      elif segment.id in network.dynamicSegment:
         self.getModelField( 'segmentType' ).value = 'dynamic'

class PortToHostBindingModel( BaseModel ):
   ''' Internal implementation. Do not expose via any APIs (yet). '''
   host = Str( apiName='host', tacName='host', versions=[ 1 ], inputOk=True,
               outputOk=True, description='Port to host binding' )
   # Ordered list of segments; the list index is the binding level.
   segment = List( apiName='segment', tacName='segment', versions=[ 1 ],
                   inputOk=True, outputOk=True,
                   description='Network segment the port is connected to.',
                   valueType=SegmentModel )

   # pylint: disable-msg=W0221
   def toSysdb( self, region, binding ):
      '''Create a host binding under binding and link its segments by level.

      A segment whose id is not present in region.segment is skipped, but its
      position still consumes a level number.
      '''
      hostBinding = binding.newPortToHostBinding(
                       self.getModelField( 'host' ).value )
      segmentModels = self.getModelField( 'segment' ).value or []
      for level, segmentModel in enumerate( segmentModels ):
         wantedId = segmentModel.getModelField( 'id' ).value
         if wantedId not in region.segment:
            continue
         hostBinding.segment[ level ] = region.segment[ wantedId ]

   def fromSysdb( self, region, portBinding ):
      '''Load host and the segments of portBinding, ordered by level.'''
      self.getModelField( 'host' ).value = portBinding.host

      def loadSegment( level ):
         model = SegmentModel()
         model.fromSysdb( region, portBinding.segment[ level ] )
         return model

      self.getModelField( 'segment' ).value = [
         loadSegment( level ) for level in sorted( portBinding.segment.keys() ) ]

class PortToSwitchInterfaceBindingModel( BaseModel ):
   ''' Internal implementation. Do not expose via any APIs (yet). '''
   host = Str( apiName='host', tacName='host', versions=[ 1 ], inputOk=True,
               outputOk=True, description='Switch hostname' )
   # switch/interface are mapped by hand to/from a SwitchInterface value.
   switch = Str( apiName='switch', tacName=None, versions=[ 1 ], inputOk=True,
                 outputOk=True, description='Switch ID' )
   interface = Str( apiName='interface', tacName=None, versions=[ 1 ], inputOk=True,
                    outputOk=True, description='Switch Interface' )
   # Ordered list of segments; the list index is the binding level.
   segment = List( apiName='segment', tacName='segment', versions=[ 1 ],
                   inputOk=True, outputOk=True,
                   description='Network segment the port is connected to.',
                   valueType=SegmentModel )

   # pylint: disable-msg=W0221
   def toSysdb( self, region, binding ):
      '''Create a switch-interface binding under binding and link segments.

      A segment whose id is not present in region.segment is skipped, but its
      position still consumes a level number.
      '''
      switchId = self.getModelField( 'switch' ).value
      intfName = self.getModelField( 'interface' ).value
      hostName = self.getModelField( 'host' ).value
      intfBinding = binding.newPortToSwitchInterfaceBinding(
                       SwitchInterface( switchId, intfName ), hostName )
      segmentModels = self.getModelField( 'segment' ).value or []
      for level, segmentModel in enumerate( segmentModels ):
         wantedId = segmentModel.getModelField( 'id' ).value
         if wantedId not in region.segment:
            continue
         intfBinding.segment[ level ] = region.segment[ wantedId ]

   def fromSysdb( self, region, portBinding ):
      '''Load host, switch/interface and the segments of portBinding by level.'''
      self.getModelField( 'host' ).value = portBinding.host
      switchIntf = portBinding.switchInterface
      self.getModelField( 'switch' ).value = switchIntf.switchId
      self.getModelField( 'interface' ).value = switchIntf.interface

      def loadSegment( level ):
         model = SegmentModel()
         model.fromSysdb( region, portBinding.segment[ level ] )
         return model

      self.getModelField( 'segment' ).value = [
         loadSegment( level ) for level in sorted( portBinding.segment.keys() ) ]

class PortBindingModel( BaseModel ):
   '''All host and switch-interface bindings of a single port.'''
   portId = Str( apiName='portId', tacName='portId', versions=[ 1 ], inputOk=True,
                 outputOk=True, description='Port ID' )
   host = List( apiName='hostBinding', tacName='portToHostBinding', versions=[ 1 ],
                inputOk=True,
                outputOk=True, description='Host to which the port is bound to',
                valueType=PortToHostBindingModel )
   # The description previously duplicated the host binding's text.
   switchBinding = List( apiName='switchBinding',
                         tacName='portToSwitchInterfaceBinding', versions=[ 1 ],
                         inputOk=True, outputOk=True,
                         description='Switch interfaces to which the port is bound',
                         valueType=PortToSwitchInterfaceBindingModel )

   # pylint: disable-msg=W0221
   def toSysdb( self, region, portId ):
      '''Write every host binding, then every switch binding, for port portId.

      For each binding, region.newPortBinding is called and the result is
      handed to that binding's toSysdb. region.port[ portId ] is looked up only
      when there is at least one binding (KeyError if the port is missing).
      '''
      bindings = list( self.getModelField( 'hostBinding' ).value or [] )
      bindings += self.getModelField( 'switchBinding' ).value or []
      if not bindings:
         return
      # Loop-invariant: the port does not change while the bindings are written.
      port = region.port[ portId ]
      for binding in bindings:
         portBinding = region.newPortBinding( portId, port )
         binding.toSysdb( region, portBinding )

   def fromSysdb( self, region, binding ):
      '''Load portId and both kinds of binding from the Sysdb entity binding.'''
      self.getModelField( 'portId' ).value = binding.portId
      switchBindings = []
      for switchBinding in binding.portToSwitchInterfaceBinding.values():
         switchBindingModel = PortToSwitchInterfaceBindingModel()
         switchBindingModel.fromSysdb( region, switchBinding )
         switchBindings.append( switchBindingModel )
      self.getModelField( 'switchBinding' ).value = switchBindings
      hostBindings = []
      for hostBinding in binding.portToHostBinding.values():
         hostBindingModel = PortToHostBindingModel()
         hostBindingModel.fromSysdb( region, hostBinding )
         hostBindings.append( hostBindingModel )
      self.getModelField( 'hostBinding' ).value = hostBindings

class VmModel( BaseModel ):
   '''A VM instance; only its host id may be changed after creation.'''
   vmInstanceId = Str( apiName='id', tacName='vmInstanceId',
                       versions=[ 1 ], inputOk=True, outputOk=True,
                       description='VM Instance ID' )
   vmHostId = Str( apiName='hostId', tacName='vmHostId', versions=[ 1 ],
                   inputOk=True, outputOk=True, description='VM Host ID' )
   tenantId = Str( apiName='tenantId', tacName=None, versions=[ 1 ],
                   inputOk=True, outputOk=True, description='VM Tenant ID' )

   def toSysdb( self, vm ):
      '''Copy vmHostId onto vm if it was populated; other fields are ignored.'''
      for field in self.getPopulatedModelFields():
         if field.tacName == 'vmHostId':
            vm.vmHostId = field.value

   def fromSysdb( self, vm ):
      '''Load fields from vm and resolve tenantId through vm.tenant.'''
      super( VmModel, self ).fromSysdb( vm )
      owner = vm.tenant
      if not owner:
         return
      self.getModelField( 'tenantId' ).value = owner.id

class BaremetalModel( BaseModel ):
   '''A baremetal instance; only its host id may be changed after creation.'''
   baremetalInstanceId = Str( apiName='id', tacName='baremetalInstanceId',
                              versions=[ 1 ], inputOk=True, outputOk=True,
                              description='Baremetal Instance ID' )
   baremetalHostId = Str( apiName='hostId', tacName='baremetalHostId',
                          versions=[ 1 ], inputOk=True, outputOk=True,
                          description='Baremetal Host ID' )
   tenantId = Str( apiName='tenantId', tacName=None, versions=[ 1 ],
                   inputOk=True, outputOk=True, description='Baremetal Tenant ID' )

   def toSysdb( self, baremetal ):
      '''Copy baremetalHostId onto baremetal if it was populated.'''
      for field in self.getPopulatedModelFields():
         if field.tacName == 'baremetalHostId':
            baremetal.baremetalHostId = field.value

   def fromSysdb( self, baremetal ):
      '''Load fields from baremetal and resolve tenantId through its tenant.'''
      super( BaremetalModel, self ).fromSysdb( baremetal )
      owner = baremetal.tenant
      if not owner:
         return
      self.getModelField( 'tenantId' ).value = owner.id

class RouterModel( BaseModel ):
   '''A router instance; only its host id may be changed after creation.'''
   routerInstanceId = Str( apiName='id', tacName='routerInstanceId',
                           versions=[ 1 ], inputOk=True, outputOk=True,
                           description='Router Instance ID' )
   routerHostId = Str( apiName='hostId', tacName='routerHostId', versions=[ 1 ],
                       inputOk=True, outputOk=True, description='Router Host ID' )
   tenantId = Str( apiName='tenantId', tacName=None, versions=[ 1 ],
                   inputOk=True, outputOk=True, description='Router Tenant ID' )

   def toSysdb( self, router ):
      '''Copy routerHostId onto router if it was populated.'''
      for field in self.getPopulatedModelFields():
         if field.tacName == 'routerHostId':
            router.routerHostId = field.value

   def fromSysdb( self, router ):
      '''Load fields from router and resolve tenantId through router.tenant.'''
      super( RouterModel, self ).fromSysdb( router )
      owner = router.tenant
      if not owner:
         return
      self.getModelField( 'tenantId' ).value = owner.id

class DhcpModel( BaseModel ):
   '''A DHCP instance; only its host id may be changed after creation.'''
   dhcpInstanceId = Str( apiName='id', tacName='dhcpInstanceId',
                         versions=[ 1 ], inputOk=True, outputOk=True,
                         description='DHCP instance ID' )
   dhcpHostId = Str( apiName='hostId', tacName='dhcpHostId', versions=[ 1 ],
                     inputOk=True, outputOk=True, description='DHCP host ID' )
   tenantId = Str( apiName='tenantId', tacName=None, versions=[ 1 ],
                   inputOk=True, outputOk=True, description='DHCP Tenant ID' )

   def toSysdb( self, dhcp ):
      '''Copy dhcpHostId onto dhcp if it was populated.'''
      for field in self.getPopulatedModelFields():
         if field.tacName == 'dhcpHostId':
            dhcp.dhcpHostId = field.value

   def fromSysdb( self, dhcp ):
      '''Load fields from dhcp and resolve tenantId through dhcp.tenant.'''
      super( DhcpModel, self ).fromSysdb( dhcp )
      owner = dhcp.tenant
      if not owner:
         return
      self.getModelField( 'tenantId' ).value = owner.id
