Source code for pydgc.modules.ssc

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils import target_distribution
from torch.nn.parameter import Parameter


[docs]class SSCLayer(nn.Module): """Self-supervised clustering layer. Reference: http://proceedings.mlr.press/v48/xieb16.pdf Args: in_channels (int): Dimension of embeddings. out_channels (int): Number of clusters. method (str): Method of loss calculation, e.g., 'kl_div', 'cross_entropy', 'mse'. v (float): Parameter for soft assignment. """ def __init__(self, in_channels: int, out_channels: int, method: str = 'kl_div', v: float = 1.0): super(SSCLayer, self).__init__() self.cluster_centers = Parameter(torch.Tensor(out_channels, in_channels)) self.method = method self.v = v self.reset_parameters()
[docs] def reset_parameters(self): nn.init.xavier_normal_(self.cluster_centers.data)
[docs] def forward(self, embedding): q = 1.0 / (1.0 + torch.sum(torch.pow(embedding.unsqueeze(1) - self.cluster_centers, 2), 2) / self.v) q = q.pow((self.v + 1.0) / 2.0) q = (q.t() / torch.sum(q, 1)).t() return q
[docs] def get_q(self, embedding): with torch.no_grad(): self.eval() q = 1.0 / (1.0 + torch.sum(torch.pow(embedding.unsqueeze(1) - self.cluster_centers, 2), 2) / self.v) q = q.pow((self.v + 1.0) / 2.0) q = (q.t() / torch.sum(q, 1)).t() return q
[docs] @staticmethod def loss(q, method='kl_div'): """Calculate loss of self-supervised clustering. Args: q (torch.Tensor): Soft assignment matrix. method (str): Method of loss calculation, e.g., 'kl_div', 'cross_entropy', 'mse'. Returns: torch.Tensor: Loss value. """ p = target_distribution(q.detach().data) if method == 'mse': return F.mse_loss(q, p) if method == 'cross_entropy': return F.cross_entropy(q.log(), p) return F.kl_div(q.log(), p, reduction='batchmean')