Source code for invertedai.api.initialize

import time
from pydantic import BaseModel, validate_arguments, root_validator
from typing import List, Optional, Dict, Tuple
import asyncio

import invertedai as iai
from invertedai.api.config import TIMEOUT, should_use_mock_api
from invertedai.error import TryAgain, InvalidInputType, InvalidInput
from invertedai.api.mock import (
    get_mock_agent_attributes,
    get_mock_agent_state,
    get_mock_recurrent_state,
    get_mock_birdview,
    get_mock_infractions,
)
from invertedai.common import (
    RecurrentState,
    AgentState,
    AgentAttributes,
    TrafficLightStatesDict,
    Image,
    InfractionIndicators,
)


class InitializeResponse(BaseModel):
    """
    Response returned from an API call to :func:`iai.initialize`.
    """

    #: To pass to :func:`iai.drive` at the first time step.
    recurrent_states: List[Optional[RecurrentState]]
    #: Initial states of all initialized agents.
    agent_states: List[Optional[AgentState]]
    #: Static attributes of all initialized agents.
    agent_attributes: List[Optional[AgentAttributes]]
    #: If `get_birdview` was set, this contains the resulting image.
    birdview: Optional[Image]
    #: If `get_infractions` was set, they are returned here.
    infractions: Optional[List[InfractionIndicators]]
    #: Model version used for this API call.
    model_version: str
@validate_arguments
def initialize(
    location: str,
    agent_attributes: Optional[List[AgentAttributes]] = None,
    states_history: Optional[List[List[AgentState]]] = None,
    traffic_light_state_history: Optional[
        List[TrafficLightStatesDict]
    ] = None,
    get_birdview: bool = False,
    location_of_interest: Optional[Tuple[float, float]] = None,
    get_infractions: bool = False,
    agent_count: Optional[int] = None,
    random_seed: Optional[int] = None,
    model_version: Optional[str] = None,  # Model version used for this API call
) -> InitializeResponse:
    """
    Initializes a simulation in a given location, using a combination of
    **user-defined** and **sampled** agents.

    **User-defined** agents are placed in a scene first, after which a number
    of agents are sampled conditionally, inferred from the `agent_count`
    argument. If **user-defined** agents are desired, `states_history` must
    contain a list of `AgentState` objects of all **user-defined** agents per
    historical time step. Any **user-defined** agent must have a corresponding
    fully specified static `AgentAttributes` in `agent_attributes`. Any
    **sampled** agents not specified in `agent_attributes` will be generated
    with default static attribute values; however, **sampled** agents may be
    defined by specifying `agent_type` only. Agents are identified by their
    list index, so ensure the indices of each agent match in `states_history`
    and `agent_attributes` when applicable.

    If traffic lights are present in the scene, for best results their state
    should be specified for the current time in a `TrafficLightStatesDict`,
    and for all historical time steps for which `states_history` is provided.
    It is legal to omit the traffic light state specification, but the scene
    will be initialized as if the traffic lights were disabled.

    Every simulation must start with a call to this function in order to
    obtain correct recurrent states for :func:`drive`.

    Parameters
    ----------
    location:
        Location name in IAI format.

    agent_attributes:
        Static attributes for all agents.
        The pre-defined agents should be specified first, followed by the
        sampled agents.

    states_history:
        History of pre-defined agent states - the outer list is over time and
        the inner over agents, in chronological order, i.e., index 0 is the
        oldest state and index -1 is the current state. The order of agents
        should be the same as in `agent_attributes`. For best results, provide
        at least 10 historical states for each agent.

    traffic_light_state_history:
        History of traffic light states - the list is over time, in
        chronological order, i.e. the last element is the current state. Not
        specifying traffic light state is equivalent to disabling traffic
        lights.

    location_of_interest:
        Optional coordinates for spawning agents with the given location as
        center instead of the default map center.

    get_birdview:
        If True, a birdview image will be returned representing the current
        world. Note this will significantly impact the latency.

    get_infractions:
        If True, infraction metrics will be returned for each agent.

    agent_count:
        Deprecated. Equivalent to padding the `agent_attributes` list to this
        length with default `AgentAttributes`.

    random_seed:
        Controls the stochastic aspects of initialization for reproducibility.

    model_version:
        Optionally specify the version of the model. If None is passed, which
        is the default, the best model will be used.

    Raises
    ------
    AssertionError
        In mock-API mode, when neither `agent_attributes` nor `agent_count`
        is provided.
    TryAgain
        When the service keeps asking to retry for longer than the configured
        `TIMEOUT`.

    See Also
    --------
    :func:`drive`
    :func:`location_info`
    :func:`light`
    :func:`blame`
    """

    if should_use_mock_api():
        if agent_attributes is None:
            assert (
                agent_count is not None
            ), "agent_count is required when agent_attributes is not provided"
            agent_attributes = [
                get_mock_agent_attributes() for _ in range(agent_count)
            ]
            agent_states = [get_mock_agent_state() for _ in range(agent_count)]
        elif states_history:
            # The most recent time step holds the current agent states.
            agent_states = states_history[-1]
        else:
            # No usable history: fabricate one mock state per supplied agent
            # instead of indexing into a missing/empty history.
            agent_states = [get_mock_agent_state() for _ in agent_attributes]
        # One recurrent state per returned agent; `agent_count` may be None
        # or disagree with the number of states, so it is not used here.
        recurrent_states = [get_mock_recurrent_state() for _ in agent_states]
        birdview = get_mock_birdview()
        infractions = get_mock_infractions(len(agent_states))
        return InitializeResponse(
            agent_states=agent_states,
            agent_attributes=agent_attributes,
            recurrent_states=recurrent_states,
            birdview=birdview,
            infractions=infractions,
            # `model_version` is a required field of the response model; echo
            # the requested version or fall back to a placeholder.
            model_version=model_version if model_version is not None else "mock",
        )

    model_inputs = dict(
        location=location,
        num_agents_to_spawn=agent_count,
        states_history=states_history
        if states_history is None
        else [[st.tolist() for st in states] for states in states_history],
        agent_attributes=agent_attributes
        if agent_attributes is None
        else [state.tolist() for state in agent_attributes],
        traffic_light_state_history=traffic_light_state_history,
        get_birdview=get_birdview,
        location_of_interest=location_of_interest,
        get_infractions=get_infractions,
        random_seed=random_seed,
        model_version=model_version,
    )

    # A monotonic clock keeps the timeout immune to system clock adjustments.
    start = time.monotonic()
    timeout = TIMEOUT
    while True:
        try:
            response = iai.session.request(model="initialize", data=model_inputs)
        except TryAgain as e:
            if timeout is not None and time.monotonic() > start + timeout:
                raise
            iai.logger.info(
                iai.logger.logfmt("Waiting for model to warm up", error=e)
            )
        else:
            return InitializeResponse(
                agent_states=[
                    AgentState.fromlist(state) for state in response["agent_states"]
                ],
                agent_attributes=[
                    AgentAttributes.fromlist(attr)
                    for attr in response["agent_attributes"]
                ],
                recurrent_states=[
                    RecurrentState.fromval(r) for r in response["recurrent_states"]
                ],
                birdview=Image.fromval(response["birdview"])
                if response["birdview"] is not None
                else None,
                infractions=[
                    InfractionIndicators.fromlist(infractions)
                    for infractions in response["infraction_indicators"]
                ]
                if response["infraction_indicators"]
                else [],
                model_version=response["model_version"],
            )
