"""
Temporal Graph Neural Network predictor for spectral metric evolution.
Replaces the flat LSTM with a GCN encoder that operates on the actual
interbank graph at each timestep, producing graph-level embeddings that
feed into a temporal LSTM for spectral metric forecasting.
Architecture: GCNConv layers → global_mean_pool → LSTM → FC → [λ₂, gap, ρ]
"""
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
logger = logging.getLogger(__name__)
TARGET_NAMES = ["lambda_2", "spectral_gap", "spectral_radius"]
# Node features extracted per bank at each ABM step
NODE_FEATURE_NAMES = ["CET1_ratio", "LCR", "NSFR", "total_assets", "is_stressed"]
class GNNEncoder(nn.Module):
    """Multi-layer GCN that produces a graph-level embedding.

    Each layer applies GCNConv -> BatchNorm1d -> ReLU; dropout is applied
    between layers but not after the final one.  Node embeddings are then
    mean-pooled per graph.
    """

    def __init__(self, in_channels: int, hidden_channels: int, num_gcn_layers: int = 3,
                 dropout: float = 0.1):
        super().__init__()
        # Always build at least one conv layer (first layer maps
        # in_channels -> hidden_channels, the rest hidden -> hidden).
        dims = [in_channels] + [hidden_channels] * max(num_gcn_layers, 1)
        self.convs = nn.ModuleList(
            GCNConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.bns = nn.ModuleList(nn.BatchNorm1d(d) for d in dims[1:])
        self.dropout = nn.Dropout(dropout)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_weight: Optional[torch.Tensor] = None,
                batch: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode node features into a ``[num_graphs, hidden]`` embedding."""
        last = len(self.convs) - 1
        for layer_idx, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            x = self.act(bn(conv(x, edge_index, edge_weight)))
            if layer_idx != last:
                x = self.dropout(x)
        if batch is None:
            # Single-graph input: every node belongs to graph 0.
            batch = x.new_zeros(x.size(0), dtype=torch.long)
        return global_mean_pool(x, batch)  # [num_graphs, hidden]
class TemporalGNN(nn.Module):
    """GNN encoder + LSTM for temporal graph sequences → spectral predictions."""

    def __init__(self, node_features: int, hidden_dim: int = 64,
                 output_dim: int = 3, num_gcn_layers: int = 3,
                 num_lstm_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.gnn = GNNEncoder(node_features, hidden_dim, num_gcn_layers, dropout)
        # nn.LSTM only applies inter-layer dropout when num_layers > 1
        # (PyTorch warns otherwise).
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_lstm_layers,
                            batch_first=True,
                            dropout=dropout if num_lstm_layers > 1 else 0.0)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, output_dim),
        )
        self.hidden_dim = hidden_dim
        self.num_lstm_layers = num_lstm_layers

    def forward(self, graph_sequences: List[List[Data]]) -> torch.Tensor:
        """Encode each graph, run the LSTM over time, predict the targets.

        Parameters
        ----------
        graph_sequences : list of equal-length lists of PyG ``Data`` graphs
            (outer list = batch, inner list = time steps).

        Returns
        -------
        torch.Tensor of shape ``[batch_size, output_dim]``.

        Raises
        ------
        ValueError
            If ``graph_sequences`` is empty or the sequences have unequal
            lengths (previously an IndexError / cryptic ``view`` failure).
        """
        from torch_geometric.data import Batch
        if len(graph_sequences) == 0:
            raise ValueError("graph_sequences must not be empty")
        batch_size = len(graph_sequences)
        seq_len = len(graph_sequences[0])
        if any(len(seq) != seq_len for seq in graph_sequences):
            raise ValueError("all graph sequences must have the same length")
        # Flatten batch × time into one PyG batch so the GNN encodes every
        # graph in a single pass.  Order is sequence-major, so the reshape
        # below recovers [batch, time, hidden] correctly.
        all_graphs = [g for seq in graph_sequences for g in seq]
        batched = Batch.from_data_list(all_graphs)
        all_emb = self.gnn(batched.x, batched.edge_index, batched.edge_weight,
                           batched.batch)  # [batch_size * seq_len, hidden]
        embeddings = all_emb.view(batch_size, seq_len, self.hidden_dim)
        # nn.LSTM defaults the hidden/cell states to zeros, matching the
        # explicit zero initialisation this model previously performed.
        out, _ = self.lstm(embeddings)
        return self.fc(out[:, -1, :])

    def count_parameters(self) -> int:
        """Total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
class GNNPredictor:
    """Drop-in replacement for SCGPredictor using a temporal GNN.

    Wraps :class:`TemporalGNN` with snapshot preprocessing, feature/target
    normalisation, a training loop, and autoregressive forecasting.

    Supports a progress_callback(epoch, total_epochs, train_loss, test_loss)
    for real-time UI updates during training.
    """

    def __init__(self, seq_len: int = 10, hidden_dim: int = 64,
                 num_gcn_layers: int = 3, num_lstm_layers: int = 2,
                 dropout: float = 0.1):
        self.seq_len = seq_len
        self.hidden_dim = hidden_dim
        self.num_gcn_layers = num_gcn_layers
        self.num_lstm_layers = num_lstm_layers
        self.dropout = dropout
        self.model: Optional[TemporalGNN] = None
        self._trained = False
        self.train_metrics: Dict[str, Any] = {}
        self.test_metrics: Dict[str, Any] = {}
        self.test_actuals = np.array([])
        self.test_predictions = np.array([])
        self.training_history: List[Dict[str, float]] = []
        # Normalisation statistics fitted on the training split only.
        # _feat_std is floored to 1.0 for (near-)constant features at fit
        # time, so _snapshot_to_data never divides by zero.
        self._feat_mean: Optional[np.ndarray] = None
        self._feat_std: Optional[np.ndarray] = None
        self._target_mean: Optional[np.ndarray] = None
        self._target_std: Optional[np.ndarray] = None
        # Torch copies of the feature stats, pre-built at fit time so
        # _snapshot_to_data does not re-allocate tensors on every call.
        self._feat_mean_t: Optional[torch.Tensor] = None
        self._feat_std_t: Optional[torch.Tensor] = None

    def _snapshot_to_data(self, snap: Dict[str, Any]) -> Data:
        """Convert snapshot dict to PyG Data object with normalisation."""
        x = torch.tensor(snap["node_features"], dtype=torch.float32)
        if self._feat_mean is not None:
            if self._feat_mean_t is None:
                # Lazy fallback in case stats were set without _fit_normalisation.
                self._feat_mean_t = torch.tensor(self._feat_mean, dtype=torch.float32)
                self._feat_std_t = torch.tensor(self._feat_std, dtype=torch.float32)
            x = (x - self._feat_mean_t) / self._feat_std_t
        ei = torch.tensor(snap["edge_index"], dtype=torch.long)
        ew = (torch.tensor(snap["edge_weight"], dtype=torch.float32)
              if len(snap["edge_weight"]) > 0 else None)
        # Scale edge weights into [0, 1] per graph for numerical stability.
        if ew is not None and ew.numel() > 0 and ew.max() > 0:
            ew = ew / ew.max()
        return Data(x=x, edge_index=ei, edge_weight=ew)

    def _build_sequences(self, snapshots: List[Dict]) -> Tuple[List[List[Dict]], np.ndarray]:
        """Build (graph_sequence, target) pairs from snapshot list.

        Each training example is a sliding window of ``seq_len`` snapshots;
        the target is the "targets" entry of the snapshot immediately after
        the window, ordered by TARGET_NAMES.
        """
        sequences, targets = [], []
        for i in range(len(snapshots) - self.seq_len):
            sequences.append(snapshots[i: i + self.seq_len])
            tgt = snapshots[i + self.seq_len]
            targets.append([tgt["targets"][k] for k in TARGET_NAMES])
        return sequences, np.array(targets, dtype=np.float32)

    @staticmethod
    def _compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Any]:
        """Return MSE plus per-target and mean R² (R² clipped below at -1)."""
        mse = float(np.mean((y_true - y_pred) ** 2))
        r2_per = {}
        for i, name in enumerate(TARGET_NAMES):
            ss_res = np.sum((y_true[:, i] - y_pred[:, i]) ** 2)
            ss_tot = np.sum((y_true[:, i] - np.mean(y_true[:, i])) ** 2)
            if ss_tot < 1e-8:
                # Constant target: perfect if residuals vanish too, else 0.
                r2_per[name] = 1.0 if ss_res < 1e-8 else 0.0
            else:
                r2_per[name] = float(max(-1.0, 1.0 - ss_res / ss_tot))
        return {"mse": mse, "r2": float(np.mean(list(r2_per.values()))), "r2_per_target": r2_per}

    def _fit_normalisation(self, train_seqs: List[List[Dict]], y_train: np.ndarray) -> None:
        """Fit feature/target mean-std statistics on the training split only."""
        all_feats = np.concatenate(
            [s["node_features"] for seq in train_seqs for s in seq], axis=0)
        self._feat_mean = all_feats.mean(axis=0)
        feat_std = all_feats.std(axis=0)
        # Floor (near-)constant features once here instead of copying and
        # re-flooring on every _snapshot_to_data call.
        feat_std[feat_std < 1e-8] = 1.0
        self._feat_std = feat_std
        self._feat_mean_t = torch.tensor(self._feat_mean, dtype=torch.float32)
        self._feat_std_t = torch.tensor(feat_std, dtype=torch.float32)
        self._target_mean = y_train.mean(axis=0)
        target_std = y_train.std(axis=0)
        target_std[target_std < 1e-8] = 1.0
        self._target_std = target_std

    def _predict_scaled(self, graph_seqs: List[List[Data]]) -> np.ndarray:
        """Run the model in eval mode; return normalised predictions as numpy.

        Restores the model's previous train/eval mode afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            out = self.model(graph_seqs).numpy()
        if was_training:
            self.model.train()
        return out

    def train(
        self,
        snapshots: List[Dict[str, Any]],
        epochs: int = 200,
        lr: float = 3e-3,
        test_fraction: float = 0.2,
        progress_callback: Optional[Callable[[int, int, float, Optional[float]], None]] = None,
    ) -> float:
        """Train the temporal GNN. Returns final train loss.

        Parameters
        ----------
        snapshots : list of dict
            Temporal graph snapshots with "node_features", "edge_index",
            "edge_weight" and "targets" entries.
        epochs : int
            Number of optimisation epochs (cosine-annealed Adam).
        lr : float
            Initial learning rate.
        test_fraction : float
            Fraction of sequences (chronologically last) held out for testing.
        progress_callback : callable(epoch, total_epochs, train_loss, test_loss_or_None)
            Called every 5 epochs for UI progress updates.

        Raises
        ------
        ValueError
            If fewer than 10 training sequences can be built.
        """
        sequences, targets = self._build_sequences(snapshots)
        n_seqs = len(sequences)
        if n_seqs < 10:
            raise ValueError(f"Need >= {self.seq_len + 10} snapshots, got {len(snapshots)} "
                             f"({n_seqs} sequences).")
        # Chronological split: earliest sequences train, latest test.
        split = max(5, int(n_seqs * (1 - test_fraction)))
        train_seqs, test_seqs = sequences[:split], sequences[split:]
        y_train, y_test = targets[:split], targets[split:]
        self._fit_normalisation(train_seqs, y_train)
        y_train_s = (y_train - self._target_mean) / self._target_std
        train_graph_seqs = [[self._snapshot_to_data(s) for s in seq] for seq in train_seqs]
        y_train_t = torch.tensor(y_train_s, dtype=torch.float32)
        # Also prepare test graphs if we have them (for progress reporting).
        test_graph_seqs = None
        if len(test_seqs) > 0:
            test_graph_seqs = [[self._snapshot_to_data(s) for s in seq] for seq in test_seqs]
        self.model = TemporalGNN(
            len(NODE_FEATURE_NAMES), self.hidden_dim, len(TARGET_NAMES),
            self.num_gcn_layers, self.num_lstm_layers, self.dropout,
        )
        n_params = self.model.count_parameters()
        logger.info("TemporalGNN: %d params, %d GCN layers, %d LSTM layers, hidden=%d",
                    n_params, self.num_gcn_layers, self.num_lstm_layers, self.hidden_dim)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        loss_fn = nn.MSELoss()
        batch_size = min(32, len(train_graph_seqs))
        self.model.train()
        self.training_history = []
        final_loss = 0.0
        for epoch in range(epochs):
            perm = np.random.permutation(len(train_graph_seqs))
            epoch_loss = 0.0
            n_batches = 0
            for start in range(0, len(perm), batch_size):
                idx = perm[start: start + batch_size]
                optimizer.zero_grad()
                pred = self.model([train_graph_seqs[i] for i in idx])
                loss = loss_fn(pred, y_train_t[idx])
                loss.backward()
                # Clip gradients to stabilise LSTM training.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                epoch_loss += loss.item()
                n_batches += 1
            # Capture the LR actually used this epoch BEFORE annealing it;
            # the old code logged the *next* epoch's rate into history.
            epoch_lr = optimizer.param_groups[0]["lr"]
            scheduler.step()
            final_loss = epoch_loss / max(n_batches, 1)
            # Progress reporting every 5 epochs (and on the last epoch).
            if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
                test_loss = None
                if test_graph_seqs is not None:
                    y_test_s = (y_test - self._target_mean) / self._target_std
                    test_loss = float(np.mean(
                        (self._predict_scaled(test_graph_seqs) - y_test_s) ** 2))
                self.training_history.append({
                    "epoch": epoch + 1,
                    "train_loss": final_loss,
                    "test_loss": test_loss,
                    "lr": epoch_lr,
                })
                if progress_callback is not None:
                    progress_callback(epoch + 1, epochs, final_loss, test_loss)
        self._trained = True
        # Final evaluation in original (de-normalised) target units.
        train_pred = self._predict_scaled(train_graph_seqs) * self._target_std + self._target_mean
        self.train_metrics = self._compute_metrics(y_train, train_pred)
        if test_graph_seqs is not None:
            test_pred = self._predict_scaled(test_graph_seqs) * self._target_std + self._target_mean
            self.test_metrics = self._compute_metrics(y_test, test_pred)
            self.test_actuals = y_test
            self.test_predictions = test_pred
        else:
            self.test_metrics = {"mse": 0.0, "r2": 0.0, "r2_per_target": {}}
        logger.info(
            "GNNPredictor trained: %d params, %d train / %d test, "
            "train_mse=%.6f, test_mse=%.6f, test_r2=%.4f",
            n_params, len(train_seqs), len(test_seqs),
            self.train_metrics["mse"], self.test_metrics["mse"], self.test_metrics["r2"],
        )
        return final_loss

    def predict(self, recent_snapshots: List[Dict[str, Any]], steps: int = 20) -> List[Dict[str, float]]:
        """Autoregressively predict spectral metrics forward.

        The last observed graph is carried forward unchanged at each step
        (only the predicted targets are attached), i.e. the topology is
        assumed frozen over the forecast horizon.

        Raises
        ------
        RuntimeError
            If the model has not been trained.
        ValueError
            If fewer than ``seq_len`` snapshots are supplied.
        """
        if not self._trained or self.model is None:
            raise RuntimeError("Model not trained.")
        if len(recent_snapshots) < self.seq_len:
            raise ValueError(f"Need >= {self.seq_len} snapshots, got {len(recent_snapshots)}.")
        window = list(recent_snapshots[-self.seq_len:])
        self.model.eval()
        predictions: List[Dict[str, float]] = []
        with torch.no_grad():
            for _ in range(steps):
                graph_seq = [self._snapshot_to_data(s) for s in window]
                pred_s = self.model([graph_seq]).numpy()[0]
                pred = pred_s * self._target_std + self._target_mean
                pred_dict = {name: float(val) for name, val in zip(TARGET_NAMES, pred)}
                predictions.append(pred_dict)
                # Reuse the last observed graph as the next pseudo-snapshot.
                window = window[1:] + [{
                    "node_features": window[-1]["node_features"].copy(),
                    "edge_index": window[-1]["edge_index"].copy(),
                    "edge_weight": window[-1]["edge_weight"].copy(),
                    "targets": pred_dict,
                }]
        return predictions