# -*- coding: utf-8 -*-
import os
import torch
import numpy as np
import os.path as osp
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from ogb.nodeproppred import PygNodePropPredDataset
from sklearn.neighbors import kneighbors_graph
from torch_geometric.data import Data, InMemoryDataset, Dataset, download_google_url
from torch_geometric.graphgym.loader import set_dataset_attr
from torch_geometric.utils import index_to_mask, to_undirected, add_remaining_self_loops
from torch_geometric.datasets import (Planetoid, Coauthor, Amazon, WebKB, Actor, CitationFull,
CoraFull, AttributedGraphDataset, NELL, Reddit, Reddit2, Yelp, AmazonProducts,
LastFMAsia, Airports, HeterophilousGraphDataset)
# Keys that load_dataset() routes to load_pyg().
PYG_SUPPORTED_DATASET = [
    "CORA", "CITE", "CITESEER", "PUBMED",
    "COCS", "COPS",
    "AMAC", "AMAP",
    "CORNELL", "TEXAS", "WISC",
    "ACTOR", "DBLPFULL", "CORAFULL",
    "WIKI", "BLOG", "PPI", "FLICKR", "FACEBOOK", "TWEIBO", "MAG",
    "NELL", "REDDIT", "REDDIT2", "YELP", "AMP", "LFMA",
    "BAT", "EAT", "ROMAN",
]

# Keys that load_dataset() routes to load_dgc_graph().
DGC_SUPPORTED_DATASET = ["ACM", "DBLP", "UAT"]

# Keys that load_dataset() routes to load_dgc_non_graph().
NONGRAPH_SUPPORTED_DATASET = ["USPS", "HHAR", "REUT"]

# Keys that load_dataset() routes to load_ogb().
OGB_SUPPORTED_DATASET = ["ARXIV"]

# Upper-case dataset key -> name string handed to the underlying loader.
DATASET_NAME_MAP = {
    "ACM": "ACM",
    "DBLP": "DBLP",
    "UAT": "UAT",
    "BAT": "Brazil",
    "EAT": "Europe",
    "USPS": "USPS",
    "HHAR": "HHAR",
    "REUT": "REUT",
    "CORA": "Cora",
    "CITESEER": "CiteSeer",
    "CITE": "CiteSeer",
    "PUBMED": "PubMed",
    "COCS": "CS",
    "COPS": "Physics",
    "AMAC": "Computers",
    "AMAP": "Photo",
    "CORNELL": "Cornell",
    "TEXAS": "Texas",
    "WISC": "Wisconsin",
    "ACTOR": "Actor",
    "DBLPFULL": "DBLP",
    "CORAFULL": "Corafull",
    "WIKI": "Wiki",
    "BLOG": "BlogCatalog",
    "PPI": "PPI",
    "FLICKR": "Flickr",
    "FACEBOOK": "Facebook",
    "TWEIBO": "TWeibo",
    "MAG": "MAG",
    "NELL": "NELL",
    "REDDIT": "Reddit",
    "REDDIT2": "Reddit",
    "YELP": "Yelp",
    "AMP": "AmazonProducts",
    "LFMA": "LastFMAsia",
    "ROMAN": "roman-empire",
    "ARXIV": "ogbn-arxiv",
}

