Source code for pydgc.clusterings.batch_kmeans_gpu

# -*- coding: utf-8 -*-
from functools import partial

import numpy as np
import torch
from tqdm import tqdm


[docs]def initialize(X, num_clusters, seed): """initialize cluster centers Args: X: (torch.tensor) matrix num_clusters: (int) number of clusters seed: (int) seed for kmeans Returns: (np.array) initial state """ num_samples = len(X) if seed == None: indices = np.random.choice(num_samples, num_clusters, replace=False) else: np.random.seed(seed) indices = np.random.choice(num_samples, num_clusters, replace=False) initial_state = X[indices] return initial_state
[docs]def kmeans( X, num_clusters, distance='euclidean', batch_size=100000, cluster_centers=[], tol=1e-4, tqdm_flag=True, iter_limit=0, device=torch.device('cpu'), gamma_for_soft_dtw=0.001, seed=None, ): """perform kmeans Reference: https://github.com/EdisonLeeeee/MAGI/blob/master/magi/batch_kmeans_cuda.py Args: X: (torch.tensor) matrix num_clusters: (int) number of clusters distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] seed: (int) seed for kmeans tol: (float) threshold [default: 0.0001] device: (torch.device) device [default: cpu] tqdm_flag: Allows to turn logs on and off iter_limit: hard limit for max number of iterations gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0 Returns: (torch.tensor, torch.tensor) cluster ids, cluster centers """ if tqdm_flag: print(f'running k-means on {device}..') if distance == 'euclidean': pairwise_distance_function = partial(pairwise_distance, batch_size=batch_size, device=device, tqdm_flag=tqdm_flag) else: raise NotImplementedError # convert to float X = X.float() # transfer to device X = X.to(device) if type(cluster_centers) == list: initial_state = initialize(X, num_clusters, seed=seed) else: if tqdm_flag: print('resuming') # find data point closest to the initial cluster center initial_state = cluster_centers dis = pairwise_distance_function(X, initial_state) choice_points = torch.argmin(dis, dim=0) initial_state = X[choice_points] initial_state = initial_state.to(device) iteration = 0 if tqdm_flag: tqdm_meter = tqdm(desc='[running kmeans]') while True: choice_cluster = pairwise_distance_function(X, initial_state) initial_state_pre = initial_state.clone() for index in range(num_clusters): # selected = idx[choice_cluster == index].to(device) # selected = X[selected] selected = torch.nonzero(choice_cluster == index).squeeze().to(device) selected = torch.index_select(X, 0, selected) # https://github.com/subhadarship/kmeans_pytorch/issues/16 if selected.shape[0] == 0: selected = X[torch.randint(len(X), (1,))] initial_state[index] = selected.mean(dim=0) center_shift = torch.sum( torch.sqrt( torch.sum((initial_state - initial_state_pre) ** 2, dim=1) )) # increment iteration iteration = iteration + 1 # update tqdm meter if tqdm_flag: tqdm_meter.set_postfix( iteration=f'{iteration}', center_shift=f'{center_shift ** 2:0.6f}', tol=f'{tol:0.6f}' ) tqdm_meter.update() if center_shift ** 2 < tol: break if iter_limit != 0 and iteration >= iter_limit: break return choice_cluster.cpu(), initial_state.cpu()
[docs]def kmeans_predict( X, cluster_centers, batch_size=100000, distance='euclidean', device=torch.device('cpu'), gamma_for_soft_dtw=0.001, tqdm_flag=True ): """predict using cluster centers Args: X: (torch.tensor) matrix cluster_centers: (torch.tensor) cluster centers distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] device: (torch.device) device [default: 'cpu'] gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0 Returns: (torch.tensor) cluster ids """ if tqdm_flag: print(f'predicting on {device}..') if distance == 'euclidean': pairwise_distance_function = partial(pairwise_distance, batch_size=batch_size, device=device, tqdm_flag=tqdm_flag) else: raise NotImplementedError # convert to float X = X.float() # transfer to device X = X.to(device) choice_cluster = pairwise_distance_function(X, cluster_centers, batch_size=batch_size) return choice_cluster.cpu()
[docs]def pairwise_distance(data1, data2, batch_size=100000, device=torch.device('cpu'), tqdm_flag=True): """compute pairwise distance Args: data1: (torch.tensor) matrix data2: (torch.tensor) matrix batch_size: (int) batch size device: (torch.device) device [default: 'cpu'] tqdm_flag: Allows to turn logs on and off Returns: (torch.tensor) pairwise distance """ if tqdm_flag: print(f'device is :{device}') # transfer to device data1, data2 = data1.to(device), data2.to(device) # N*1*M A = data1.unsqueeze(dim=1) # 1*N*M B = data2.unsqueeze(dim=0) if batch_size == -1: # full batch kmeans dis_ = (A - B) ** 2.0 # return N*N matrix for pairwise distance dis_ = dis_.sum(dim=-1).squeeze() return torch.argmin(dis_, dim=1) else: # mini-batch kmeans choice_cluster = torch.zeros(data1.shape[0]) for batch_idx in tqdm(range(int(np.ceil(data1.shape[0] / batch_size)))): dis = (A[batch_idx * batch_size: (batch_idx + 1) * batch_size] - B) ** 2.0 dis = dis.sum(dim=-1).squeeze() choice_ = torch.argmin(dis, dim=1) choice_cluster[batch_idx * batch_size: (batch_idx + 1) * batch_size] = choice_ choice_cluster = choice_cluster.long() return choice_cluster