Source code for invertedai.utils

import json
import os
import re
import csv
import math
import logging
import random
import time
import numpy as np
import warnings

from typing import Dict, Optional, List, Tuple, Union, Any
from copy import deepcopy
from pydantic import validate_call, validate_arguments

import requests
from requests import Response
from requests.auth import AuthBase
from requests.adapters import HTTPAdapter, Retry

import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib.patches import Rectangle
from matplotlib.animation import FuncAnimation
from matplotlib.axes import Axes
from matplotlib import transforms

import invertedai as iai
import invertedai.api
import invertedai.api.config
from invertedai import error
from invertedai.future import to_thread
from invertedai.error import InvertedAIError
from invertedai.common import (
    AgentState, 
    AgentAttributes, 
    AgentProperties, 
    AgentType,
    RecurrentState,
    StaticMapActor,
    TrafficLightState, 
    TrafficLightStatesDict 
)

H_SCALE = 10
text_x_offset = 0
text_y_offset = 0.7
text_size = 7
TIMEOUT_SECS = 600
MAX_RETRIES = 10
AGENT_SCOPE_FOV = 120

logger = logging.getLogger(__name__)

STATUS_MESSAGE = {
    403: "Access denied. Please check the provided API key.",
    429: "Throttled",
    502: "The server is having trouble communicating. This is usually a temporary issue. Please try again later.",
    504: "The server took too long to respond. Please try again later.",
    500: "The server encountered an unexpected issue. We're working to resolve this. Please try again later.",
}