# Similarity metric used by load_dataset() for each built-in non-graph dataset.
METRIC_MAP = {
    "USPS": "heat",
    "HHAR": "cosine",
    "REUT": "cosine",
}
class UserDataset(InMemoryDataset):
    """User custom graph dataset inherited from InMemoryDataset of PyG.

    The raw data is a single archive ``<root>/raw/<DATASET_NAME>.npz`` holding
    the arrays ``feature``, ``graph`` and ``label``. ``process`` converts them
    into one ``Data`` object and stores it as ``data.pt``.

    Args:
        root (str): Path of data stored
        dataset_name (str): Name of dataset (upper-cased internally)
    """

    def __init__(self, root: str, dataset_name: str):
        # Assigned before the parent constructor is called because
        # raw_file_names is derived from it.
        self.dataset_name = dataset_name.upper()
        super().__init__(root)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> str:
        """File name of the raw ``.npz`` archive."""
        return f"{self.dataset_name.upper()}.npz"

    @property
    def processed_file_names(self) -> str:
        """File name of the processed dataset."""
        return "data.pt"

    def download(self):
        """No-op here; subclasses may override to fetch the raw archive."""
        pass

    def process(self):
        """Build a single ``Data`` object from the raw ``.npz`` archive.

        ``feature`` becomes ``x`` (float32), ``label`` becomes ``y`` (int64)
        and the positions of the non-zero entries of ``graph`` become
        ``edge_index`` (int64).
        """
        data_path = os.path.join(self.root, 'raw', self.raw_file_names)
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive comes from an untrusted source - confirm it is required.
        # The archive is opened as a context manager so its file handle is
        # released as soon as the arrays have been read.
        with np.load(data_path, allow_pickle=True) as raw_data:
            x = torch.from_numpy(raw_data['feature']).to(torch.float32)
            # presumably a dense (N, N) adjacency matrix - TODO confirm
            graph = torch.from_numpy(raw_data['graph'])
            y = torch.from_numpy(raw_data['label']).to(torch.int64)

        edge_index = graph.nonzero().t().to(torch.int64)
        data = Data(x=x, edge_index=edge_index, y=y)
        self.save([data], self.processed_paths[0])
class NonGraphDataset(InMemoryDataset):
    """Dataset object for constructing a k-NN graph from non-graph data.

    The raw data is ``<root>/raw/<DATASET_NAME>.npz`` with the arrays
    ``feature`` and ``label``. The processed result is stored under
    ``<root>/<neighbors>NN/processed``.

    NOTE(review): the processed directory depends on ``neighbors`` only, so a
    different ``metric`` or ``p`` with the same ``neighbors`` maps to the same
    processed file - confirm this is intended.

    Args:
        root (str): Path of data stored
        dataset_name (str): Name of dataset. Only the part before the first
            ``'_'`` is used, upper-cased.
        neighbors (int, optional): k for knn. Defaults to 1.
        metric (str, optional): Similarity measurement. ``'heat'`` selects
            ``heat_kernel_knn_graph``; any other value is forwarded to
            ``sklearn.neighbors.kneighbors_graph``. Defaults to 'minkowski'.
        p (int, optional): Power parameter for the Minkowski metric. Defaults to 2.
    """

    def __init__(self,
                 root: str,
                 dataset_name: str,
                 neighbors: int = 1,
                 metric: str = 'minkowski',
                 p: int = 2):
        # These attributes feed raw_file_names / processed_dir, so they are
        # assigned before the parent constructor is called.
        self.root = root
        self.dataset_name = dataset_name.split('_')[0].upper()
        self.neighbors = neighbors
        self.metric = metric
        self.p = p
        super().__init__(root)
        self.load(self.processed_paths[0])

    def download(self):
        """No-op here; subclasses may override to fetch the raw archive."""
        pass

    @property
    def raw_file_names(self) -> str:
        """File name of the raw ``.npz`` archive."""
        return f"{self.dataset_name.upper()}.npz"

    @property
    def processed_dir(self) -> str:
        """Directory of the processed file, one per value of ``neighbors``."""
        return os.path.join(self.root, f"{self.neighbors}NN", 'processed')

    @property
    def processed_file_names(self) -> str:
        """File name of the processed dataset."""
        return "data.pt"

    def process(self):
        """Build the k-NN graph and save it as a single ``Data`` object."""
        data_path = os.path.join(self.root, 'raw', self.raw_file_names)
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive comes from an untrusted source - confirm it is required.
        # Opened as a context manager so the archive's file handle is closed.
        with np.load(data_path, allow_pickle=True) as raw_data:
            x = raw_data['feature']
            label = raw_data['label']

        if self.metric == 'heat':
            adjacency = heat_kernel_knn_graph(x, self.neighbors)
            edge_index = adjacency.nonzero().t()
        else:
            graph = kneighbors_graph(x,
                                     n_neighbors=self.neighbors,
                                     mode='connectivity',
                                     metric=self.metric,
                                     p=self.p,
                                     include_self=False,
                                     n_jobs=-1)
            # Read the edges straight from the sparse matrix instead of
            # materialising a dense N x N array. Sorting the column indices
            # first yields the same row-major edge order that nonzero() on
            # the dense matrix produces.
            graph.sort_indices()
            coo = graph.tocoo()
            edge_index = torch.from_numpy(np.vstack((coo.row, coo.col))).to(torch.int64)

        x_tensor = torch.from_numpy(x)
        y = torch.from_numpy(label)
        data = Data(x=x_tensor, edge_index=edge_index, y=y)
        self.save([data], self.processed_paths[0])
