Source code for invertedai.logs.logger

from collections import defaultdict

from pydantic import BaseModel, validate_arguments, model_validator
from typing import List, Optional, Dict, Tuple, Any, Union
from copy import deepcopy

import matplotlib.pyplot as plt
import json

from invertedai import location_info
from invertedai.utils import ScenePlotter, WaypointsDict, convert_attributes_to_properties
from invertedai.api.location import LocationResponse
from invertedai.api.initialize import InitializeResponse
from invertedai.api.drive import DriveResponse
from invertedai.common import ( 
    AgentAttributes,
    AgentData, 
    AgentProperties,
    AgentState, 
    LightRecurrentState,
    LightRecurrentStates,
    Point,
    RecurrentState,
    SimulationAgentDict,
    TrafficLightStatesDict 
)

class ScenarioLog(BaseModel):
    """
    Simulation record suitable for storage, replay, or resuming a simulation.

    Some fields carry one entry per historic time step (``agent_data``,
    ``traffic_lights_states``); others only describe the most recent time step
    (``light_recurrent_states``, ``recurrent_states``) so that a simulation can be
    continued from the point where the log ends.
    """
    agent_data: List[SimulationAgentDict] #: Historic data for all SimulationAgentDict up until the most recent time step.
    traffic_lights_states: Optional[List[TrafficLightStatesDict]] = None #: Historic data for all TrafficLightStatesDict up until the most recent time step.
    location: str #: Location name in IAI format.
    rendering_center: Optional[Tuple[float, float]] = None #: Please refer to the documentation of :func:`location_info` for information on this parameter.
    rendering_fov: Optional[int] = None #: Please refer to the documentation of :func:`location_info` for information on this parameter.
    lights_random_seed: Optional[int] = None #: Controls the stochastic aspects of the traffic lights states.
    initialize_random_seed: Optional[int] = None #: Please refer to the documentation of :func:`initialize` for information on the random_seed parameter.
    drive_random_seed: Optional[int] = None #: Please refer to the documentation of :func:`drive` for information on the random_seed parameter.
    initialize_model_version: Optional[str] = "best" #: Please refer to the documentation of :func:`initialize` for information on the api_model_version parameter.
    drive_model_version: Optional[str] = "best" #: Please refer to the documentation of :func:`drive` for information on the api_model_version parameter.
    light_recurrent_states: Optional[LightRecurrentStates] = None #: As of the most recent time step. Please refer to the documentation of :func:`drive` for further information on this parameter.
    recurrent_states: Optional[List[RecurrentState]] = None #: As of the most recent time step. Please refer to the documentation of :func:`drive` for further information on this parameter.

    def get_agent_states(self, timestep: Optional[int] = None) -> List[AgentState]:
        """Return the states of the agents recorded at ``timestep`` (latest step when ``None``)."""
        snapshot = self.get_agents(timestep)
        return [agent.state for agent in snapshot.values()]

    def get_agent_properties(self, timestep: Optional[int] = None) -> List[AgentProperties]:
        """Return the properties of the agents recorded at ``timestep`` (latest step when ``None``)."""
        snapshot = self.get_agents(timestep)
        return [agent.properties for agent in snapshot.values()]

    def get_agent_ids(self, timestep: Optional[int] = None) -> List[str]:
        """Return the agent IDs recorded at ``timestep`` (latest step when ``None``)."""
        return list(self.get_agents(timestep))

    def get_agents(self, timestep: Optional[int] = None) -> SimulationAgentDict:
        """Return the full agent dict recorded at ``timestep`` (latest step when ``None``)."""
        index = -1 if timestep is None else timestep
        return self.agent_data[index]

    def get_agent_ids_all(self) -> List[str]:
        """Return every agent ID seen in the log, ordered by first appearance."""
        # dict.fromkeys keeps the first insertion position of each key, which gives an
        # order-preserving de-duplication.
        return list(dict.fromkeys(
            agent_id for snapshot in self.agent_data for agent_id in snapshot
        ))

    def get_agent_properties_all(self) -> Dict[str, AgentProperties]:
        """
        Map every agent ID ever seen to its properties from the last step it appears in.

        Later time steps overwrite earlier ones; key order follows first appearance.
        """
        return {
            agent_id: agent.properties
            for snapshot in self.agent_data
            for agent_id, agent in snapshot.items()
        }

    @model_validator(mode='after')
    def validate_agent_data(self):
        """Run the per-step agent check over every recorded time step."""
        for step_index, snapshot in enumerate(self.agent_data):
            self._validate_agent_dict(snapshot, timestep=step_index)
        return self

    @staticmethod
    def _validate_agent_dict(agent_dict: SimulationAgentDict, timestep: Optional[int] = None):
        """Assert that every agent in ``agent_dict`` carries a non-None state."""
        suffix = "" if timestep is None else f" at timestep {timestep}"
        for agent_id, agent in agent_dict.items():
            assert agent.state is not None, f"Agent '{agent_id}'{suffix} has no state."

    def add_time_step_data(self, agent_dict: SimulationAgentDict):
        """Validate ``agent_dict`` and append a deep copy of it as the newest time step."""
        self._validate_agent_dict(agent_dict)
        snapshot = deepcopy(agent_dict)
        self.agent_data.append(snapshot)
