import asyncio
import warnings
from typing import Tuple, Optional, List, Union
from pydantic import BaseModel, validate_call
from math import ceil
import invertedai as iai
from invertedai.large.common import Region
from invertedai.common import Point, AgentState, AgentAttributes, AgentProperties, RecurrentState, TrafficLightStatesDict, LightRecurrentState
from invertedai.api.drive import DriveResponse
from invertedai.utils import convert_attributes_to_properties
from invertedai.error import InvertedAIError, InvalidRequestError
from ._quadtree import QuadTreeAgentInfo, QuadTree, _flatten_and_sort, QUADTREE_SIZE_BUFFER
DRIVE_MAXIMUM_NUM_AGENTS = 100
async def async_drive_all(async_input_params):
all_responses = await asyncio.gather(*[iai.async_drive(**input_params) for input_params in async_input_params])
return all_responses
[docs]@validate_call
def large_drive(
location: str,
agent_states: List[AgentState],
agent_properties: List[Union[AgentAttributes,AgentProperties]],
recurrent_states: Optional[List[RecurrentState]] = None,
traffic_lights_states: Optional[TrafficLightStatesDict] = None,
light_recurrent_states: Optional[List[LightRecurrentState]] = None,
get_infractions: bool = False,
random_seed: Optional[int] = None,
api_model_version: Optional[str] = None,
single_call_agent_limit: Optional[int] = None,
async_api_calls: bool = True
) -> DriveResponse:
"""
A utility function to drive more than the normal capacity of agents in a call to :func:`drive`.
The agents are inserted into a quadtree structure and :func:`drive` is then called on each
region represented by a leaf node of the quadtree. Agents near this region are included in the
:func:`drive` calls to ensure the agents see all their neighbours. The quadtree is constructed
during each call to this utility function to maintain statelessness.
Parameters
----------
location:
Please refer to the documentation of :func:`drive` for information on this parameter.
agent_states:
Please refer to the documentation of :func:`drive` for information on this parameter.
agent_properties:
Please refer to the documentation of :func:`drive` for information on this parameter.
recurrent_states:
Please refer to the documentation of :func:`drive` for information on this parameter.
traffic_lights_states:
Please refer to the documentation of :func:`drive` for information on this parameter.
light_recurrent_states:
Please refer to the documentation of :func:`drive` for information on this parameter.
get_infractions:
Please refer to the documentation of :func:`drive` for information on this parameter.
random_seed:
Please refer to the documentation of :func:`drive` for information on this parameter.
api_model_version:
Please refer to the documentation of :func:`drive` for information on this parameter.
single_call_agent_limit:
The number of agents allowed in a region before it must subdivide. Currently this value
represents the capacity of a quadtree leaf node that will subdivide if the number of vehicles
in the region, plus relevant neighbouring regions, passes this threshold. In any case, this
will be limited to the maximum currently supported by :func:`drive`.
async_api_calls:
A flag to control whether to use asynchronous DRIVE calls.
See Also
--------
:func:`drive`
"""
# Validate input arguments
if single_call_agent_limit is None:
single_call_agent_limit = DRIVE_MAXIMUM_NUM_AGENTS
if single_call_agent_limit > DRIVE_MAXIMUM_NUM_AGENTS:
single_call_agent_limit = DRIVE_MAXIMUM_NUM_AGENTS
iai.logger.warning(f"Single Call Agent Limit cannot be more than {DRIVE_MAXIMUM_NUM_AGENTS}, limiting this value to {DRIVE_MAXIMUM_NUM_AGENTS} and proceeding.")
num_agents = len(agent_states)
if not (num_agents == len(agent_properties)):
if recurrent_states is not None and not (num_agents == len(recurrent_states)):
raise InvalidRequestError(message="Input lists are not of equal size.")
if not num_agents > 0:
raise InvalidRequestError(message="Valid call must contain at least 1 agent.")
# Convert any AgentAttributes to AgentProperties for backwards compatibility
agent_properties_new = []
is_using_attributes = False
for properties in agent_properties:
properties_new = properties
if isinstance(properties,AgentAttributes):
properties_new = convert_attributes_to_properties(properties)
is_using_attributes = True
agent_properties_new.append(properties_new)
agent_properties = agent_properties_new
if is_using_attributes:
warnings.warn('agent_attributes is deprecated. Please use agent_properties.',category=DeprecationWarning)
# Generate quadtree
agent_x = [agent.center.x for agent in agent_states]
agent_y = [agent.center.y for agent in agent_states]
max_x, min_x, max_y, min_y = max(agent_x), min(agent_x), max(agent_y), min(agent_y)
region_size = ceil(max(max_x - min_x, max_y - min_y)) + QUADTREE_SIZE_BUFFER
region_center = (round((max_x+min_x)/2),round((max_y+min_y)/2))
quadtree = QuadTree(
capacity=single_call_agent_limit,
region=Region.create_square_region(
center=Point.fromlist(list(region_center)),
size=region_size
),
)
for i, (agent, attrs) in enumerate(zip(agent_states,agent_properties)):
if recurrent_states is None:
recurr_state = None
else:
recurr_state = recurrent_states[i]
agent_info = QuadTreeAgentInfo.fromlist([agent, attrs, recurr_state, i])
is_inserted = quadtree.insert(agent_info)
if not is_inserted:
raise InvertedAIError(message=f"Unable to insert agent into region.")
# Call DRIVE API on all leaf nodes
all_leaf_nodes = quadtree.get_leaf_nodes()
async_input_params = []
all_responses = []
non_empty_nodes = []
agent_id_order = []
if len(all_leaf_nodes) > 1:
for i, leaf_node in enumerate(all_leaf_nodes):
region, region_buffer = leaf_node.region, leaf_node.region_buffer
region_agents_ids = [particle.agent_id for particle in leaf_node.particles]
if len(region.agent_states) > 0:
non_empty_nodes.append(leaf_node)
agent_id_order.extend(region_agents_ids)
input_params = {
"location":location,
"agent_states":region.agent_states+region_buffer.agent_states,
"recurrent_states":None if recurrent_states is None else region.recurrent_states+region_buffer.recurrent_states,
"agent_properties":region.agent_properties+region_buffer.agent_properties,
"light_recurrent_states":light_recurrent_states,
"traffic_lights_states":traffic_lights_states,
"get_birdview":False,
"rendering_center":None,
"rendering_fov":None,
"get_infractions":get_infractions,
"random_seed":random_seed,
"api_model_version":api_model_version
}
if not async_api_calls:
all_responses.append(iai.drive(**input_params))
else:
async_input_params.append(input_params)
if async_api_calls:
all_responses = asyncio.run(async_drive_all(async_input_params))
response = DriveResponse(
agent_states = _flatten_and_sort([region_response.agent_states[:leaf_node.get_number_of_agents_in_node()] for region_response, leaf_node in zip(all_responses,non_empty_nodes)],agent_id_order),
recurrent_states = _flatten_and_sort([region_response.recurrent_states[:leaf_node.get_number_of_agents_in_node()] for region_response, leaf_node in zip(all_responses,non_empty_nodes)],agent_id_order),
is_inside_supported_area = _flatten_and_sort([region_response.is_inside_supported_area[:leaf_node.get_number_of_agents_in_node()] for region_response, leaf_node in zip(all_responses,non_empty_nodes)],agent_id_order),
infractions = [] if not get_infractions else _flatten_and_sort([region_response.infractions[:leaf_node.get_number_of_agents_in_node()] for region_response, leaf_node in zip(all_responses,non_empty_nodes)],agent_id_order),
api_model_version = all_responses[0].api_model_version,
birdview = None,
traffic_lights_states = all_responses[0].traffic_lights_states,
light_recurrent_states = all_responses[0].light_recurrent_states
)
else:
# Quadtree capacity has not been surpassed therefore can just call regular drive()
response = iai.drive(
location = location,
agent_states = agent_states,
agent_properties = agent_properties,
recurrent_states = recurrent_states,
traffic_lights_states = traffic_lights_states,
light_recurrent_states = light_recurrent_states,
get_birdview = False,
rendering_center = None,
rendering_fov = None,
get_infractions = get_infractions,
random_seed = random_seed,
api_model_version = api_model_version
)
return response