class Session:
    def __init__(self,debug_logger=None):
        self.session = requests.Session()
        self.session.mount(
            "https://",
            requests.adapters.HTTPAdapter(),
        )
        self.session.mount(
            "http://",
            requests.adapters.HTTPAdapter(),
        )
        self.session.headers.update(
            {
                "Content-Type": "application/json",
                "Accept-Encoding": "gzip, deflate, br",
                "Connection": "keep-alive",
                "x-client-version": iai.__version__,
            }
        )
        self._base_url = self._get_base_url()
        self._max_retries = float("inf")
        self._status_force_list = [403, 408, 429, 500, 502, 503, 504]
        self._base_backoff = 1  # Base backoff time in seconds
        self._backoff_factor = 2
        self._jitter_factor = 0.5
        self._current_backoff = self._base_backoff
        self._max_backoff = None

        self._debug_logger = debug_logger

    @property
    def base_url(self):
        return self._base_url

    @property
    def max_retries(self):
        return self._max_retries

    @max_retries.setter
    def max_retries(self, value):
        self._max_retries = value

    @property
    def status_force_list(self):
        return self._status_force_list

    @status_force_list.setter
    def status_force_list(self, value):
        self._status_force_list = value.copy()

    @property
    def base_backoff(self):
        return self._base_backoff

    @base_backoff.setter
    def base_backoff(self, value):
        self._base_backoff = value
        self.current_backoff = (
            self._base_backoff
        )  # Reset current_backoff when base_backoff changes

    @property
    def backoff_factor(self):
        return self._backoff_factor

    @backoff_factor.setter
    def backoff_factor(self, value):
        self._backoff_factor = value

    
    @property
    def current_backoff(self):
        return self._current_backoff
    
    @current_backoff.setter
    def current_backoff(self, value):
        self._current_backoff = value

    @property
    def max_backoff(self):
        return self._max_backoff
    
    @max_backoff.setter
    def max_backoff(self, value):
        self._max_backoff = value

    @property
    def jitter_factor(self):
        return self._jitter_factor
    
    @jitter_factor.setter
    def jitter_factor(self, value):
        self._jitter_factor = value


    def should_log(self, retry_count):
        return retry_count == 0 or math.log2(retry_count).is_integer()

    @base_url.setter
    def base_url(self, value):
        self._base_url = value

    def _verify_api_key(
        self, 
        api_token: str, 
        verifying_url: str
    ):
        """
        Verifies the API key by making a request to the verifying URL.

        Args:
            api_token (str): The API token to be used for authentication.
            verifying_url (str): The URL to be used for verification.

        Returns:
            str: The final verifying URL after fallback (if applicable).

        Raises:
            error.AuthenticationError: If access is denied due to an invalid API key.
        """
        self.session.auth = APITokenAuth(api_token)
        response = self.session.request(method="get", url=verifying_url)
        if verifying_url == iai.commercial_url and response.status_code != 200:
            # Check for academic access in case the previous call to the commercial server fails.
            logger.warning(
                "Commercial access denied and fallback to check for academic access."
            )
            verifying_url = iai.academic_url
            response_acd = self.session.request(method="get", url=verifying_url)
            if response_acd.status_code == 200:
                self.base_url = verifying_url
                response = response_acd
            elif response_acd.status_code != 403:
                response = response_acd
        if response.status_code == 403:
            raise error.AuthenticationError(
                "Access denied. Please check the provided API key."
            )
        return verifying_url

    def add_apikey(
        self,
        api_token: str = "",
        key_type: Optional[str] = None,
        url: Optional[str] = None,
    ):
        """
        Bind an API key to the session for authentication.

        Args:
            api_token (str): The API key to be added. Defaults to an empty string.
            key_type (str, optional): The type of API key. Defaults to None. When passed, the base_url will be set according to the key_type.
            url (str, optional): The URL to be used for the request. Defaults to None. When passed, the base_rul will be set to the passed value 
            and the key_type will be ignored.

        Raises:
            InvalidAPIKeyError: If the API key is empty and not in development mode.
            InvalidAPIKeyError: If the key_type is invalid.
            AuthenticationError: If access is denied due to an invalid API key.
            APIError: If the server encounters an error or is unable to perform the requested method.
        """
        if not iai.dev and not api_token:
            raise error.InvalidAPIKeyError("Empty API key received.")
        if url is None:
            request_url = self._get_base_url()
        if key_type is not None and key_type not in ["commercial", "academic"]:
            raise error.InvalidAPIKeyError(f"Invalid API key type: {key_type}.")
        if key_type == "academic":
            request_url = iai.academic_url
        elif key_type == "commercial":
            request_url = iai.commercial_url
        if url is not None:
            request_url = url
        self.base_url = self._verify_api_key(api_token, request_url)

    def use_mock_api(
        self, 
        use_mock: bool = True
    ) -> None:
        invertedai.api.config.mock_api = use_mock
        if use_mock:
            iai.logger.warning(
                "Using mock Inverted AI API - predictions will be trivial"
            )

    async def async_request(
        self, 
        *args, 
        **kwargs
    ):
        return await to_thread(self.request, *args, **kwargs)

    def request(
        self, 
        model: str, 
        params: Optional[dict] = None, 
        data: Optional[dict] = None
    ):
        method, relative_path = iai.model_resources[model]
        
        if self._debug_logger is not None:
            self._debug_logger.append_request(model,data)

        response = self._request(
            method=method,
            relative_path=relative_path,
            params=params,
            json_body=data,
        )

        if self._debug_logger is not None:
            self._debug_logger.append_response(model,response)

        return response

    def _request(
        self,
        method,
        relative_path: str = "",
        params=None,
        headers=None,
        json_body=None,
        data=None,
    ) -> Dict:
        try:
            retries = 0
            while retries < self.max_retries:
                try:
                    response = self.session.request(
                        method=method,
                        params=params,
                        url=self.base_url + relative_path,
                        headers=headers,
                        data=data,
                        json=json_body,
                    )
                except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
                    logger.warning("Error communicating with IAI, will retry.")
                    response = None
                if response is not None and response.status_code not in self.status_force_list:
                    self.current_backoff = max(
                        self.base_backoff, self.current_backoff / self.backoff_factor
                    )
                    response.raise_for_status()
                    break
                else:
                    if self.jitter_factor is not None:
                        jitter = random.uniform(-self.jitter_factor, self.jitter_factor)
                    else:
                        jitter = 0
                    if self.should_log(retries):
                        if response is not None:
                            logger.warning(
                                f"Retrying {relative_path}: Status {response.status_code}, Message {STATUS_MESSAGE.get(response.status_code, response.text)} Retry #{retries + 1}, Backoff {self.current_backoff} seconds"
                            )
                        else:
                            logger.warning(f"Retrying {relative_path}: No response received, Retry #{retries + 1}, Backoff {self.current_backoff} seconds")
                    time.sleep(min(self.current_backoff * (1 + jitter), self.max_backoff if self.max_backoff is not None else float("inf")))
                    self.current_backoff *= self.backoff_factor
                    if self.max_backoff is not None:
                        self.current_backoff = min(
                            self.current_backoff, self.max_backoff
                        )
                    retries += 1
            else:
                if response is not None:
                    response.raise_for_status()
                else:
                    error.APIConnectionError(
                        "Error communicating with IAI", should_retry=True)


        except requests.exceptions.ConnectionError as e:
            raise error.APIConnectionError(
                "Error communicating with IAI", should_retry=True
            ) from None
        except requests.exceptions.Timeout as e:
            raise error.APIConnectionError("Error communicating with IAI") from None
        except requests.exceptions.RequestException as e:
            if e.response.status_code == 403:
                raise error.AuthenticationError(STATUS_MESSAGE[403]) from None
            elif e.response.status_code in [400, 422]:
                raise error.InvalidRequestError(e.response.text, param="") from None
            elif e.response.status_code == 404:
                raise error.ResourceNotFoundError(e.response.text) from None
            elif e.response.status_code == 408:
                raise error.RequestTimeoutError(e.response.text) from None
            elif e.response.status_code == 413:
                raise error.RequestTooLarge(e.response.text) from None
            elif e.response.status_code == 429:
                raise error.RateLimitError(STATUS_MESSAGE[429]) from None
            elif e.response.status_code == 502:
                raise error.APIError(STATUS_MESSAGE[502]) from None
            elif e.response.status_code == 503:
                raise error.RequestTimeoutError(e.response.text) from None
            elif e.response.status_code == 504:
                raise error.ServiceUnavailableError(STATUS_MESSAGE[504]) from None
            elif 400 <= e.response.status_code < 500:
                raise error.APIError(e.response.text) from None
            else:
                raise error.APIError(STATUS_MESSAGE[500]) from None
        iai.logger.info(
            iai.logger.logfmt(
                "IAI API response",
                path=self.base_url,
                response_code=response.status_code,
            )
        )
        try:
            data = json.loads(response.content)
        except json.decoder.JSONDecodeError:
            raise error.APIError(
                f"HTTP code {response.status_code} from API ({response.content})",
                response.content,
                response.status_code,
                headers=response.headers,
            )
        return data

    def _get_base_url(self) -> str:
        """
        This function returns the endpoint for API calls, which includes the
        version and other endpoint specifications.
        The method path should be appended to the base_url
        """
        if not iai.dev:
            base_url = iai.commercial_url  # Default to commercial when initializing.
        else:
            base_url = iai.dev_url
        # TODO: Add endpoint option and versioning to base_url
        return base_url

    def _handle_error_response(
        self, 
        rbody, 
        rcode, 
        resp, 
        rheaders, 
        stream_error=False
    ):
        try:
            error_data = resp["error"]
        except (KeyError, TypeError):
            raise error.APIError(
                "Invalid response object from API: %r (HTTP response code "
                "was %d)" % (rbody, rcode),
                rbody,
                rcode,
                resp,
            )

        if "internal_message" in error_data:
            error_data["message"] += "\n\n" + error_data["internal_message"]

        iai.logger.info(
            iai.logger.logfmt(
                "IAI API error received",
                error_code=error_data.get("code"),
                error_type=error_data.get("type"),
                error_message=error_data.get("message"),
                error_param=error_data.get("param"),
                stream_error=stream_error,
            )
        )

        # Rate limits were previously coded as 400's with code 'rate_limit'
        if rcode == 429:
            return error.RateLimitError(
                error_data.get("message"), rbody, rcode, resp, rheaders
            )
        elif rcode in [400, 404, 415]:
            return error.InvalidRequestError(
                error_data.get("message"),
                error_data.get("param"),
                error_data.get("code"),
                rbody,
                rcode,
                resp,
                rheaders,
            )
        elif rcode == 401:
            return error.AuthenticationError(
                error_data.get("message"), rbody, rcode, resp, rheaders
            )
        elif rcode == 403:
            return error.PermissionError(
                error_data.get("message"), rbody, rcode, resp, rheaders
            )
        elif rcode == 409:
            return error.TryAgain(
                error_data.get("message"), rbody, rcode, resp, rheaders
            )
        else:
            return error.APIError(
                error_data.get("message"), rbody, rcode, resp, rheaders
            )

    def _interpret_response_line(
        self, 
        result
    ):
        rbody = result.content
        rcode = result.status_code
        rheaders = result.headers

        if rcode == 503:
            raise error.ServiceUnavailableError(
                "The server is overloaded or not ready yet.",
                rbody,
                rcode,
                headers=rheaders,
            )
        try:
            data = json.loads(rbody)
        except BaseException:
            raise error.APIError(
                f"HTTP code {rcode} from API ({rbody})", rbody, rcode, headers=rheaders
            )
        if "error" in data or not 200 <= rcode < 300:
            raise self._handle_error_response(rbody, rcode, data, rheaders)

        return data


