Kmeans聚类 使用Pytorch和GPU加速
目标
sklearn库里面的kmeans算法默认运行在gpu上,运行效率较低。有时候需要在网络内动态的对特征进行分聚类。如果有基于Pytorch Tensor的kmeans实现则可以极大提升效率。
经过检索发现ContrastiveSceneContexts有类似实现,可以参考该实现:
环境
pytorch, pykeops
pip install pykeops -i https://pypi.tuna.tsinghua.edu.cn/simple
代码
import os
import torch
import numpy as np
import glob
import time
import argparse
import pykeops
from pykeops.torch import LazyTensor
pykeops.clean_pykeops()
def kmeans(pointcloud, k=10, iterations=10, verbose=True):
n, dim = pointcloud.shape # Number of samples, dimension of the ambient space
start = time.time()
clusters = pointcloud[:k, :].clone() # Simplistic random initialization
pointcloud_cuda = LazyTensor(pointcloud[:, None, :]) # (Npoints, 1, D)
# K-means loop:
for _ in range(iterations):
clusters_previous = clusters.clone()
clusters_gpu = LazyTensor(clusters[None, :, :]) # (1, Nclusters, D)
distance_matrix = ((pointcloud_cuda - clusters_gpu) ** 2).sum(-1) # (Npoints, Nclusters) symbolic matrix of squared distances
cloest_clusters = distance_matrix.argmin(dim=1).long().view(-1) # Points -> Nearest cluster
# #points for each cluster
clusters_count = torch.bincount(cloest_clusters, minlength=k).float() # Class weights
for d in range(dim): # Compute the cluster centroids with torch.bincount:
clusters[:, d] = torch.bincount(cloest_clusters, weights=pointcloud[:, d], minlength=k) / clusters_count
# for clusters that have no points assigned
mask = clusters_count == 0
clusters[mask] = clusters_previous[mask]
end = time.time()
if verbose:
print("K-means example with {:,} points in dimension {:,}, K = {:,}:".format(n, dim, k))
print('Timing for {} iterations: {:.5f}s = {} x {:.5f}s\n'.format(
iterations, end - start, iterations, (end-start) / iterations))
# nearest neighbouring search for each cluster
cloest_points_to_centers = distance_matrix.argmin(dim=0).long().view(-1)
return cloest_points_to_centers
Reference
- Ji Hou, Benjamin Graham, Matthias Nießner, Saining Xie:
Exploring Data-Efficient 3D Scene Understanding With Contrastive Scene Contexts. CVPR 2021: 15587-15597 - https://github.com/facebookresearch/ContrastiveSceneContexts/blob/83515bef4754b3d90fc3b3a437fa939e0e861af8/downstream/semseg/lib/sampling_points.py#L28
这篇好文章是转载于:学新通技术网
- 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
- 本站站名: 学新通技术网
- 本文地址: /boutique/detail/tanhghfggh
系列文章
更多
同类精品
更多
-
photoshop保存的图片太大微信发不了怎么办
PHP中文网 06-15 -
《学习通》视频自动暂停处理方法
HelloWorld317 07-05 -
word里面弄一个表格后上面的标题会跑到下面怎么办
PHP中文网 06-20 -
Android 11 保存文件到外部存储,并分享文件
Luke 10-12 -
photoshop扩展功能面板显示灰色怎么办
PHP中文网 06-14 -
微信公众号没有声音提示怎么办
PHP中文网 03-31 -
excel下划线不显示怎么办
PHP中文网 06-23 -
excel打印预览压线压字怎么办
PHP中文网 06-22 -
TikTok加速器哪个好免费的TK加速器推荐
TK小达人 10-01 -
怎样阻止微信小程序自动打开
PHP中文网 06-13