def heat_kernel_knn_graph(x: np.ndarray, k: int) -> Tensor:
    """Construct a k-nearest-neighbour graph from heat-kernel similarities.

    The similarity between two rows of ``x`` is ``exp(-d^2 / 2)`` where ``d``
    is their Euclidean distance. For every row, all entries whose similarity
    is at least the k-th largest similarity of that row are set to 1, the rest
    to 0. A row's similarity to itself is 1 (the maximum), so the diagonal is
    normally among the selected entries, and ties can select more than ``k``
    entries in a row.

    Args:
        x (np.ndarray): Input data, one sample per row.
        k (int): Number of entries kept per row.

    Returns:
        Tensor: Dense 0/1 adjacency matrix with one row and column per sample,
        in the floating dtype of the similarity matrix.
    """
    # Squared pairwise distances via ||a||^2 + ||b||^2 - 2 * <a, b>.
    gram = np.matmul(x, x.transpose())
    squared_norms = (x * x).sum(1).reshape(-1, 1)
    squared_distance = squared_norms + squared_norms.transpose() - 2 * gram
    # Clamp tiny / negative values caused by floating-point cancellation.
    squared_distance[squared_distance < 1e-5] = 0

    # heat kernel, exp^{- euclidean^2/t} with t = 2.
    # The sqrt-then-square round trip is kept on purpose so the similarities
    # stay bit-identical to graphs that were built (and cached) earlier.
    euclidean = np.sqrt(squared_distance)
    similarity = torch.from_numpy(np.exp(-(euclidean ** 2) / 2))

    # Per-row threshold: the smallest of the k largest similarities. Shape
    # (N, 1) broadcasts against (N, N), so no repeated copy is needed.
    threshold = torch.topk(similarity, k).values.min(dim=-1, keepdim=True).values
    return torch.ge(similarity, threshold).to(similarity.dtype)
class DGCGraphDataset(UserDataset):
    """DGC graph dataset whose raw archive can be fetched for ACM and DBLP.

    Args:
        root (str): Path of data stored
        dataset_name (str): Name of dataset
    """

    def __init__(self, root, dataset_name):
        super().__init__(root, dataset_name)

    def download(self) -> None:
        """Fetch ``<DATASET_NAME>.npz`` into ``<root>/raw``.

        Raises:
            ValueError: If the dataset is neither ACM nor DBLP.
        """
        # File ids handed to download_google_url, keyed by dataset name.
        file_ids = {
            'ACM': '1QQK4-5hMcP5MitE3vu6hnBydiunQEaPh',
            'DBLP': '1n614RUq-SLh_b3xxffaP5bgt1lBxM_OJ',
        }
        file_id = file_ids.get(self.dataset_name)
        if file_id is None:
            raise ValueError(f"Custom dataset {self.dataset_name} does not exist!")

        target_dir = osp.join(self.root, 'raw')
        target_name = f'{self.dataset_name}.npz'
        download_google_url(file_id, target_dir, target_name)
class DGCNonGraphDataset(NonGraphDataset):
    """DGC non-graph dataset whose raw archive can be fetched for USPS, HHAR and REUT.

    Args:
        root (str): Path of data stored
        dataset_name (str): Name of dataset
        neighbors (int, optional): k for knn. Defaults to 1.
        metric (str, optional): Similarity measurement. Defaults to 'minkowski'.
        p (int, optional): Power parameter for the Minkowski metric. Defaults to 2.
    """

    def __init__(self, root, dataset_name, neighbors=1, metric='minkowski', p=2):
        super().__init__(root, dataset_name, neighbors, metric, p)

    def download(self) -> None:
        """Fetch ``<DATASET_NAME>.npz`` into ``<root>/raw``.

        Raises:
            ValueError: If the dataset is not one of USPS, HHAR, REUT.
        """
        if self.dataset_name not in ["USPS", "HHAR", "REUT"]:
            raise ValueError(f"Custom non-graph dataset {self.dataset_name} does not exist!")

        # File ids handed to download_google_url, keyed by dataset name.
        file_ids = {
            'USPS': '1d-PBz2Hk3ZHbgr4QeaZuD7Dsk1Qw21jw',
            'HHAR': '1bCBvv3ENYScXPf0uST9tZ1diaSnK5aLg',
            'REUT': '1b4MV5a-B3kHqDFlj59lgpzDhDjawB3gA',
        }
        download_google_url(
            file_ids[self.dataset_name],
            osp.join(self.root, 'raw'),
            f'{self.dataset_name}.npz',
        )
