Source code for pydgc.models.ae

# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F

from . import DGCModel
from torch import Tensor
from ..utils import Logger
from typing import List, Tuple
from ..metrics import DGCMetric
from sklearn.cluster import KMeans
from ..clusterings import KMeansGPU
from ..datasets import LoadAttribute
from torch_geometric.data import Data
from yacs.config import CfgNode as CN
from torch.utils.data import DataLoader
from ..modules import MLPEncoder, MLPDecoder


class AE(DGCModel):
    """Autoencoder model with MLP as encoder and decoder. Performs kmeans on embeddings.

    Args:
        logger (Logger): Logger.
        cfg (CN): Config.
    """

    def __init__(self, logger: Logger, cfg: CN):
        super().__init__(logger, cfg)
        # Layer widths: the dataset feature dimension followed by the configured
        # hidden dims. The list is copied so the config itself is not mutated.
        # NOTE(review): self.cfg / self.device are presumably set by
        # DGCModel.__init__ -- confirm against the base class.
        dims = cfg.model.dims.copy()
        dims.insert(0, self.cfg.dataset.num_features)
        self.encoder = MLPEncoder(dims, self.cfg.model.act, self.cfg.model.act_last).to(self.device)
        self.decoder = MLPDecoder(dims, self.cfg.model.act, self.cfg.model.act_last).to(self.device)
        # Per-epoch training history, appended to by train_model().
        self.loss_curve = []
        self.nmi_curve = []
        # Best evaluation seen so far; ACC starts at -1 so the first evaluated
        # epoch always replaces it.
        self.best_embedding = None
        self.best_predicted_labels = None
        self.best_results = {'ACC': -1}
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of both the encoder and the decoder."""
        self.encoder.reset_parameters()
        self.decoder.reset_parameters()

    def forward(self, x) -> Tuple[List[Tensor], List[Tensor]]:
        """Encode ``x`` and decode the last encoder output.

        Args:
            x: Input features; moved to ``self.device`` before encoding.

        Returns:
            Tuple[List[Tensor], List[Tensor]]: The encoder outputs and the decoder
            outputs. The last element of each is used as the embedding and the
            reconstruction respectively.
        """
        x = x.to(self.device)
        encodes = self.encoder(x)
        decodes = self.decoder(encodes[-1])
        return encodes, decodes

    def loss(self, x: Tensor, hat_x: Tensor) -> Tensor:
        """Mean squared reconstruction error between ``hat_x`` and ``x``.

        Args:
            x (Tensor): Original input; moved to ``self.device``.
            hat_x (Tensor): Reconstruction of ``x``.

        Returns:
            Tensor: The MSE loss.
        """
        x = x.to(self.device)
        return F.mse_loss(hat_x, x)

    def train_model(self, data: Data, cfg: CN = None, flag: str = "TRAIN AE"):
        """Train the autoencoder on ``data.x`` with Adam and evaluate clustering.

        Args:
            data (Data): Graph data; ``data.x`` is used for training and the whole
                object is passed to :meth:`evaluate`.
            cfg (CN, optional): Training config providing ``lr`` and ``max_epoch``.
                Defaults to ``self.cfg.train``. When the AE is trained in
                pre-training mode, the pretrain config must be passed here.
            flag (str): Message passed to ``self.logger.flag`` before training.

        Returns:
            tuple: ``(loss_curve, nmi_curve, embedding, predicted_labels, results)``.
            When ``self.cfg.evaluate.each`` is true these are the values from the
            epoch with the highest ``ACC``; otherwise they come from a single
            evaluation performed after training.
        """
        attribute = LoadAttribute(data.x)
        train_loader = DataLoader(attribute, batch_size=256, shuffle=True)
        if cfg is None:
            cfg = self.cfg.train
        self.logger.flag(flag)
        optimizer = torch.optim.Adam(self.parameters(), lr=float(cfg.lr))

        for epoch in range(1, cfg.max_epoch + 1):
            # evaluate() switches the model to eval mode (via get_embedding), so
            # training mode has to be restored at the start of every epoch.
            self.train()
            loss_sum = torch.tensor(0.0)
            for x, _ in train_loader:
                optimizer.zero_grad()
                _, decodes = self.forward(x)
                loss = self.loss(x, decodes[-1])
                loss.backward()
                optimizer.step()
                loss_sum += loss.item()
            self.loss_curve.append(loss_sum.item())
            self.logger.loss(epoch, loss_sum)

            if self.cfg.evaluate.each:
                embedding, predicted_labels, results = self.evaluate(data)
                self.nmi_curve.append(results['NMI'])
                # Keep the outputs of the epoch with the highest accuracy.
                if results['ACC'] > self.best_results['ACC']:
                    self.best_embedding = embedding
                    self.best_predicted_labels = predicted_labels
                    self.best_results = results

        if not self.cfg.evaluate.each:
            # No per-epoch evaluation was done: evaluate once on the final model.
            embedding, predicted_labels, results = self.evaluate(data)
            return self.loss_curve, self.nmi_curve, embedding, predicted_labels, results
        return (self.loss_curve, self.nmi_curve, self.best_embedding,
                self.best_predicted_labels, self.best_results)

    def get_embedding(self, x: Tensor) -> Tensor:
        """Return the last encoder output for ``x`` without tracking gradients.

        Note:
            This puts the model into eval mode and does not restore the
            previous mode.

        Args:
            x (Tensor): Input features; moved to ``self.device``.

        Returns:
            Tensor: The embedding (last element of the encoder outputs).
        """
        x = x.to(self.device)
        with torch.no_grad():
            self.eval()
            encodes = self.encoder(x)
        return encodes[-1]

    def clustering(self, data: Data, method: str = 'kmeans_gpu') -> Tuple[Tensor, Tensor, Tensor]:
        """Cluster the embedding of ``data.x`` into ``cfg.dataset.n_clusters`` groups.

        Args:
            data (Data): Graph data whose ``x`` attribute is embedded.
            method (str): ``'kmeans_gpu'`` uses ``KMeansGPU`` on the embedding
                tensor. ``'kmeans_cpu'`` uses scikit-learn ``KMeans`` on a NumPy
                copy of the embedding. The scikit-learn path is also taken for any
                other value when ``self.device == 'cpu'``.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: ``(embedding, labels, cluster_centers)``.
            On the scikit-learn path all three are CPU tensors created from NumPy
            arrays.

        Raises:
            ValueError: If ``method`` selects neither branch. Previously this case
                fell through and returned ``None``, which only surfaced later as
                an unpacking error in the caller.
        """
        embedding = self.get_embedding(data.x)
        if method == 'kmeans_gpu':
            labels_, clustering_centers_ = KMeansGPU(self.cfg.dataset.n_clusters).fit(embedding)
            return embedding, labels_, clustering_centers_
        if method == 'kmeans_cpu' or self.device == 'cpu':
            embedding = embedding.cpu().numpy()
            kmeans = KMeans(self.cfg.dataset.n_clusters, n_init=20)
            kmeans.fit_predict(embedding)
            labels_ = torch.from_numpy(kmeans.labels_)
            clustering_centers_ = torch.from_numpy(kmeans.cluster_centers_)
            return torch.from_numpy(embedding), labels_, clustering_centers_
        raise ValueError(
            f"Unknown clustering method {method!r}; expected 'kmeans_gpu' or 'kmeans_cpu'."
        )

    def evaluate(self, data: Data):
        """Cluster the current embedding and score it against ``data.y``.

        Args:
            data (Data): Graph data providing ``x``, ``y`` and ``edge_index``.

        Returns:
            tuple: ``(embedding, predicted_labels, results)`` where ``results`` is
            whatever ``DGCMetric.evaluate_one_epoch`` returns (``train_model``
            reads its ``'ACC'`` and ``'NMI'`` entries).
        """
        embedding, predicted_labels, clustering_centers = self.clustering(data)
        # Tensor.numpy() only works on CPU tensors; .cpu() is a no-op when the
        # tensor already lives on the CPU, so this is safe on either device.
        ground_truth = data.y.cpu().numpy()
        metric = DGCMetric(ground_truth, predicted_labels.cpu().numpy(), embedding, data.edge_index)
        results = metric.evaluate_one_epoch(self.logger, self.cfg.evaluate)
        return embedding, predicted_labels, results