Source code for pydgc.pipelines.base_pipeline

# -*- coding: utf-8 -*-
import time
import torch
import traceback
import numpy as np
import os.path as osp

from argparse import Namespace
from abc import ABC, abstractmethod
from yacs.config import CfgNode as CN

from ..models import DGCModel
from ..datasets import load_dataset
from ..utils.logger import create_logger
from ..utils.visualization import DGCVisual
from ..utils.device import auto_select_device
from ..metrics import build_results_dict
from ..utils import load_dataset_specific_cfg, setup_seed, get_formatted_time, dump_cfg, check_required_cfg


[docs]class BasePipeline(ABC): """Standardized pipeline for deep graph clustering. Args: args (Namespace): Arguments for setting values frequently changed. """ def __init__(self, args: Namespace): torch.set_default_dtype(torch.float32) self.args = args self.cfg_file_path = "config.yaml" if not hasattr(args, "cfg_file_path") else args.cfg_file_path if hasattr(args, "dataset_name"): self.dataset_name = args.dataset_name else: raise ValueError("Please specify dataset name! You can specify it in run.py or use --dataset_name!") self.cfg = None self.logger = None self.device = None self.data = None self.ground_truth = None self.predicted_labels = None self.results = {} self.loss_curve = [] self.nmi_curve = [] self.embeddings = None self.times = [] self.current_round = 0
[docs] def load_config(self): """load config from yaml Args: self.cfg_file_path (str): Path to the config file. self.dataset_name (str): Name of the dataset. """ self.cfg = load_dataset_specific_cfg(self.cfg_file_path, self.dataset_name) cfg = check_required_cfg(self.cfg, dataset_name=self.dataset_name) if isinstance(cfg, CN): self.cfg = cfg self.cfg.dataset.name = self.dataset_name if hasattr(self.args, 'drop_edge'): self.cfg.dataset.augmentation.drop_edge = float(self.args.drop_edge) if hasattr(self.args, 'drop_feature'): self.cfg.dataset.augmentation.drop_feature = float(self.args.drop_feature) if hasattr(self.args, 'add_edge'): self.cfg.dataset.augmentation.add_edge = float(self.args.add_edge) if hasattr(self.args, 'add_noise'): self.cfg.dataset.augmentation.add_noise = float(self.args.add_noise) if hasattr(self.args, 'rounds'): self.cfg.train.rounds = int(self.args.rounds) self.cfg.evaluate.each = self.args.eval_each
[docs] def load_logger(self): """Load logger. Args: self.cfg (CN): Config object. """ log_file = osp.join(self.cfg.logger.dir, f'{get_formatted_time()}.log') self.logger = create_logger(self.cfg.logger.name, self.cfg.logger.mode, log_file) auto_select_device(self.logger, self.cfg) self.device = torch.device(self.cfg.device) if self.cfg.train.rounds > 1: self.results = build_results_dict(self.cfg.evaluate)
[docs] def load_dataset(self): """Load dataset. Args: self.cfg (CN): Config object. self.dataset_name (str): Name of the dataset. """ try: if not self.cfg: raise ValueError("Please load config before loading data!") if not self.cfg.dataset.is_custom: dataset = load_dataset(self.cfg.dataset.dir, self.dataset_name) else: dataset = load_dataset(self.cfg.dataset.dir, self.dataset_name, p=self.cfg.dataset.p, is_custom=self.cfg.dataset.is_custom, custom_is_graph=self.cfg.dataset.custom_is_graph, metric=self.cfg.dataset.metric) self.cfg.dataset.n_clusters = dataset.num_classes # if self.dataset_name.lower() == "arxiv": # data = dataset[0] # else: data = dataset[0] if data.x.layout == torch.sparse_csr: data.x = data.x.to_dense() data.x = data.x.float() self.cfg.dataset.num_nodes = data.num_nodes self.cfg.dataset.num_features = data.num_features num_edges = int((data.edge_index.shape[1]) / 2) self.cfg.dataset.num_edges = num_edges self.ground_truth = data.y.numpy() self.data = data except ValueError as e: print(e) except Exception as e: print(e)
[docs] @abstractmethod def augment_data(self): pass
[docs] @abstractmethod def build_model(self) -> DGCModel: """Build model. Args: self.cfg (CN): Config object. Returns: DGCModel: Model object. """ pass
[docs] def evaluate(self, results): """Evaluate model. Args: self.cfg (CN): Config object. results (dict): Evaluation results. """ if self.cfg.train.rounds > 1: for key, value in results.items(): self.results[key].append(value) else: self.results = results
[docs] def visualize(self): """Visualize results. Args: self.cfg (CN): Config object. """ cfg = self.cfg.visualize plot = DGCVisual(save_path=cfg.dir, font_family=['Times New Roman', 'SimSun'], font_size=24) if cfg.tsne: self.logger.flag(f"TSNE START") plot.plot_clustering(self.embeddings.cpu().numpy(), self.predicted_labels, palette='Set2', method='tsne', filename='tsne_plot') self.logger.flag(f"TSNE END") if cfg.umap: self.logger.flag(f"UMAP START") plot.plot_clustering(self.embeddings.cpu().numpy(), self.predicted_labels, palette='Set2', method='umap', filename='umap_plot') self.logger.flag(f"UMAP END") if cfg.heatmap: self.logger.flag(f"HEATMAP START") plot.plot_heatmap(self.embeddings.cpu().numpy(), self.predicted_labels, method='inner_product', show_axis=False, show_color_bar=False) self.logger.flag(f"HEATMAP END") if cfg.loss: self.logger.flag(f"LOSS START") plot.plot_loss(self.loss_curve, metrics=self.nmi_curve) self.logger.flag(f"LOSS END")
[docs] def run(self, pretrain=False, flag="TRAIN"): """Run pipeline. Args: self.cfg_file_path (str): Path to the config file. self.dataset_name (str): Name of the dataset. self.args (Namespace): Arguments. pretrain (bool): Whether to pretrain the model. flag (str): Flag for logging. """ try: self.load_config() self.load_logger() self.load_dataset() self.augment_data() if self.cfg.train.seed == -1: # set seed to no. current round for round_ in range(self.cfg.train.rounds): self.logger.flag(f"Round: {round_+1}/{self.cfg.train.rounds} Dataset: {self.dataset_name}") setup_seed(round_) start = time.time() model = self.build_model() if pretrain: if hasattr(model, 'pretrain'): self.loss_curve = model.pretrain(self.data, self.cfg.train.pretrain, flag) end = time.time() time_cost = round(end - start, 4) self.times.append(time_cost) self.logger.info(f"Time cost: {time_cost}") return else: raise ValueError("Model does not support pretraining!") else: self.loss_curve, self.nmi_curve, embeddings, predicted_labels, results = model.train_model(self.data, self.cfg.train) end = time.time() time_cost = round(end - start, 4) self.times.append(time_cost) self.logger.info(f"Time cost: {time_cost}") self.predicted_labels = predicted_labels.numpy() self.embeddings = embeddings.detach() self.evaluate(results) if self.cfg.visualize.when == 'each': self.visualize() else: # fixed seed with given seed setup_seed(self.cfg.train.seed) for round_ in range(self.cfg.train.rounds): self.logger.flag(f"Round: {round_+1}/{self.cfg.train.rounds} Dataset: {self.dataset_name}") start = time.time() model = self.build_model() if pretrain: if hasattr(model, 'pretrain'): self.loss_curve = model.pretrain(self.data, self.cfg.train.pretrain, flag) end = time.time() time_cost = end - start self.times.append(time_cost) self.logger.info(f"Time cost: {time_cost}") return else: raise ValueError("Model does not support pretraining!") else: self.loss_curve, self.nmi_curve, embeddings, predicted_labels, results = model.train_model(self.data, self.cfg.train) end = time.time() time_cost = end - start self.times.append(time_cost) self.logger.info(f"Time cost: {time_cost}") self.predicted_labels = predicted_labels.numpy() self.embeddings = embeddings.detach() self.evaluate(results) if self.cfg.visualize.when == 'each': self.visualize() self.logger.table(self.cfg.logger.dir, self.dataset_name, self.results) self.logger.info(f"Average time cost: {np.mean(self.times)}±{np.std(self.times)}") mem_used = torch.cuda.max_memory_allocated(device=self.device) / 1024 / 1024 self.logger.info(f"The max memory allocated to model is: {mem_used:.2f} MB.") if self.cfg.visualize.when == 'end': self.visualize() dump_cfg(self.cfg) except Exception as e: self.logger.error(str(e)) self.logger.error(traceback.format_exc())