@validate_call
def get_default_agent_properties(
    agent_count_dict: Dict[AgentType,int],
    use_agent_properties: Optional[bool] = True
) -> List[Union[AgentAttributes,AgentProperties]]:
    """
    Function that outputs a list a AgentAttributes with minimal default settings. 
    Mainly meant to be used to pad a list of AgentAttributes to send as input to
    initialize(). This list is created by reading a dictionary containing the
    desired agent types with the agent count for each type respectively.
    If desired to use deprecate AgentAttributes instead of AgentProperties, set the
    use_agent_properties flag to False.
    """

    agent_attributes_list = []

    for agent_type, agent_count in agent_count_dict.items():
        for _ in range(agent_count):
            if use_agent_properties:
                agent_properties = AgentProperties(agent_type=agent_type)
                agent_attributes_list.append(agent_properties)
            else:
                agent_attributes_list.append(AgentAttributes.fromlist([agent_type]))

    return agent_attributes_list


@validate_call
def convert_attributes_to_properties(
    attributes: AgentAttributes
) -> AgentProperties:
    """
    Convert deprecated AgentAttributes data type to AgentProperties.
    """

    properties = AgentProperties(
        length=attributes.length,
        width=attributes.width,
        rear_axis_offset=attributes.rear_axis_offset,
        agent_type=attributes.agent_type,
        waypoint=attributes.waypoint
    )

    return properties


@validate_call
def iai_conditional_initialize(
    location: str, 
    agent_type_count: Dict[str,int],
    location_of_interest: Tuple[float] = (0,0),
    recurrent_states: Optional[List[RecurrentState]] = None,
    agent_properties: Optional[List[AgentProperties]] = None,
    states_history: Optional[List[List[AgentState]]] = None,
    traffic_light_state_history: Optional[List[TrafficLightStatesDict]] = None, 
    get_birdview: Optional[bool] = False,
    get_infractions: Optional[bool] = False,
    random_seed: Optional[int] = None,
    api_model_version: Optional[str] = None
):
    """
    A utility function to run initialize with conditional agents located at arbitrary distances from the location
    of interest. Only agents within a defined distance of the location of interest are passed to initialize as 
    conditional. Agents outisde of this distance are padded on to the initialize response, including their reccurent
    states. Recurrent states must be provided for all agents, otherwise this function behaves like :func:`initialize`.
    Please refer to the documentation for :func:`initialize` for more information.

    Arguments
    ----------
    location:
        Location name in IAI format.

    agent_type_count:
        A dictionary containing valid AgentType strings as keys mapped to an integer value specifying the desired
        number of agents of that type to initialize.

    location_of_interest:
        Optional coordinates for spawning agents with the given location as center instead of the default map center
    
    See Also
    --------
    :func:`initialize`
    """

    conditional_agent_properties = []
    conditional_agent_states_indexes = []
    conditional_recurrent_states = []
    outside_agent_states = []
    outside_agent_properties = []
    outside_recurrent_states = []

    current_agent_states = states_history[-1]
    conditional_agent_type_count = deepcopy(agent_type_count)
    for i in range(len(current_agent_states)):
        agent_state = current_agent_states[i]
        dist = math.dist(location_of_interest, (agent_state.center.x, agent_state.center.y))
        if dist < AGENT_SCOPE_FOV:
            conditional_agent_states_indexes.append(i)
            conditional_agent_properties.append(agent_properties[i])
            conditional_recurrent_states.append(recurrent_states[i])

            conditional_agent_type = agent_properties[i].agent_type
            if conditional_agent_type in conditional_agent_type_count:
                conditional_agent_type_count[conditional_agent_type] -= 1
                if conditional_agent_type_count[conditional_agent_type] <= 0:
                    del conditional_agent_type_count[conditional_agent_type]

        else:
            outside_agent_states.append(agent_state)
            outside_agent_properties.append(agent_properties[i])
            outside_recurrent_states.append(recurrent_states[i])

    if not conditional_agent_type_count: #The dictionary is empty.
        iai.logger.warning("Agent count requirement already satisfied, no new agents initialized.")

    padded_agent_properties = get_default_agent_properties(conditional_agent_type_count)
    conditional_agent_properties.extend(padded_agent_properties)

    conditional_agent_states = [[]*len(conditional_agent_states_indexes)]
    for ts in range(len(conditional_agent_states)):
        for agent_index in conditional_agent_states_indexes:
            conditional_agent_states[ts].append(states_history[ts][agent_index])

    response = invertedai.api.initialize(
        location = location,
        agent_properties = conditional_agent_properties,
        states_history = conditional_agent_states,
        location_of_interest = location_of_interest,
        traffic_light_state_history = traffic_light_state_history,
        get_birdview = get_birdview,
        get_infractions = get_infractions,
        random_seed = random_seed,
        api_model_version = api_model_version
    )
    response.agent_properties = response.agent_properties + outside_agent_properties
    response.agent_states = response.agent_states + outside_agent_states
    response.recurrent_states = response.recurrent_states + outside_recurrent_states

    return response