@validate_arguments
async def async_initialize(
    location: str,
    agent_attributes: Optional[List[AgentAttributes]] = None,
    states_history: Optional[List[List[AgentState]]] = None,
    traffic_light_state_history: Optional[
        List[TrafficLightStatesDict]
    ] = None,
    get_birdview: bool = False,
    location_of_interest: Optional[Tuple[float, float]] = None,
    get_infractions: bool = False,
    agent_count: Optional[int] = None,
    random_seed: Optional[int] = None,
    model_version: Optional[str] = None,
) -> InitializeResponse:
    """
    The async version of :func:`initialize`.

    Accepts the same arguments as :func:`initialize` and returns an
    :class:`InitializeResponse` built from the service reply. Unlike the
    synchronous function, this coroutine always issues a real request (the
    mock API is not consulted) and does not retry on `TryAgain`.

    A warning is logged when `agent_count` was given and the number of agents
    returned by the service differs from it.
    """
    model_inputs = dict(
        location=location,
        num_agents_to_spawn=agent_count,
        states_history=states_history
        if states_history is None
        else [[st.tolist() for st in states] for states in states_history],
        agent_attributes=agent_attributes
        if agent_attributes is None
        else [state.tolist() for state in agent_attributes],
        traffic_light_state_history=traffic_light_state_history,
        get_birdview=get_birdview,
        location_of_interest=location_of_interest,
        get_infractions=get_infractions,
        random_seed=random_seed,
        model_version=model_version,
    )

    response = await iai.session.async_request(model="initialize", data=model_inputs)

    agents_spawned = len(response["agent_states"])
    # `agent_count` defaults to None; only compare when the caller actually
    # requested a specific number of agents, otherwise every call would warn.
    if agent_count is not None and agents_spawned != agent_count:
        iai.logger.warning(
            f"Unable to spawn a scenario for {agent_count} agents, {agents_spawned} spawned instead."
        )

    return InitializeResponse(
        agent_states=[
            AgentState.fromlist(state) for state in response["agent_states"]
        ],
        agent_attributes=[
            AgentAttributes.fromlist(attr) for attr in response["agent_attributes"]
        ],
        recurrent_states=[
            RecurrentState.fromval(r) for r in response["recurrent_states"]
        ],
        birdview=Image.fromval(response["birdview"])
        if response["birdview"] is not None
        else None,
        infractions=[
            InfractionIndicators.fromlist(infractions)
            for infractions in response["infraction_indicators"]
        ]
        if response["infraction_indicators"]
        else [],
        model_version=response["model_version"],
    )