from typing import List, Optional, Tuple
from collections import defaultdict
from copy import deepcopy
from invertedai.common import RECURRENT_SIZE, AgentState, AgentProperties, RecurrentState, SimulationAgentDict, AgentData
from invertedai.api.initialize import InitializeResponse
from invertedai.api.drive import DriveResponse
from invertedai.helpers.waypoints import WaypointManagerConfig, WaypointManager
from pydantic import BaseModel
from invertedai.utils import ScenePlotterConfig, ScenePlotter, WaypointsDict
from invertedai.large.initialize import large_initialize, get_regions_default, RegionsConfig
from invertedai.large.drive import large_drive
from invertedai.logs.logger import LogWriterConfig, LogWriter
from invertedai.large.common import Region
from matplotlib.animation import FuncAnimation
import uuid
class SimulationManager:
    """
    Stateful class for managing keyed agents with an internal dictionary SimulationAgentDict to manage AgentData by AgentID
    and provides wrappers around the IAI large_initialize and large_drive APIs

    Parameters:
        scene_plotter_cfg : Optional[ScenePlotterConfig]
            Configuration object used to initialize a ScenePlotter instance.
            Enables birdview visualization and animation of the simulation.
        waypoint_cfg : Optional[WaypointManagerConfig]
            Configuration for initializing a WaypointManager.
            If provided, waypoints will be dynamically updated during simulation.
        log_writer_cfg : Optional[LogWriterConfig]
            Configuration for enabling structured logging of the simulation.
            If provided, all initialize and drive steps will be recorded to a JSON log.
    """

    def __init__(
        self,
        scene_plotter_cfg: Optional[ScenePlotterConfig] = None,  # optional scene plotter for visualization
        waypoint_cfg: Optional[WaypointManagerConfig] = None,  # optional WaypointManager to manage waypoints
        log_writer_cfg: Optional[LogWriterConfig] = None,  # optional log writer producing a JSON log of the simulation
    ):
        self.scene_plotter: Optional[ScenePlotter] = None
        if scene_plotter_cfg:
            location_info = scene_plotter_cfg.location_info_response
            self.scene_plotter = ScenePlotter(
                location_info.birdview_image.decode(),
                location_info.map_fov,
                (location_info.map_center.x, location_info.map_center.y),
                location_info.static_actors,
                # Locations whose identifier starts with "carla" are plotted with left-handed coordinates.
                left_hand_coordinates=scene_plotter_cfg.location.split(":")[0] == "carla",
            )
        # Internal, persistent agents keyed by AgentID.
        self.agents_dict: SimulationAgentDict = defaultdict(AgentData)
        self.waypoint_manager: Optional[WaypointManager] = None
        if waypoint_cfg:
            self.waypoint_manager = WaypointManager(cfg=waypoint_cfg)
        self.log_writer: Optional[LogWriter] = None
        self.log_writer_cfg = log_writer_cfg
        if log_writer_cfg:
            self.log_writer = LogWriter()

    def insert_agents(
        self,
        agent_data_list: List[AgentData],
        ids: Optional[List[str]] = None,
        overwrite: bool = False,
    ) -> List[str]:
        """
        Insert multiple agents into the existing agents_dict using their AgentData.

        The call is all-or-nothing: every ID is validated before any agent is inserted,
        so a rejected call leaves agents_dict unchanged.

        Parameters:
            agent_data_list : List[AgentData]
                List of AgentData for each agent to be inserted.
            ids : Optional[List[str]]
                Optional list of AgentIDs to use for the new agents.
                If None, random UUIDs will be generated for each new agent.
            overwrite : bool
                If True, allows new agents to overwrite existing agents with the same ID.
                If the same ID appears more than once in ids, the last entry wins.

        Returns:
            The list of AgentIDs used for the inserted agents.

        Raises:
            ValueError
                If ids and agent_data_list differ in length, or if overwrite is False and an ID
                already exists in agents_dict or is repeated within ids.
        """
        if ids is None:
            new_ids = [str(uuid.uuid4()) for _ in agent_data_list]
        else:
            new_ids = ids
        if len(new_ids) != len(agent_data_list):
            raise ValueError("Length of ids provided and agent_data_list is not equal")
        if not overwrite:
            # Validate up front so a conflict midway does not leave a partial insert behind.
            seen = set()
            for agent_id in new_ids:
                if agent_id in self.agents_dict or agent_id in seen:
                    raise ValueError(f"Agent '{agent_id}' already exists. Cannot be inserted again with overwrite=False.")
                seen.add(agent_id)
        for agent_id, agent_data in zip(new_ids, agent_data_list):
            self.agents_dict[agent_id] = agent_data
        return new_ids

    def remove_agents(
        self,
        agent_ids: List[str],
    ):
        """
        Removes multiple agents from the SimulationManager given their AgentIDs.

        The call is all-or-nothing. Repeated IDs in agent_ids are removed once.

        Parameters:
            agent_ids : List[str]
                List of AgentIDs to remove.

        Raises:
            KeyError
                If any AgentID does not exist in self.agents_dict (nothing is removed in that case).
        """
        missing = [aid for aid in agent_ids if aid not in self.agents_dict]
        if missing:
            raise KeyError(f"Agents do not exist: {missing}. Cannot be removed.")
        # dict.fromkeys de-duplicates while keeping order, so a repeated ID cannot raise mid-removal.
        for aid in dict.fromkeys(agent_ids):
            del self.agents_dict[aid]

    def _unpack(
        self,
        agent_dict: Optional[SimulationAgentDict] = None
    ) -> Tuple[
        List[str],
        Optional[List[AgentState]],
        List[AgentProperties],
        List[RecurrentState],
    ]:
        """
        Flatten an agent dictionary into parallel lists for the IAI APIs.

        Agents without properties are skipped. Agents that have a state are placed before
        agents that do not, so the returned states cover a prefix of the other lists.

        Parameters:
            agent_dict : Optional[SimulationAgentDict]
                Dictionary to flatten. Defaults to self.agents_dict.

        Returns:
            (agent_ids, states, properties, recurrent_states). agent_ids, properties and
            recurrent_states are index-aligned. states holds only the states of the leading
            agents that have one, or None when no agent has a state.
        """
        if agent_dict is None:
            agent_dict = self.agents_dict
        agents_with_states = []
        agents_without_states = []
        for aid, data in agent_dict.items():
            if data.properties is None:
                continue
            if data.state is not None:
                agents_with_states.append((aid, data))
            else:
                agents_without_states.append((aid, data))
        # Agents with states go in front so the states list lines up with the leading properties.
        ordered_agents = agents_with_states + agents_without_states
        agent_ids = [aid for aid, _ in ordered_agents]
        properties = [data.properties for _, data in ordered_agents]
        recurrent_states = [data.recurrent for _, data in ordered_agents]
        # Only the state prefix is sent. Trailing None placeholders would not describe any agent.
        states: Optional[List[AgentState]] = [data.state for _, data in agents_with_states] or None
        return agent_ids, states, properties, recurrent_states

    def _pack(
        self,
        agent_ids: List[str],
        states: List[AgentState],
        properties: List[AgentProperties],
        recurrent_states: List[RecurrentState],
    ) -> SimulationAgentDict:
        """
        Build a SimulationAgentDict from index-aligned lists.

        An empty or falsy properties / recurrent_states list stores None for that field.
        """
        agents_dict = defaultdict(AgentData)
        for i, aid in enumerate(agent_ids):
            agents_dict[aid] = AgentData(
                state=states[i],
                properties=properties[i] if properties else None,
                recurrent=recurrent_states[i] if recurrent_states else None,
            )
        return agents_dict

    def _pack_internal(
        self,
        agent_ids: List[str],
        external_ids: set,
        states: List[AgentState],
        properties: List[AgentProperties],
        recurrent_states: List[RecurrentState],
    ) -> SimulationAgentDict:
        """Pack only the agents whose IDs are not in external_ids (external agents are never stored)."""
        keep = [i for i, aid in enumerate(agent_ids) if aid not in external_ids]
        return self._pack(
            agent_ids=[agent_ids[i] for i in keep],
            states=[states[i] for i in keep],
            properties=[properties[i] for i in keep],
            recurrent_states=[recurrent_states[i] for i in keep],
        )

    def initialize(
        self,
        regions: List[Region],
        external_agent_data: Optional[SimulationAgentDict] = None,
        **kwargs
    ) -> InitializeResponse:
        """
        Initialize simulation using :func:`large_initialize` with agents in self.agents_dict along with any optionally provided external_agent_data

        Parameters:
            regions : List[Region]
                Regions with presampled agents. Use iai.get_regions_default() to obtain a list of Regions.
            external_agent_data : Optional[SimulationAgentDict]
                Optional stateless dictionary of externally created agents to initialize alongside internal self.agents_dict.
                Requires valid properties and optional state to be provided for each agent in the dictionary.
                You can use the Scenario Builder tool available on the Inverted AI website to validate an agent's state and properties.
            Please see :func:`large_initialize` for documentation on **kwargs

        Returns:
            The InitializeResponse from :func:`large_initialize`.

        Raises:
            ValueError
                If external agent IDs collide with internal ones, or any agent lacks properties.

        Note:
            - agent_states, agent_properties, and recurrent_states should not be
              provided in **kwargs. These values are automatically derived from the
              internal agent dictionary and managed by this wrapper.
            - Agents returned beyond the ones supplied are assigned fresh UUIDs and stored internally.
            - For all other supported parameters, please refer to the documentation for :func:`large_initialize`.
        """
        if external_agent_data:
            overlap = set(self.agents_dict.keys()) & set(external_agent_data.keys())
            if overlap:
                raise ValueError(f"External agent IDs conflict with internal agents: {overlap}")
            combined_agents = {**self.agents_dict, **external_agent_data}
            external_ids = set(external_agent_data.keys())
        else:
            combined_agents = self.agents_dict
            external_ids = set()

        # Must run before _unpack: _unpack skips agents without properties, which would
        # otherwise be dropped silently and erased from agents_dict below.
        missing_props = [aid for aid, data in combined_agents.items() if data.properties is None]
        if missing_props:
            raise ValueError(f"All agents must have non-None properties before initialization: {missing_props}")

        agent_ids, states, properties, _ = self._unpack(combined_agents)
        original_agent_count = len(agent_ids)
        response = large_initialize(
            regions=regions,
            agent_properties=properties,
            agent_states=states,
            return_exact_agents=True,
            **kwargs
        )
        # NOTE(review): assumes the response lists the supplied agents first, in the order
        # supplied, followed by any agents added by large_initialize.
        num_new_agents = len(response.agent_states) - original_agent_count
        all_agent_ids = agent_ids + [str(uuid.uuid4()) for _ in range(num_new_agents)]

        new_properties = response.agent_properties
        if self.waypoint_manager:
            new_properties = self.waypoint_manager.update(
                response=response,
                agent_properties=response.agent_properties,
            )

        # External agents are stateless and are not kept in agents_dict.
        self.agents_dict = self._pack_internal(
            agent_ids=all_agent_ids,
            external_ids=external_ids,
            states=response.agent_states,
            properties=new_properties,
            recurrent_states=response.recurrent_states,
        )

        if self.scene_plotter:
            # Use the (possibly waypoint-updated) properties, consistent with drive().
            self.scene_plotter.initialize_recording(
                agent_states=response.agent_states,
                agent_properties=new_properties,
            )
        if self.log_writer is not None:
            all_agents_dict = self._pack(  # both internal + external agents
                agent_ids=all_agent_ids,
                states=response.agent_states,
                properties=new_properties,
                recurrent_states=response.recurrent_states,
            )
            self.log_writer.initialize(
                location=self.log_writer_cfg.location,
                location_info_response=self.log_writer_cfg.location_info_response,
                agents_dict=all_agents_dict,
                init_response=response,
            )
        return response

    def drive(
        self,
        external_agent_data: Optional[SimulationAgentDict] = None,
        **kwargs
    ) -> DriveResponse:
        """
        Advance the simulation by one timestep with :func:`large_drive` using the data from agents
        in self.agents_dict and optionally provided external_agent_data

        This method:
            - updates self.agents_dict with results from large_drive
            - uses iai.WaypointManager to update waypoints if configured
            - records visualization and logging outputs if configured

        Parameters:
            external_agent_data : Optional[SimulationAgentDict]
                Optional stateless dictionary of externally created agents to
                drive alongside internal self.agents_dict at this timestep.
                Requires both valid state and properties to be provided for each agent in the dictionary.
                External agents are given zero-filled recurrent states.
                You can use the Scenario Builder tool available on the Inverted AI website to validate an agent's state and properties.
            Please see :func:`large_drive` for information on kwargs

        Returns:
            DriveResponse

        Raises:
            ValueError
                If external agent IDs collide with internal ones, or any internal or external agent
                lacks a state or properties.

        Note:
            - agent_states, agent_properties, and recurrent_states should not be
              provided in kwargs. These values are automatically derived from the
              internal agent dictionary and managed by this wrapper.
            - Internal agents without a recurrent state are given a zero-filled one.
            - For all other supported parameters, please refer to the documentation for :func:`large_drive`.
        """
        if external_agent_data:
            overlap = set(self.agents_dict.keys()) & set(external_agent_data.keys())
            if overlap:
                raise ValueError(f"External agent IDs conflict with internal agents: {overlap}")

        # Every driven agent needs a state and properties. Without this check, _unpack would
        # silently skip property-less agents (erasing them from agents_dict below) and
        # state-less agents would reach large_drive as None.
        incomplete = [
            aid for aid, data in self.agents_dict.items()
            if data.state is None or data.properties is None
        ]
        if incomplete:
            raise ValueError(f"Internal agents must have both state and properties for drive: {incomplete}")

        agent_ids, states, properties, recurrent_states = self._unpack(self.agents_dict)
        # _unpack returns None for states when there are no internal agents.
        states = states if states is not None else []

        # Size zero-filled recurrent states after an existing one when available.
        reference = next((r for r in recurrent_states if r is not None), None)
        internal_recur_size = len(reference.packed) if reference is not None else RECURRENT_SIZE
        recurrent_states = [
            r if r is not None else RecurrentState(packed=[0.0] * internal_recur_size)
            for r in recurrent_states
        ]

        external_ids = set()
        if external_agent_data:
            # external data validation
            missing_states = [aid for aid, data in external_agent_data.items() if data.state is None]
            if missing_states:
                raise ValueError(f"External agents must have a state for drive: {missing_states}")
            missing_props = [aid for aid, data in external_agent_data.items() if data.properties is None]
            if missing_props:
                raise ValueError(f"External agents must have properties for drive: {missing_props}")
            ext_ids, ext_states, ext_props, _ = self._unpack(external_agent_data)
            # External agents are stateless between steps, so they start from zeros every step.
            ext_recurrent_states = [RecurrentState(packed=[0.0] * internal_recur_size) for _ in ext_ids]
            agent_ids = agent_ids + ext_ids
            states = states + ext_states
            properties = properties + ext_props
            recurrent_states = recurrent_states + ext_recurrent_states
            external_ids = set(ext_ids)

        response = large_drive(
            agent_states=states,
            agent_properties=properties,
            recurrent_states=recurrent_states,
            **kwargs
        )
        if self.waypoint_manager:
            properties = self.waypoint_manager.update(
                response=response,
                agent_properties=properties,
            )

        self.agents_dict = self._pack_internal(
            agent_ids=agent_ids,
            external_ids=external_ids,
            states=response.agent_states,
            properties=properties,
            recurrent_states=response.recurrent_states,
        )

        if self.scene_plotter:
            self.scene_plotter.record_step(
                response.agent_states,
                traffic_light_states=response.traffic_lights_states,
                agent_properties=properties,
            )
        if self.log_writer is not None:
            all_agents_dict = self._pack(  # both internal + external agents
                agent_ids=agent_ids,
                states=response.agent_states,
                properties=properties,
                recurrent_states=response.recurrent_states,
            )
            self.log_writer.drive(
                drive_response=response,
                agents_dict=all_agents_dict,
            )
        return response

    def visualize_data(self, **kwargs) -> FuncAnimation:
        """
        Produce an animation of sequentially recorded steps. If a ScenePlotter was configured during initialization,
        recorded steps from each drive will be visualized using the birdview map and static actors.
        A matplotlib animation object can be returned and/or a gif saved of the scene.
        For kwargs, please see documentation from :func:`animate_scene` in the ScenePlotter class

        Returns:
            Whatever ScenePlotter.animate_scene returns (annotated here as a FuncAnimation).

        Raises:
            ValueError
                If no ScenePlotter was configured.
        """
        if self.scene_plotter is None:
            raise ValueError("ScenePlotter not initialized, failed to animate scene")
        return self.scene_plotter.animate_scene(**kwargs)

    def export_log(self, path: Optional[str] = None):
        """
        Export the log of the simulation to a JSON file if logging was enabled with a
        LogWriterConfig during initialization of the SimulationManager

        Parameters:
            path : Optional[str]
                Optional path to specify where the log should be saved. If None, will default to the path provided in LogWriterConfig during initialization.
                If no path is provided in either place, will raise an error.

        Raises:
            ValueError
                If logging is not enabled or no export path is available.
        """
        if self.log_writer is None:
            raise ValueError("Logging not enabled.")
        log_path = path or self.log_writer_cfg.log_path
        if log_path is None:
            raise ValueError("No export path specified.")
        self.log_writer.export_to_file(log_path=log_path)

    # Getters
    def get_scene_plotter(self) -> Optional[ScenePlotter]:
        """Return the configured ScenePlotter, or None."""
        return self.scene_plotter

    def get_states(self) -> List[AgentState]:
        """Return the state of every internal agent, in get_agent_ids() order."""
        return [data.state for data in self.agents_dict.values()]

    def get_agent_ids(self) -> List[str]:
        """Return the IDs of all internal agents."""
        return list(self.agents_dict.keys())

    def get_properties(self) -> List[AgentProperties]:
        """Return the properties of every internal agent, in get_agent_ids() order."""
        return [data.properties for data in self.agents_dict.values()]

    def get_recurrent_states(self) -> List[RecurrentState]:
        """Return the recurrent state of every internal agent, in get_agent_ids() order."""
        return [data.recurrent for data in self.agents_dict.values()]

    def get_agent_data(self, agent_id: str) -> AgentData:
        """Return the AgentData for agent_id. Raises KeyError if it does not exist."""
        if agent_id not in self.agents_dict:
            raise KeyError(f"Agent '{agent_id}' does not exist")
        return self.agents_dict[agent_id]

    def get_agent_dict(self) -> SimulationAgentDict:
        """Return the internal agents dictionary itself (not a copy)."""
        return self.agents_dict

    # Setters for individual agents in self.agents_dict
    def set_state(self, agent_id: str, state: AgentState):
        """Replace the state of agent_id. Raises KeyError if it does not exist."""
        if agent_id not in self.agents_dict:
            raise KeyError(f"Agent '{agent_id}' does not exist")
        self.agents_dict[agent_id].state = state

    def set_property(self, agent_id: str, properties: AgentProperties):
        """Replace the properties of agent_id. Raises KeyError if it does not exist."""
        if agent_id not in self.agents_dict:
            raise KeyError(f"Agent '{agent_id}' does not exist")
        self.agents_dict[agent_id].properties = properties

    def set_recurrent_state(self, agent_id: str, recurrent: RecurrentState):
        """Replace the recurrent state of agent_id. Raises KeyError if it does not exist."""
        if agent_id not in self.agents_dict:
            raise KeyError(f"Agent '{agent_id}' does not exist")
        self.agents_dict[agent_id].recurrent = recurrent

    # Setters for all agents in self.agents_dict
    def set_states(self, states: List[AgentState]):
        """Assign states to all internal agents, matched by get_agent_ids() order. Raises ValueError on length mismatch."""
        if len(states) != len(self.agents_dict):
            raise ValueError(f"Expected {len(self.agents_dict)} states, got {len(states)}")
        for data, state in zip(self.agents_dict.values(), states):
            data.state = state

    def set_properties(self, properties: List[AgentProperties]):
        """Assign properties to all internal agents, matched by get_agent_ids() order. Raises ValueError on length mismatch."""
        if len(properties) != len(self.agents_dict):
            raise ValueError(f"Expected {len(self.agents_dict)} properties, got {len(properties)}")
        for data, props in zip(self.agents_dict.values(), properties):
            data.properties = props

    def set_recurrent_states(self, recurrent_states: List[RecurrentState]):
        """Assign recurrent states to all internal agents, matched by get_agent_ids() order. Raises ValueError on length mismatch."""
        if len(recurrent_states) != len(self.agents_dict):
            raise ValueError(f"Expected {len(self.agents_dict)} recurrent states, got {len(recurrent_states)}")
        for data, recurrent in zip(self.agents_dict.values(), recurrent_states):
            data.recurrent = recurrent