# Source code for dashboard.gnn_export

"""
GNN dataset exporter for the SCR Financial Networks dashboard.

Builds a graph dataset from the current simulation state + LLM-fetched bank
features and writes it to disk in multiple formats:

  nodes.csv          — node feature matrix (one row per bank)
  edges.csv          — directed edge list with weights
  graph_data.json    — full graph as JSON (PyG-loadable via custom loader)
  pyg_data.pt        — torch_geometric.data.Data object (if PyG installed)
  metadata.json      — feature names, bank labels, dataset provenance

Usage::

    from dashboard.gnn_export import build_and_export
    info = build_and_export(gnn_features, sim_graph, output_dir="data/gnn_datasets")
"""

from __future__ import annotations

import csv
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from dashboard.data_loader import BANK_COUNTRIES, BANK_LABELS
from dashboard.llm import GNN_NODE_FEATURES

logger = logging.getLogger(__name__)

# Features derived from the simulation (not from the LLM web-fetch).
# These are always available and supplement the LLM-fetched features.
# NOTE(review): not referenced within this module's visible code — presumably
# consumed by callers/importers; confirm before removing.
_SIM_FEATURES = ["CET1_ratio", "LCR", "NSFR", "total_assets"]

# Binary node labels (for supervised GNN tasks).
# Order matters: it defines the column order of the exported `y` matrix.
_LABEL_FIELDS = ["solvent", "liquid"]


def _country_encoding(bank_ids: List[str]) -> Dict[str, int]:
    """Map each bank id to an integer label for its country.

    Country codes are looked up in ``BANK_COUNTRIES`` (``"XX"`` for unknown
    banks); labels are assigned by the sorted order of the unique codes
    present in *bank_ids*, so the encoding is deterministic for a given set.

    Returns
    -------
    {bank_id: country_label} for every id in *bank_ids*.
    """
    countries = sorted({BANK_COUNTRIES.get(b, "XX") for b in bank_ids})
    # Precompute code -> label once instead of calling list.index per bank
    # (the original was O(n_banks * n_countries)).
    code_to_label = {c: i for i, c in enumerate(countries)}
    return {b: code_to_label[BANK_COUNTRIES.get(b, "XX")] for b in bank_ids}


