
tgm-team/tgm


Efficient and Modular ML on Temporal Graphs

Read Our Docs» Read Our Paper»


About The Project

TGM is a research library for temporal graph learning, designed to accelerate training on dynamic graphs while enabling rapid prototyping of new methods. It provides a unified abstraction for both discrete and continuous-time graphs, supporting diverse tasks across link, node, and graph-level prediction.

Important

TGM is in beta, and may introduce breaking changes.

Highlights

  • Unified Temporal API: supports both continuous-time and discrete-time graphs, plus discretization of continuous-time graphs into snapshots
  • Efficiency: ~7.8× faster training and ~175× faster discretization vs. existing research libraries
  • Research-Oriented: modular hook framework standardizes workflows for link, node, and graph-level tasks
  • Datasets: built-in support for popular datasets (e.g., TGB[1])
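Discretization maps continuous event timestamps onto fixed-width snapshots. The NumPy sketch below only illustrates the idea and is not TGM's implementation; the helper name `discretize`, the measurement of buckets from the first timestamp, and the choice to merge repeated edges within a snapshot are all assumptions made for illustration.

```python
import numpy as np


def discretize(src, dst, t, bucket_width):
    """Hypothetical helper (not part of the TGM API).

    Buckets continuous-time edges (src, dst, t) into snapshots of width
    `bucket_width`, measured from the earliest timestamp. Returns an array
    with columns (snapshot, src, dst), sorted by snapshot, in which repeated
    edges inside the same snapshot are merged into one row.
    """
    src, dst, t = (np.asarray(a) for a in (src, dst, t))
    # Snapshot index for every event, computed in one vectorized step
    snap = ((t - t.min()) // bucket_width).astype(np.int64)
    # np.unique over rows deduplicates (snapshot, src, dst) triples and
    # sorts them lexicographically, so rows come out grouped by snapshot
    return np.unique(np.stack([snap, src, dst], axis=1), axis=0)


# Usage: five events, bucketed into snapshots of 10 time units
triples = discretize([0, 0, 1, 2, 0], [1, 1, 2, 0, 1], [0, 3, 5, 12, 14], bucket_width=10)
edges_per_snapshot = np.bincount(triples[:, 0])
```

Doing the bucketing and deduplication in a single `np.unique` call avoids a Python loop over snapshots, which is one way to approach the "vectorized snapshot creation" the data layer describes.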

Supported Methods

To request a method for prioritization, please open an issue or join the discussion.

Status       Methods
Implemented  EdgeBank[2], GCN[3], GC-LSTM[4], GraphMixer[5], TGAT[6], TGN[7], DyGFormer[8], TPNet[9]
Planned      TNCN[10], DyGMamba[11], NAT[12]

Installation

From Source (recommended)

pip install git+https://github.com/tgm-team/tgm.git@main

From PyPI

pip install tgm-lib

Note

Windows is not directly tested in our CI, so additional setup may be required. For instance, with CUDA 12.4 you will need to manually install the matching PyTorch wheels:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

Quick Tour for New Users

System Design Overview


TGM is organized as a three-layer architecture:

  1. Data Layer

    • Immutable, time-sorted coordinate-format graph storage with lightweight, concurrency-safe graph views.
    • Efficient time-based slicing and binary search over timestamps, enabling fast recent-neighbor retrieval.
    • Supports continuous-time and discrete-time loading, with vectorized snapshot creation.
    • Extensible backend allows alternative storage layouts for future models.
  2. Execution Layer

    • The DataLoader iterates over the temporal graph data stream by time or by events, at a user-defined granularity.
    • HookManager orchestrates transformations during data loading (e.g., temporal neighbor sampling), dynamically adding relevant attributes to the Batch yielded by the dataloader.
    • Hooks can be combined and registered under specific conditions (analytics, training, etc.).
    • Pre-defined recipes simplify common setups (e.g., TGB link prediction) and prevent common pitfalls (e.g., mismanaging negatives).
  3. ML Layer

    • Materializes batches directly on-device for model computation.
    • Supports node-, link-, and graph-level prediction.
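To make the data-layer description more concrete, here is a toy NumPy sketch of immutable, time-sorted coordinate-format (COO) edge storage with binary-search time slicing and a recent-neighbor lookup. The class and method names are hypothetical and are not TGM's API. The sketch also assumes neighbors are undirected (both endpoints count), and its neighbor lookup is a linear scan over past events, whereas a real sampler would use a more efficient index.

```python
import numpy as np


class TimeSortedEdges:
    """Toy time-sorted COO edge store (illustration only, not TGM's storage class)."""

    def __init__(self, src, dst, t):
        # Sort once by timestamp; a stable sort keeps the original order of ties
        order = np.argsort(np.asarray(t), kind="stable")
        self.src = np.asarray(src)[order]
        self.dst = np.asarray(dst)[order]
        self.t = np.asarray(t)[order]
        # Mark the arrays read-only to mimic immutable storage
        for arr in (self.src, self.dst, self.t):
            arr.setflags(write=False)

    def slice_time(self, start, end):
        """Return (src, dst, t) for events with start <= t < end."""
        # Binary search on the sorted timestamps: O(log E) per bound
        lo = np.searchsorted(self.t, start, side="left")
        hi = np.searchsorted(self.t, end, side="left")
        # Basic slicing returns read-only views, so no edge data is copied
        return self.src[lo:hi], self.dst[lo:hi], self.t[lo:hi]

    def recent_neighbors(self, node, before, k):
        """Return up to k most recent neighbors of `node` among events with t < before."""
        hi = np.searchsorted(self.t, before, side="left")
        src, dst, t = self.src[:hi], self.dst[:hi], self.t[:hi]
        out_mask, in_mask = src == node, dst == node
        # Assumption: treat edges as undirected, so both endpoints are neighbors
        nbrs = np.concatenate([dst[out_mask], src[in_mask]])
        times = np.concatenate([t[out_mask], t[in_mask]])
        # Most recent first, truncated to k
        order = np.argsort(times, kind="stable")[::-1][:k]
        return nbrs[order], times[order]


# Usage: four events inserted out of time order
store = TimeSortedEdges(src=[2, 0, 1, 0], dst=[3, 1, 2, 2], t=[30, 10, 20, 40])
window = store.slice_time(15, 35)
nbrs, times = store.recent_neighbors(node=2, before=40, k=2)
```

Because the storage is never mutated after construction, views returned by `slice_time` can be shared freely, which is the property that makes lightweight graph views safe to use concurrently.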

Tip

Check out our paper for technical details.
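The hook pattern in the execution layer described above can be illustrated with a small, framework-free sketch. `ToyBatch`, `ToyHookManager`, `toy_loader`, and the two example hooks are hypothetical stand-ins written for this illustration; they are not TGM's Batch, HookManager, or DataLoader APIs, and the loader here only batches by event count, not by time. See the docs for the real interfaces.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterator, List


@dataclass
class ToyBatch:
    """Hypothetical batch; hooks attach extra attributes under `extra`."""
    src: List[int]
    dst: List[int]
    t: List[int]
    extra: Dict[str, Any] = field(default_factory=dict)


Hook = Callable[[ToyBatch], ToyBatch]


class ToyHookManager:
    """Illustrative registry that groups hooks under a condition key (e.g. 'train')."""

    def __init__(self) -> None:
        self._hooks: Dict[str, List[Hook]] = {}

    def register(self, key: str, hook: Hook) -> None:
        self._hooks.setdefault(key, []).append(hook)

    def execute(self, key: str, batch: ToyBatch) -> ToyBatch:
        # Hooks run in registration order, so later hooks can use earlier results
        for hook in self._hooks.get(key, []):
            batch = hook(batch)
        return batch


def add_num_edges(batch: ToyBatch) -> ToyBatch:
    batch.extra["num_edges"] = len(batch.src)
    return batch


def add_unique_nodes(batch: ToyBatch) -> ToyBatch:
    batch.extra["unique_nodes"] = sorted(set(batch.src) | set(batch.dst))
    return batch


def toy_loader(src, dst, t, batch_size: int, manager: ToyHookManager, key: str) -> Iterator[ToyBatch]:
    """Yield event-based batches of `batch_size` edges, each passed through the hooks for `key`."""
    for i in range(0, len(src), batch_size):
        end = i + batch_size
        yield manager.execute(key, ToyBatch(src[i:end], dst[i:end], t[i:end]))


# Usage: the same manager adds attributes only when the 'train' hooks are selected
manager = ToyHookManager()
manager.register("train", add_num_edges)
manager.register("train", add_unique_nodes)
train_batches = list(toy_loader([0, 1, 2], [1, 2, 0], [1, 2, 3], 2, manager, "train"))
plain_batches = list(toy_loader([0, 1, 2], [1, 2, 0], [1, 2, 3], 2, manager, "analytics"))
```

Keying hooks by condition lets one data pipeline serve training and analytics without duplicating loader code, which is the motivation behind registering hooks "under specific conditions".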

Minimal Example

Here’s a basic example demonstrating how to train TGCN for dynamic node property prediction on tgbn-trade:

import torch
import torch.nn as nn
import torch.nn.functional as F

from tgm import DGraph, DGBatch
from tgm.data import DGData, DGDataLoader
from tgm.nn import TGCN, NodePredictor

# Load TGB data splits
train_data, val_data, test_data = DGData.from_tgb("tgbn-trade").split()

# Construct a DGraph and set up iteration over yearly ('Y') snapshots
train_dg = DGraph(train_data)
train_loader = DGDataLoader(train_dg, batch_unit="Y")

# tgbn-trade has no static node features, so we create Gaussian ones (dim=64)
static_node_feats = torch.randn((train_dg.num_nodes, 64))

class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_dim: int, embed_dim: int) -> None:
        super().__init__()
        self.recurrent = TGCN(in_channels=node_dim, out_channels=embed_dim)
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(
        self, batch: DGBatch, node_feat: torch.Tensor, h: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        edge_index = torch.stack([batch.src, batch.dst], dim=0)
        h_0 = self.recurrent(node_feat, edge_index, H=h)
        z = F.relu(h_0)
        z = self.linear(z)
        return z, h_0

# Initialize our model and optimizer
encoder = RecurrentGCN(node_dim=static_node_feats.shape[1], embed_dim=128)
decoder = NodePredictor(in_dim=128, out_dim=train_dg.dynamic_node_feats_dim)
# Use a list (not a set) so the parameter order is deterministic across runs
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)

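# Training loop (a single pass over the yearly snapshots)
h_0 = None  # recurrent hidden state carried across snapshots
for batch in train_loader:
    opt.zero_grad()
    y_true = batch.dynamic_node_feats
    if y_true is None:  # skip snapshots that carry no node labels
        continue

    z, h_0 = encoder(batch, static_node_feats, h_0)
    z_node = z[batch.node_ids]  # keep only rows for the nodes labelled in this batch
    y_pred = decoder(z_node)

    loss = F.cross_entropy(y_pred, y_true)
    loss.backward()
    opt.step()
    h_0 = h_0.detach()  # truncate backpropagation through time at snapshot boundaries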
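The loss line relies on `F.cross_entropy` accepting class-probability targets, which PyTorch supports (since 1.10) when the target is a float tensor with the same shape as the logits. Assuming each row of `batch.dynamic_node_feats` is a probability vector over classes (for example, normalized proportions), the loss equals the mean over nodes of the negative sum of p · log_softmax(logits). The standalone snippet below checks that identity on random tensors; it does not use TGM, and the shapes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 4)                          # 5 nodes, 4 classes
targets = torch.softmax(torch.randn(5, 4), dim=1)   # each row sums to 1 (soft labels)

# Built-in: soft targets are accepted when target has the same shape as the input
loss_builtin = F.cross_entropy(logits, targets)

# Manual: mean over nodes of -sum_c p_c * log_softmax(logits)_c
loss_manual = -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
```

If your labels are class indices rather than distributions, pass an integer tensor of shape `(num_nodes,)` instead; `F.cross_entropy` handles both forms.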

Running Pre-packaged Examples

TGM includes pre-packaged example scripts to help you get started quickly. The examples require extra dependencies beyond the core library.

To get started, clone the repository and, from its root directory, install TGM in editable mode together with the additional dependencies:

git clone https://github.com/tgm-team/tgm.git
cd tgm
pip install -e ".[examples]"

After installing the dependencies, you can run any of our examples. For instance, to run TGAT dynamic link prediction on tgbl-wiki:

python examples/linkproppred/tgat.py --dataset tgbl-wiki --device cuda

Note

Link prediction examples use tgbl-wiki by default, and node prediction examples use tgbn-trade. Examples also run on CPU by default; use the --device flag to override this, as shown above.

Next steps

Citation

If you use TGM in your work, please cite our paper:

@misc{chmura2025tgm,
  title  = {TGM: A Modular and Efficient Library for Machine Learning on Temporal Graphs},
  author = {Chmura, Jacob and Huang, Shenyang and Ngo, Tran Gia Bao and Parviz, Ali and Poursafaei, Farimah and Leskovec, Jure and Bronstein, Michael and Rabusseau, Guillaume and Fey, Matthias and Rabbany, Reihaneh},
  year   = {2025},
  note   = {arXiv:2510.07586}
}

Contributing

We welcome contributions. If you encounter problems or would like to propose new features, please open an issue and join the discussion. For details on contributing to TGM, see our contribution guide.


References

Footnotes

  1. Temporal Graph Benchmark

  2. Towards Better Evaluation for Dynamic Link Prediction

  3. Semi-Supervised Classification with Graph Convolutional Networks

  4. GC-LSTM: Graph Convolution Embedded LSTM for Dynamic Link Prediction

  5. Do We Really Need Complicated Model Architectures For Temporal Networks?

  6. Inductive Representation Learning on Temporal Graphs

  7. Temporal Graph Networks for Deep Learning on Dynamic Graphs

  8. Towards Better Dynamic Graph Learning: New Architecture and Unified Library

  9. Improving Temporal Link Prediction via Temporal Walk Matrix Projection

  10. Efficient Neural Common Neighbor for Temporal Graph Link Prediction

  11. DyGMamba: Efficiently Modeling Long-Term Temporal Dependency on Continuous-Time Dynamic Graphs with State Space Models

  12. Neighborhood-aware Scalable Temporal Network Representation Learning