# Source code for invertedai.api.drive

import time
from typing import List, Optional, Tuple
from pydantic import BaseModel, validate_call
import asyncio

import invertedai as iai
from invertedai.api.config import TIMEOUT, should_use_mock_api
from invertedai.api.mock import (
    mock_update_agent_state,
    get_mock_birdview,
    get_mock_infractions,
    get_mock_light_recurrent_states
)
from invertedai.error import APIConnectionError, InvalidInput
from invertedai.common import (
    AgentState,
    RecurrentState,
    Image,
    InfractionIndicators,
    AgentAttributes,
    AgentProperties,
    TrafficLightStatesDict,
    LightRecurrentStates,
    LightRecurrentState,
)


class DriveResponse(BaseModel):
    """
    Response returned from an API call to :func:`iai.drive`.
    """

    #: Predicted states for all agents at the next time step.
    agent_states: List[AgentState]
    #: To pass to :func:`iai.drive` at the subsequent time step.
    recurrent_states: List[RecurrentState]
    #: If `get_birdview` was set, this contains the resulting image.
    birdview: Optional[Image]
    #: If `get_infractions` was set, they are returned here.
    infractions: Optional[List[InfractionIndicators]]
    #: For each agent, indicates whether the predicted state is inside the supported area.
    is_inside_supported_area: List[bool]
    #: Traffic light states for the full map, as seen by the agents before they
    #: performed their actions resulting in the returned state. Each key-value
    #: pair corresponds to one particular traffic light.
    traffic_lights_states: Optional[TrafficLightStatesDict]
    #: Light recurrent states for the full map, each element corresponds to one
    #: light group. Pass this to the next call of :func:`iai.drive` for the
    #: server to realistically update the traffic light states.
    light_recurrent_states: Optional[LightRecurrentStates]
    #: Model version used for this API call.
    api_model_version: str
