Source code for pydgc.clusterings.kmeans_gpu

# -*- coding: utf-8 -*-
import torch
import numpy as np

from torch import Tensor
from typing import Tuple


[docs]class KMeansGPU: """Performs K-means clustering on GPU Reference: https://github.com/yueliu1999/HSAN/blob/main/kmeans_gpu.py Args: n_clusters: (int) number of clusters distance: (str) distance metric [default: 'euclidean'] tol: (float) tolerance [default: 1e-4] max_iter: (int) maximum number of iterations [default: 500] device: (str) device [default: 'cuda'] """ def __init__(self, n_clusters: int, *, distance: str = 'euclidean', tol: float = 1e-4, max_iter: int = 500, device: str = 'cuda'): self.n_clusters = n_clusters self.max_iter = max_iter self.tol = tol self.distance = distance self.device = torch.device(device) self.cluster_centers_ = None self.labels_ = None
[docs] def initialize(self, X: Tensor) -> np.array: """initialize cluster centers Args: X: (torch.tensor) matrix Returns: (np.array) initial state """ num_samples = len(X) indices = np.random.choice(num_samples, self.n_clusters, replace=False) initial_state = X[indices] return initial_state
[docs] def pairwise_distance(self, data1: Tensor, data2: Tensor) -> Tensor: """compute pairwise distance Args: data1: (torch.tensor) matrix data2: (torch.tensor) matrix Returns: (torch.tensor) pairwise distance """ # transfer to device data1, data2 = data1.to(self.device), data2.to(self.device) # N*1*M A = data1.unsqueeze(dim=1) # 1*N*M B = data2.unsqueeze(dim=0) dis = (A - B) ** 2.0 # return N*N matrix for pairwise distance dis = dis.sum(dim=-1).squeeze() return dis
[docs] def pairwise_cosine(self, data1: Tensor, data2: Tensor) -> Tensor: """compute pairwise cosine distance Args: data1: (torch.tensor) matrix data2: (torch.tensor) matrix Returns: (torch.tensor) pairwise cosine distance """ # transfer to device data1, data2 = data1.to(self.device), data2.to(self.device) # N*1*M A = data1.unsqueeze(dim=1) # 1*N*M B = data2.unsqueeze(dim=0) # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5] A_normalized = A / A.norm(dim=-1, keepdim=True) B_normalized = B / B.norm(dim=-1, keepdim=True) cosine = A_normalized * B_normalized # return N*N matrix for pairwise distance cosine_dis = 1 - cosine.sum(dim=-1).squeeze() return cosine_dis
[docs] def fit(self, X: Tensor) -> Tuple[Tensor, Tensor]: """perform kmeans Args: X: (torch.tensor) matrix Returns: (torch.tensor, torch.tensor) cluster ids, cluster centers """ if self.distance == 'euclidean': pairwise_distance_function = self.pairwise_distance elif self.distance == 'cosine': pairwise_distance_function = self.pairwise_cosine else: raise NotImplementedError # convert to float X = X.float().to(self.device) # initialize dis_min = float('inf') initial_state_best = None for i in range(20): initial_state = self.initialize(X) dis = pairwise_distance_function(X, initial_state).sum() if dis < dis_min: dis_min = dis initial_state_best = initial_state self.cluster_centers_ = initial_state_best iteration = 0 while iteration <= self.max_iter: dis = pairwise_distance_function(X, self.cluster_centers_) self.labels_ = torch.argmin(dis, dim=1) initial_state_pre = self.cluster_centers_.clone() for index in range(self.n_clusters): selected = torch.nonzero(self.labels_ == index).squeeze().to(self.device) selected = torch.index_select(X, 0, selected) self.cluster_centers_[index] = selected.mean(dim=0) center_shift = torch.sum( torch.sqrt( torch.sum((self.cluster_centers_ - initial_state_pre) ** 2, dim=1) )) # increment iteration iteration = iteration + 1 if center_shift ** 2 < self.tol: break return self.labels_.cpu(), self.cluster_centers_
[docs] def predict(self, X: Tensor) -> Tensor: """predict using cluster centers Args: X: (torch.tensor) matrix Returns: (torch.tensor) cluster ids """ if self.distance == 'euclidean': pairwise_distance_function = self.pairwise_distance elif self.distance == 'cosine': pairwise_distance_function = self.pairwise_cosine else: raise NotImplementedError # convert to float X = X.float().to(self.device) dis = pairwise_distance_function(X, self.cluster_centers_) self.labels_ = torch.argmin(dis, dim=1) return self.labels_.cpu()