class ScenarioLogLegacy(BaseModel):
    """
    Deprecated, do not use. Please use ScenarioLog instead.

    A log containing simulation information for storage, replay, or an initial state from which a
    simulation can be continued. Some data fields contain data for all historic time steps while
    others contain information for the most recent time step to be used to continue a simulation.
    Agents are identified by their index into ``agent_properties``.
    """
    agent_states: List[List[AgentState]] #: Historic data for all agents states up until the most recent time step.
    agent_properties: List[AgentProperties] #: Agent properties data for all agents in this scenario/log.
    traffic_lights_states: Optional[List[TrafficLightStatesDict]] = None #: Historic data for all TrafficLightStatesDict up until the most recent time step.
    location: str #: Location name in IAI format.
    rendering_center: Optional[Tuple[float, float]] = None #: Please refer to the documentation of :func:`location_info` for information on this parameter.
    rendering_fov: Optional[int] = None #: Please refer to the documentation of :func:`location_info` for information on this parameter.
    lights_random_seed: Optional[int] = None #: Controls the stochastic aspects of the traffic lights states.
    initialize_random_seed: Optional[int] = None #: Please refer to the documentation of :func:`initialize` for information on the random_seed parameter.
    drive_random_seed: Optional[int] = None #: Please refer to the documentation of :func:`drive` for information on the random_seed parameter.
    initialize_model_version: Optional[str] = "best" #: Please refer to the documentation of :func:`initialize` for information on the api_model_version parameter.
    drive_model_version: Optional[str] = "best" #: Please refer to the documentation of :func:`drive` for information on the api_model_version parameter.
    light_recurrent_states: Optional[LightRecurrentStates] = None #: As of the most recent time step. Please refer to the documentation of :func:`drive` for further information on this parameter.
    recurrent_states: Optional[List[RecurrentState]] = None #: As of the most recent time step. Please refer to the documentation of :func:`drive` for further information on this parameter.
    waypoints_per_frame: Optional[List[WaypointsDict]] = None #: As of the most recent time step. A list of waypoints keyed to agent ID's not including waypoints already passed. These waypoints are not automatically populated into the agent properties.
    present_indexes: Optional[List[List[int]]] = None #: List of indexes corresponding to agent_properties for which agents are present at each time step. If None, all agents are present at every time step.

    @model_validator(mode='after')
    def validate_states_and_present_indexes_init(self):
        """When present_indexes is given, check it lines up with agent_states step by step."""
        if self.present_indexes is not None:
            assert len(self.agent_states) == len(self.present_indexes), "Given different number of time steps for agent states and present indexes."
            for states, pres_ids in zip(self.agent_states, self.present_indexes):
                self.validate_states_and_present_indexes_time_step(
                    current_agent_states=states,
                    current_present_indexes=pres_ids
                )
        return self

    def validate_states_and_present_indexes_time_step(
        self,
        current_agent_states: List[AgentState],
        current_present_indexes: List[int]
    ):
        """
        Check a single time step: indexes must be non-negative and there must be exactly one
        state per present index. An empty time step (no agents present) is valid.
        """
        # all() rather than min(): min() raises ValueError when no agents are present.
        assert all(idx >= 0 for idx in current_present_indexes), "Invalid agent ID's in given list of present indexes."
        assert len(current_present_indexes) == len(current_agent_states), "Given number of agent states does not match number of present agents."

    def _resolved_present_indexes(self) -> List[List[int]]:
        """Return present_indexes, expanding ``None`` into "every agent present" at each step."""
        if self.present_indexes is not None:
            return self.present_indexes
        return [list(range(len(states))) for states in self.agent_states]

    def add_time_step_data(
        self,
        current_agent_states: List[AgentState],
        current_present_indexes: List[int]
    ):
        """Validate and append one time step of agent states and their present indexes."""
        self.validate_states_and_present_indexes_time_step(
            current_agent_states=current_agent_states,
            current_present_indexes=current_present_indexes
        )
        if self.present_indexes is None:
            # Make the implicit "all agents present" history explicit so the new entry
            # stays aligned with agent_states instead of failing on None.append.
            self.present_indexes = self._resolved_present_indexes()
        self.present_indexes.append(current_present_indexes)
        self.agent_states.append(current_agent_states)

    def to_scenario_log(self):
        """Convert this deprecated ScenarioLogLegacy into a ScenarioLog."""
        agent_data = []
        for t, (states, present) in enumerate(zip(self.agent_states, self._resolved_present_indexes())):
            agent_dict: SimulationAgentDict = {}
            for pos, idx in enumerate(present):
                props = self.agent_properties[idx]
                if self.waypoints_per_frame is not None and t < len(self.waypoints_per_frame):
                    frame_wp = self.waypoints_per_frame[t]
                    if frame_wp and str(idx) in frame_wp:
                        # Copy so per-frame waypoints do not leak into the shared properties.
                        props = deepcopy(props)
                        props.waypoints = frame_wp[str(idx)]
                agent_dict[str(idx)] = AgentData(
                    state=states[pos],
                    properties=props,
                    recurrent=None,
                )
            agent_data.append(agent_dict)

        return ScenarioLog(
            agent_data=agent_data,
            traffic_lights_states=self.traffic_lights_states,
            location=self.location,
            rendering_center=self.rendering_center,
            rendering_fov=self.rendering_fov,
            lights_random_seed=self.lights_random_seed,
            initialize_random_seed=self.initialize_random_seed,
            drive_random_seed=self.drive_random_seed,
            initialize_model_version=self.initialize_model_version,
            drive_model_version=self.drive_model_version,
            light_recurrent_states=self.light_recurrent_states,
            recurrent_states=self.recurrent_states,
        )

    @classmethod
    def from_scenario_log(
        cls,
        scenario_log: ScenarioLog
    ):
        """Convert a ScenarioLog into the deprecated ScenarioLogLegacy format."""
        all_keys = scenario_log.get_agent_ids_all()
        id_to_idx = {aid: i for i, aid in enumerate(all_keys)}
        props_map = scenario_log.get_agent_properties_all()
        agent_properties_list = [props_map[k] for k in all_keys]

        agent_states_list = []
        present_indexes_list = []
        waypoints_per_frame_list = []
        for agent_dict in scenario_log.agent_data:
            states = []
            present = []
            wp_dict = {}
            for aid, data in agent_dict.items():
                idx = id_to_idx[aid]
                present.append(idx)
                states.append(data.state)
                if data.properties and data.properties.waypoints:
                    wp_dict[str(idx)] = data.properties.waypoints
            present_indexes_list.append(present)
            agent_states_list.append(states)
            waypoints_per_frame_list.append(wp_dict)

        return cls(
            agent_states=agent_states_list,
            agent_properties=agent_properties_list,
            traffic_lights_states=scenario_log.traffic_lights_states,
            location=scenario_log.location,
            rendering_center=scenario_log.rendering_center,
            rendering_fov=scenario_log.rendering_fov,
            lights_random_seed=scenario_log.lights_random_seed,
            initialize_random_seed=scenario_log.initialize_random_seed,
            drive_random_seed=scenario_log.drive_random_seed,
            initialize_model_version=scenario_log.initialize_model_version,
            drive_model_version=scenario_log.drive_model_version,
            light_recurrent_states=scenario_log.light_recurrent_states,
            recurrent_states=scenario_log.recurrent_states,
            waypoints_per_frame=waypoints_per_frame_list if any(wp_dict for wp_dict in waypoints_per_frame_list) else None,
            present_indexes=present_indexes_list,
        )


