Source code for invertedai.api.drive

import time
from typing import List, Optional, Tuple
from pydantic import BaseModel, validate_call
import asyncio

import invertedai as iai
from invertedai.api.config import TIMEOUT, should_use_mock_api
from invertedai.api.mock import (
    mock_update_agent_state,
    get_mock_birdview,
    get_mock_infractions,
)
from invertedai.error import APIConnectionError, InvalidInput
from invertedai.common import (
    AgentState,
    RecurrentState,
    Image,
    InfractionIndicators,
    AgentAttributes,
    TrafficLightStatesDict,
    LightRecurrentStates,
    LightRecurrentState,
)


class DriveResponse(BaseModel):
    """
    Result object produced by a call to :func:`iai.drive`.

    Bundles the predicted agent states together with the recurrent state
    that has to be fed into the following :func:`iai.drive` call, plus the
    optional extras (birdview image, infractions, traffic light data).
    """

    #: Predicted states for all agents at the next time step.
    agent_states: List[AgentState]
    #: To pass to :func:`iai.drive` at the subsequent time step.
    recurrent_states: List[RecurrentState]
    #: If `get_birdview` was set, this contains the resulting image.
    birdview: Optional[Image]
    #: If `get_infractions` was set, they are returned here.
    infractions: Optional[List[InfractionIndicators]]
    #: For each agent, indicates whether the predicted state is inside supported area.
    is_inside_supported_area: List[bool]
    #: Traffic light states for the full map, each key-value pair corresponds to one particular traffic light.
    traffic_lights_states: Optional[TrafficLightStatesDict]
    #: Light recurrent states for the full map, each element corresponds to one light group.
    light_recurrent_states: Optional[LightRecurrentStates]
    #: Model version used for this API call.
    api_model_version: str
@validate_call
def drive(
    location: str,
    agent_states: List[AgentState],
    agent_attributes: List[AgentAttributes],
    recurrent_states: List[RecurrentState],
    traffic_lights_states: Optional[TrafficLightStatesDict] = None,
    light_recurrent_states: Optional[LightRecurrentStates] = None,
    get_birdview: bool = False,
    rendering_center: Optional[Tuple[float, float]] = None,
    rendering_fov: Optional[float] = None,
    get_infractions: bool = False,
    random_seed: Optional[int] = None,
    api_model_version: Optional[str] = None,
) -> DriveResponse:
    """
    Send the current simulation state to the ``drive`` endpoint and return
    the predicted state of every agent at the next time step.

    Parameters
    ----------
    location:
        Location name in IAI format.

    agent_states:
        Current states of all agents.
        The state must include x: [float], y: [float] coordinate in meters
        orientation: [float] in radians with 0 pointing along x and pi/2
        pointing along y and speed: [float] in m/s.

    agent_attributes:
        Static attributes of all agents.
        List of agent attributes. Each agent requires, length: [float]
        width: [float] and rear_axis_offset: [float] all in meters.
        agent_type: [str], currently supports 'car' and 'pedestrian'.
        waypoint: optional [Point], the target waypoint of the agent.

    recurrent_states:
        Recurrent states for all agents, obtained from the previous call to
        :func:`drive` or :func:`initialize`.

    get_birdview:
        Whether to return an image visualizing the simulation state.
        This is very slow and should only be used for debugging.

    rendering_center:
        Optional center coordinates for the rendered birdview.

    rendering_fov:
        Optional fov for the rendered birdview.

    get_infractions:
        Whether to check predicted agent states for infractions.
        This introduces some overhead, but it should be relatively small.

    traffic_lights_states:
        If the location contains traffic lights within the supported area,
        their current state should be provided here. Any traffic light for
        which no state is provided will have a state generated by iai.

    light_recurrent_states:
        Light recurrent states for the full map, obtained from the previous
        call to :func:`drive` or :func:`initialize`.
        Specifies the state and time remaining for each light group in the
        map. If manual control of individual traffic lights is desired,
        modify the relevant state(s) in traffic_lights_states, then pass in
        light_recurrent_states as usual.

    random_seed:
        Controls the stochastic aspects of agent behavior for reproducibility.

    api_model_version:
        Optionally specify the version of the model. If None is passed which
        is by default, the best model will be used.

    Returns
    -------
    DriveResponse
        Predicted agent states, the recurrent states for the next call and
        the optional birdview / infraction / traffic light data.

    Raises
    ------
    APIConnectionError
        When the request fails and either the error is flagged as not
        retryable or the configured ``TIMEOUT`` has elapsed.

    See Also
    --------
    :func:`initialize`
    :func:`location_info`
    :func:`light`
    :func:`blame`
    """
    if should_use_mock_api():
        mock_states = [mock_update_agent_state(s) for s in agent_states]
        # Every DriveResponse field is supplied explicitly: the model declares
        # no defaults, so leaving any of them out fails validation.
        return DriveResponse(
            agent_states=mock_states,
            is_inside_supported_area=[True] * len(mock_states),
            recurrent_states=recurrent_states,
            birdview=get_mock_birdview(),
            infractions=get_mock_infractions(len(mock_states)),
            traffic_lights_states=traffic_lights_states,
            light_recurrent_states=light_recurrent_states,
            # NOTE(review): the response field is a plain `str`; "mock" is a
            # placeholder used when the caller did not request a version --
            # confirm the preferred value.
            api_model_version=(
                api_model_version if api_model_version is not None else "mock"
            ),
        )

    # validate_call has already validated the list-typed arguments against
    # their annotations, so no extra list conversion is needed here.
    model_inputs = dict(
        location=location,
        agent_states=[state.tolist() for state in agent_states],
        agent_attributes=[attr.tolist() for attr in agent_attributes],
        recurrent_states=[r.packed for r in recurrent_states],
        traffic_lights_states=traffic_lights_states,
        light_recurrent_states=(
            [light_state.tolist() for light_state in light_recurrent_states]
            if light_recurrent_states is not None
            else None
        ),
        get_birdview=get_birdview,
        get_infractions=get_infractions,
        random_seed=random_seed,
        rendering_center=rendering_center,
        rendering_fov=rendering_fov,
        model_version=api_model_version,
    )

    # A monotonic clock keeps the retry deadline immune to wall-clock jumps.
    deadline = None if TIMEOUT is None else time.monotonic() + TIMEOUT
    while True:
        try:
            response = iai.session.request(model="drive", data=model_inputs)
        except APIConnectionError as e:
            timed_out = deadline is not None and time.monotonic() > deadline
            if timed_out or not e.should_retry:
                raise
            # Only announce a retry when one is actually going to happen.
            iai.logger.warning("Retrying")
            continue

        raw_birdview = response["birdview"]
        raw_infractions = response["infraction_indicators"]
        raw_light_states = response["light_recurrent_states"]
        return DriveResponse(
            agent_states=[
                AgentState.fromlist(state) for state in response["agent_states"]
            ],
            recurrent_states=[
                RecurrentState.fromval(r) for r in response["recurrent_states"]
            ],
            birdview=(
                Image.fromval(raw_birdview) if raw_birdview is not None else None
            ),
            # A missing or empty infraction list is normalised to [].
            infractions=(
                [InfractionIndicators.fromlist(item) for item in raw_infractions]
                if raw_infractions
                else []
            ),
            is_inside_supported_area=response["is_inside_supported_area"],
            api_model_version=response["model_version"],
            traffic_lights_states=response["traffic_lights_states"],
            # Each entry is indexed as [state, time_remaining].
            light_recurrent_states=(
                [
                    LightRecurrentState(
                        state=state_arr[0], time_remaining=state_arr[1]
                    )
                    for state_arr in raw_light_states
                ]
                if raw_light_states is not None
                else None
            ),
        )