def load_pyg(dataset_dir: str, dataset_name: str) -> Dataset:
    """Load PyG dataset built in PyDGC.

    Args:
        dataset_dir (str): Dataset stored root path.
        dataset_name (str): Dataset name. Available datasets: CORA, CITE, CITESEER, PUBMED, BAT, EAT, UAT, COCS, COPS, AMAC, AMAP, CORNELL, TEXAS, WISC, WIKI, BLOG, PPI, FLICKR, FACEBOOK, TWEIBO, MAG, ACTOR, CORAFULL, DBLPFULL, NELL, REDDIT, REDDIT2, YELP, AMP, LFMA, ROMAN.

    Returns:
        Dataset: PyG dataset.

    Raises:
        ValueError: If ``dataset_name`` matches none of the keys above.
    """
    # Dataset classes that take a ``name`` argument looked up in DATASET_NAME_MAP.
    # NOTE(review): DATASET_NAME_MAP maps "UAT" to 'UAT' - confirm that
    # Airports accepts that name.
    named_loaders = (
        (("CORA", "CITE", "CITESEER", "PUBMED"), Planetoid),
        (("BAT", "EAT", "UAT"), Airports),
        (("COCS", "COPS"), Coauthor),
        (("AMAC", "AMAP"), Amazon),
        (("CORNELL", "TEXAS", "WISC"), WebKB),
        (("WIKI", "BLOG", "PPI", "FLICKR", "FACEBOOK", "TWEIBO", "MAG"), AttributedGraphDataset),
        (("DBLPFULL",), CitationFull),
        (("ROMAN",), HeterophilousGraphDataset),
    )
    for keys, loader in named_loaders:
        if dataset_name in keys:
            return loader(dataset_dir, name=DATASET_NAME_MAP[dataset_name])

    # Dataset classes that are constructed from the root directory alone.
    plain_loaders = (
        ("ACTOR", Actor),
        ("CORAFULL", CoraFull),
        ("NELL", NELL),
        ("REDDIT", Reddit),
        ("REDDIT2", Reddit2),
        ("YELP", Yelp),
        ("AMP", AmazonProducts),
        ("LFMA", LastFMAsia),
    )
    for key, loader in plain_loaders:
        if dataset_name == key:
            return loader(dataset_dir)

    raise ValueError(f"Unsupported pyg dataset {dataset_name}!")
def load_dgc_graph(dataset_dir: str, dataset_name: str) -> Dataset:
    """Load custom DGC graph dataset.

    Args:
        dataset_dir (str): Dataset stored root path.
        dataset_name (str): Dataset name.

    Returns:
        Dataset: A ``DGCGraphDataset`` built from the given root and name.
    """
    dataset = DGCGraphDataset(root=dataset_dir, dataset_name=dataset_name)
    return dataset
def load_dgc_non_graph(dataset_dir: str,
                       dataset_name: str,
                       *,
                       neighbors: int = 1,
                       metric: str = 'minkowski',
                       p: int = 2) -> Dataset:
    """Load custom non-graph dataset.

    Args:
        dataset_dir (str): Dataset stored root path.
        dataset_name (str): Dataset name for non-graph dataset.
        neighbors (int, optional): K for KNN. Defaults to 1.
        metric (str, optional): Distance type. Defaults to 'minkowski'.
        p (int, optional): Power parameter for the Minkowski metric. When p = 1, this is equivalent to using manhattan_distance (l1), and euclidean_distance (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used. Defaults to 2.

    Returns:
        Dataset: A ``DGCNonGraphDataset`` built from the given arguments.
    """
    dataset = DGCNonGraphDataset(
        root=dataset_dir,
        dataset_name=dataset_name,
        neighbors=neighbors,
        metric=metric,
        p=p,
    )
    return dataset