class LogBase():
    """
    A class for containing features relevant to both log reading and writing such as visualization.
    """

    def __init__(self):
        self._scenario_log = None
        self.simulation_length = None

    @classmethod
    def scenario_log_from_debug_log(
        cls,
        debug_log_path
    ):
        """
        Build a ScenarioLog from a JSON debug log of raw API requests and responses.

        One time step is produced per (drive response, drive request) pair. Agent keys are the
        integer strings "0", "1", ... in the order agents appear in each response. If any drive
        response has no traffic light states, the log stores ``None`` for all traffic lights.
        """
        with open(debug_log_path) as f:
            debug_log_data = json.load(f)

        last_init_res = json.loads(debug_log_data["initialize_responses"][-1])
        last_init_req = json.loads(debug_log_data["initialize_requests"][-1])
        last_drive_req = json.loads(debug_log_data["drive_requests"][-1])

        all_drive_responses = []
        all_agent_states = []
        all_agent_properties = []
        all_traffic_lights_states = []
        lights_available = True
        for res_, req_ in zip(debug_log_data["drive_responses"], debug_log_data["drive_requests"]):
            res = json.loads(res_)
            req = json.loads(req_)
            all_drive_responses.append(res)
            all_agent_states.append([AgentState.fromlist(state) for state in res["agent_states"]])
            all_agent_properties.append([AgentProperties.deserialize(prop) for prop in req["agent_properties"]])
            # Track availability with a flag: rebinding the list to None (as before) made a
            # later non-None response crash on None.append.
            if res["traffic_lights_states"] is None:
                lights_available = False
            elif lights_available:
                all_traffic_lights_states.append(res["traffic_lights_states"])

        log_location = last_init_req["location"]
        if debug_log_data["location_responses"]:
            loc_res = json.loads(debug_log_data["location_responses"][-1])
            rendering_center = loc_res["map_center"]
            rendering_fov = loc_res["map_fov"]
        else:
            location_info_response = location_info(
                location=log_location,
            )
            rendering_center = [
                location_info_response.map_center.x,
                location_info_response.map_center.y
            ]
            rendering_fov = location_info_response.map_fov

        agent_data = []
        for states_list, props_list in zip(all_agent_states, all_agent_properties):
            agent_dict: SimulationAgentDict = {
                str(i): AgentData(state=state, properties=props, recurrent=None)
                for i, (state, props) in enumerate(zip(states_list, props_list))
            }
            agent_data.append(agent_dict)

        scenario_log = ScenarioLog(
            agent_data=agent_data,
            traffic_lights_states=all_traffic_lights_states if lights_available else None,
            location=log_location,
            rendering_center=rendering_center,
            rendering_fov=rendering_fov,
            lights_random_seed=last_drive_req["random_seed"],
            initialize_random_seed=last_init_req["random_seed"],
            drive_random_seed=last_drive_req["random_seed"],
            initialize_model_version=last_init_res["model_version"],
            drive_model_version=all_drive_responses[-1]["model_version"],
            light_recurrent_states=all_drive_responses[-1]["light_recurrent_states"],
            recurrent_states=[RecurrentState.fromval(rec_state) for rec_state in all_drive_responses[-1]["recurrent_states"]],
        )

        return scenario_log

    @validate_arguments
    def visualize_range(
        self,
        timestep_range: Tuple[int,int],
        gif_path: str,
        fov: int = 200,
        resolution: Tuple[int,int] = (2048,2048),
        dpi: int = 300,
        map_center: Optional[Tuple[float,float]] = None,
        direction_vec: bool = False,
        velocity_vec: bool = False,
        plot_frame_number: bool = True,
        left_hand_coordinates: bool = False,
        agent_ids: Optional[List[int]] = None
    ):
        """
        Use the available internal tools to visualize a specific range of time steps within the
        log and save it to a given location.

        Both ends of ``timestep_range`` must lie in ``[0, simulation_length - 1]`` and the end
        must not precede the start; otherwise an AssertionError is raised. Please refer to
        ScenePlotter for details on the visualization tool.
        """
        start_step, end_step = timestep_range
        last_step = self.simulation_length - 1
        # The previous check (`t >= 0 or t <= last`) accepted any non-negative value.
        assert 0 <= start_step <= last_step and 0 <= end_step <= last_step, "Visualization time range invalid."
        assert end_step >= start_step, "Visualization time range invalid."

        location_info_response = location_info(
            location=self._scenario_log.location,
            rendering_fov=fov,
            rendering_center=map_center
        )
        rendered_static_map = location_info_response.birdview_image.decode()
        map_center = tuple([location_info_response.map_center.x, location_info_response.map_center.y]) if map_center is None else map_center

        traffic_lights_states = [None]*len(self._scenario_log.agent_data) if self._scenario_log.traffic_lights_states is None else self._scenario_log.traffic_lights_states
        first_agent_dict = self._scenario_log.agent_data[0]

        scene_plotter = ScenePlotter(
            map_image=rendered_static_map,
            fov=fov,
            xy_offset=map_center,
            static_actors=location_info_response.static_actors,
            resolution=resolution,
            dpi=dpi,
            left_hand_coordinates=left_hand_coordinates
        )
        scene_plotter.initialize_recording(
            agent_states=[d.state for d in first_agent_dict.values()],
            agent_properties=[d.properties for d in first_agent_dict.values()],
            # Guarded like the per-step lookup below, in case fewer light states were recorded.
            traffic_light_states=traffic_lights_states[start_step] if start_step < len(traffic_lights_states) else None,
        )

        for ts, agent_dict in enumerate(self._scenario_log.agent_data):
            lights = traffic_lights_states[ts] if ts < len(traffic_lights_states) else None
            scene_plotter.record_step(
                agent_states=[d.state for d in agent_dict.values()],
                traffic_light_states=lights,
                agent_properties=[d.properties for d in agent_dict.values()],
            )

        fig, ax = plt.subplots(constrained_layout=True, figsize=(50, 50))
        plt.axis('off')
        scene_plotter.animate_scene(
            output_name=gif_path,
            start_idx=start_step,
            end_idx=end_step,
            ax=ax,
            direction_vec=direction_vec,
            velocity_vec=velocity_vec,
            plot_frame_number=plot_frame_number,
            numbers=agent_ids
        )
        plt.close(fig)

    @validate_arguments
    def visualize(
        self,
        gif_path: str,
        fov: int = 200,
        resolution: Tuple[int,int] = (2048,2048),
        dpi: int = 300,
        map_center: Optional[Tuple[float,float]] = None,
        direction_vec: bool = False,
        velocity_vec: bool = False,
        plot_frame_number: bool = True,
        left_hand_coordinates: bool = False,
        agent_ids: Optional[List[int]] = None
    ):
        """
        Use the available internal tools to visualize the entire log and save it to a given location.
        Please refer to ScenePlotter for details on the visualization tool.
        """
        self.visualize_range(
            timestep_range = tuple([0,self.simulation_length-1]),
            gif_path = gif_path,
            fov = fov,
            resolution = resolution,
            dpi = dpi,
            map_center = map_center,
            direction_vec = direction_vec,
            velocity_vec = velocity_vec,
            plot_frame_number = plot_frame_number,
            left_hand_coordinates = left_hand_coordinates,
            agent_ids = agent_ids
        )

    def initialize(self):
        """Placeholder overridden by subclasses."""
        pass

    def drive(self):
        """Placeholder overridden by subclasses."""
        pass


