Source code for pydgc.pipelines.magi_pipeline

# -*- coding: utf-8 -*-
import torch
from torch_sparse import SparseTensor

from ..models import MAGI
from . import BasePipeline
from argparse import Namespace
from ..utils import perturb_data
from torch_geometric.utils import to_undirected, add_remaining_self_loops, add_self_loops


def get_sim(batch, adj, wt=20, wl=3):
    """Build a random-walk visit-count matrix restricted to the batch nodes.

    ``wt`` random walks of length ``wl`` are started from every node in
    ``batch``. For each start node, the number of times every other node is
    reached is counted, and those counts become the values of a sparse
    matrix. The matrix is then restricted to the batch nodes and its
    diagonal is set to zero.

    Args:
        batch (torch.Tensor): 1-D tensor of start-node indices.
        adj (SparseTensor): Adjacency matrix to walk on.
        wt (int, optional): Number of random walks per node. Defaults to 20.
        wl (int, optional): Length of random walks. Defaults to 3.

    Returns:
        tuple: ``(batch, adj_batch)`` where ``batch`` is the input tensor,
        returned unchanged, and ``adj_batch`` is a SparseTensor of visit
        counts between batch nodes.
    """
    batch_size = batch.shape[0]

    walks = adj.random_walk(batch.repeat(wt), wl)
    # Some backends may hand back a tuple instead of a tensor; unwrap it
    # *before* slicing, because a tuple cannot take the `[:, 1:]` index.
    if not isinstance(walks, torch.Tensor):
        walks = walks[0]
    # Drop the first column (presumably the start node itself -- confirm
    # against the random-walk backend in use).
    walks = walks[:, 1:]
    # `batch.repeat(wt)` tiles the batch, so walk j started at
    # batch[j % batch_size]. This reshape gathers every step of every walk
    # that started at batch[i] into row i.
    walks = walks.t().reshape(-1, batch_size).t()

    start_ids = batch.tolist()
    row, col, val = [], [], []
    for start, visited in zip(start_ids, walks):
        nodes, counts = torch.unique(visited, return_counts=True)
        row += [start] * nodes.shape[0]
        col += nodes.tolist()
        val += counts.tolist()

    # Relabel the global node ids that occur to a compact 0..n-1 range so
    # the intermediate matrix only has to span the touched nodes.
    unique_nodes = list(set(row + col))
    subg2g = {node: k for k, node in enumerate(unique_nodes)}
    row = [subg2g[x] for x in row]
    col = [subg2g[x] for x in col]
    idx = torch.tensor([subg2g[x] for x in start_ids])

    num_unique = len(unique_nodes)
    adj_ = SparseTensor(
        row=torch.tensor(row, dtype=torch.long),
        col=torch.tensor(col, dtype=torch.long),
        value=torch.tensor(val),
        sparse_sizes=(num_unique, num_unique),
    )
    # Keep only the rows/columns that belong to batch nodes.
    adj_batch, _ = adj_.saint_subgraph(idx)
    adj_batch = adj_batch.set_diag(0.)
    return batch, adj_batch
def get_mask(adj):
    """Keep only the entries of ``adj`` that reach their row's mean.

    An entry is kept when its value is not smaller than the per-row mean
    returned by ``adj.mean(dim=1)``; a tolerance of ``1e-10`` absorbs
    floating-point noise in the comparison.

    Args:
        adj (SparseTensor): Adjacency matrix.

    Returns:
        SparseTensor: Matrix with the same sparse sizes as ``adj`` that
        contains only the retained entries, with their original values.
    """
    storage = adj.storage
    row, col, value = storage.row(), storage.col(), storage.value()

    row_mean = adj.mean(dim=1)
    # `row` is already an index tensor, so it can address `row_mean`
    # directly; this also works when the tensors are not on the CPU.
    keep = (value - row_mean[row]) > -1e-10

    return SparseTensor(
        row=row[keep],
        col=col[keep],
        value=value[keep],
        sparse_sizes=(adj.size(0), adj.size(1)),
    )
class MAGIPipeline(BasePipeline):
    """MAGI pipeline.

    Args:
        args (Namespace): Arguments.
    """

    def __init__(self, args: Namespace):
        super().__init__(args)

    def augment_data(self):
        """Perturb the data and attach the random-walk mask to it.

        The data is first passed through ``perturb_data`` using the
        configured augmentation. The edge index then gets self loops and is
        made undirected, a random-walk visit-count matrix over all nodes is
        built with ``get_sim`` and filtered with ``get_mask``. The processed
        edge index is stored in ``self.data.edge_index`` and the filtered
        matrix in ``self.data.mask``.
        """
        self.data = perturb_data(self.data, self.cfg.dataset.augmentation)

        edge_index = self.data.edge_index
        # DBLP adds self loops unconditionally; every other dataset only
        # adds the ones that are still missing.
        if self.dataset_name == "DBLP":
            edge_index = to_undirected(add_self_loops(edge_index)[0])
        else:
            edge_index = to_undirected(add_remaining_self_loops(edge_index)[0])

        num_nodes = self.data.num_nodes
        adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                           sparse_sizes=(num_nodes, num_nodes))
        adj.fill_value_(1.)

        # Every node is a start node: the whole graph is one batch.
        batch = torch.arange(num_nodes)
        _, adj_batch = get_sim(batch, adj,
                               wt=self.cfg.dataset.wt,
                               wl=self.cfg.dataset.wl)
        mask = get_mask(adj_batch)

        self.data.edge_index = edge_index
        self.data.mask = mask

    def build_model(self):
        """Create the MAGI model, log its summary and return it.

        Returns:
            MAGI: Model built from ``self.logger`` and ``self.cfg``.
        """
        model = MAGI(self.logger, self.cfg)
        self.logger.model_info(model)
        return model