@validate_call
def drive(
    location: str,
    agent_states: List[AgentState],
    agent_attributes: Optional[List[AgentAttributes]] = None,
    agent_properties: Optional[List[AgentProperties]] = None,
    recurrent_states: Optional[List[RecurrentState]] = None,
    traffic_lights_states: Optional[TrafficLightStatesDict] = None,
    light_recurrent_states: Optional[LightRecurrentStates] = None,
    get_birdview: bool = False,
    rendering_center: Optional[Tuple[float, float]] = None,
    rendering_fov: Optional[float] = None,
    get_infractions: bool = False,
    random_seed: Optional[int] = None,
    api_model_version: Optional[str] = None,
) -> DriveResponse:
    """
    Predict the states of all agents at the next time step.

    When the mock API is enabled (``should_use_mock_api()`` returns True), a
    mock response is built locally and no request is sent.

    Parameters
    ----------
    location:
        Location name in IAI format.

    agent_states:
        Current states of all agents.
        The state must include x: [float], y: [float] coordinate in meters,
        orientation: [float] in radians with 0 pointing along x and pi/2
        pointing along y, and speed: [float] in m/s.

    agent_attributes:
        Static attributes of all agents (soon to be deprecated in favour of
        `agent_properties`). Each agent requires length: [float],
        width: [float] and rear_axis_offset: [float], all in meters.
        agent_type: [str], currently supports 'car' and 'pedestrian'.
        waypoint: optional [Point], the target waypoint of the agent.

    agent_properties:
        Agent properties for all agents, replacing the soon to be deprecated
        `agent_attributes`. Each agent requires length: [float],
        width: [float] and rear_axis_offset: [float], all in meters.
        agent_type: [str], currently supports 'car' and 'pedestrian'.
        waypoint: optional [Point], the target waypoint of the agent.
        max_speed: optional [float], the desired maximum speed of the agent in m/s.

    recurrent_states:
        Recurrent states for all agents, obtained from the previous call to
        :func:`drive` or :func:`initialize`.

    traffic_lights_states:
        If the location contains traffic lights within the supported area,
        their current state can be provided here. It is legal to not provide
        this field and use `light_recurrent_states` to step the traffic lights.
        If provided, light states from `traffic_lights_states` override the
        light states given by `light_recurrent_states`. The server does not
        currently support continuing user-provided light state sequences, so
        once the states are provided at any step, they should also be provided
        on all subsequent steps to guarantee coherent light sequences.
        If neither `traffic_lights_states` nor `light_recurrent_states` is
        provided, the server will arbitrarily initialize the traffic light
        states and return the associated `light_recurrent_states` in the
        response.

    light_recurrent_states:
        Light recurrent states for the full map, obtained from the previous
        call to :func:`drive` or :func:`initialize`. Specifies the state and
        time remaining for each light group in the map. To let the server
        manage all light states in the scene, pass `light_recurrent_states`
        from the previous response of :func:`drive` here and leave
        `traffic_lights_states=None`.

    get_birdview:
        Whether to return an image visualizing the simulation state.
        This is very slow and should only be used for debugging.

    rendering_center:
        Optional center coordinates for the rendered birdview.

    rendering_fov:
        Optional field of view for the rendered birdview.

    get_infractions:
        Whether to check predicted agent states for infractions.
        This introduces some overhead, but it should be relatively small.

    random_seed:
        Controls the stochastic aspects of agent behavior for reproducibility.

    api_model_version:
        Optionally specify the version of the model. If None is passed (the
        default), the best model will be used.

    Returns
    -------
    DriveResponse
        Predicted agent states and updated recurrent states, plus the
        optional birdview, infraction indicators and traffic light data.

    Raises
    ------
    APIConnectionError
        If the request fails with a non-retryable error, or retryable errors
        keep occurring for longer than ``TIMEOUT`` seconds.

    See Also
    --------
    :func:`initialize`
    :func:`location_info`
    :func:`light`
    :func:`blame`
    """
    if should_use_mock_api():
        agent_states = [mock_update_agent_state(s) for s in agent_states]
        return DriveResponse(
            agent_states=agent_states,
            is_inside_supported_area=[True] * len(agent_states),
            # NOTE(review): recurrent_states is passed through unchanged. DriveResponse
            # declares it as a non-Optional list, so calling the mock with
            # recurrent_states=None will fail validation -- confirm intended behavior.
            recurrent_states=recurrent_states,
            birdview=get_mock_birdview(),
            infractions=get_mock_infractions(len(agent_states)),
            traffic_lights_states=traffic_lights_states,
            light_recurrent_states=(
                get_mock_light_recurrent_states(len(traffic_lights_states))
                if traffic_lights_states is not None
                else None
            ),
            api_model_version=api_model_version if api_model_version is not None else "best",
        )

    # validate_call has already coerced every List[...] argument to a plain list,
    # so the elements can be serialized directly.
    model_inputs = dict(
        location=location,
        agent_states=[state.tolist() for state in agent_states],
        agent_attributes=(
            [attr.tolist() for attr in agent_attributes]
            if agent_attributes is not None
            else None
        ),
        agent_properties=(
            [prop.serialize() for prop in agent_properties]
            if agent_properties is not None
            else None
        ),
        recurrent_states=(
            [r.packed for r in recurrent_states]
            if recurrent_states is not None
            else None
        ),
        traffic_lights_states=traffic_lights_states,
        light_recurrent_states=(
            [light_state.tolist() for light_state in light_recurrent_states]
            if light_recurrent_states is not None
            else None
        ),
        get_birdview=get_birdview,
        get_infractions=get_infractions,
        random_seed=random_seed,
        rendering_center=rendering_center,
        rendering_fov=rendering_fov,
        model_version=api_model_version,
    )

    # Monotonic clock: elapsed-time checks must not be affected by wall-clock changes.
    start = time.monotonic()
    timeout = TIMEOUT
    while True:
        try:
            raw = iai.session.request(model="drive", data=model_inputs)
            break
        except APIConnectionError as e:
            timed_out = timeout is not None and time.monotonic() > start + timeout
            if timed_out or not e.should_retry:
                raise
            # Only log once we know another attempt will actually be made.
            iai.logger.warning("Retrying")

    return DriveResponse(
        agent_states=[AgentState.fromlist(state) for state in raw["agent_states"]],
        recurrent_states=[RecurrentState.fromval(r) for r in raw["recurrent_states"]],
        birdview=Image.fromval(raw["birdview"]) if raw["birdview"] is not None else None,
        infractions=(
            [InfractionIndicators.fromlist(ind) for ind in raw["infraction_indicators"]]
            if raw["infraction_indicators"]
            else []
        ),
        is_inside_supported_area=raw["is_inside_supported_area"],
        api_model_version=raw["model_version"],
        traffic_lights_states=raw["traffic_lights_states"],
        light_recurrent_states=(
            [
                LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
                for state_arr in raw["light_recurrent_states"]
            ]
            if raw["light_recurrent_states"] is not None
            else None
        ),
    )
