# -*- coding: utf-8 -*-
from typing import Any
import torch.nn as nn
import torch.nn.functional as F
from ..clusterings import kmeans
from sklearn.cluster import SpectralClustering, KMeans
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from . import DGCModel
from yacs.config import CfgNode as CN
from ..metrics import DGCMetric
from ..utils import Logger
import numpy as np
from typing import Callable, List, NamedTuple, Optional, Tuple, Union
import torch
from torch import Tensor
from torch_sparse import SparseTensor
import time
class Encoder(torch.nn.Module):
    """Stack of graph-convolution layers used as the MAGI-Batch encoder.

    Args:
        in_channels (int): Input feature dimension.
        hidden_channels (list): Output dimension of every layer, in order.
        base_model (torch.nn.Module): Convolution class; each layer is built
            as ``base_model(in_dim, out_dim)``.
        dropout (float): Dropout probability applied before each layer.
        ns (float): Negative slope for LeakyReLU.
    """

    def __init__(self, in_channels: int, hidden_channels, base_model=SAGEConv, dropout: float = 0.5, ns: float = 0.5):
        super().__init__()
        self.base_model = base_model
        self.dropout = dropout
        self.k = len(hidden_channels)
        self.ns = ns

        # First layer maps the input features; every later layer chains the
        # previous hidden size to the next one.
        layers = [base_model(in_channels, hidden_channels[0])]
        layers += [
            base_model(hidden_channels[depth - 1], hidden_channels[depth])
            for depth in range(1, self.k)
        ]
        self.convs = nn.ModuleList(layers)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution layer."""
        for depth in range(self.k):
            self.convs[depth].reset_parameters()

    def forward(self, x: torch.Tensor, edge_index=None, adjs=None, dropout=True):
        """Encode node features.

        Args:
            x (torch.Tensor): Node features.
            edge_index: Edges used by every layer when ``adjs`` is empty/None.
            adjs: Optional sequence of ``(edge_index, e_id, size)`` triples,
                one per layer, as produced by a neighbour sampler.
            dropout (bool): Whether to apply dropout before each layer.

        Returns:
            torch.Tensor: Output of the last executed layer.
        """
        if adjs:
            # Mini-batch mode: one sampled bipartite graph per layer.
            for depth, (hop_edges, _, hop_size) in enumerate(adjs):
                if dropout:
                    x = F.dropout(x, p=self.dropout, training=self.training)
                targets = x[:hop_size[1]]  # Target nodes are always placed first.
                x = F.leaky_relu(self.convs[depth]((x, targets), hop_edges), self.ns)
            return x

        # Full-graph mode: the same edge_index feeds every layer.
        for conv in self.convs:
            if dropout:
                x = F.dropout(x, p=self.dropout, training=self.training)
            x = F.leaky_relu(conv(x, edge_index), self.ns)
        return x
class Loss(nn.Module):
    """Contrastive loss for MAGI-Batch.

    Args:
        temperature (float): Temperature parameter for softmax.
        scale_by_temperature (bool): Whether to scale the loss by temperature.
        scale_by_weight (bool): Stored on the module; not used by ``forward``.
    """

    def __init__(self, temperature=0.07, scale_by_temperature=True, scale_by_weight=False):
        super().__init__()
        self.temperature = temperature
        self.scale_by_temperature = scale_by_temperature
        self.scale_by_weight = scale_by_weight

    def forward(self, out, mask):
        """Compute the loss for one batch.

        Args:
            out (torch.Tensor): ``(batch, dim)`` embeddings; similarities are
                computed as ``out @ out.T``.
            mask: Object exposing ``mask.storage.row()`` / ``.col()``; each
                ``(row, col)`` pair marks ``col`` as a positive of anchor
                ``row``. Stored values are ignored.

        Returns:
            torch.Tensor: Scalar loss, averaged over the anchors that have at
            least one positive.

        Raises:
            ValueError: If the log-probabilities contain NaN.
        """
        # Use the tensor's own device so non-default GPUs (e.g. cuda:1) work.
        device = out.device
        row = mask.storage.row().to(device)
        col = mask.storage.col().to(device)
        batch_size = out.shape[0]

        # compute logits
        logits = torch.matmul(out, out.T) / self.temperature
        # for numerical stability
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()

        # Exclude every sample's similarity with itself from the denominator.
        off_diagonal = 1.0 - torch.eye(batch_size, dtype=logits.dtype, device=device)
        exp_logits = torch.exp(logits) * off_diagonal
        log_probs = logits - torch.log(exp_logits.sum(1, keepdim=True))
        if torch.any(torch.isnan(log_probs)):
            raise ValueError("Log_prob has nan!")

        # Map every anchor id to a dense slot. Scattering on the raw row id
        # would run out of bounds whenever some anchor has no positive.
        anchors, anchor_slot, positives_per_anchor = torch.unique(
            row, return_inverse=True, return_counts=True)
        positive_log_probs = log_probs[row, col]
        summed = torch.zeros(
            anchors.shape[0], dtype=positive_log_probs.dtype, device=device)
        summed.scatter_add_(0, anchor_slot, positive_log_probs)

        loss = -summed / positives_per_anchor.to(summed.dtype)
        if self.scale_by_temperature:
            loss = loss * self.temperature
        return loss.mean()
def clustering(feature, n_clusters, true_labels, kmeans_device='cpu', batch_size=100000, tol=1e-4,
               device=torch.device('cuda:0'), spectral_clustering=False):
    """Cluster latent representations with k-means or spectral clustering.

    Args:
        feature (torch.Tensor or np.ndarray): Latent representation.
        n_clusters (int): Number of clusters.
        true_labels (torch.Tensor): True labels (accepted but not used here).
        kmeans_device (str): ``'cuda'`` selects the project ``kmeans``
            implementation; anything else uses scikit-learn ``KMeans``.
        batch_size (int): Batch size forwarded to the GPU kmeans.
        tol (float): Tolerance forwarded to the GPU kmeans.
        device (torch.device): Device forwarded to the GPU kmeans.
        spectral_clustering (bool): Whether to use spectral clustering.

    Returns:
        tuple: ``(labels, None)`` where ``labels`` is a ``torch.Tensor`` of
        predicted cluster ids and the second item is a placeholder for
        cluster centers.
    """
    if spectral_clustering:
        if isinstance(feature, torch.Tensor):
            # .numpy() only works on CPU tensors that do not require grad.
            feature = feature.detach().cpu().numpy()
        print("spectral clustering on cpu...")
        Cluster = SpectralClustering(
            n_clusters=n_clusters, affinity='precomputed', random_state=0)
        # The inner-product matrix is handed over as a precomputed affinity.
        f_adj = np.matmul(feature, np.transpose(feature))
        predict_labels = Cluster.fit_predict(f_adj)
    elif kmeans_device == 'cuda':
        if isinstance(feature, np.ndarray):
            feature = torch.tensor(feature)
        print("kmeans on gpu...")
        predict_labels, _ = kmeans(
            X=feature, num_clusters=n_clusters, batch_size=batch_size, tol=tol, device=device)
        # NOTE(review): assumes kmeans returns a tensor; move it to CPU in
        # case it was left on the GPU.
        predict_labels = predict_labels.cpu().numpy()
    else:
        if isinstance(feature, torch.Tensor):
            feature = feature.detach().cpu().numpy()
        print("kmeans on cpu...")
        Cluster = KMeans(n_clusters=n_clusters, max_iter=10000, n_init=20)
        predict_labels = Cluster.fit_predict(feature)
    return torch.from_numpy(predict_labels), None
def scale(z: torch.Tensor):
    """Min-max rescale every row of the latent representation.

    Args:
        z (torch.Tensor): Latent representation; reduced along ``dim=1``.

    Returns:
        torch.Tensor: ``(z - row_min) / (row_max - row_min + 1e-20)``.
    """
    row_max, _ = torch.max(z, dim=1, keepdim=True)
    row_min, _ = torch.min(z, dim=1, keepdim=True)
    # A tiny constant is added to the denominator.
    span = (row_max - row_min) + 1e-20
    return (z - row_min) / span
class EdgeIndex(NamedTuple):
    """One sampled layer expressed as an edge-index tensor."""

    # Edge tensor of the sampled layer.
    edge_index: Tensor
    # Optional ids of the sampled edges.
    e_id: Optional[Tensor]
    # Size pair stored alongside the edges.
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        """Return a copy whose tensors were passed through ``Tensor.to``."""
        converted_edges = self.edge_index.to(*args, **kwargs)
        converted_ids = None
        if self.e_id is not None:
            converted_ids = self.e_id.to(*args, **kwargs)
        return EdgeIndex(converted_edges, converted_ids, self.size)
class Adj(NamedTuple):
    """One sampled layer expressed as a sparse adjacency."""

    # Sparse adjacency of the sampled layer.
    adj_t: SparseTensor
    # Optional ids of the sampled edges.
    e_id: Optional[Tensor]
    # Size pair stored alongside the adjacency.
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        """Return a copy whose members were passed through their ``to``."""
        converted_adj = self.adj_t.to(*args, **kwargs)
        converted_ids = None
        if self.e_id is not None:
            converted_ids = self.e_id.to(*args, **kwargs)
        return Adj(converted_adj, converted_ids, self.size)
class NeighborSampler(torch.utils.data.DataLoader):
    """Neighbor sampler for graph convolution.

    This code adapted from the pytorch geometric
    (https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/loader/neighbor_sampler.py).

    Args:
        edge_index (Tensor or SparseTensor): Graph connectivity used for
            layer-wise neighbour sampling.
        adj (SparseTensor): Adjacency used for the random walks in
            :meth:`get_batch`.
        sizes (List[int]): Number of neighbours to sample at each layer.
        is_train (bool): When True, :meth:`sample` first expands the seed
            nodes through :meth:`get_batch`.
        wt (int): Number of random walks started from every node.
        wl (int): Length passed to every random walk.
        drop_last (bool): Forwarded to ``DataLoader``.
        node_idx (Tensor, optional): Nodes to iterate over (index or boolean
            mask). All nodes when None.
        num_nodes (int, optional): Number of nodes; inferred when None.
        return_e_id (bool): Whether edge ids are tracked for sampled edges.
        transform (Callable, optional): Applied to ``(batch_size, n_id, adjs)``.
        **kwargs: Extra ``DataLoader`` arguments (``collate_fn`` and
            ``dataset`` are discarded).
    """

    def __init__(self,
                 edge_index: Union[Tensor, SparseTensor],
                 adj: SparseTensor,
                 sizes: List[int],
                 is_train: bool = False,
                 wt: int = 20,
                 wl: int = 4,
                 drop_last=False,
                 node_idx: Optional[Tensor] = None,
                 num_nodes: Optional[int] = None,
                 return_e_id: bool = True,
                 transform: Optional[Callable] = None,
                 **kwargs):
        edge_index = edge_index.to('cpu')

        # The loader supplies these two itself; ignore caller-provided values.
        kwargs.pop('collate_fn', None)
        kwargs.pop('dataset', None)

        # Save for Pytorch Lightning < 1.6:
        self.edge_index = edge_index
        self.adj = adj
        self.node_idx = node_idx
        self.num_nodes = num_nodes
        self.is_train = is_train
        self.drop_last = drop_last
        self.wt = wt
        self.wl = wl

        self.sizes = sizes
        self.return_e_id = return_e_id
        self.transform = transform
        self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
        self.__val__ = None

        # Obtain a *transposed* `SparseTensor` instance.
        if not self.is_sparse_tensor:
            if (num_nodes is None and node_idx is not None
                    and node_idx.dtype == torch.bool):
                num_nodes = node_idx.size(0)
            if (num_nodes is None and node_idx is not None
                    and node_idx.dtype == torch.long):
                num_nodes = max(int(edge_index.max()), int(node_idx.max())) + 1
            if num_nodes is None:
                num_nodes = int(edge_index.max()) + 1

            value = torch.arange(edge_index.size(1)) if return_e_id else None
            self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                                      value=value,
                                      sparse_sizes=(num_nodes, num_nodes)).t()
        else:
            adj_t = edge_index
            if return_e_id:
                # Keep the original values and store edge ids in their place;
                # `sample` restores the values from `__val__`.
                self.__val__ = adj_t.storage.value()
                value = torch.arange(adj_t.nnz())
                adj_t = adj_t.set_value(value, layout='coo')
            self.adj_t = adj_t

        self.adj_t.storage.rowptr()

        if node_idx is None:
            node_idx = torch.arange(self.adj_t.sparse_size(0))
        elif node_idx.dtype == torch.bool:
            node_idx = node_idx.nonzero(as_tuple=False).view(-1)

        super().__init__(
            node_idx.view(-1).tolist(), collate_fn=self.sample, drop_last=self.drop_last, **kwargs)

    def _walk_from(self, start):
        """Run ``wt`` random walks of length ``wl`` from every node in ``start``.

        Returns:
            Tensor: One row per start node holding the nodes visited by all
            of its walks, with the starting position of each walk removed.
        """
        walks = self.adj.random_walk(start.repeat(self.wt), self.wl)
        # random_walk may hand back a tuple whose first item is the node
        # sequence; unwrap it *before* slicing, since a tuple cannot be
        # indexed with `[:, 1:]`.
        if not isinstance(walks, torch.Tensor):
            walks = walks[0]
        walks = walks[:, 1:]
        # Regroup so that all walks of the same start node share a row.
        return walks.t().reshape(-1, start.shape[0]).t()

    def get_batch(self, random_nodes):
        """Expand seed nodes into a batch and build its co-visit adjacency.

        Args:
            random_nodes (Tensor): Seed node ids.

        Returns:
            tuple: ``(batch, adj_batch)`` - the unique node ids of the
            expanded batch and a ``SparseTensor`` over those nodes whose
            values are random-walk visit counts, with a zeroed diagonal.
        """
        # Return value unused; call kept from the original code.
        # NOTE(review): presumably only warms cached CSR storage - confirm.
        self.adj.csr()

        # stage one: keep, for every seed, the nodes visited more often than
        # that seed's average visit count.
        rw1 = self._walk_from(random_nodes)
        batch = []
        for i in range(random_nodes.shape[0]):
            rw_nodes, rw_times = torch.unique(rw1[i], return_counts=True)
            batch += rw_nodes[rw_times > rw_times.float().mean()].tolist()
        batch += random_nodes.tolist()
        batch = torch.tensor(batch).unique()

        # stage two: walk again from every batch node and record visit counts.
        rw2 = self._walk_from(batch)
        row, col, val = [], [], []
        for i in range(batch.shape[0]):
            rw2_nodes, rw2_times = torch.unique(rw2[i], return_counts=True)
            row += [batch[i].item()] * rw2_nodes.shape[0]
            col += rw2_nodes.tolist()
            val += rw2_times.tolist()

        # Relabel global node ids to a compact 0..n-1 range.
        unique_nodes = list(set(row + col))
        subg2g = dict(zip(unique_nodes, range(len(unique_nodes))))
        row = [subg2g[x] for x in row]
        col = [subg2g[x] for x in col]
        idx = torch.tensor([subg2g[x] for x in batch.tolist()])

        adj_ = SparseTensor(row=torch.LongTensor(row), col=torch.LongTensor(col), value=torch.tensor(val),
                            sparse_sizes=(len(unique_nodes), len(unique_nodes)))
        # Restrict the adjacency to the batch nodes.
        adj_batch, _ = adj_.saint_subgraph(idx)

        # The diagonal is cleared through scipy; the original code marks
        # `adj_batch.set_diag(0.)` as buggy.
        adj_batch_sp = adj_batch.to_scipy(layout='coo')
        adj_batch_sp.setdiag([0] * idx.shape[0])
        adj_batch = SparseTensor.from_scipy(adj_batch_sp)
        return batch, adj_batch

    def sample(self, batch):
        """Collate function: sample multi-layer neighbourhoods for ``batch``.

        Returns:
            tuple: ``(out, adj_batch, batch)`` where ``out`` is
            ``(batch_size, n_id, adjs)`` (or its transform), ``adj_batch`` is
            the adjacency from :meth:`get_batch` (None unless ``is_train``)
            and ``batch`` holds the target node ids.
        """
        if not isinstance(batch, Tensor):
            batch = torch.tensor(batch)

        adj_batch = None
        if self.is_train:
            batch, adj_batch = self.get_batch(batch)

        batch_size: int = len(batch)

        adjs = []
        n_id = batch
        for size in self.sizes:
            adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)
            e_id = adj_t.storage.value()
            size = adj_t.sparse_sizes()[::-1]
            if self.__val__ is not None:
                adj_t.set_value_(self.__val__[e_id], layout='coo')

            if self.is_sparse_tensor:
                adjs.append(Adj(adj_t, e_id, size))
            else:
                row, col, _ = adj_t.coo()
                edge_index = torch.stack([col, row], dim=0)
                adjs.append(EdgeIndex(edge_index, e_id, size))

        # A single layer is returned bare; several layers are reversed.
        adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]
        out = (batch_size, n_id, adjs)
        out = self.transform(*out) if self.transform is not None else out
        return out, adj_batch, batch

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(sizes={self.sizes})'
def get_mask(adj):
    """Get mask for positive edges.

    An entry is kept when its value is not smaller (within ``1e-10``) than
    the value ``adj.mean(dim=1)`` reports for its row.

    Args:
        adj (SparseTensor): Adjacency matrix.

    Returns:
        SparseTensor: Adjacency of the same sparse size containing only the
        kept entries.
    """
    storage = adj.storage
    row, col, val = storage.row(), storage.col(), storage.value()
    # Index with the stored row tensor directly; the legacy
    # ``torch.LongTensor(tensor)`` constructor copies and rejects non-CPU input.
    row_mean = adj.mean(dim=1)[row]
    keep = (val - row_mean) > -1e-10
    return SparseTensor(row=row[keep], col=col[keep], value=val[keep],
                        sparse_sizes=(adj.size(0), adj.size(1)))
class MAGIBatch(DGCModel):
    """ Revisiting Modularity Maximization for Graph Clustering: A Contrastive Learning Perspective.

    Reference: https://doi.org/10.1145/3637528.3671967

    Args:
        logger (Logger): Logger object.
        cfg (CN): Configuration object.
    """

    def __init__(self, logger: Logger, cfg: CN):
        super(MAGIBatch, self).__init__(logger, cfg)
        # Work on a private copy: inserting into ``cfg.model.dims.encoder``
        # in place would grow the shared config on every construction and
        # silently add a layer to any later model built from the same cfg.
        encoder_dims = [cfg.dataset.num_features, *cfg.model.dims.encoder]
        projection_dims = cfg.model.dims.projection
        self.encoder = Encoder(encoder_dims[0], encoder_dims[1:],
                               dropout=cfg.model.dropout, ns=cfg.model.ns).to(self.device)
        self.tau = cfg.model.tau
        self.in_channels = encoder_dims[-1]
        # "", None and an empty list all mean "no projection head".
        self.project_hidden = projection_dims if projection_dims else None
        self.Loss = Loss(temperature=self.tau)
        self.project = None
        if self.project_hidden is not None:
            self.project = nn.ModuleList()
            self.activations = nn.ModuleList()
            previous_dim = self.in_channels
            for hidden_dim in self.project_hidden:
                self.project.append(nn.Linear(previous_dim, hidden_dim))
                self.activations.append(nn.PReLU(hidden_dim))
                previous_dim = hidden_dim
            self.project.to(self.device)
            # PReLU carries learnable weights, so it must live on the same
            # device as the linear layers.
            self.activations.to(self.device)
        self.loss_curve = []
        self.nmi_curve = []
        self.best_embedding = None
        self.best_predicted_labels = None
        self.best_results = {'ACC': -1}

    def reset_parameters(self):
        """No-op; the encoder initialises itself on construction."""
        pass

    def _adjs_to_device(self, adjs):
        """Return the sampled per-layer adjacencies as a list on ``self.device``."""
        # The sampler returns a bare item instead of a list when a single
        # layer is sampled; re-wrap it so the encoder can iterate.
        if self.encoder.k == 1:
            adjs = [adjs]
        return [adj.to(self.device) for adj in adjs]

    def forward(self, data, n_id) -> Any:
        """Embed the nodes selected by ``n_id``.

        Args:
            data: Object providing ``x`` (node features) and ``adjs``
                (per-layer sampled adjacencies).
            n_id: Index of the nodes whose features feed the encoder.

        Returns:
            Tensor: Encoder output, passed through the projection head when
            one is configured.
        """
        x = data.x
        # Gather the needed rows first so only they are copied to the
        # device, not the whole feature matrix on every batch.
        if isinstance(n_id, Tensor):
            n_id = n_id.to(x.device)
        x = x[n_id].to(self.device)
        x = self.encoder(x, adjs=data.adjs)
        if self.project is not None:
            for linear, activation in zip(self.project, self.activations):
                x = activation(linear(x))
        return x

    def loss(self, *args, **kwargs) -> Tensor:
        """Unused placeholder; training calls ``self.Loss`` directly."""
        pass

    def train_model(self, data: Data, cfg: CN = None, flag: str = "TRAIN MAGI"):
        """Train with random-walk mini-batches.

        Args:
            data (Data): Graph providing ``x``, ``edge_index``, ``adj`` and ``y``.
            cfg (CN, optional): Training config; defaults to ``self.cfg.train``.
            flag (str): Accepted but not used here.

        Returns:
            tuple: ``(loss_curve, nmi_curve, embedding, predicted_labels,
            results)``. With ``cfg.evaluate.each`` the last three are the
            best values seen (by ACC) at the every-10-epoch evaluations;
            otherwise they come from one final evaluation and ``nmi_curve``
            is None.
        """
        if cfg is None:
            cfg = self.cfg.train
        edge_index, adj = data.edge_index, data.adj
        size = self.cfg.dataset.size
        train_loader = NeighborSampler(edge_index, adj,
                                       is_train=True,
                                       node_idx=None,
                                       wt=self.cfg.dataset.wt,
                                       wl=self.cfg.dataset.wl,
                                       sizes=size,
                                       batch_size=cfg.batchsize,
                                       shuffle=True,
                                       drop_last=True,
                                       num_workers=6)
        optimizer = torch.optim.Adam(self.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay))

        # train
        time_train = time.time()
        for epoch in range(1, cfg.max_epoch + 1):
            self.train()
            total_loss = total_examples = 0
            for (batch_size, n_id, adjs), adj_batch, batch in train_loader:
                # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
                data.adjs = self._adjs_to_device(adjs)
                adj_ = get_mask(adj_batch)
                optimizer.zero_grad()
                out = self.forward(data, n_id)
                out = F.normalize(out, p=2, dim=1)
                loss = self.Loss(out, adj_)
                loss.backward()
                optimizer.step()
                total_loss += float(loss)
                total_examples += batch_size
            self.loss_curve.append(total_loss / total_examples)
            self.logger.loss(epoch, total_loss / total_examples)

            if epoch % 10 == 0 and self.cfg.evaluate.each:
                embedding, predicted_labels, results = self.evaluate(data)
                self.nmi_curve.append(results['NMI'])
                if results['ACC'] > self.best_results['ACC']:
                    self.best_embedding = embedding
                    self.best_predicted_labels = predicted_labels
                    self.best_results = results

            # Stop once the elapsed whole minutes exceed the budget.
            time_cost = time.time() - time_train
            if time_cost // 60 > cfg.max_duration:
                break

        if not self.cfg.evaluate.each:
            embedding, predicted_labels, results = self.evaluate(data)
            self.nmi_curve = None
            return self.loss_curve, self.nmi_curve, embedding, predicted_labels, results
        return self.loss_curve, self.nmi_curve, self.best_embedding, self.best_predicted_labels, self.best_results

    def get_embedding(self, data) -> Tensor:
        """Embed every node in evaluation mode.

        Returns:
            Tensor: L2-normalised float embeddings on the CPU, concatenated
            in loader order.
        """
        # eval
        edge_index, adj = data.edge_index, data.adj
        size = self.cfg.dataset.size
        test_loader = NeighborSampler(edge_index, adj,
                                      is_train=False,
                                      node_idx=None,
                                      sizes=size,
                                      batch_size=10000,
                                      shuffle=False,
                                      drop_last=False,
                                      num_workers=6)
        with torch.no_grad():
            self.eval()
            z = []
            for (batch_size, n_id, adjs), _, batch in test_loader:
                data.adjs = self._adjs_to_device(adjs)
                out = self.forward(data, n_id)
                z.append(out.detach().cpu().float())
            embedding = torch.cat(z, dim=0)
            embedding = F.normalize(embedding, p=2, dim=1)
        return embedding

    def clustering(self, data) -> Tuple[Tensor, Tensor, Any]:
        """Embed all nodes and cluster the embeddings.

        Returns:
            tuple: ``(embedding, labels, clustering_centers)`` as produced by
            :meth:`get_embedding` and the module-level ``clustering`` helper.
        """
        embedding = self.get_embedding(data)
        labels, clustering_centers = clustering(embedding.cpu().numpy(), self.cfg.dataset.n_clusters, data.y,
                                                kmeans_device=self.cfg.train.kmeans.device,
                                                batch_size=self.cfg.train.kmeans.batch, tol=1e-4,
                                                device=self.cfg.device,
                                                spectral_clustering=False)
        return embedding, labels, clustering_centers

    def evaluate(self, data):
        """Cluster the current embeddings and score them against ``data.y``.

        Returns:
            tuple: ``(embedding, predicted_labels, results)`` where
            ``results`` comes from ``DGCMetric.evaluate_one_epoch``.
        """
        embedding, predicted_labels, clustering_centers = self.clustering(data)
        # Labels may live on an accelerator; numpy conversion needs CPU.
        ground_truth = data.y.cpu().numpy().flatten()
        metric = DGCMetric(ground_truth, predicted_labels.numpy(), embedding, data.edge_index)
        results = metric.evaluate_one_epoch(self.logger, self.cfg.evaluate)
        return embedding, predicted_labels, results