l5kit.prediction.vectorized.safepathnet_model module
- class l5kit.prediction.vectorized.safepathnet_model.SafePathNetModel(history_num_frames_ego: int, history_num_frames_agents: int, num_timesteps: int, weights_scaling: List[float], criterion: torch.nn.modules.module.Module, disable_other_agents: bool, disable_map: bool, disable_lane_boundaries: bool, agent_num_trajectories: int, max_num_agents: int = 30, cost_prob_coeff: float = 0.01)
Bases: torch.nn.modules.module.Module
SafePathNet model: a unified prediction and planning model with multimodal output.
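The class produces multimodal output, and the constructor exposes agent_num_trajectories, weights_scaling and cost_prob_coeff. A common way to train such an output is a winner-takes-all objective: only the mode closest to the ground truth receives the regression loss, and a classification term pushes probability mass toward that mode. The NumPy sketch below illustrates this idea; it is not l5kit's implementation. The tensor layouts, the weighted L1 distance (standing in for the configurable criterion), and the use of weights_scaling as per-feature weights and cost_prob_coeff as the weight of the classification term are assumptions based on the parameter names. wta_multimodal_loss is a hypothetical helper.

```python
import numpy as np


def wta_multimodal_loss(pred, logits, target, avail, weights_scaling, cost_prob_coeff):
    """Hypothetical winner-takes-all loss for K trajectories per agent (sketch only).

    pred:    [batch, agents, K, T, F] predicted trajectories
    logits:  [batch, agents, K]       unnormalized mode scores
    target:  [batch, agents, T, F]    ground-truth trajectories
    avail:   [batch, agents, T]       True where the target timestep is valid
    """
    w = np.asarray(weights_scaling)                       # assumed per-feature weights, [F]
    err = np.abs(pred - target[:, :, None]) * w           # weighted L1 per feature, [B, A, K, T, F]
    err = err.sum(-1) * avail[:, :, None]                 # zero out invalid timesteps, [B, A, K, T]
    n_valid = np.maximum(avail.sum(-1), 1)[:, :, None]    # avoid division by zero, [B, A, 1]
    mode_cost = err.sum(-1) / n_valid                     # mean error per mode, [B, A, K]
    best = mode_cost.argmin(-1)                           # closest mode per agent, [B, A]
    reg = np.take_along_axis(mode_cost, best[..., None], -1)[..., 0]  # [B, A]
    # Numerically stable log-softmax over the K modes
    shifted = logits - logits.max(-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, best[..., None], -1)[..., 0]  # [B, A]
    agent_valid = avail.any(-1)                           # agents with at least one valid step
    total = reg + cost_prob_coeff * nll
    return (total * agent_valid).sum() / max(agent_valid.sum(), 1)


# Usage: two modes, mode 0 matches the target exactly, both modes equally likely
pred = np.zeros((1, 1, 2, 2, 2))
pred[:, :, 1] = 1.0                    # mode 1 is off by 1 in every feature
target = np.zeros((1, 1, 2, 2))
avail = np.ones((1, 1, 2), dtype=bool)
logits = np.zeros((1, 1, 2))
loss = wta_multimodal_loss(pred, logits, target, avail, [1.0, 1.0], cost_prob_coeff=0.01)
print(round(float(loss), 6))  # 0.006931: regression term is 0, plus 0.01 * log(2)
```

Only the best mode is penalized for its distance, so the other modes remain free to cover alternative futures. The small default cost_prob_coeff keeps the classification term from dominating the regression term.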
- embed_polyline(features: torch.Tensor, mask: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Embeds the inputs, generates the positional embedding and calls the local subgraph.
- Parameters
features – input features, a tensor of shape [batch_size, num_elements, max_num_points, max_num_features]
mask – availability mask, a tensor of shape [batch_size, num_elements, max_num_points]
- Returns
tuple of the local subgraph output and the (in-)availability mask
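embed_polyline reduces the points of each element to a single descriptor. The NumPy sketch below traces that shape flow, assuming a VectorNet-style local subgraph that max-pools over the available points. It is not l5kit's code: the fixed matrices w_in and pos_emb stand in for learned embedding layers, embed_polyline_sketch is hypothetical, and reading the returned mask as "True where an element has no available point" is our interpretation of "(in-)availability mask".

```python
import numpy as np


def embed_polyline_sketch(features, mask, w_in, pos_emb):
    """Hypothetical polyline embedding with a max-pooling local subgraph.

    features: [batch, num_elements, max_num_points, max_num_features]
    mask:     [batch, num_elements, max_num_points], True where a point is available
    w_in:     [max_num_features, embed_dim], stand-in for a learned input embedding
    pos_emb:  [max_num_points, embed_dim], stand-in for a positional embedding
    """
    polys = features @ w_in + pos_emb                       # [B, E, P, D]
    invalid_mask = ~mask                                     # True where a point is missing
    # Exclude missing points from the max-pool by setting them to -inf
    masked = np.where(invalid_mask[..., None], -np.inf, polys)
    pooled = masked.max(axis=2)                              # one descriptor per element, [B, E, D]
    invalid_polys = invalid_mask.all(axis=-1)                # element has no available points, [B, E]
    pooled = np.where(invalid_polys[..., None], 0.0, pooled)  # replace all -inf rows with zeros
    return pooled, invalid_polys


# Usage: one batch, two elements of up to three 2-D points each
features = np.array([[[[1.0, 0.0], [3.0, 1.0], [2.0, 5.0]],
                      [[9.0, 9.0], [9.0, 9.0], [9.0, 9.0]]]])
mask = np.array([[[True, True, False], [False, False, False]]])
pooled, invalid = embed_polyline_sketch(features, mask, np.eye(2), np.zeros((3, 2)))
```

The third point of the first element is masked out, so its large second feature does not leak into the pooled descriptor; the second element has no available points and is flagged as invalid.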
- forward(data_batch: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
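The hook behavior described in this note can be illustrated with a toy model of the mechanism. The pure-Python sketch below is not PyTorch's implementation: ToyModule and Doubler are hypothetical classes that mimic how calling a torch.nn.Module instance runs forward hooks (registered with register_forward_hook) around forward(), while calling forward() directly bypasses them.

```python
class ToyModule:
    """Toy stand-in for torch.nn.Module, showing why __call__ is preferred over forward()."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # hook(module, inputs, output) runs after each forward pass made via __call__
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError  # subclasses define the computation

    def __call__(self, *args):
        output = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output


class Doubler(ToyModule):
    def forward(self, x):
        return 2 * x


calls = []
model = Doubler()
model.register_forward_hook(lambda module, inputs, output: calls.append(output))
model(3)          # runs forward and the hook
model.forward(4)  # runs forward only; the hook is silently skipped
print(calls)      # [6]
```

For SafePathNetModel this means writing outputs = model(data_batch) rather than model.forward(data_batch).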
- model_call(agents_polys: torch.Tensor, static_polys: torch.Tensor, agents_avail: torch.Tensor, static_avail: torch.Tensor, type_embedding: torch.Tensor, lane_bdry_len: int) → Tuple[torch.Tensor, torch.Tensor]
Encapsulates calling the global_head and preparing the data it needs.
- Parameters
agents_polys – dynamic elements, i.e. vectors corresponding to agents
static_polys – static elements, i.e. vectors corresponding to map elements
agents_avail – availability of agents
static_avail – availability of map elements
type_embedding – agent type embeddings
lane_bdry_len – number of map elements
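model_call receives agents and map elements separately, so before a global head can attend across all of them the two sets have to be merged into one sequence with a single mask. The sketch below shows one plausible way to build that mask, including the effect of the disable_other_agents and disable_map constructor flags. It assumes element-level boolean availability, that agents precede map elements, and that index 0 is the ego (or agent of interest) and must stay visible. build_invalid_mask is hypothetical and does not model lane_bdry_len or disable_lane_boundaries.

```python
import numpy as np


def build_invalid_mask(agents_avail, static_avail, disable_other_agents=False, disable_map=False):
    """Hypothetical sketch: merge per-element availability into one invalid mask.

    agents_avail: [batch, num_agents] bool, index 0 assumed to be the ego / agent of interest
    static_avail: [batch, num_map_elements] bool
    Returns a [batch, num_agents + num_map_elements] bool mask, True = ignored by attention.
    """
    invalid = ~np.concatenate([agents_avail, static_avail], axis=1)
    num_agents = agents_avail.shape[1]
    if disable_other_agents:
        invalid[:, 1:num_agents] = True   # other agents cannot be attended to
    if disable_map:
        invalid[:, num_agents:] = True    # map elements cannot be attended to
    invalid[:, 0] = False                 # keep the ego element always available
    return invalid


# Usage: three agent slots (the last one empty) and two map elements (the last one empty)
agents_avail = np.array([[True, True, False]])
static_avail = np.array([[True, False]])
mask = build_invalid_mask(agents_avail, static_avail, disable_other_agents=True)
```

Masking elements rather than removing them keeps tensor shapes fixed across the batch, which is convenient when max_num_agents pads every sample to the same size.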
- training: bool