@validate_call
async def async_drive(
    location: str,
    agent_states: List[AgentState],
    agent_attributes: Optional[List[AgentAttributes]] = None,
    agent_properties: Optional[List[AgentProperties]] = None,
    recurrent_states: Optional[List[RecurrentState]] = None,
    traffic_lights_states: Optional[TrafficLightStatesDict] = None,
    light_recurrent_states: Optional[LightRecurrentStates] = None,
    get_birdview: bool = False,
    rendering_center: Optional[Tuple[float, float]] = None,
    rendering_fov: Optional[float] = None,
    get_infractions: bool = False,
    random_seed: Optional[int] = None,
    api_model_version: Optional[str] = None,
) -> DriveResponse:
    """
    A light async version of :func:`drive`.

    Takes the same parameters and returns the same :class:`DriveResponse` as
    :func:`drive`, but awaits ``iai.session.async_request`` instead of
    blocking. Unlike :func:`drive`, a failed request is not retried: any
    exception raised by the session propagates to the caller.

    When the mock API is enabled (``should_use_mock_api()`` returns True), a
    mock response is built locally and no request is sent, as in :func:`drive`.
    """
    if should_use_mock_api():
        agent_states = [mock_update_agent_state(s) for s in agent_states]
        return DriveResponse(
            agent_states=agent_states,
            is_inside_supported_area=[True] * len(agent_states),
            # NOTE(review): passed through unchanged; None will fail DriveResponse
            # validation because the field is a non-Optional list -- confirm intent.
            recurrent_states=recurrent_states,
            birdview=get_mock_birdview(),
            infractions=get_mock_infractions(len(agent_states)),
            traffic_lights_states=traffic_lights_states,
            light_recurrent_states=(
                get_mock_light_recurrent_states(len(traffic_lights_states))
                if traffic_lights_states is not None
                else None
            ),
            api_model_version=api_model_version if api_model_version is not None else "best",
        )

    # validate_call has already coerced every List[...] argument to a plain list.
    model_inputs = dict(
        location=location,
        agent_states=[state.tolist() for state in agent_states],
        agent_attributes=(
            [attr.tolist() for attr in agent_attributes]
            if agent_attributes is not None
            else None
        ),
        agent_properties=(
            [prop.serialize() for prop in agent_properties]
            if agent_properties is not None
            else None
        ),
        recurrent_states=(
            [r.packed for r in recurrent_states]
            if recurrent_states is not None
            else None
        ),
        traffic_lights_states=traffic_lights_states,
        light_recurrent_states=(
            [light_state.tolist() for light_state in light_recurrent_states]
            if light_recurrent_states is not None
            else None
        ),
        get_birdview=get_birdview,
        get_infractions=get_infractions,
        random_seed=random_seed,
        rendering_center=rendering_center,
        rendering_fov=rendering_fov,
        model_version=api_model_version,
    )

    raw = await iai.session.async_request(model="drive", data=model_inputs)

    return DriveResponse(
        agent_states=[AgentState.fromlist(state) for state in raw["agent_states"]],
        recurrent_states=[RecurrentState.fromval(r) for r in raw["recurrent_states"]],
        birdview=Image.fromval(raw["birdview"]) if raw["birdview"] is not None else None,
        infractions=(
            [InfractionIndicators.fromlist(ind) for ind in raw["infraction_indicators"]]
            if raw["infraction_indicators"]
            else []
        ),
        is_inside_supported_area=raw["is_inside_supported_area"],
        api_model_version=raw["model_version"],
        traffic_lights_states=raw["traffic_lights_states"],
        light_recurrent_states=(
            [
                LightRecurrentState(state=state_arr[0], time_remaining=state_arr[1])
                for state_arr in raw["light_recurrent_states"]
            ]
            if raw["light_recurrent_states"] is not None
            else None
        ),
    )