def load_ogb(dataset_dir: str, dataset_name: str) -> Dataset:
    """Load OGB dataset.

    The index splits returned by ``get_idx_split()`` are converted into the
    boolean node masks ``train_mask``, ``val_mask`` and ``test_mask``, and the
    edge index is replaced by a version with self-loops added and both edge
    directions present.

    Args:
        dataset_dir (str): Dataset stored root path.
        dataset_name (str): Dataset name.

    Returns:
        Dataset: OGB dataset.

    Raises:
        ValueError: If ``get_idx_split()`` returns ``None``.
    """
    dataset = PygNodePropPredDataset(root=dataset_dir, name=dataset_name)
    splits = dataset.get_idx_split()
    if splits is None:
        raise ValueError("Splits returned by get_idx_split() is None.")

    # Same node count the masks are sized with.
    num_nodes = dataset.data.y.shape[0]

    # Pair every split with its mask name explicitly instead of relying on
    # the iteration order of the split dictionary.
    split_to_mask = {'train': 'train_mask', 'valid': 'val_mask', 'test': 'test_mask'}
    for split_key, mask_name in split_to_mask.items():
        split_idx = splits[split_key]
        if not isinstance(split_idx, torch.Tensor):
            split_idx = torch.tensor(split_idx)
        mask = index_to_mask(split_idx, size=num_nodes)
        set_dataset_attr(dataset, mask_name, mask, len(mask))

    # num_nodes is passed explicitly so that nodes without any edge (including
    # those with the highest indices) still receive a self-loop.
    with_loops = add_remaining_self_loops(dataset.data.edge_index, num_nodes=num_nodes)[0]
    edge_index = to_undirected(with_loops, num_nodes=num_nodes)
    set_dataset_attr(dataset, 'edge_index', edge_index,
                     edge_index.shape[1])
    return dataset
def _strip_suffix_from_last_component(dataset_dir: str) -> str:
    """Cut the last path component of ``dataset_dir`` at its first ``'_'``.

    Parent directories are left untouched, so underscores in them no longer
    truncate the whole path.
    """
    separators = os.sep + (os.altsep or '')
    head, tail = osp.split(dataset_dir.rstrip(separators))
    if '_' not in tail:
        return dataset_dir
    return osp.join(head, tail.split('_')[0])


def load_dataset(dataset_dir: str, dataset_name: str, p: int = 2, is_custom: bool = False, custom_is_graph: bool = True, metric: str = 'minkowski') -> Dataset:
    """Load raw datasets.

    ``dataset_name`` may carry a ``_<k>`` suffix (for example ``USPS_3``): the
    part before the first ``'_'`` is the dataset key and the part after the
    last ``'_'`` is the number of neighbours for k-NN graph construction
    (1 when there is no suffix). The same suffix is removed from the last
    component of ``dataset_dir``.

    Args:
        dataset_dir (str): Dataset stored root path.
        dataset_name (str): Dataset name.
        p (int, optional): Power parameter for the Minkowski metric. Defaults to 2.
        is_custom (bool, optional): Whether the dataset is custom. Defaults to False.
        custom_is_graph (bool, optional): Whether the custom dataset is graph. Defaults to True.
        metric (str, optional): Distance type for non-graph data. Defaults to 'minkowski'.

    Returns:
        Dataset: Raw dataset.

    Raises:
        RuntimeError: If loading fails for any reason. A message is printed
            and the original exception is attached as ``__cause__``.
    """
    failure = None
    try:
        dataset_dir = _strip_suffix_from_last_component(dataset_dir)
        name_parts = dataset_name.split('_')
        # int() runs before dataset_name is shortened, so a malformed suffix
        # is reported with the full name that was passed in.
        neighbors = int(name_parts[-1]) if len(name_parts) > 1 else 1
        dataset_name = name_parts[0]

        if dataset_name in OGB_SUPPORTED_DATASET:
            return load_ogb(dataset_dir, DATASET_NAME_MAP[dataset_name])
        if dataset_name in PYG_SUPPORTED_DATASET:
            return load_pyg(dataset_dir, dataset_name)
        if dataset_name in DGC_SUPPORTED_DATASET:
            return load_dgc_graph(dataset_dir, dataset_name)
        if dataset_name in NONGRAPH_SUPPORTED_DATASET:
            return load_dgc_non_graph(dataset_dir, dataset_name, neighbors=neighbors, metric=METRIC_MAP[dataset_name], p=p)
        # load custom dataset
        if is_custom:
            if custom_is_graph:
                return load_dgc_graph(dataset_dir, dataset_name)
            return load_dgc_non_graph(dataset_dir, dataset_name, neighbors=neighbors, metric=metric, p=p)
        raise ValueError
    except NotADirectoryError as e:
        print(f"{dataset_dir} is not a directory!")
        failure = e
    except ValueError as e:
        print(f"Dataset name {dataset_name} is unsupported! Must be selected from {str(PYG_SUPPORTED_DATASET + DGC_SUPPORTED_DATASET + NONGRAPH_SUPPORTED_DATASET + OGB_SUPPORTED_DATASET)}")
        failure = e
    except Exception as e:
        print(f"Unknown error occurred: {e}")
        failure = e
    # Always raise an error if no valid Dataset is returned; chain the real
    # cause so the original traceback is not lost.
    raise RuntimeError("Failed to load dataset. Please check the error messages above.") from failure
