Source code for pydgc.pipelines.dgcluster_pipeline

# -*- coding: utf-8 -*-
from torch_geometric.utils import add_self_loops

from ..models import DGCLUSTER
from . import BasePipeline
from argparse import Namespace
from ..utils import perturb_data
import numpy as np
import scipy.sparse as sp
import torch.nn.functional as F
import torch


class DGCLUSTERPipeline(BasePipeline):
    """DGCLUSTER pipeline.

    Args:
        args (Namespace): Arguments.
    """

    def __init__(self, args: Namespace) -> None:
        super().__init__(args)

    def augment_data(self) -> None:
        """Perturb the dataset and attach the derived graph quantities.

        Steps performed, in order:

        1. ``self.data`` is replaced by the result of ``perturb_data`` called
           with ``self.cfg.dataset.augmentation``.
        2. For the ``"DBLP"`` dataset only, self loops are added to
           ``edge_index``.
        3. The following attributes are computed and stored:

           * ``self.data.oh_labels`` -- one-hot encoding of the flattened
             labels, with ``labels.max() + 1`` classes.
           * ``self.data.sparse_adj`` -- SciPy CSR matrix of shape
             ``(num_nodes, num_nodes)`` holding a ``1.0`` for every column of
             ``edge_index``.
           * ``self.data.degree`` -- float tensor of the row sums of
             ``sparse_adj``, placed on ``self.device``.
           * ``self.cfg.dataset.num_edges`` -- half the number of columns of
             ``edge_index``.
        4. ``self.data`` is moved to ``self.device``.
        """
        self.data = perturb_data(self.data, self.cfg.dataset.augmentation)
        if self.dataset_name == "DBLP":
            self.data.edge_index = add_self_loops(self.data.edge_index)[0]

        num_nodes = self.data.x.shape[0]
        # One entry per column of edge_index (counted after any self loops
        # were added above).
        num_directed_edges = self.data.edge_index.shape[1]

        # Labels may arrive as a NumPy array or as a tensor of another integer
        # dtype; F.one_hot needs an int64 tensor.
        labels = self.data.y.flatten()
        if not isinstance(labels, torch.Tensor):
            labels = torch.from_numpy(labels).long()
        elif labels.dtype != torch.long:
            labels = labels.long()

        # Use the tensor reduction rather than the builtin ``max``: the builtin
        # walks the tensor element by element in Python and yields a 0-d
        # tensor, whereas ``num_classes`` is meant to be a plain int.
        num_classes = int(labels.max()) + 1
        oh_labels = F.one_hot(labels, num_classes=num_classes)

        sparse_adj = sp.csr_matrix(
            (np.ones(num_directed_edges), self.data.edge_index.cpu().numpy()),
            shape=(num_nodes, num_nodes),
        )
        # ``sum(axis=1)`` yields an (N, 1) matrix; squeeze drops the unit axis.
        degree = (
            torch.tensor(sparse_adj.sum(axis=1)).squeeze().float().to(self.device)
        )

        self.data.oh_labels = oh_labels
        self.data.sparse_adj = sparse_adj
        self.data.degree = degree
        # NOTE(review): halving presumably reflects undirected edges being
        # stored in both directions in edge_index -- confirm against the
        # dataset loader. Integer division avoids a float round-trip.
        self.cfg.dataset.num_edges = num_directed_edges // 2
        self.data.to(self.device)

    def build_model(self):
        """Create the DGCLUSTER model, log its summary, and return it.

        Returns:
            DGCLUSTER: Model built from ``self.logger`` and ``self.cfg``.
        """
        model = DGCLUSTER(self.logger, self.cfg)
        self.logger.model_info(model)
        return model