Source code for invertedai.api.blame

import time
from typing import List, Optional, Tuple, Dict
from pydantic import BaseModel, validate_call

import invertedai as iai
from invertedai.api.config import TIMEOUT, should_use_mock_api
from invertedai.api.mock import (
    get_mock_birdview,
    get_mock_agents_at_fault,
    get_mock_blamed_reasons,
    get_mock_confidence_score,
)
from invertedai.error import APIConnectionError, InvalidInput
from invertedai.common import (
    AgentState,
    Image,
    AgentAttributes,
    TrafficLightStatesDict,
)

class BlameResponse(BaseModel):
    """
    Response returned from an API call to :func:`iai.blame`.
    """
    agents_at_fault: Optional[Tuple[int, ...]]  #: A tuple containing all agents predicted to be at fault. If empty, BLAME has predicted no agents are at fault.
    reasons: Optional[Dict[int, List[str]]]  #: A dictionary with agent IDs as keys corresponding to "agents_at_fault", paired with a list of reasons why the keyed agent is at fault (e.g. traffic_light_violation).
    confidence_score: Optional[float]  #: Float value between [0, 1] indicating BLAME's confidence in the response, where 0.0 represents the minimum confidence and 1.0 represents the maximum.
    birdviews: Optional[List[Image]]  #: If `get_birdviews` was set, this contains the resulting image.
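# Illustrative sketch (not part of this module): how the fields of a BlameResponse
# might be inspected once a response has been obtained from :func:`blame` below.
# The `response` variable here is a placeholder, not something defined in this file.
#
#     if response.agents_at_fault:
#         for agent_idx in response.agents_at_fault:
#             # reasons and confidence_score are None unless requested in the call
#             print(agent_idx, (response.reasons or {}).get(agent_idx))
#     print(response.confidence_score)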
@validate_call
def blame(
    location: str,
    colliding_agents: Tuple[int, int],
    agent_state_history: List[List[AgentState]],
    agent_attributes: List[AgentAttributes],
    traffic_light_state_history: Optional[List[TrafficLightStatesDict]] = None,
    get_reasons: bool = False,
    get_confidence_score: bool = False,
    get_birdviews: bool = False,
) -> BlameResponse:
    """
    Parameters
    ----------
    location:
        Location name in IAI format.

    colliding_agents:
        The two agents involved in the collision. These integers should correspond to
        the indices of the relevant agents in the lists within `agent_state_history`.

    agent_state_history:
        Lists containing AgentState objects for every agent within the scene (up to
        100 agents) for each time step within the relevant continuous sequence
        immediately preceding the collision. The list of AgentState objects should
        include the first time step of the collision and no time steps afterwards.
        The lists of AgentState objects preceding the collision should capture enough
        of the scenario context before the collision for BLAME to analyze and assign
        fault. For best results it is recommended to input 20-50 time steps of 0.1s
        each preceding the collision. Each AgentState must include x: [float],
        y: [float] coordinates in meters, orientation: [float] in radians with 0
        pointing along the positive x axis and pi/2 pointing along the positive
        y axis, and speed: [float] in m/s.

    agent_attributes:
        List of static AgentAttributes objects for all agents. Each agent requires
        length: [float], width: [float], and rear_axis_offset: [float], all in meters.

    traffic_light_state_history:
        List of TrafficLightStatesDict objects containing the state of all traffic
        lights for every time step. The dictionary keys are the traffic light IDs and
        the value is the state, i.e., 'green', 'yellow', 'red', or None.

    get_reasons:
        Whether to return the reasons why each agent was blamed.

    get_confidence_score:
        Whether to return how confident BLAME is in its response.

    get_birdviews:
        Whether to return the image visualizing the collision case. This is very slow
        and should only be used for debugging.
    See Also
    --------
    :func:`drive`
    :func:`initialize`
    :func:`location_info`
    :func:`light`
    """
    if len(agent_state_history[0]) != len(agent_attributes):
        raise InvalidInput("Incompatible Number of Agents in either 'agent_state_history' or 'agent_attributes'.")

    if should_use_mock_api():
        agents_at_fault = get_mock_agents_at_fault()
        birdviews = [get_mock_birdview()]
        reasons = get_mock_blamed_reasons()
        confidence_score = get_mock_confidence_score()
        response = BlameResponse(
            agents_at_fault=agents_at_fault,
            birdviews=birdviews,
            reasons=reasons,
            confidence_score=confidence_score,
        )
        return response

    model_inputs = dict(
        location=location,
        colliding_agents=colliding_agents,
        agent_state_history=[
            [state.tolist() for state in agent_states]
            for agent_states in agent_state_history
        ],
        agent_attributes=[attr.tolist() for attr in agent_attributes],
        traffic_light_state_history=traffic_light_state_history,
        get_reasons=get_reasons,
        get_confidence_score=get_confidence_score,
        get_birdviews=get_birdviews,
    )
    start = time.time()
    timeout = TIMEOUT
    while True:
        try:
            response = iai.session.request(model="blame", data=model_inputs)
            response = BlameResponse(
                agents_at_fault=response["agents_at_fault"],
                reasons=response["reasons"],
                confidence_score=response["confidence_score"],
                birdviews=[Image.fromval(birdview) for birdview in response["birdviews"]],
            )
            return response
        except APIConnectionError as e:
            iai.logger.warning("Retrying")
            if (
                timeout is not None and time.time() > start + timeout
            ) or not e.should_retry:
                raise e
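# Usage sketch (illustrative, not part of this module): a typical call to blame()
# after a collision is detected in a simulation. The location name, the colliding
# agent indices, and the `states_log` / `attributes` variables are assumptions for
# the example, not values defined elsewhere in this file.
#
#     import invertedai as iai
#
#     # states_log: List[List[AgentState]] covering roughly 20-50 steps of 0.1 s,
#     # up to and including the first collision step.
#     # attributes: List[AgentAttributes], one entry per agent.
#     blame_response = iai.blame(
#         location="canada:ubc_roundabout",  # assumed location name
#         colliding_agents=(0, 3),           # indices into each inner state list
#         agent_state_history=states_log,
#         agent_attributes=attributes,
#         get_reasons=True,
#         get_confidence_score=True,
#     )
#     print(blame_response.agents_at_fault, blame_response.confidence_score)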
@validate_call
async def async_blame(
    location: str,
    colliding_agents: Tuple[int, int],
    agent_state_history: List[List[AgentState]],
    agent_attributes: List[AgentAttributes],
    traffic_light_state_history: Optional[List[TrafficLightStatesDict]] = None,
    get_reasons: bool = False,
    get_confidence_score: bool = False,
    get_birdviews: bool = False,
) -> BlameResponse:
    """
    A light async version of :func:`blame`.
    """
    model_inputs = dict(
        location=location,
        colliding_agents=colliding_agents,
        agent_state_history=[
            [state.tolist() for state in agent_states]
            for agent_states in agent_state_history
        ],
        agent_attributes=[attr.tolist() for attr in agent_attributes],
        traffic_light_state_history=traffic_light_state_history,
        get_reasons=get_reasons,
        get_confidence_score=get_confidence_score,
        get_birdviews=get_birdviews,
    )

    response = await iai.session.async_request(model="blame", data=model_inputs)
    response = BlameResponse(
        agents_at_fault=response["agents_at_fault"],
        reasons=response["reasons"],
        confidence_score=response["confidence_score"],
        birdviews=[Image.fromval(birdview) for birdview in response["birdviews"]],
    )
    return response
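# Async usage sketch (illustrative, not part of this module): async_blame() can be
# awaited inside an asyncio event loop, e.g. to run fault assignment concurrently
# with other API calls. `states_log` and `attributes` are the same assumptions as in
# the synchronous sketch above.
#
#     import asyncio
#     from invertedai.api.blame import async_blame
#
#     async def assign_fault():
#         return await async_blame(
#             location="canada:ubc_roundabout",  # assumed location name
#             colliding_agents=(0, 3),
#             agent_state_history=states_log,
#             agent_attributes=attributes,
#             get_reasons=True,
#         )
#
#     blame_response = asyncio.run(assign_fault())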