def preprocess_custom_data(root: str, dataset_name: str, dataset_type: str = 'graph'):
    """Transform dataset with format from Awesome-Deep-Graph-Clustering.

    Reads ``feature.npy`` and ``label.npy`` (plus ``graph.npy`` for graph
    datasets) from ``<root>/<dataset_name>/raw`` and packs them into
    ``<root>/<DATASET_NAME>/raw/<DATASET_NAME>.npz``. The input directory uses
    ``dataset_name`` as given, the output uses its upper-cased form; the
    output directory is created when it does not exist yet.

    Any error is printed instead of being raised.

    Args:
        root (str): Dataset stored root path.
        dataset_name (str): Dataset name.
        dataset_type (str, optional): Dataset type. Options: 'graph', 'non-graph'. Defaults to 'graph'.

    Returns:
        None
    """
    try:
        if not osp.isdir(root):
            raise NotADirectoryError(f"{root} is not a directory!")
        if dataset_type not in ['graph', 'non-graph']:
            raise ValueError(f"Dataset type {dataset_type} is unsupported! Supported: 'graph', 'non-graph'.")

        array_names = ['feature', 'label']
        if dataset_type == 'graph':
            array_names.append('graph')

        arrays = {}
        for array_name in array_names:
            path = osp.join(root, f"{dataset_name}/raw/{array_name}.npy")
            if not osp.exists(path):
                raise FileNotFoundError(f"{path} not found!")
            # NOTE(review): allow_pickle=True can execute arbitrary code for
            # files from an untrusted source - confirm it is required.
            arrays[array_name] = np.load(path, allow_pickle=True)

        save_path = osp.join(root, f'{dataset_name.upper()}/raw/{dataset_name.upper()}.npz')
        # The upper-cased output directory can differ from the input one on
        # case-sensitive file systems, so make sure it exists before saving.
        os.makedirs(osp.dirname(save_path), exist_ok=True)
        np.savez(save_path, **arrays)
    except Exception as e:
        # Best-effort helper: report the problem and return normally.
        print(e)
    return None
class LoadAttribute(TorchDataset):
    """Map-style dataset over the rows of an attribute matrix.

    Each item is the pair ``(row as float tensor, index as tensor)``.

    Args:
        x (np.ndarray): Attribute matrix. A ``torch.Tensor`` is accepted as
            well and is moved to the CPU when it lives on another device.
    """

    def __init__(self, x):
        off_cpu = isinstance(x, torch.Tensor) and x.device != torch.device('cpu')
        self.x = x.cpu() if off_cpu else x

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        row = torch.from_numpy(np.array(self.x[idx])).float()
        position = torch.from_numpy(np.array(idx))
        return row, position