def build_graph_tensors(
    bank_ids: List[str],
    node_data: Dict[str, Dict[str, Any]],
    edges: List[Dict[str, Any]],
    feature_cols: Optional[List[str]] = None,
    label_fields: Optional[List[str]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[str]]:
    """
    Convert node/edge dicts into numpy arrays ready for GNN consumption.

    Parameters
    ----------
    bank_ids : ordered bank identifiers; defines the row order of X and y.
    node_data : {bank_id: {field: value}} raw per-bank fields.
    edges : list of ``{"source", "target", "weight"}`` dicts. Edges whose
        endpoints are not in *bank_ids* or whose weight is <= 0 are dropped.
    feature_cols : feature names to extract; defaults to
        ``GNN_NODE_FEATURES + ["country_code"]``.
    label_fields : label field names; defaults to ``_LABEL_FIELDS``
        (``["solvent", "liquid"]``). Missing fields default to True.

    Returns
    -------
    X : float32 [N, F] — node feature matrix (NaN-imputed with column median)
    edge_index : int64 [2, E] — source/target index pairs
    edge_attr : float32 [E, 1] — edge weight (normalised 0-1)
    y : int64 [N, L] — binary labels, one column per entry of *label_fields*
    feat_names : list[str] — feature column names (matches X columns)
    """
    if feature_cols is None:
        feature_cols = GNN_NODE_FEATURES + ["country_code"]
    if label_fields is None:
        label_fields = _LABEL_FIELDS

    n = len(bank_ids)
    idx = {b: i for i, b in enumerate(bank_ids)}
    # Only compute the country encoding when the feature set asks for it;
    # custom feature sets then avoid the BANK_COUNTRIES lookup entirely.
    country_enc = (
        _country_encoding(bank_ids) if "country_code" in feature_cols else {}
    )

    # ── Node features ────────────────────────────────────────────────────────
    X = np.full((n, len(feature_cols)), np.nan, dtype=np.float32)
    for i, bid in enumerate(bank_ids):
        row = node_data.get(bid, {})
        for j, feat in enumerate(feature_cols):
            if feat == "country_code":
                X[i, j] = float(country_enc.get(bid, 0))
                continue
            v = row.get(feat)
            if v is None:
                continue
            try:
                X[i, j] = float(v)
            except (TypeError, ValueError):
                pass  # non-numeric junk stays NaN and is imputed below

    # Impute missing values with the column median (robust to outliers).
    for j in range(X.shape[1]):
        col = X[:, j]
        missing = np.isnan(col)
        valid = col[~missing]
        X[missing, j] = float(np.median(valid)) if valid.size else 0.0

    # ── Labels ───────────────────────────────────────────────────────────────
    # Absent label fields default to True (i.e. a healthy bank).
    y = np.zeros((n, len(label_fields)), dtype=np.int64)
    for i, bid in enumerate(bank_ids):
        row = node_data.get(bid, {})
        for j, lf in enumerate(label_fields):
            y[i, j] = int(bool(row.get(lf, True)))

    # ── Edges ────────────────────────────────────────────────────────────────
    valid_edges = [
        e for e in edges
        if e["source"] in idx and e["target"] in idx and e["weight"] > 0
    ]
    if valid_edges:
        max_w = max(e["weight"] for e in valid_edges) or 1.0
        srcs = np.array([idx[e["source"]] for e in valid_edges], dtype=np.int64)
        dsts = np.array([idx[e["target"]] for e in valid_edges], dtype=np.int64)
        edge_index = np.stack([srcs, dsts], axis=0)  # [2, E]
        edge_attr = np.array(
            [[e["weight"] / max_w] for e in valid_edges], dtype=np.float32
        )  # [E, 1]
    else:
        edge_index = np.zeros((2, 0), dtype=np.int64)
        edge_attr = np.zeros((0, 1), dtype=np.float32)

    return X, edge_index, edge_attr, y, feature_cols
def _merge_node_data(
    nodes_raw: List[Dict[str, Any]],
    gnn_features: Dict[str, Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """Merge simulation node fields with LLM-fetched features per bank.

    Non-None LLM values take precedence for shared fields (except
    solvent/liquid, which are not overwritten because they come from the ABM
    — presumably the LLM feature set does not include them; verify upstream).
    """
    node_data: Dict[str, Dict[str, Any]] = {}
    for nd in nodes_raw:
        bid = nd["id"]
        merged = dict(nd)  # sim fields: CET1_ratio, LCR, NSFR, total_assets, solvent, liquid
        for k, v in gnn_features.get(bid, {}).items():
            if v is not None:
                merged[k] = v  # LLM overrides sim for financial ratios
        node_data[bid] = merged
    return node_data


def _write_nodes_csv(
    out: str,
    bank_ids: List[str],
    X: np.ndarray,
    y: np.ndarray,
    feat_names: List[str],
) -> str:
    """Write the node feature matrix to ``nodes.csv``; return the file path."""
    path = os.path.join(out, "nodes.csv")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["bank_id", "label", "country"] + feat_names + ["solvent", "liquid"])
        for i, bid in enumerate(bank_ids):
            writer.writerow(
                [bid, BANK_LABELS.get(bid, bid), BANK_COUNTRIES.get(bid, "")]
                + X[i].tolist()
                + [int(y[i, 0]), int(y[i, 1])]
            )
    return path


def _write_edges_csv(
    out: str,
    bank_ids: List[str],
    edge_index: np.ndarray,
    edge_attr: np.ndarray,
) -> str:
    """Write the directed edge list to ``edges.csv``; return the file path."""
    path = os.path.join(out, "edges.csv")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["source_idx", "target_idx", "source_id", "target_id", "weight_norm"])
        for k in range(edge_index.shape[1]):
            s, t = int(edge_index[0, k]), int(edge_index[1, k])
            writer.writerow([s, t, bank_ids[s], bank_ids[t], float(edge_attr[k, 0])])
    return path


def _write_graph_json(
    out: str,
    bank_ids: List[str],
    X: np.ndarray,
    edge_index: np.ndarray,
    edge_attr: np.ndarray,
    y: np.ndarray,
    feat_names: List[str],
    node_data: Dict[str, Dict[str, Any]],
) -> str:
    """Write the full graph to ``graph_data.json``; return the file path."""
    graph_json = {
        "bank_ids": bank_ids,
        "bank_labels": {b: BANK_LABELS.get(b, b) for b in bank_ids},
        "bank_countries": {b: BANK_COUNTRIES.get(b, "") for b in bank_ids},
        "feature_names": feat_names,
        "label_names": _LABEL_FIELDS,
        "node_features": X.tolist(),    # [N, F]
        "edge_index": edge_index.tolist(),  # [2, E]
        "edge_attr": edge_attr.tolist(),    # [E, 1]
        "labels": y.tolist(),               # [N, 2]
        "raw_node_data": node_data,         # full unprocessed fields per bank
    }
    path = os.path.join(out, "graph_data.json")
    with open(path, "w") as f:
        # default=str: raw_node_data may carry non-JSON-serializable values;
        # stringifying them is a deliberate best-effort choice here.
        json.dump(graph_json, f, indent=2, default=str)
    return path


def _write_metadata(
    out: str,
    timestamp: str,
    tag: str,
    bank_ids: List[str],
    edge_index: np.ndarray,
    feat_names: List[str],
    gnn_features: Dict[str, Dict[str, Any]],
) -> str:
    """Write dataset provenance to ``metadata.json``; return the file path."""
    meta = {
        "timestamp": timestamp,
        "tag": tag,
        "n_nodes": len(bank_ids),
        "n_edges": int(edge_index.shape[1]),
        "n_features": len(feat_names),
        "feature_names": feat_names,
        "label_names": _LABEL_FIELDS,
        # Per-bank count of non-None LLM-fetched feature values.
        "llm_coverage": {
            bid: sum(1 for v in gnn_features.get(bid, {}).values() if v is not None)
            for bid in bank_ids
        },
        "source": "Cerebras LLM + DuckDuckGo web search + SCR ABM simulation",
    }
    path = os.path.join(out, "metadata.json")
    with open(path, "w") as f:
        json.dump(meta, f, indent=2)
    return path


def _try_export_pyg(
    out: str,
    X: np.ndarray,
    edge_index: np.ndarray,
    edge_attr: np.ndarray,
    y: np.ndarray,
    bank_ids: List[str],
    feat_names: List[str],
) -> Optional[str]:
    """Save a ``torch_geometric.data.Data`` object if PyG is installed.

    Returns the ``.pt`` path, or None when torch / torch_geometric are
    unavailable (the export is optional by design).
    """
    try:
        import torch
        from torch_geometric.data import Data  # type: ignore
    except ImportError:
        logger.info("torch_geometric not installed — skipping .pt export")
        return None

    data = Data(
        x=torch.tensor(X, dtype=torch.float32),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
        edge_attr=torch.tensor(edge_attr, dtype=torch.float32),
        y=torch.tensor(y, dtype=torch.long),
    )
    # Store metadata as graph-level attributes
    data.bank_ids = bank_ids
    data.feat_names = feat_names
    data.label_names = _LABEL_FIELDS
    pt_path = os.path.join(out, "pyg_data.pt")
    torch.save(data, pt_path)
    logger.info("Saved PyTorch Geometric Data object → %s", pt_path)
    return pt_path


def build_and_export(
    gnn_features: Dict[str, Dict[str, Any]],
    sim_graph: Dict[str, Any],
    output_dir: str = "data/gnn_datasets",
    tag: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Build a GNN dataset from LLM-fetched features + simulation graph and
    save it to *output_dir*.

    Parameters
    ----------
    gnn_features : {bank_id: {feature: value}}
        Output of ``fetch_bank_features_for_gnn()``.
    sim_graph : {nodes: [...], edges: [...]}
        Output of ``simulation_state.get_network_graph_data()``.
    output_dir : str
        Directory to write dataset files into.
    tag : str, optional
        Short label for the export (used in filenames). Defaults to a
        timestamp.

    Returns
    -------
    dict with keys: output_dir, files, n_nodes, n_edges, n_features, timestamp
    """
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    tag = tag or timestamp
    out = os.path.join(output_dir, tag)
    os.makedirs(out, exist_ok=True)

    nodes_raw = sim_graph.get("nodes", [])
    edges_raw = sim_graph.get("edges", [])
    bank_ids = [nd["id"] for nd in nodes_raw]

    node_data = _merge_node_data(nodes_raw, gnn_features)

    feature_cols = GNN_NODE_FEATURES + ["country_code"]
    X, edge_index, edge_attr, y, feat_names = build_graph_tensors(
        bank_ids, node_data, edges_raw, feature_cols
    )

    files: Dict[str, str] = {
        "nodes_csv": _write_nodes_csv(out, bank_ids, X, y, feat_names),
        "edges_csv": _write_edges_csv(out, bank_ids, edge_index, edge_attr),
        "graph_json": _write_graph_json(
            out, bank_ids, X, edge_index, edge_attr, y, feat_names, node_data
        ),
        "metadata": _write_metadata(
            out, timestamp, tag, bank_ids, edge_index, feat_names, gnn_features
        ),
    }

    # ── PyTorch Geometric .pt (optional) ─────────────────────────────────────
    pt_path = _try_export_pyg(out, X, edge_index, edge_attr, y, bank_ids, feat_names)
    if pt_path is not None:
        files["pyg_pt"] = pt_path

    # ── Also save numpy arrays ───────────────────────────────────────────────
    np.save(os.path.join(out, "X.npy"), X)
    np.save(os.path.join(out, "edge_index.npy"), edge_index)
    np.save(os.path.join(out, "edge_attr.npy"), edge_attr)
    np.save(os.path.join(out, "y.npy"), y)
    files["numpy"] = out

    logger.info(
        "GNN dataset exported: %d nodes, %d edges, %d features → %s",
        len(bank_ids), int(edge_index.shape[1]), len(feat_names), out,
    )
    return {
        "output_dir": out,
        "files": files,
        "n_nodes": len(bank_ids),
        "n_edges": int(edge_index.shape[1]),
        "n_features": len(feat_names),
        "timestamp": timestamp,
    }