Source code for invertedai.api.drive

import time
from typing import List, Optional, Tuple
from pydantic import BaseModel, validate_arguments
import asyncio

import invertedai as iai
from invertedai.api.config import TIMEOUT, should_use_mock_api
from invertedai.api.mock import (
    mock_update_agent_state,
    get_mock_birdview,
    get_mock_infractions,
)
from invertedai.error import APIConnectionError, InvalidInput
from invertedai.common import (
    AgentState,
    RecurrentState,
    Image,
    InfractionIndicators,
    AgentAttributes,
    TrafficLightStatesDict,
)


[docs]class DriveResponse(BaseModel): """ Response returned from an API call to :func:`iai.drive`. """ agent_states: List[ AgentState ] #: Predicted states for all agents at the next time step. recurrent_states: List[ RecurrentState ] #: To pass to :func:`iai.drive` at the subsequent time step. birdview: Optional[ Image ] #: If `get_birdview` was set, this contains the resulting image. infractions: Optional[ List[InfractionIndicators] ] #: If `get_infractions` was set, they are returned here. is_inside_supported_area: List[ bool ] #: For each agent, indicates whether the predicted state is inside supported area. model_version: str # Model version used for this API call
[docs]@validate_arguments def drive( location: str, agent_states: List[AgentState], agent_attributes: List[AgentAttributes], recurrent_states: List[RecurrentState], traffic_lights_states: Optional[TrafficLightStatesDict] = None, get_birdview: bool = False, rendering_center: Optional[Tuple[float, float]] = None, rendering_fov: Optional[float] = None, get_infractions: bool = False, random_seed: Optional[int] = None, model_version: Optional[str] = None ) -> DriveResponse: """ Parameters ---------- location: Location name in IAI format. agent_states: Current states of all agents. The state must include x: [float], y: [float] coordinate in meters orientation: [float] in radians with 0 pointing along x and pi/2 pointing along y and speed: [float] in m/s. agent_attributes: Static attributes of all agents. List of agent attributes. Each agent requires, length: [float] width: [float] and rear_axis_offset: [float] all in meters. agent_type: [str], currently only supports 'car', but support for 'pedestrian' will be added in the future recurrent_states: Recurrent states for all agents, obtained from the previous call to :func:`drive` or :func:`initialize`. get_birdview: Whether to return an image visualizing the simulation state. This is very slow and should only be used for debugging. rendering_center: Optional center coordinates for the rendered birdview. rendering_fov: Optional fov for the rendered birdview. get_infractions: Whether to check predicted agent states for infractions. This introduces some overhead, but it should be relatively small. traffic_lights_states: If the location contains traffic lights within the supported area, their current state should be provided here. Any traffic light for which no state is provided will be ignored by the agents. random_seed: Controls the stochastic aspects of agent behavior for reproducibility. model_version: Optionally specify the version of the model. If None is passed which is by default, the best model will be used. See Also -------- :func:`initialize` :func:`location_info` :func:`light` :func:`blame` """ if should_use_mock_api(): agent_states = [mock_update_agent_state(s) for s in agent_states] present_mask = [True for _ in agent_states] birdview = get_mock_birdview() infractions = get_mock_infractions(len(agent_states)) response = DriveResponse( agent_states=agent_states, is_inside_supported_area=present_mask, recurrent_states=recurrent_states, birdview=birdview, infractions=infractions, ) return response def _tolist(input_data: List): if not isinstance(input_data, list): return input_data.tolist() else: return input_data recurrent_states = ( _tolist(recurrent_states) if recurrent_states is not None else None ) # AxTx2x64 model_inputs = dict( location=location, agent_states=[state.tolist() for state in agent_states], agent_attributes=[state.tolist() for state in agent_attributes], recurrent_states=[r.packed for r in recurrent_states], traffic_lights_states=traffic_lights_states, get_birdview=get_birdview, get_infractions=get_infractions, random_seed=random_seed, rendering_center=rendering_center, rendering_fov=rendering_fov, model_version=model_version ) start = time.time() timeout = TIMEOUT while True: try: response = iai.session.request(model="drive", data=model_inputs) response = DriveResponse( agent_states=[ AgentState.fromlist(state) for state in response["agent_states"] ], recurrent_states=[ RecurrentState.fromval(r) for r in response["recurrent_states"] ], birdview=Image.fromval(response["birdview"]) if response["birdview"] is not None else None, infractions=[ InfractionIndicators.fromlist(infractions) for infractions in response["infraction_indicators"] ] if response["infraction_indicators"] else [], is_inside_supported_area=response["is_inside_supported_area"], model_version=response["model_version"] ) return response except APIConnectionError as e: iai.logger.warning("Retrying") if ( timeout is not None and time.time() > start + timeout ) or not e.should_retry: raise e
@validate_arguments async def async_drive( location: str, agent_states: List[AgentState], agent_attributes: List[AgentAttributes], recurrent_states: List[RecurrentState], traffic_lights_states: Optional[TrafficLightStatesDict] = None, get_birdview: bool = False, rendering_center: Optional[Tuple[float, float]] = None, rendering_fov: Optional[float] = None, get_infractions: bool = False, random_seed: Optional[int] = None, model_version: Optional[str] = None ) -> DriveResponse: """ A light async version of :func:`drive` """ def _tolist(input_data: List): if not isinstance(input_data, list): return input_data.tolist() else: return input_data recurrent_states = ( _tolist(recurrent_states) if recurrent_states is not None else None ) # AxTx2x64 model_inputs = dict( location=location, agent_states=[state.tolist() for state in agent_states], agent_attributes=[state.tolist() for state in agent_attributes], recurrent_states=[r.packed for r in recurrent_states], traffic_lights_states=traffic_lights_states, get_birdview=get_birdview, get_infractions=get_infractions, random_seed=random_seed, rendering_center=rendering_center, rendering_fov=rendering_fov, model_version=model_version, ) response = await iai.session.async_request(model="drive", data=model_inputs) response = DriveResponse( agent_states=[ AgentState.fromlist(state) for state in response["agent_states"] ], recurrent_states=[ RecurrentState.fromval(r) for r in response["recurrent_states"] ], birdview=Image.fromval(response["birdview"]) if response["birdview"] is not None else None, infractions=[ InfractionIndicators.fromlist(infractions) for infractions in response["infraction_indicators"] ] if response["infraction_indicators"] else [], is_inside_supported_area=response["is_inside_supported_area"], model_version=response["model_version"] ) return response