Source code for pydgc.models.magi

# -*- coding: utf-8 -*-
from typing import Tuple, List, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..clusterings import kmeans
from sklearn.cluster import SpectralClustering, KMeans
from torch import Tensor
from torch_geometric.data import Data

from torch_geometric.nn import GCNConv

from . import DGCModel

from yacs.config import CfgNode as CN

from ..metrics import DGCMetric
from ..utils import Logger
import numpy as np


[docs]class Encoder(nn.Module): """Encoder for MAGI. Args: in_channels (int): Number of input channels. hidden_channels (list): List of hidden channels. base_model (torch.nn.Module): Base model for graph convolution. dropout (float): Dropout rate. ns (float): Negative slope for leaky ReLU. """ def __init__(self, in_channels: int, hidden_channels, base_model=GCNConv, dropout: float = 0.5, ns: float = 0.5): super(Encoder, self).__init__() self.base_model = base_model self.dropout = dropout self.k = len(hidden_channels) self.ns = ns self.convs = nn.ModuleList() self.convs.extend([base_model(in_channels, hidden_channels[0])]) for i in range(1, self.k): self.convs.extend( [base_model(hidden_channels[i - 1], hidden_channels[i])]) self.reset_parameters()
[docs] def reset_parameters(self): """初始化模型参数""" for i in range(self.k): self.convs[i].reset_parameters()
[docs] def forward(self, x: torch.Tensor, edge_index=None, adjs=None, dropout=True): if not adjs: for i in range(self.k): if dropout: x = F.dropout(x, p=self.dropout, training=self.training) x = self.convs[i](x, edge_index) x = F.leaky_relu(x, self.ns) else: for i, (edge_index, _, size) in enumerate(adjs): if dropout: x = F.dropout(x, p=self.dropout, training=self.training) x_target = x[:size[1]] # Target nodes are always placed first. x = self.convs[i]((x, x_target), edge_index) x = F.leaky_relu(x, self.ns) return x
[docs]class Loss(nn.Module): """Loss function for MAGI. Args: temperature (float): Temperature scale_by_temperature (bool): Whether to scale loss by temperature. scale_by_weight (bool): Whether to scale loss by weight. """ def __init__(self, temperature=0.07, scale_by_temperature=True, scale_by_weight=False): super(Loss, self).__init__() self.temperature = temperature self.scale_by_temperature = scale_by_temperature self.scale_by_weight = scale_by_weight
[docs] def forward(self, out, mask): device = (torch.device('cuda') if out.is_cuda else torch.device('cpu')) row, col, val = mask.storage.row(), mask.storage.col(), mask.storage.value() row, col, val = row.to(device), col.to(device), val.to(device) batch_size = out.shape[0] # compute logits dot = torch.matmul(out, out.T) dot = torch.div(dot, self.temperature) # for numerical stability logits_max, _ = torch.max(dot, dim=1, keepdim=True) dot = dot - logits_max.detach() logits_mask = torch.scatter( torch.ones(batch_size, batch_size).to(device), 1, torch.arange(batch_size).view(-1, 1).to(device), 0 ) exp_logits = torch.exp(dot) * logits_mask log_probs = dot - torch.log(exp_logits.sum(1, keepdim=True)) if torch.any(torch.isnan(log_probs)): raise ValueError("Log_prob has nan!") labels = row.view(row.shape[0], 1) unique_labels, labels_count = labels.unique(dim=0, return_counts=True) log_probs = log_probs[row, col] log_probs = log_probs.view(-1, 1) loss = torch.zeros_like(unique_labels, dtype=torch.float).to(device) loss.scatter_add_(0, labels, log_probs) loss = -1 * loss / labels_count.float().unsqueeze(1) if self.scale_by_temperature: loss *= self.temperature loss = loss.mean() return loss
[docs]def clustering(feature, n_clusters, true_labels, kmeans_device='cpu', batch_size=100000, tol=1e-4, device=torch.device('cuda:0'), spectral_clustering=False): """Clustering function. Args: feature (torch.Tensor): Latent representation. n_clusters (int): Number of clusters. true_labels (torch.Tensor): True labels. kmeans_device (str): Device for kmeans. batch_size (int): Batch size. tol (float): Tolerance. device (torch.device): Device. spectral_clustering (bool): Whether to use spectral clustering. Returns: torch.Tensor: Clustering labels. None: Clustering centers. """ if spectral_clustering: if isinstance(feature, torch.Tensor): feature = feature.numpy() print("spectral clustering on cpu...") Cluster = SpectralClustering( n_clusters=n_clusters, affinity='precomputed', random_state=0) f_adj = np.matmul(feature, np.transpose(feature)) predict_labels = Cluster.fit_predict(f_adj) else: if kmeans_device == 'cuda': if isinstance(feature, np.ndarray): feature = torch.tensor(feature) print("kmeans on gpu...") predict_labels, _ = kmeans( X=feature, num_clusters=n_clusters, batch_size=batch_size, tol=tol, device=device) predict_labels = predict_labels.numpy() else: if isinstance(feature, torch.Tensor): feature = feature.numpy() print("kmeans on cpu...") Cluster = KMeans(n_clusters=n_clusters, max_iter=10000, n_init=20) predict_labels = Cluster.fit_predict(feature) return torch.from_numpy(predict_labels), None
[docs]def scale(z: torch.Tensor): """Scale the latent representation. Args: z (torch.Tensor): Latent representation. Returns: torch.Tensor: Scaled latent representation. """ zmax = z.max(dim=1, keepdim=True)[0] zmin = z.min(dim=1, keepdim=True)[0] z_std = (z - zmin) / ((zmax - zmin) + 1e-20) z_scaled = z_std return z_scaled
[docs]class MAGI(DGCModel): """ Revisiting Modularity Maximization for Graph Clustering: A Contrastive Learning Perspective. Reference: https://doi.org/10.1145/3637528.3671967 Args: logger (Logger): Logger object. cfg (CN): Configuration object. """ def __init__(self, logger: Logger, cfg: CN): super(MAGI, self).__init__(logger, cfg) encoder_dims = cfg.model.dims.encoder.copy() projection_dims = cfg.model.dims.projection encoder_dims.insert(0, cfg.dataset.num_features) self.encoder = Encoder(encoder_dims[0], encoder_dims[1:], base_model=GCNConv, dropout=cfg.model.dropout, ns=cfg.model.ns).to(self.device) self.tau = cfg.model.tau self.in_channels = encoder_dims[-1] self.project_hidden = projection_dims if projection_dims != "" else None self.activation = nn.PReLU self.Loss = Loss(temperature=self.tau) self.project = None if self.project_hidden is not None: self.project = nn.ModuleList() self.activations = nn.ModuleList() self.project.extend( [nn.Linear(self.in_channels, self.project_hidden[0])]) self.activations.extend([nn.PReLU(projection_dims[0])]) for i in range(1, len(self.project_hidden)): self.project.extend( [nn.Linear(self.project_hidden[i - 1], self.project_hidden[i])]) self.activations.extend([nn.PReLU(projection_dims[i])]) self.project.to(self.device) self.loss_curve = [] self.nmi_curve = [] self.best_embedding = None self.best_predicted_labels = None self.best_results = {'ACC': -1}
[docs] def reset_parameters(self): pass
[docs] def forward(self, data) -> Any: x = data.x.to(self.device) edge_index = data.edge_index.to(self.device) x = self.encoder(x, edge_index) if self.project is not None: for i in range(len(self.project_hidden)): x = self.project[i](x) x = self.activations[i](x) return x
[docs] def loss(self, *args, **kwargs) -> Tensor: pass
[docs] def train_model(self, data: Data, cfg: CN = None, flag: str = "TRAIN MAGI"): if cfg is None: cfg = self.cfg.train optimizer = torch.optim.Adam(self.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay)) mask = data.mask.to(self.device) # train for epoch in range(1, cfg.max_epoch + 1): self.train() optimizer.zero_grad() out = self.forward(data) out = scale(out) out = F.normalize(out, p=2, dim=1) loss = self.Loss(out, mask) loss.backward() optimizer.step() self.loss_curve.append(loss.item()) self.logger.loss(epoch, loss) if epoch % 1 == 0: if self.cfg.evaluate.each: embedding, predicted_labels, results = self.evaluate(data) self.nmi_curve.append(results['NMI']) if results['ACC'] > self.best_results['ACC']: self.best_embedding = embedding self.best_predicted_labels = predicted_labels self.best_results = results if not self.cfg.evaluate.each: embedding, predicted_labels, results = self.evaluate(data) self.nmi_curve = None return self.loss_curve, self.nmi_curve, embedding, predicted_labels, results return self.loss_curve, self.nmi_curve, self.best_embedding, self.best_predicted_labels, self.best_results
[docs] def get_embedding(self, data) -> Tensor: # eval with torch.no_grad(): self.eval() out = self.forward(data) out = scale(out) embedding = F.normalize(out, p=2, dim=1) return embedding.detach()
[docs] def clustering(self, data) -> Tuple[Tensor, Tensor, Any]: embedding = self.get_embedding(data) labels, clustering_centers = clustering(embedding.cpu().numpy(), self.cfg.dataset.n_clusters, data.y, spectral_clustering=True) return embedding, labels, clustering_centers
[docs] def evaluate(self, data): embedding, predicted_labels, clustering_centers = self.clustering(data) ground_truth = data.y.numpy() metric = DGCMetric(ground_truth, predicted_labels.numpy(), embedding, data.edge_index) results = metric.evaluate_one_epoch(self.logger, self.cfg.evaluate) return embedding, predicted_labels, results