# Source code for invertedai.cosimulation

from typing import List, Tuple, Optional, Union
import random
from collections import deque
from queue import Queue
from pydantic import BaseModel, validate_arguments
from itertools import product
import numpy as np
import asyncio
from itertools import chain

import invertedai as iai
from invertedai import drive, initialize, location_info, light, async_drive, async_initialize
from invertedai.common import (AgentState, InfractionIndicators, Image,
                               TrafficLightStatesDict, AgentAttributes, RecurrentState, Point)


class BasicCosimulation:
    """
    Stateful wrapper around the Inverted AI API to simplify co-simulation.

    All arguments to :func:`initialize` can be passed to the constructor here,
    and a sufficient combination of them must be passed as required by :func:`initialize`.
    This wrapper caches static agent attributes and propagates the recurrent state,
    so that only states of ego agents and NPCs need to be exchanged with it to perform co-simulation.
    Typically, each time step requires a single read of :attr:`npc_states`
    and a single call to :meth:`step`.

    This wrapper only supports a minimal co-simulation functionality.
    For more advanced use cases, call :func:`initialize` and :func:`drive` directly.

    :param location: Location name as expected by :func:`initialize`.
    :param ego_agent_mask: List indicating which agent is ego, meaning that it is controlled by you externally.
        The order in this list should be the same as that used in arguments to :func:`initialize`.
        If :func:`initialize` returns a different number of agents, the mask is truncated,
        or padded with ``False`` (NPC), to match the returned agent count.
    :param monitor_infractions: Whether to monitor driving infractions, at a small increase in latency and payload size.
    :param get_birdview: Whether to render the bird's eye view of the simulation state at each time step.
        It drastically increases the payload received from Inverted AI servers
        and therefore slows down the simulation - use only for debugging.
    :param random_seed: Controls the stochastic aspects of simulation for reproducibility.
    :param traffic_lights: Whether to track traffic lights. This only takes effect if :func:`location_info`
        reports at least one static actor of type ``"traffic-light"`` at this location.
    :param kwargs: Any further arguments, forwarded unchanged to :func:`initialize`.
    """

    def __init__(
        self,
        location: str,
        ego_agent_mask: Optional[List[bool]] = None,
        monitor_infractions: bool = False,
        get_birdview: bool = False,
        random_seed: Optional[int] = None,
        traffic_lights: bool = False,
        # sufficient arguments to initialize must also be included
        **kwargs,
    ):
        self._location = location
        # One private RNG supplies the seed passed to every initialize/drive call,
        # so a single random_seed determines all of them.
        self.rng = None if random_seed is None else random.Random(random_seed)
        self.light_flag = False
        self.light_recurrent_state = None
        self._light_state = None
        if traffic_lights:
            location_info_response = location_info(location=location)
            static_actors = location_info_response.static_actors
            # Only enable light tracking if the map actually has traffic lights.
            if any(actor.agent_type == "traffic-light" for actor in static_actors):
                self.light_flag = True

        response = initialize(
            location=location,
            get_birdview=get_birdview,
            get_infractions=monitor_infractions,
            random_seed=self._next_seed(),
            traffic_light_state_history=None,
            **kwargs,
        )
        if self.light_flag:
            self._light_state = response.traffic_lights_states
            self.light_recurrent_state = response.light_recurrent_states
        if monitor_infractions and (response.infractions is not None):
            self._infractions = response.infractions
        else:
            self._infractions = None
        # initialize may produce a different agent count than requested
        self._agent_count = len(response.agent_attributes)
        self._agent_attributes = response.agent_attributes
        self._agent_states = response.agent_states
        self._recurrent_states = response.recurrent_states
        self._monitor_infractions = monitor_infractions
        self._birdview = response.birdview if get_birdview else None
        self._get_birdview = get_birdview

        # initialize might not return the exact number of agents requested,
        # so truncate or pad (with NPC entries) the ego mask to the actual count.
        # list(...) copies the caller's sequence and also accepts tuples.
        mask = [] if ego_agent_mask is None else list(ego_agent_mask[:self._agent_count])
        mask += [False] * (self._agent_count - len(mask))
        self._ego_agent_mask = mask
        self._time_step = 0

    def _next_seed(self) -> Optional[int]:
        """Draw the seed for the next API call from the private RNG, or None if unseeded."""
        return None if self.rng is None else self.rng.randint(0, int(9e6))

    @property
    def location(self) -> str:
        """
        Location name as recognized by Inverted AI API.
        """
        return self._location

    @property
    def agent_count(self) -> int:
        """
        The total number of agents, both ego and NPCs.
        """
        return self._agent_count

    @property
    def agent_states(self) -> List[AgentState]:
        """
        The predicted states for all agents, including ego.
        """
        return self._agent_states

    @property
    def agent_attributes(self) -> List[AgentAttributes]:
        """
        The attributes (length, width, rear_axis_offset) for all agents, including ego.
        """
        return self._agent_attributes

    @property
    def ego_agent_mask(self) -> List[bool]:
        """
        Lists which agents are ego, which means that you control them externally.
        It can be updated during the simulation, but see caveats in user guide
        regarding the quality of resulting predictions.
        """
        return self._ego_agent_mask

    @ego_agent_mask.setter
    def ego_agent_mask(self, value):
        # Store to the backing field; assigning to the property itself here
        # would re-enter this setter and recurse without end.
        self._ego_agent_mask = list(value)

    @property
    def ego_states(self) -> List[AgentState]:
        """
        Returns the predicted states of ego agents in order.
        The NPC agents are excluded.
        """
        return [d for d, s in zip(self._agent_states, self._ego_agent_mask) if s]

    @property
    def ego_attributes(self) -> List[AgentAttributes]:
        """
        Returns the attributes of ego agents in order.
        The NPC agents are excluded.
        """
        return [attr for attr, s in zip(self._agent_attributes, self._ego_agent_mask) if s]

    @property
    def infractions(self) -> Optional[List[InfractionIndicators]]:
        """
        If `monitor_infractions` was set in the constructor,
        lists infractions currently committed by each agent, including ego agents.
        Otherwise None.
        """
        return self._infractions

    @property
    def birdview(self) -> Optional[Image]:
        """
        If `get_birdview` was set in the constructor,
        this is the image showing the current state of the simulation.
        Otherwise None.
        """
        return self._birdview

    @property
    def npc_states(self) -> List[AgentState]:
        """
        Returns the predicted states of NPCs (non-ego agents) in order.
        The predictions for ego agents are excluded.
        """
        return [s for i, s in enumerate(self._agent_states) if not self._ego_agent_mask[i]]

    @property
    def light_states(self) -> Optional[TrafficLightStatesDict]:
        """
        Returns the traffic light states if any exists on the map.
        """
        return self._light_state

    def step(self, current_ego_agent_states: List[AgentState]) -> None:
        """
        Calls :func:`drive` to advance the simulation by one time step.

        Current states of ego agents need to be provided to synchronize with your local simulator.

        :param current_ego_agent_states: States of ego agents before the step, in the
            same order as the ``True`` entries of :attr:`ego_agent_mask`.
        :return: None - read :attr:`npc_states` to retrieve predictions.
        """
        self._update_ego_states(current_ego_agent_states)
        response = drive(
            location=self.location,
            agent_attributes=self._agent_attributes,
            agent_states=self._agent_states,
            recurrent_states=self._recurrent_states,
            get_infractions=self._monitor_infractions,
            get_birdview=self._get_birdview,
            random_seed=self._next_seed(),
            light_recurrent_states=self.light_recurrent_state,
        )
        self._agent_states = response.agent_states
        self._recurrent_states = response.recurrent_states
        # NOTE(review): if the response carries no infractions, the previous
        # step's infractions are kept as-is.
        if self._monitor_infractions and (response.infractions is not None):
            self._infractions = response.infractions
        if self._get_birdview:
            self._birdview = response.birdview
        self._time_step += 1
        if self.light_flag:
            self._light_state = response.traffic_lights_states
            self.light_recurrent_state = response.light_recurrent_states

    def _update_ego_states(self, ego_agent_states):
        """
        Overwrite the cached states of ego agents with externally supplied ones.

        Ego states are consumed in order, one per ``True`` entry in the ego mask.
        NPC states are kept unchanged.
        NOTE(review): too few ego states raises IndexError; extra ones are silently ignored.
        """
        new_states = []
        ego_idx = 0
        for i, state in enumerate(self._agent_states):
            if self._ego_agent_mask[i]:
                new_states.append(ego_agent_states[ego_idx])
                ego_idx += 1
            else:
                new_states.append(state)
        self._agent_states = new_states