class LogWriterConfig(BaseModel):
    """
    Configuration for logging + exporting simulation runs.
    """
    log_path: str
    location: str
    location_info_response: Optional[LocationResponse] = None
[docs]class LogWriter(LogBase): """ A class for conveniently writing a log to a JSON log format. """ def __init__(self): super().__init__()
[docs] @validate_arguments def export_to_file( self, log_path: str, scenario_log: Optional[ScenarioLog] = None ): """ Convert the data currently contained within the log into a JSON format and export it to a given file. This function can furthermore be used to export a given scenario log instead of the log contained within the object. """ def _format_waypoints_json( wps: List[Point] ): return { "suggestion_strength": 0.8, #Default value "states": [{ "center": { "x": wp.x, "y": wp.y } } for wp in wps] } if scenario_log is None: scenario_log = self._scenario_log all_keys = scenario_log.get_agent_ids_all() agent_properties = scenario_log.get_agent_properties_all() # Build waypoints from agent properties map individual_suggestions_dict = {} for key, prop in agent_properties.items(): if prop.waypoints: individual_suggestions_dict[key] = _format_waypoints_json(prop.waypoints) elif prop.waypoint: individual_suggestions_dict[key] = _format_waypoints_json([prop.waypoint]) num_cars, num_pedestrians = 0, 0 for prop in agent_properties.values(): if prop.agent_type == "car": num_cars += 1 elif prop.agent_type == "pedestrian": num_pedestrians += 1 num_controls_light, num_controls_yield, num_controls_stop, num_controls_other = 0, 0, 0, 0 static_actors_list = location_info(location=scenario_log.location).static_actors for actor in static_actors_list: if actor.agent_type == "traffic_light": num_controls_light += 1 elif actor.agent_type == "yield_sign": num_controls_yield += 1 elif actor.agent_type == "stop_sign": num_controls_stop += 1 else: num_controls_other += 1 # Build predetermined_agents_dict using agent keys from agent_data predetermined_agents_dict = {} for key in all_keys: prop = agent_properties[key] states_dict = {} for t, agent_dict in enumerate(scenario_log.agent_data): if key in agent_dict: state = agent_dict[key].state states_dict[str(t)] = { "center": {"x": state.center.x, "y": state.center.y}, "orientation": state.orientation, "speed": state.speed, } 
predetermined_agents_dict[key] = { "entity_type": prop.agent_type, "static_attributes": { "length": prop.length, "width": prop.width, "rear_axis_offset": prop.rear_axis_offset, }, "states":states_dict } predetermined_controls_dict = {} if scenario_log.traffic_lights_states is not None: for actor in [actor for actor in static_actors_list if actor.agent_type == "traffic_light"]: actor_id = actor.actor_id states_dict = {} for t, tls in enumerate(scenario_log.traffic_lights_states): states_dict[str(t)] = { "center": {"x": actor.center.x, "y": actor.center.y}, "orientation": actor.orientation, "speed": 0, "control_state": tls[actor_id] } predetermined_controls_dict[actor_id] = { "entity_type": "traffic_light", "static_attributes": { "length": actor.length, "width": actor.width, "rear_axis_offset": 0, }, "states":states_dict } self.output_dict = { "location": { "identifier": scenario_log.location }, "scenario_length": len(scenario_log.agent_data), "num_agents": { "car": num_cars, "pedestrian": num_pedestrians }, "predetermined_agents": predetermined_agents_dict, "num_controls": { "traffic_light": num_controls_light, "yield_sign": num_controls_yield, "stop_sign": num_controls_stop, "other": num_controls_other, }, "predetermined_controls": predetermined_controls_dict, "individual_suggestions": individual_suggestions_dict, "initialize_random_seed": scenario_log.initialize_random_seed, "lights_random_seed": scenario_log.lights_random_seed, "drive_random_seed": scenario_log.drive_random_seed, "drive_model_version": scenario_log.drive_model_version, "initialize_model_version": scenario_log.initialize_model_version, "birdview_options": { "rendering_center": [ scenario_log.rendering_center[0], scenario_log.rendering_center[1] ], "renderingFOV": scenario_log.rendering_fov }, "light_recurrent_states": [] if scenario_log.light_recurrent_states is None else [lrs.tolist() for lrs in scenario_log.light_recurrent_states], "rendering_centers": [ scenario_log.rendering_center[0], 
scenario_log.rendering_center[1] ] } with open(log_path, "w") as outfile: json.dump( self.output_dict, outfile, indent=4 )
[docs] @classmethod def export_log_to_file( cls, log_path: str, scenario_log: ScenarioLog ): """ Class function to convert a given log data type into a JSON format and export it to a given file. """ cls.export_to_file(cls,log_path,scenario_log)
[docs] @validate_arguments def initialize( self, location: Optional[str] = None, location_info_response: Optional[LocationResponse] = None, init_response: Optional[InitializeResponse] = None, lights_random_seed: Optional[int] = None, initialize_random_seed: Optional[int] = None, drive_random_seed: Optional[int] = None, drive_model_version: Optional[str] = None, scenario_log: Optional[Union[ScenarioLogLegacy, ScenarioLog]] = None, waypoints: Optional[WaypointsDict] = None, agents_dict: Optional[SimulationAgentDict] = None, agent_ids: Optional[List[str]] = None, ): """ Consume and store all initial information within a ScenarioLog data object. If random seed information is desired to be stored, it must be given separately but is not mandatory. Preferred: pass agents_dict (SimulationAgentDict) so UUID keys are preserved. Fallback: pass init_response for backwards compatibility — integer string keys ("0", "1", ...) will be assigned automatically. Legacy: pass scenario_log (ScenarioLog) to initialize from an existing log. Deprecated parameters (kept for backwards compatibility, use agents_dict instead): waypoints: Waypoints are now stored directly in AgentProperties within the agents_dict. agent_ids: Agent IDs are now the keys of agents_dict. """ if scenario_log is not None: if isinstance(scenario_log, ScenarioLog): self._scenario_log = scenario_log else: self._scenario_log = scenario_log.to_scenario_log() self.simulation_length = len(self._scenario_log.agent_data) return assert location is not None, "No scenario log given, must provide a location argument." assert location_info_response is not None, "No scenario log given, must provide a location_info_response argument." assert init_response is not None, "No scenario log given, must provide a init_response argument." 
if agents_dict is None: # Build from init_response (backwards compatibility) agent_properties = init_response.agent_properties if type(agent_properties[0]) == AgentAttributes: agent_properties = [convert_attributes_to_properties(attr) for attr in agent_properties] keys = agent_ids if agent_ids is not None else [str(i) for i in range(len(init_response.agent_states))] assert len(keys) == len(init_response.agent_states), ( f"agent_ids length ({len(keys)}) must match number of agents in init_response ({len(init_response.agent_states)})." ) agent_dict: SimulationAgentDict = {} for key, state, prop, rec in zip(keys, init_response.agent_states, agent_properties, init_response.recurrent_states): if waypoints is not None and key in waypoints: prop = deepcopy(prop) prop.waypoints = waypoints[key] agent_dict[key] = AgentData(state=state, properties=prop, recurrent=rec) self._scenario_log = ScenarioLog( agent_data=[deepcopy(agent_dict)], traffic_lights_states=[init_response.traffic_lights_states] if init_response.traffic_lights_states is not None else None, location=location, rendering_center=[ location_info_response.map_center.x, location_info_response.map_center.y ], rendering_fov=location_info_response.map_fov, lights_random_seed=lights_random_seed, initialize_random_seed=initialize_random_seed, drive_random_seed=drive_random_seed, initialize_model_version=init_response.api_model_version, drive_model_version=drive_model_version, light_recurrent_states=init_response.light_recurrent_states, recurrent_states=init_response.recurrent_states, ) self.simulation_length = 1
[docs] @validate_arguments def drive( self, drive_response: DriveResponse, current_present_indexes: Optional[List[int]] = None, new_agent_properties: Optional[List[AgentProperties]] = None, waypoints: Optional[WaypointsDict] = None, #unused, only for backwards compatibility with older logs that may have waypoints stored separately from agent properties agent_properties: Optional[List[AgentProperties]] = None, agents_dict: Optional[SimulationAgentDict] = None, agent_ids: Optional[List[str]] = None, ): """ Consume and store driving response information from a single timestep and append it to the end of the log. Preferred: pass agents_dict (SimulationAgentDict) with all agents for this timestep. Fallback: pass drive_response with optional agent_ids and agent_properties to build the dict. Legacy: pass current_present_indexes and new_agent_properties for index-based tracking. Deprecated parameters (kept for backwards compatibility, use agents_dict instead): waypoints: Waypoints are now stored directly in AgentProperties within the agents_dict. agent_ids: Agent IDs are now the keys of agents_dict. 
""" if agents_dict is not None: self._scenario_log.add_time_step_data(agents_dict) elif current_present_indexes is not None: # Backwards compatibility: index-based agent tracking all_props_map = dict(self._scenario_log.get_agent_properties_all()) if new_agent_properties is not None: next_idx = len(all_props_map) for j, prop in enumerate(new_agent_properties): all_props_map[str(next_idx + j)] = prop recurrent_states = drive_response.recurrent_states or [None] * len(current_present_indexes) agent_dict: SimulationAgentDict = {} for state_idx, prop_idx in enumerate(current_present_indexes): key = str(prop_idx) agent_dict[key] = AgentData( state=drive_response.agent_states[state_idx], properties=all_props_map[key], recurrent=recurrent_states[state_idx], ) self._scenario_log.add_time_step_data(agent_dict) else: # No agents_dict and no current_present_indexes provided: # assume agent keys have not changed from the previous timestep. # if agent_ids is provided use them, otherwise assume ids are the same from previous timestep keys = agent_ids if agent_ids is not None else list(self._scenario_log.get_agents().keys()) assert len(keys) == len(drive_response.agent_states), ( f"Number of agent_ids ({len(keys)}) does not match number of agents " f"in drive_response ({len(drive_response.agent_states)}). " f"If agents were added or removed, provide agents_dict or current_present_indexes." 
) properties = agent_properties or [self._scenario_log.get_agents()[k].properties for k in keys] recurrent_states = drive_response.recurrent_states or [None] * len(keys) agent_dict: SimulationAgentDict = {} for key, state, prop, rec in zip(keys, drive_response.agent_states, properties, recurrent_states): agent_dict[key] = AgentData(state=state, properties=prop, recurrent=rec) self._scenario_log.add_time_step_data(agent_dict) if drive_response.traffic_lights_states is not None: self._scenario_log.traffic_lights_states.append(drive_response.traffic_lights_states) self._scenario_log.drive_model_version = drive_response.api_model_version self._scenario_log.light_recurrent_states = drive_response.light_recurrent_states self._scenario_log.recurrent_states = drive_response.recurrent_states self.simulation_length += 1
@property
def current_present_indexes(self):
    """
    Indexes of the agents present at the most recent logged time step.

    An agent's index is its position in ``self._scenario_log.get_agent_ids_all()``.
    The agents considered are the keys of ``agent_data[simulation_length - 1]``,
    taken in that dict's order. Keys missing from the full ID list are skipped.
    """
    position_of = {}
    for position, agent_id in enumerate(self._scenario_log.get_agent_ids_all()):
        position_of[agent_id] = position
    latest_snapshot = self._scenario_log.agent_data[self.simulation_length - 1]
    return [position_of[agent_id] for agent_id in latest_snapshot if agent_id in position_of]

@property
def all_agent_properties(self):
    """
    Properties of every agent present at any point of the simulation this log
    captures, as a list in the order given by ``get_agent_properties_all()``.
    """
    properties_by_id = self._scenario_log.get_agent_properties_all()
    return [*properties_by_id.values()]
class LogReader(LogBase):
    """
    A class for conveniently reading in a log file then rendering it and/or plugging it into a simulation.
    Once the log is read, it is intended to be used in place of calling the API.
    """

    def __init__(
        self,
        log_path: str
    ):
        """
        The initialization of this object must be given the path to a JSON file in the IAI format.
        Assume that the 0th time step is taken from the output of :func:`initialize` and set the
        time step to the 1st time step, which correlates to the first time step produced by
        :func:`drive`.

        Args:
            log_path: Path to a JSON log file in the IAI format; it is read as UTF-8.

        Raises:
            KeyError: If a required field (e.g. ``location``, ``scenario_length``,
                ``predetermined_agents``, ``predetermined_controls``) is missing from the log.

        Note:
            This constructor calls :func:`location_info` for the logged location.
        """
        super().__init__()

        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
        with open(log_path, encoding="utf-8") as f:
            log_data = json.load(f)

        location = log_data["location"]["identifier"]
        agent_waypoints = self._parse_agent_waypoints(log_data)
        agent_data = self._parse_agent_data(log_data, agent_waypoints)
        all_traffic_light_states = self._parse_traffic_light_states(log_data)
        rendering_center, rendering_fov = self._parse_birdview_options(log_data)
        light_recurrent_states = self._parse_light_recurrent_states(log_data)

        self._scenario_log = ScenarioLog(
            agent_data=agent_data,
            traffic_lights_states=all_traffic_light_states,
            location=location,
            rendering_center=rendering_center,
            rendering_fov=rendering_fov,
            # Optional fields: None when absent from the JSON.
            lights_random_seed=log_data.get("lights_random_seed"),
            initialize_random_seed=log_data.get("initialize_random_seed"),
            drive_random_seed=log_data.get("drive_random_seed"),
            initialize_model_version=log_data.get("initialize_model_version"),
            drive_model_version=log_data.get("drive_model_version"),
            light_recurrent_states=light_recurrent_states,
            recurrent_states=None,
        )
        self._scenario_log_original = self._scenario_log
        self.reset_log()

        self.simulation_length = len(agent_data)
        self.initialize_model_version = self._scenario_log.initialize_model_version
        self.drive_model_version = self._scenario_log.drive_model_version
        self.all_waypoints = agent_waypoints

        self.location_info_response = location_info(
            location=self._scenario_log.location,
            rendering_fov=self._scenario_log.rendering_fov,
            rendering_center=self._scenario_log.rendering_center,
        )

    @staticmethod
    def _parse_agent_waypoints(log_data: Dict[str, Any]) -> Optional[Dict[str, List[Point]]]:
        """Return per-agent waypoints from ``individual_suggestions``, or None if that key is absent."""
        if "individual_suggestions" not in log_data:
            return None
        agent_waypoints = {}
        for agent_id, waypoints in log_data["individual_suggestions"].items():
            agent_waypoints[agent_id] = [
                Point.fromlist([pt["center"]["x"], pt["center"]["y"]])
                for pt in waypoints["states"]
            ]
        return agent_waypoints

    @staticmethod
    def _parse_agent_properties(
        agent_id: str,
        agent: Dict[str, Any],
        agent_waypoints: Optional[Dict[str, List[Point]]],
    ) -> AgentProperties:
        """Build an agent's properties from its ``static_attributes``, ``entity_type`` and any waypoints."""
        static_attributes = agent["static_attributes"]
        agent_properties = AgentProperties()
        agent_properties.length = static_attributes["length"]
        agent_properties.width = static_attributes["width"]
        agent_properties.rear_axis_offset = static_attributes["rear_axis_offset"]
        agent_properties.agent_type = agent["entity_type"]
        if agent_waypoints is not None and agent_id in agent_waypoints:
            agent_properties.waypoints = agent_waypoints[agent_id]
        return agent_properties

    @classmethod
    def _parse_agent_data(
        cls,
        log_data: Dict[str, Any],
        agent_waypoints: Optional[Dict[str, List[Point]]],
    ) -> List[SimulationAgentDict]:
        """Build one agent dict per logged time step, keyed by the agent IDs used in the JSON."""
        agent_data: List[SimulationAgentDict] = []
        # Properties are built once per agent and shared by all of that agent's time steps.
        properties_by_id: Dict[str, AgentProperties] = {}
        for i in range(log_data["scenario_length"]):
            ts_key = str(i)
            # Plain dict: a missing-key lookup on defaultdict(AgentData) would call
            # AgentData() instead of raising KeyError.
            agent_dict: SimulationAgentDict = {}
            for agent_id, agent in log_data["predetermined_agents"].items():
                if agent_id not in properties_by_id:
                    properties_by_id[agent_id] = cls._parse_agent_properties(agent_id, agent, agent_waypoints)
                # An agent only appears at the time steps for which the log holds a state.
                if ts_key not in agent["states"]:
                    continue
                state_json = agent["states"][ts_key]
                state = AgentState.fromlist([
                    state_json["center"]["x"],
                    state_json["center"]["y"],
                    state_json["orientation"],
                    state_json["speed"],
                ])
                agent_dict[agent_id] = AgentData(
                    state=state,
                    properties=properties_by_id[agent_id],
                    recurrent=None,
                )
            agent_data.append(agent_dict)
        return agent_data

    @staticmethod
    def _parse_traffic_light_states(log_data: Dict[str, Any]) -> Optional[List[TrafficLightStatesDict]]:
        """Return per-time-step traffic light states (actor id -> control state), or None if there are none."""
        # Filter the traffic lights once instead of re-checking every control at every time step.
        traffic_lights = {
            actor_id: actor
            for actor_id, actor in log_data["predetermined_controls"].items()
            if actor["entity_type"] == "traffic_light"
        }
        all_traffic_light_states = []
        for i in range(log_data["scenario_length"]):
            ts_key = str(i)
            traffic_light_states_ts = {
                int(actor_id): actor["states"][ts_key]["control_state"]
                for actor_id, actor in traffic_lights.items()
            }
            if traffic_light_states_ts:
                all_traffic_light_states.append(traffic_light_states_ts)
        return all_traffic_light_states or None

    @staticmethod
    def _parse_birdview_options(
        log_data: Dict[str, Any]
    ) -> Tuple[Optional[Tuple[float, float]], Optional[int]]:
        """Return ``(rendering_center, rendering_fov)`` from ``birdview_options``, or ``(None, None)``."""
        if "birdview_options" not in log_data:
            return None, None
        birdview_options = log_data["birdview_options"]
        center = birdview_options["rendering_center"]
        return (center[0], center[1]), birdview_options["renderingFOV"]

    @staticmethod
    def _parse_light_recurrent_states(log_data: Dict[str, Any]) -> Optional[List[LightRecurrentState]]:
        """Return the logged light recurrent states, or None if they are absent, null or an empty list."""
        raw_states = log_data.get("light_recurrent_states")
        if raw_states is None or raw_states == []:
            return None
        return [
            LightRecurrentState(state=state[0], time_remaining=state[1])
            for state in raw_states
        ]

    def _slice_original_log(
        self,
        timestep_range: Tuple[int, int]
    ) -> ScenarioLog:
        """
        Return a deep copy of the original log restricted to the half-open time step range
        ``[start, end)`` (Python slice semantics).

        Raises:
            AssertionError: If either bound lies outside ``[0, simulation_length]`` or ``end < start``.
        """
        start, end = timestep_range[0], timestep_range[1]
        for timestep in timestep_range:
            # The original check joined these with ``or``, which accepted every integer.
            # ``simulation_length`` itself is a valid end bound because slicing excludes the end.
            assert 0 <= timestep <= self.simulation_length, "Visualization time range invalid."
        assert end >= start, "Visualization time range invalid."
        sliced_log = deepcopy(self._scenario_log_original)
        sliced_log.agent_data = sliced_log.agent_data[start:end]
        if sliced_log.traffic_lights_states is not None:
            sliced_log.traffic_lights_states = sliced_log.traffic_lights_states[start:end]
        return sliced_log

    @validate_arguments
    def return_scenario_logV2(
        self,
        timestep_range: Optional[Tuple[int,int]] = None
    ) -> ScenarioLog:
        """
        Return the original scenario log. Optionally choose a time range within the log of interest.

        ``timestep_range`` is a half-open ``(start, end)`` pair. When it is given, a deep copy
        covering time steps ``start`` up to, but excluding, ``end`` is returned. Without it,
        the stored original log object itself is returned (not a copy).
        """
        if timestep_range is None:
            return self._scenario_log_original
        return self._slice_original_log(timestep_range)

    @validate_arguments
    def return_scenario_log(
        self,
        timestep_range: Optional[Tuple[int,int]] = None
    ) -> ScenarioLogLegacy:
        """
        Return the original scenario log as a legacy ScenarioLogLegacy.
        Optionally choose a time range within the log of interest.

        ``timestep_range`` is a half-open ``(start, end)`` pair, with the same meaning and
        validation as in :meth:`return_scenario_logV2`.
        """
        if timestep_range is None:
            log_to_convert = self._scenario_log_original
        else:
            log_to_convert = self._slice_original_log(timestep_range)
        return ScenarioLogLegacy.from_scenario_log(log_to_convert)

    @validate_arguments
    def _return_state_at_timestep(
        self,
        timestep: int
    ):
        """
        Populate all state data from the given time step into the relevant member variables.

        Returns False, leaving the member variables untouched, if ``timestep`` lies outside
        ``[0, simulation_length)``; otherwise returns True. ``recurrent_states`` is always set
        to None. ``light_recurrent_states`` is only populated at the final time step.
        """
        # Reject negatives explicitly; Python indexing would otherwise silently read from the end
        # (and an empty log would raise IndexError from return_last_state).
        if not 0 <= timestep < self.simulation_length:
            return False

        agent_dict = self._scenario_log.agent_data[timestep]
        self.agent_states = [data.state for data in agent_dict.values()]
        self.agent_properties = [data.properties for data in agent_dict.values()]
        self.recurrent_states = None
        if self._scenario_log.traffic_lights_states is None:
            self.traffic_lights_states = None
        else:
            self.traffic_lights_states = self._scenario_log.traffic_lights_states[timestep]
        # The log only stores light recurrent states as of its final time step.
        if timestep == self.simulation_length - 1:
            self.light_recurrent_states = self._scenario_log.light_recurrent_states
        else:
            self.light_recurrent_states = None

        return True

    @validate_arguments
    def return_last_state(self):
        """
        Read and make available state data from the final time step contained within the log,
        which is useful as a launching point for another simulation.
        Returns False if the log contains no time steps.
        """
        return self._return_state_at_timestep(timestep=self.simulation_length - 1)

    @validate_arguments
    def initialize(self):
        """
        Read and make available state data from the 0th time step into the relevant state member
        variables, e.g. agent_states, then set the current time step to 1.
        """
        is_init_response = self._return_state_at_timestep(timestep=0)
        self.current_timestep = 1

        return is_init_response

    @validate_arguments
    def drive(self):
        """
        Read and make available state data from the current time step into the relevant member
        variables, then increment the current time step so that this function may be called again.
        If the end of the log has been reached, return False; otherwise return True.
        """
        if self.current_timestep >= self.simulation_length:
            return False

        is_drive_response = self._return_state_at_timestep(timestep=self.current_timestep)
        self.current_timestep += 1

        return is_drive_response

    @validate_arguments
    def reset_log(self):
        """
        In case the log was modified, revert the log to its initial state after being read and
        clear all state data. Furthermore, change the current time step such that the first
        :func:`drive` time step can be read.
        """
        # Work on a deep copy. Rebinding to the original object would let in-place edits
        # (e.g. via the ``agents`` property) leak into the original and defeat this reset.
        self._scenario_log = deepcopy(self._scenario_log_original)

        self.agent_states = None
        self.agent_properties = None
        self.recurrent_states = None
        self.traffic_lights_states = None
        self.light_recurrent_states = None

        self.current_timestep = 1

    @property
    def all_agent_properties(self):
        """
        Return a list of agent properties of all agents present in the simulation.
        """
        return self._scenario_log.get_agent_properties()

    @property
    def location(self):
        """
        Return the location from the log.
        """
        return self._scenario_log.location

    @property
    def log_length(self):
        """
        Return the length of the simulation in time steps captured in this log.
        """
        return len(self._scenario_log.agent_data)

    @property
    def agents(self) -> SimulationAgentDict:
        """
        Return the agent dict for the most recently read time step (time step 0 before any read).
        """
        ts = max(0, self.current_timestep - 1)
        return self._scenario_log.agent_data[ts]