@validate_call
async def async_drive(
    location: str,
    agent_states: List[AgentState],
    agent_attributes: List[AgentAttributes],
    recurrent_states: List[RecurrentState],
    traffic_lights_states: Optional[TrafficLightStatesDict] = None,
    get_birdview: bool = False,
    rendering_center: Optional[Tuple[float, float]] = None,
    rendering_fov: Optional[float] = None,
    get_infractions: bool = False,
    random_seed: Optional[int] = None,
    api_model_versioin: Optional[str] = None,
    light_recurrent_states: Optional[LightRecurrentStates] = None,
    api_model_version: Optional[str] = None,
) -> DriveResponse:
    """
    A light async version of :func:`drive`

    Issues a single awaited request to the ``drive`` endpoint; unlike
    :func:`drive` it has no mock-API branch and no retry loop.

    Parameters
    ----------
    location, agent_states, agent_attributes, recurrent_states,
    traffic_lights_states, get_birdview, rendering_center, rendering_fov,
    get_infractions, random_seed:
        Same meaning as the parameters of the same name in :func:`drive`.

    api_model_versioin:
        Deprecated, misspelled alias of `api_model_version`, kept so that
        existing keyword callers keep working. Ignored when
        `api_model_version` is given.

    light_recurrent_states:
        Light recurrent states obtained from the previous response, sent to
        the server the same way :func:`drive` sends them.

    api_model_version:
        Optionally specify the version of the model.

    Returns
    -------
    DriveResponse
        The parsed server response.
    """
    # Prefer the correctly spelled keyword, fall back to the legacy one.
    requested_version = (
        api_model_version if api_model_version is not None else api_model_versioin
    )

    # validate_call has already validated the list-typed arguments against
    # their annotations, so no extra list conversion is needed here.
    model_inputs = dict(
        location=location,
        agent_states=[state.tolist() for state in agent_states],
        agent_attributes=[attr.tolist() for attr in agent_attributes],
        recurrent_states=[r.packed for r in recurrent_states],
        traffic_lights_states=traffic_lights_states,
        light_recurrent_states=(
            [light_state.tolist() for light_state in light_recurrent_states]
            if light_recurrent_states is not None
            else None
        ),
        get_birdview=get_birdview,
        get_infractions=get_infractions,
        random_seed=random_seed,
        rendering_center=rendering_center,
        rendering_fov=rendering_fov,
        model_version=requested_version,
    )

    response = await iai.session.async_request(model="drive", data=model_inputs)

    raw_birdview = response["birdview"]
    raw_infractions = response["infraction_indicators"]
    raw_light_states = response["light_recurrent_states"]
    return DriveResponse(
        agent_states=[
            AgentState.fromlist(state) for state in response["agent_states"]
        ],
        recurrent_states=[
            RecurrentState.fromval(r) for r in response["recurrent_states"]
        ],
        birdview=Image.fromval(raw_birdview) if raw_birdview is not None else None,
        # A missing or empty infraction list is normalised to [].
        infractions=(
            [InfractionIndicators.fromlist(item) for item in raw_infractions]
            if raw_infractions
            else []
        ),
        is_inside_supported_area=response["is_inside_supported_area"],
        api_model_version=response["model_version"],
        traffic_lights_states=response["traffic_lights_states"],
        # Each entry is indexed as [state, time_remaining].
        light_recurrent_states=(
            [
                LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
                for state_arr in raw_light_states
            ]
            if raw_light_states is not None
            else None
        ),
    )