GNN¶
Note
Some of the below documentation has been created with the assistance of generative AI and so should be taken with a grain of salt.
dataloader.py¶
This file converts map, backward Dijkstra, and path data into graph-structured data for training Graph Neural Networks (GNNs). It handles the creation of node features, edge connections, and labels for multi-agent path finding (MAPF) tasks.
Key operations:
Convert raw MAPF data into PyTorch Geometric Data objects
Generate local observation windows for each agent
Create graph structures with appropriate node and edge features
Process and normalize features for neural network training
Save the resulting graphs to a folder for later use
Utility Functions¶
apply_masks¶
- apply_masks(data_len, curdata)¶
Applies train/test masks to split data for training and evaluation.
- Parameters:
data_len (int) – Length of the data
curdata (Data) – PyTorch Geometric Data object
- Returns:
Data object with train_mask and test_mask attributes
- Return type:
Data
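The splitting idea can be illustrated with a minimal sketch. It is not the function's actual implementation: the helper name make_split_masks, the random 80/20 split, and the use of NumPy boolean arrays (rather than tensors attached to the Data object as train_mask and test_mask) are all assumptions made for illustration.

```python
import numpy as np

def make_split_masks(data_len, train_fraction=0.8, seed=0):
    """Hypothetical helper: build disjoint boolean train/test masks.

    The 80/20 ratio and the random permutation are assumptions; the real
    apply_masks may split the data differently.
    """
    rng = np.random.default_rng(seed)
    perm = rng.permutation(data_len)
    n_train = int(train_fraction * data_len)

    train_mask = np.zeros(data_len, dtype=bool)
    train_mask[perm[:n_train]] = True
    test_mask = ~train_mask  # every item is in exactly one split
    return train_mask, test_mask

train_mask, test_mask = make_split_masks(10)
```

In the real pipeline the two masks would presumably be stored on the Data object so that the loss and the evaluation metrics can each be restricted to their own subset of nodes.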
normalize_graph_data¶
- normalize_graph_data(data, k, edge_normalize='k', bd_normalize='center')¶
Normalizes edge attributes and backward Dijkstra values in the graph data.
- Parameters:
data (Data) – PyTorch Geometric Data object
k (int) – Window size for local observation
edge_normalize (str) – Method for normalizing edge attributes
bd_normalize (str) – Method for normalizing backward Dijkstra values
- Returns:
Normalized Data object
- Return type:
Data
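One plausible reading of the default options is sketched below. It assumes that edge_normalize='k' divides edge attributes by the window size k and that bd_normalize='center' subtracts the value at the center of each agent's local backward Dijkstra window. Neither interpretation is confirmed by the source, and both helper names are hypothetical.

```python
import numpy as np

def normalize_edges_by_k(edge_attr, k):
    """Assumed meaning of edge_normalize='k': scale edge attributes by 1/k."""
    return np.asarray(edge_attr, dtype=float) / float(k)

def center_bd_window(bd_window, k):
    """Assumed meaning of bd_normalize='center': subtract the center value.

    bd_window has shape (2k+1, 2k+1) and is centered on the agent, so the
    agent's own cell is at index (k, k).
    """
    bd_window = np.asarray(bd_window, dtype=float)
    return bd_window - bd_window[k, k]

edges = normalize_edges_by_k([[2, -4]], k=4)
window = center_bd_window(np.arange(9).reshape(3, 3), k=1)
```

Under these assumptions, edge attributes that are offsets within the window end up roughly in [-1, 1], and each backward Dijkstra window expresses cost relative to the agent's current cell.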
get_bd_prefs¶
- get_bd_prefs(pos_list, bds, range_num_agents)¶
Gets movement preferences based on backward Dijkstra values.
- Parameters:
pos_list (numpy.ndarray) – Agent positions of shape (N,2)
bds (numpy.ndarray) – Backward Dijkstra values of shape (N,W,H)
range_num_agents (numpy.ndarray) – Range of agent indices
- Returns:
Agent movement preferences
- Return type:
numpy.ndarray
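As a rough illustration of what "movement preferences" could mean, the sketch below ranks five candidate moves for each agent by the backward Dijkstra value of the destination cell, lowest first. The move ordering in MOVES, the helper name, and the treatment of out-of-bounds cells are assumptions; the real function also takes range_num_agents, which this sketch does not use.

```python
import numpy as np

# Hypothetical action ordering: stay, +y, +x, -y, -x.
MOVES = np.array([[0, 0], [0, 1], [1, 0], [0, -1], [-1, 0]])

def bd_move_preferences(pos_list, bds):
    """Rank moves per agent by the backward Dijkstra value of the target cell.

    pos_list: (N, 2) integer positions; bds: (N, W, H) cost-to-goal maps.
    Returns an (N, 5) array of move indices, best (lowest cost) first.
    """
    n, w, h = bds.shape
    prefs = np.zeros((n, len(MOVES)), dtype=int)
    for i in range(n):
        costs = []
        for dx, dy in MOVES:
            x, y = pos_list[i, 0] + dx, pos_list[i, 1] + dy
            inside = 0 <= x < w and 0 <= y < h
            # Cells outside the map are treated as unreachable.
            costs.append(bds[i, x, y] if inside else np.inf)
        # Stable sort keeps the MOVES order among equally good moves.
        prefs[i] = np.argsort(costs, kind="stable")
    return prefs

# One agent at (1, 1) on a 3x3 map whose goal is (0, 0).
bds = np.array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]]])
prefs = bd_move_preferences(np.array([[1, 1]]), bds)
```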
create_data_object¶
- create_data_object(pos_list, bd_list, grid, k, m, goal_locs, extra_layers, bd_pred, labels=np.array([]), debug_checks=False)¶
Creates a PyTorch Geometric Data object from raw MAPF data. Backward Dijkstra values (bds) are almost always used as features. The optional extra layers are agent_locations, agent_goal, near_goal_info, and at_goal_grid: agent_locations gives the locations of nearby agents, agent_goal gives the goal locations of nearby agents, near_goal_info identifies the agents near the goal, and at_goal_grid is a binary grid marking the agents that are at their goal.
- Parameters:
pos_list (numpy.ndarray) – Agent positions of shape (N,2)
bd_list (numpy.ndarray) – Backward Dijkstra values of shape (N,W,H)
grid (numpy.ndarray) – Grid map of shape (W,H)
k (int) – Window size for local observation
m (int) – Number of closest neighbors to consider
goal_locs (numpy.ndarray) – Goal locations of shape (N,2)
extra_layers (list) – List of additional features to include
bd_pred (str) – Whether to include backward Dijkstra predictions
labels (numpy.ndarray) – Ground truth labels
debug_checks (bool) – Whether to perform debug assertions
- Returns:
PyTorch Geometric Data object
- Return type:
Data
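The roles of k and m can be sketched as follows. This is a simplified illustration, not the function itself: it assumes that a window size of k yields a (2k+1) x (2k+1) crop centered on the agent, that cells outside the map are padded as obstacles, and that neighbors are chosen by Manhattan distance. The helper names are hypothetical.

```python
import numpy as np

def local_window(layer, pos, k, pad_value=1):
    """Crop a (2k+1, 2k+1) window of `layer` centered on `pos`.

    The layer is padded by k cells so windows near the border stay full size;
    padding with 1 (obstacle) is an assumption.
    """
    padded = np.pad(layer, k, mode="constant", constant_values=pad_value)
    x, y = pos
    return padded[x:x + 2 * k + 1, y:y + 2 * k + 1]

def nearest_neighbor_edges(pos_list, m):
    """Connect each agent to its m closest agents (Manhattan distance).

    Returns a (2, N*m) array in the [source, target] layout that
    PyTorch Geometric uses for edge_index.
    """
    n = len(pos_list)
    diffs = pos_list[:, None, :] - pos_list[None, :, :]
    dists = np.abs(diffs).sum(axis=-1).astype(float)
    np.fill_diagonal(dists, np.inf)  # an agent is not its own neighbor
    m = min(m, n - 1)
    nbrs = np.argsort(dists, axis=1, kind="stable")[:, :m]
    src = np.repeat(np.arange(n), m)
    return np.stack([src, nbrs.reshape(-1)])

grid = np.zeros((3, 3), dtype=int)
window = local_window(grid, pos=(0, 0), k=1)
edge_index = nearest_neighbor_edges(np.array([[0, 0], [0, 1], [5, 5]]), m=1)
```

Each agent's node features would then be a stack of such windows (the grid, its backward Dijkstra map, and any extra layers), and the edges would link the agent to the neighbors found above.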
Dataset Classes¶
MyOwnDataset¶
- class MyOwnDataset(Dataset)¶
A PyTorch Geometric Dataset for processing and loading MAPF data as graphs.
- __init__(self, mapNpzFile, bdNpzFolder, pathNpzFolder, processedOutputFolder, num_cores, k, m, extra_layers, bd_pred, num_per_pt)¶
- Parameters:
mapNpzFile (str) – Path to NPZ file containing map data
bdNpzFolder (str) – Path to folder containing backward Dijkstra NPZ files
pathNpzFolder (str) – Path to folder containing path NPZ files
processedOutputFolder (str) – Path to output folder for processed data
num_cores (int) – Number of CPU cores to use for parallel processing
k (int) – Window size for local observation
m (int) – Number of closest neighbors to consider
extra_layers (list) – List of additional features to include
bd_pred (str) – Whether to include backward Dijkstra predictions
num_per_pt (int) – Number of graphs per PT file
- processed_dir(self) str ¶
- Returns:
Path to the processed directory
- Return type:
str
- load_status_data(self)¶
Loads a CSV file tracking the status of processed files.
- raw_file_names(self)¶
- Returns:
List of raw file paths
- Return type:
list
- has_process(self) bool ¶
- Returns:
Whether the dataset has a process method
- Return type:
bool
- has_download(self) bool ¶
- Returns:
Whether the dataset has a download method
- Return type:
bool
- create_and_save_graph(self, idx, time_instance)¶
Creates and saves a graph from a time instance.
- Parameters:
idx (int) – Index of the graph
time_instance (tuple) – Tuple of (pos_list, labels, bd_list, grid, goal_locs)
- Returns:
PyTorch Geometric Data object
- Return type:
Data
- custom_process(self)¶
Processes all data files and tracks processing status.
- len(self)¶
- Returns:
Number of graphs in the dataset
- Return type:
int
- get(self, idx)¶
Retrieves a graph by index.
- Parameters:
idx (int) – Index of the graph to retrieve
- Returns:
PyTorch Geometric Data object
- Return type:
Data
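Because num_per_pt graphs are grouped into each PT file, retrieving a graph by index plausibly involves mapping the index to a file and a position inside that file. The sketch below shows that arithmetic only; it assumes graphs are stored contiguously and in order, and the helper name is hypothetical.

```python
def locate_graph(idx, num_per_pt):
    """Map a global graph index to (file index, offset within that file).

    Assumes file 0 holds graphs 0..num_per_pt-1, file 1 the next
    num_per_pt graphs, and so on.
    """
    file_idx, offset = divmod(idx, num_per_pt)
    return file_idx, offset

location = locate_graph(7, num_per_pt=3)
```

Grouping several graphs per file trades a larger read per access for far fewer files on disk, which matters when a dataset contains many small graphs.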
Command Line Interface¶
The dataloader can be executed as a command-line tool:
python -m gnn.dataloader \
    --mapNpzFile=data_collection/data/benchmark_data/constant_npzs/all_maps.npz \
    --bdNpzFolder=data_collection/data/benchmark_data/constant_npzs \
    --pathNpzFolder=data_collection/data/logs/EXP_Test3/iter0/eecbs_npzs \
    --processedFolder=data_collection/data/logs/EXP_Test3/iter0/processed \
    --k=4 \
    --m=5 \
    --extra_layers=agent_locations,agent_goal,at_goal_grid \
    --bd_pred=1
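The flags in the command above could be parsed along the following lines. This is a sketch using the standard argparse module, not the module's actual parser: the flag names are taken from the example command, while the types, defaults, and the comma-splitting of extra_layers are assumptions.

```python
import argparse

def split_layers(value):
    """Turn 'a,b,c' into ['a', 'b', 'c']; an empty string gives []."""
    return [name for name in value.split(",") if name]

def build_parser():
    parser = argparse.ArgumentParser(description="Convert MAPF data to graphs")
    parser.add_argument("--mapNpzFile", type=str, required=True)
    parser.add_argument("--bdNpzFolder", type=str, required=True)
    parser.add_argument("--pathNpzFolder", type=str, required=True)
    parser.add_argument("--processedFolder", type=str, required=True)
    parser.add_argument("--k", type=int, default=4)  # window size
    parser.add_argument("--m", type=int, default=5)  # number of neighbors
    parser.add_argument("--extra_layers", type=split_layers, default=[])
    parser.add_argument("--bd_pred", type=str, default=None)
    return parser

# Parse an explicit argument list instead of sys.argv for demonstration.
args = build_parser().parse_args([
    "--mapNpzFile=all_maps.npz",
    "--bdNpzFolder=constant_npzs",
    "--pathNpzFolder=eecbs_npzs",
    "--processedFolder=processed",
    "--k=4",
    "--m=5",
    "--extra_layers=agent_locations,agent_goal,at_goal_grid",
    "--bd_pred=1",
])
```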
trainer.py¶
This file trains a GNN model on MAPF data using PyTorch Geometric. It uses a SAGEConv-based model to predict agent actions.
Several hyperparameters can be changed for training: k, m, bd_pred, and extra_layers. k and m have the same meaning as in the dataloader, bd_pred controls whether backward Dijkstra predictions are included, and extra_layers lists the extra features to include in the model. Note that changing any of these may force the .pt files to be recreated, which is costly in both time and disk space.
Command Line Interface
The trainer can be run from the command line:
python -m gnn.trainer --exp_folder=data_collection/data/logs/EXP_Test --experiment=exp0 --iternum=0 --num_cores=4 \
    --processedFolders=data_collection/data/logs/EXP_Test3/iter0/processed \
    --mapNpzFile=data_collection/data/benchmark_data/constant_npzs/all_maps.npz \
    --bdNpzFolder=data_collection/data/benchmark_data/constant_npzs \
    --pathNpzFolders=data_collection/data/logs/EXP_Test/iter0/eecbs_npzs