class APITokenAuth(AuthBase):
    def __init__(
        self, 
        api_token
    ):
        self.api_token = api_token

    def __call__(
        self, 
        r
    ):
        r.headers["x-api-key"] = self.api_token
        r.headers["api-key"] = self.api_token
        return r


def Jupyter_Render():
    import ipywidgets as widgets
    import matplotlib.pyplot as plt
    import numpy as np

    class Jupyter_Render(widgets.HBox):
        def __init__(self):
            super().__init__()
            output = widgets.Output()
            self.buffer = [np.zeros([128, 128, 3], dtype=np.uint8)]

            with output:
                self.fig, self.ax = plt.subplots(
                    constrained_layout=True, figsize=(5, 5)
                )
            self.im = self.ax.imshow(self.buffer[0])
            self.ax.set_axis_off()

            self.fig.canvas.toolbar_position = "bottom"

            self.max = 0
            # define widgets
            self.play = widgets.Play(
                value=0,
                min=0,
                max=self.max,
                step=1,
                description="Press play",
                disabled=False,
            )
            self.int_slider = widgets.IntSlider(
                value=0, min=0, max=self.max, step=1, description="Frame"
            )

            controls = widgets.HBox(
                [
                    self.play,
                    self.int_slider,
                ]
            )
            controls.layout = self._make_box_layout()
            widgets.jslink((self.play, "value"), (self.int_slider, "value"))
            output.layout = self._make_box_layout()

            self.int_slider.observe(self.update, "value")
            self.children = [controls, output]

        def update(
            self, 
            change
        ):
            self.im.set_data(self.buffer[self.int_slider.value])
            self.fig.canvas.draw()

        def add_frame(
            self, 
            frame
        ):
            self.buffer.append(frame)
            self.int_slider.max += 1
            self.play.max += 1
            self.int_slider.value = self.int_slider.max
            self.play.value = self.play.max

        def _make_box_layout(self):
            return widgets.Layout(
                border="solid 1px black",
                margin="0px 10px 10px 0px",
                padding="5px 5px 5px 5px",
            )

    return Jupyter_Render()


class IAILogger(logging.Logger):
    def __init__(
        self,
        name: str = "IAILogger",
        level: str = "WARNING",
        consoel: bool = True,
        log_file: bool = False,
    ) -> None:

        level = logging.getLevelName(level)
        log_level = level if isinstance(level, int) else 30
        super().__init__(name, log_level)
        if consoel:
            consoel_handler = logging.StreamHandler()
            self.addHandler(consoel_handler)
        if log_file:
            file_handler = logging.FileHandler("iai.log")
            self.addHandler(file_handler)

    @staticmethod
    def logfmt(
        message, 
        **params
    ):
        props = dict(message=message, **params)

        def fmt(key, val):
            # Handle case where val is a bytes or bytesarray
            if hasattr(val, "decode"):
                val = val.decode("utf-8")
            # Check if val is already a string to avoid re-encoding into ascii.
            if not isinstance(val, str):
                val = str(val)
            if re.search(r"\s", val):
                val = repr(val)
            # key should already be a string
            if re.search(r"\s", key):
                key = repr(key)
            return f"{key}={val}"

        return " ".join([fmt(key, val) for key, val in sorted(props.items())])


def rot(rot):
    """Rotate in 2d"""
    return np.array([[np.cos(rot), -np.sin(rot)], [np.sin(rot), np.cos(rot)]])


[docs]class ScenePlotter(): """ A class providing features for handling the data visualization of a scene involving IAI data. """ def __init__( self, map_image: Optional[np.array] = None, fov: Optional[float] = None, xy_offset: Optional[Tuple[float,float]] = None, static_actors: Optional[List[StaticMapActor]] = None, open_drive: Optional[str] = None, resolution: Tuple[int,int] = (640, 480), dpi: float = 100, left_hand_coordinates: bool = False, **kwargs ): """ Arguments ---------- map_image: An image used as the background for the visualization decoded from the birdview map taken from location info. fov: A single float value representing the field of view of the visualization that can be taken from location info. xy_offset: A tuple coordinate of the center of the map in metres that can be taken from location info. static_actors: A list of StaticMapActor objects representing objects such as traffic lights that can be taken from location info. open_drive: If using an ASAM OpenDRIVE format map for visualization, this string parameter is used to indicate the path to the corresponding CSV file. resolution: The desired resolution of the map image expressed as a Tuple with two integers for the width and height respectively. dpi: Dots per inch to define the level of detail in the image. left_hand_coordinates: Boolean flag dictating whether the X-coordinates of all agents and actors should be reversed to fit a left hand coordinate system. Keyword Arguments ----------------- map_image: Base image onto which the scene is visualized. This parameter must be provided if using an ASAM OpenDRIVE format map. fov: float The field of view in meters corresponding to the map_image attribute. This parameter must be provided if using an ASAM OpenDRIVE format map. xy_offset: The left-hand offset for the center of the map image. This parameter must be provided if using an ASAM OpenDRIVE format map. static_actors: A list of static actor agents (e.g. traffic lights) represented as StaticMapActor objects, in the scene. This parameter must be provided if using an ASAM OpenDRIVE format map. See Also -------- :func:`location_info` """ self._left_hand_coordinates = left_hand_coordinates self.conditional_agents = None self.agent_properties = None self.traffic_lights_history = None self.agent_states_history = None self._open_drive = open_drive self._dpi = dpi self._resolution = resolution self.map_image = map_image self.fov = fov self.xy_offset = xy_offset self.static_actors = static_actors self.traffic_lights = {static_actor.actor_id: static_actor for static_actor in self.static_actors if static_actor.agent_type == 'traffic_light'} if self._open_drive is None: self.extent = (- self.fov / 2 + self.xy_offset[0], self.fov / 2 + self.xy_offset[0]) + \ (- self.fov / 2 + self.xy_offset[1], self.fov / 2 + self.xy_offset[1]) self.traffic_light_colors = { "red": (1.0, 0.0, 0.0), "green": (0.0, 1.0, 0.0), "yellow": (1.0, 0.8, 0.0), } self.agent_c = (0.125,0.29,0.529) self.agent_ped_c = (1.0, 0.75, 0.8) self.cond_c = (0.78, 0.0, 0.0) self.dir_c = (0.392,1.0,1.0) self.v_c = (0.2, 0.75, 0.2) self.dir_lines = {} self.v_lines = {} self.actor_boxes = {} self.traffic_light_boxes = {} self.box_labels = {} self.frame_label = None self.current_ax = None self.numbers = None self.agent_face_colors = None self.agent_edge_colors = None self.reset_recording()
[docs] @validate_arguments def initialize_recording( self, agent_states: List[AgentState], agent_attributes: Optional[List[AgentAttributes]] = None, agent_properties: Optional[List[AgentProperties]] = None, traffic_light_states: Optional[Dict[int, TrafficLightState]] = None, conditional_agents: Optional[List[int]] = None ): """ Record the initial state of the scene to be visualized. This function also acts as an implicit reset of the recording and removes previous agent state, agent attribute, conditional agent, traffic light, and agent style data. Arguments ---------- agent_states: A list of AgentState objects corresponding to the initial time step to be visualized. agent_attributes: Static attributes of the agent, which don’t change over the course of a simulation. We assume every agent is a rectangle obeying a kinematic bicycle model. agent_properties: Static attributes of the agent (with the AgentProperties data type), which don’t change over the course of a simulation. We assume every agent is a rectangle obeying a kinematic bicycle model. traffic_light_states: Optional parameter containing the state of the traffic lights corresponding to the initial time step to be visualized. This parameter should only be used if the corresponding map contains traffic light static actors. conditional_agents: Optional parameter containing a list of agent IDs corresponding to conditional agents to be visualized to distinguish themselves. """ assert (agent_attributes is not None) ^ (agent_properties is not None), \ "Either agent_attributes or agent_properties is populated. Populating both or neither field is invalid." if agent_attributes is not None: self.agent_properties = [convert_attributes_to_properties(attr) for attr in agent_attributes] warnings.warn('agent_attributes is deprecated. Please use agent_properties.',category=DeprecationWarning) else: self.agent_properties = agent_properties self.agent_states_history = [agent_states] self.traffic_lights_history = [traffic_light_states] if conditional_agents is not None: self.conditional_agents = conditional_agents else: self.conditional_agents = [] self.agent_face_colors = None self.agent_edge_colors = None
[docs] def reset_recording(self): """ Explicitly reset the recording and remove the previous agent state, agent attribute, conditional agent, traffic light, and agent style data. """ self.agent_states_history = [] self.traffic_lights_history = [] self.agent_properties = None self.conditional_agents = [] self.agent_properties = None self.agent_face_colors = None self.agent_edge_colors = None
[docs] @validate_arguments def record_step( self, agent_states: List[AgentState], traffic_light_states: Optional[Dict[int, TrafficLightState]] = None ): """ Record a single timestep of scene data to be used in a visualization Arguments ---------- agent_states: A list of AgentState objects corresponding to the initial time step to be visualized. traffic_light_states: Optional parameter containing the state of the traffic lights corresponding to the initial time step to be visualized. This parameter should only be used if the corresponding map contains traffic light static actors. """ self.agent_states_history.append(agent_states) self.traffic_lights_history.append(traffic_light_states)
[docs] @validate_arguments(config=dict(arbitrary_types_allowed=True)) def plot_scene( self, agent_states: List[AgentState], agent_attributes: Optional[List[AgentAttributes]] = None, agent_properties: Optional[List[AgentProperties]] = None, traffic_light_states: Optional[Dict[int, TrafficLightState]] = None, conditional_agents: Optional[List[int]] = None, ax: Optional[Axes] = None, numbers: Optional[List[int]] = None, direction_vec: bool = True, velocity_vec: bool = False, agent_face_colors: Optional[List[Optional[Tuple[float,float,float]]]] = None, agent_edge_colors: Optional[List[Optional[Tuple[float,float,float]]]] = None ): """ Plot a single timestep of data then reset the recording. Arguments ---------- agent_states: A list of agents to be visualized in the image. agent_attributes: Static attributes of the agent, which don’t change over the course of a simulation. We assume every agent is a rectangle obeying a kinematic bicycle model. agent_properties: Static attributes of the agent (with the AgentProperties data type), which don’t change over the course of a simulation. We assume every agent is a rectangle obeying a kinematic bicycle model. traffic_light_states: Optional parameter containing the state of the traffic lights to be visualized in the image. This parameter should only be used if the corresponding map contains traffic light static actors. conditional_agents: Optional parameter containing a list of agent IDs of conditional agents to be visualized in the image to distinguish themselves. ax: A matplotlib Axes object used to plot the image. By default, an Axes object is created if a value of None is passed. numbers: A list of agent ID's that should be plotted in the image. By default this value is set to None. direction_vec: Flag to determine if a vector showing the vehicles direction should be plotted in the image. By default this flag is set to True. velocity_vec: Flag to determine if the a vector showing the vehicles velocity should be plotted in the animation. By default this flag is set to False. agent_face_colors: An optional parameter containing a list of either RGB tuples indicating the desired color of the agent with the corresponding index ID. A value of None in this list will use the default color. This value gets overwritten by the conditional agent color. agent_edge_colors: An optional parameter containing a list of either RGB tuples indicating the desired color of a border around the agent with the corresponding index ID. A value of None in this list will use the default color. This value gets overwritten by the conditional agent color. """ assert (agent_attributes is not None) ^ (agent_properties is not None), \ "Either agent_attributes or agent_properties is populated. Populating both or neither field is invalid." if agent_attributes is not None: agent_properties = [convert_attributes_to_properties(attr) for attr in agent_attributes] warnings.warn('agent_attributes is deprecated. Please use agent_properties.',category=DeprecationWarning) self.initialize_recording( agent_states=agent_states, agent_properties=agent_properties, traffic_light_states=traffic_light_states, conditional_agents=conditional_agents ) self._validate_agent_style_data( agent_face_colors=agent_face_colors, agent_edge_colors=agent_edge_colors ) self._plot_frame( idx=0, ax=ax, numbers=numbers, direction_vec=direction_vec, velocity_vec=velocity_vec, plot_frame_number=False ) self.reset_recording()
[docs] @validate_arguments(config=dict(arbitrary_types_allowed=True)) def animate_scene( self, output_name: Optional[str] = None, start_idx: int = 0, end_idx: int = -1, ax: Optional[Axes] = None, numbers: Optional[List[int]] = None, direction_vec: bool = True, velocity_vec: bool = False, plot_frame_number: bool = False, agent_face_colors: Optional[List[Optional[Tuple[float,float,float]]]] = None, agent_edge_colors: Optional[List[Optional[Tuple[float,float,float]]]] = None ) -> FuncAnimation: """ Produce an animation of sequentially recorded steps. A matplotlib animation object can be returned and/or a gif saved of the scene. Parameters ---------- output_name: File name of the gif to which the animation will be saved. start_idx: The index of the time step from which the animation will begin. By default it is assumed all recorded steps are desired to be animated. end_idx: The index of the time step from which the animation will end. By default it is assumed all recorded steps are desired to be animated. ax: A matplotlib Axes object used to plot the animation. By default, an Axes object is created if a value of None is passed. numbers: A list of agent ID's that should be plotted in the image. By default this value is set to None. direction_vec: Flag to determine if a vector showing the vehicles direction should be plotted in the animation. By default this flag is set to True. velocity_vec: Flag to determine if the a vector showing the vehicles velocity should be plotted in the animation. By default this flag is set to False. plot_frame_number: Flag to determine if the frame numbers should be plotted in the animation. By default this flag is set to False. agent_face_colors: An optional parameter containing a list of either RGB tuples indicating the desired color of the agent with the corresponding index ID. A value of None in this list will use the default color. This value gets overwritten by the conditional agent color. agent_edge_colors: An optional parameter containing a list of either RGB tuples indicating the desired color of a border around the agent with the corresponding index ID. A value of None in this list will use the default color. This value gets overwritten by the conditional agent color. """ self._validate_agent_style_data(agent_face_colors,agent_edge_colors) self._initialize_plot(ax=ax, numbers=numbers, direction_vec=direction_vec, velocity_vec=velocity_vec, plot_frame_number=plot_frame_number) end_idx = len(self.agent_states_history) if end_idx == -1 else end_idx fig = self.current_ax.figure fig.set_size_inches(self._resolution[0] / self._dpi, self._resolution[1] / self._dpi, True) def animate(i): self._update_frame_to(i) ani = FuncAnimation( fig, animate, np.arange(start_idx, end_idx), interval=100) if output_name is not None: ani.save(f'{output_name}', writer='pillow', dpi=self._dpi) return ani
def _transform_point_to_left_hand_coordinate_frame(self,x,orientation): t_x = 2*self.xy_offset[0] - x if orientation >= 0: t_orientation = -orientation + math.pi else: t_orientation = -orientation - math.pi return t_x, t_orientation def _plot_frame(self, idx, ax=None, numbers=None, direction_vec=True, velocity_vec=False, plot_frame_number=False): self._initialize_plot(ax=ax, numbers=numbers, direction_vec=direction_vec, velocity_vec=velocity_vec, plot_frame_number=plot_frame_number) self._update_frame_to(idx) def _validate_agent_style_data(self,agent_face_colors,agent_edge_colors): if self.agent_properties is not None: if agent_face_colors is not None: if len(agent_face_colors) != len(self.agent_properties): raise Exception("Number of agent face colors does not match number of agents.") if agent_edge_colors is not None: if len(agent_edge_colors) != len(self.agent_properties): raise Exception("Number of agent edge colors does not match number of agents.") self.agent_face_colors = agent_face_colors self.agent_edge_colors = agent_edge_colors def _initialize_plot(self, ax=None, numbers=None, direction_vec=True, velocity_vec=False, plot_frame_number=False): if ax is None: plt.clf() ax = plt.gca() if self._open_drive is None: ax.imshow(self.map_image, extent=self.extent) else: self._draw_xodr_map(ax) self.extent = (self.xy_offset[0] - self.fov / 2, self.xy_offset[0] + self.fov / 2) +\ (self.xy_offset[1] - self.fov / 2, self.xy_offset[1] + self.fov / 2) ax.set_xlim((self.extent[0], self.extent[1])) ax.set_ylim((self.extent[2], self.extent[3])) self.current_ax = ax self.dir_lines = {} self.v_lines = {} self.actor_boxes = {} self.traffic_light_boxes = {} self.box_labels = {} self.frame_label = None self.numbers = numbers self.direction_vec = direction_vec self.velocity_vec = velocity_vec self.plot_frame_number = plot_frame_number self._update_frame_to(0) def _get_color(self,agent_idx,color_list): c = None if color_list and color_list[agent_idx]: is_good_color_format = isinstance(color_list[agent_idx],tuple) for pc in color_list[agent_idx]: is_good_color_format *= isinstance(pc,float) and (0.0 <= pc <= 1.0) if not is_good_color_format: raise Exception(f"Expected color format is Tuple[float,float,float] with 0 <= float <= 1 but received {color_list[agent_idx]}.") c = color_list[agent_idx] return c def _update_frame_to(self, frame_idx): for i, (agent, agent_attribute) in enumerate( zip(self.agent_states_history[frame_idx], self.agent_properties) ): self._update_agent(i, agent, agent_attribute) if self.traffic_lights_history[frame_idx] is not None: for light_id, light_state in self.traffic_lights_history[frame_idx].items(): self._plot_traffic_light(light_id, light_state) if self.plot_frame_number: if self.frame_label is None: self.frame_label = self.current_ax.text( self.extent[0], self.extent[2], str(frame_idx), c="r", fontsize=18 ) else: self.frame_label.set_text(str(frame_idx)) if self._open_drive is None: self.current_ax.set_xlim(*self.extent[0:2]) self.current_ax.set_ylim(*self.extent[2:4]) def _update_agent(self, agent_idx, agent, agent_attribute): l, w = agent_attribute.length, agent_attribute.width if agent_attribute.agent_type == "pedestrian": l, w = 1.5, 1.5 x, y = agent.center.x, agent.center.y v = agent.speed psi = agent.orientation if self._left_hand_coordinates: x, psi = self._transform_point_to_left_hand_coordinate_frame(x,psi) box = np.array([ [0, 0], [l * 0.5, 0], # direction vector [0, 0], [v * 0.5, 0], # speed vector at (0.5 m / s ) / m ]) box = np.matmul(rot(psi), box.T).T + np.array([[x, y]]) if self.direction_vec: marker_offset = agent_attribute.length/4 x_data = x + marker_offset*math.cos(psi) y_data = y + marker_offset*math.sin(psi) marker_data = (3, 0, (-90+180*psi/math.pi)) if agent_idx not in self.dir_lines: self.dir_lines[agent_idx] = self.current_ax.plot( x_data, y_data, marker=marker_data, markersize=agent_attribute.width*400/self.fov, linestyle='None', c=self.dir_c ) else: self.dir_lines[agent_idx][0].set_xdata(x_data) self.dir_lines[agent_idx][0].set_ydata(y_data) self.dir_lines[agent_idx][0].set_marker(marker_data) if self.velocity_vec: if agent_idx not in self.v_lines: self.v_lines[agent_idx] = self.current_ax.plot( box[2:4, 0], box[2:4, 1], lw=1.5, c=self.v_c )[0] # plot the speed else: self.v_lines[agent_idx].set_xdata(box[2:4, 0]) self.v_lines[agent_idx].set_ydata(box[2:4, 1]) if self.numbers is not None and agent_idx in self.numbers: if agent_idx not in self.box_labels: self.box_labels[agent_idx] = self.current_ax.text( x, y, str(agent_idx), c="r", fontsize=18 ) self.box_labels[agent_idx].set_clip_on(True) else: self.box_labels[agent_idx].set_x(x) self.box_labels[agent_idx].set_y(y) lw = 1 fc = self._get_color(agent_idx,self.agent_face_colors) if fc is None: if agent_idx in self.conditional_agents: fc = self.cond_c else: fc = self.agent_c ec = self._get_color(agent_idx,self.agent_edge_colors) if ec is None: lw = 0 ec = fc rect = Rectangle( (x - l / 2, y - w / 2), l, w, angle=psi * 180 / np.pi, rotation_point='center', fc=fc, ec=ec, lw=lw ) if agent_idx in self.actor_boxes: self.actor_boxes[agent_idx].remove() self.actor_boxes[agent_idx] = rect self.actor_boxes[agent_idx].set_clip_on(True) self.current_ax.add_patch(self.actor_boxes[agent_idx]) def _plot_traffic_light(self, light_id, light_state): light = self.traffic_lights[light_id] x, y = light.center.x, light.center.y psi = light.orientation l, w = max(light.length,1.0), max(light.width,1.0) if self._left_hand_coordinates: x, psi = self._transform_point_to_left_hand_coordinate_frame(x,psi) rect = Rectangle( (x - l / 2, y - w / 2), l, w, angle=psi * 180 / np.pi, rotation_point="center", fc=self.traffic_light_colors[light_state], lw=0, ) if light_id in self.traffic_light_boxes: self.traffic_light_boxes[light_id].remove() self.current_ax.add_patch(rect) self.traffic_light_boxes[light_id] = rect def _draw_xodr_map(self, ax, extras=False): """ This function plots the parsed xodr map the `odrplot` of `esmini` is used for plotting and parsing xodr https: // esmini.github.io/ # _tools_overview """ with open(self._open_drive) as f: reader = csv.reader(f, skipinitialspace=True) positions = list(reader) ref_x = [] ref_y = [] ref_z = [] ref_h = [] lane_x = [] lane_y = [] lane_z = [] lane_h = [] border_x = [] border_y = [] border_z = [] border_h = [] road_id = [] road_id_x = [] road_id_y = [] road_start_dots_x = [] road_start_dots_y = [] road_end_dots_x = [] road_end_dots_y = [] lane_section_dots_x = [] lane_section_dots_y = [] arrow_dx = [] arrow_dy = [] current_road_id = None current_lane_id = None current_lane_section = None new_lane_section = False for i in range(len(positions) + 1): if i < len(positions): pos = positions[i] # plot road id before going to next road if i == len(positions) or ( pos[0] == "lane" and i > 0 and current_lane_id == "0" ): if current_lane_section == "0": road_id.append(int(current_road_id)) index = int(len(ref_x[-1]) / 3.0) h = ref_h[-1][index] road_id_x.append( ref_x[-1][index] + (text_x_offset * math.cos(h) - text_y_offset * math.sin(h)) ) road_id_y.append( ref_y[-1][index] + (text_x_offset * math.sin(h) + text_y_offset * math.cos(h)) ) road_start_dots_x.append(ref_x[-1][0]) road_start_dots_y.append(ref_y[-1][0]) if len(ref_x) > 0: arrow_dx.append(ref_x[-1][1] - ref_x[-1][0]) arrow_dy.append(ref_y[-1][1] - ref_y[-1][0]) else: arrow_dx.append(0) arrow_dy.append(0) lane_section_dots_x.append(ref_x[-1][-1]) lane_section_dots_y.append(ref_y[-1][-1]) if i == len(positions): break if pos[0] == "lane": current_road_id = pos[1] current_lane_section = pos[2] current_lane_id = pos[3] if pos[3] == "0": ltype = "ref" ref_x.append([]) ref_y.append([]) ref_z.append([]) ref_h.append([]) elif pos[4] == "no-driving": ltype = "border" border_x.append([]) border_y.append([]) border_z.append([]) border_h.append([]) else: ltype = "lane" lane_x.append([]) lane_y.append([]) lane_z.append([]) lane_h.append([]) else: if ltype == "ref": ref_x[-1].append(float(pos[0])) ref_y[-1].append(float(pos[1])) ref_z[-1].append(float(pos[2])) ref_h[-1].append(float(pos[3])) elif ltype == "border": border_x[-1].append(float(pos[0])) border_y[-1].append(float(pos[1])) border_z[-1].append(float(pos[2])) border_h[-1].append(float(pos[3])) else: lane_x[-1].append(float(pos[0])) lane_y[-1].append(float(pos[1])) lane_z[-1].append(float(pos[2])) lane_h[-1].append(float(pos[3])) # plot driving lanes in blue for i in range(len(lane_x)): ax.plot(lane_x[i], lane_y[i], linewidth=1.0, color="#222222") # plot road ref line segments for i in range(len(ref_x)): ax.plot(ref_x[i], ref_y[i], linewidth=2.0, color="#BB5555") # plot border lanes in gray for i in range(len(border_x)): ax.plot(border_x[i], border_y[i], linewidth=1.0, color="#AAAAAA") if extras: # plot red dots indicating lane dections for i in range(len(lane_section_dots_x)): ax.plot( lane_section_dots_x[i], lane_section_dots_y[i], "o", ms=4.0, color="#BB5555", ) for i in range(len(road_start_dots_x)): # plot a yellow dot at start of each road ax.plot( road_start_dots_x[i], road_start_dots_y[i], "o", ms=5.0, color="#BBBB33", ) # and an arrow indicating road direction ax.arrow( road_start_dots_x[i], road_start_dots_y[i], arrow_dx[i], arrow_dy[i], width=0.1, head_width=1.0, color="#BB5555", ) # plot road id numbers for i in range(len(road_id)): ax.text( road_id_x[i], road_id_y[i], road_id[i], size=text_size, ha="center", va="center", color="#3333BB", ) return None