Предположим, у меня есть такой набор данных:

import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

X,y = make_blobs(random_state=101) # My data

palette = sns.color_palette('bright',3)
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y) # Visualizing the data

enter image description here

Я хотел бы выбрать данные, которые находятся близко к центру кластера. Скажем, я хочу выбрать данные близко к центру из cluster '0', сейчас я делаю вот так:

label_0 = X[y==0] # Want to select data from the label '0'

data_index = 2 # Manaully pick the point
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y)
plt.scatter(label_0[data_index][0],label_0[data_index][1],marker='*')

enter image description here

Так как это не близко к центру, я меняю индекс и выбираю другой.

data_index = 4
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y)
plt.scatter(label_0[data_index][0],label_0[data_index][1],marker='*')

Теперь уже близко. Но мне интересно, есть ли более эффективный способ добиться этого? Это можно сделать для небольшого набора данных, подобного этому, но если в моем наборе данных есть тысячи точек, я не думаю, что этот метод больше будет работать. введите описание изображения здесь

0
Raven Cheuk 5 Дек 2018 в 06:36

1 ответ

Лучший ответ

Один из подходов - использовать алгоритм K-средних. Это поможет вам найти центры каждого кластера.

Учитывая ваш набор данных, шаги будут следующими:

1) Найдите количество кластеров

num_clusters=len(np.unique(y)) #here 3

2) Примените k- означает кластеризацию ваших данных

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(X)

3) Найдите центр каждого кластера

centers=kmeans.cluster_centers_ # gives the centers of each cluster
# array([[ 0.26542862,  1.85466779],
#        [-9.50316411, -6.52747391],
#        [ 3.64354311,  6.62683956]])

4) Поскольку эти центры могут не входить в ваши исходные данные, нам нужно найти к ним ближайшие точки.

from scipy import spatial

def nearest_point(array,query):
    return array[spatial.KDTree(array).query(query)[1]]

nearest_centers=np.array([nearest_point(X,center) for center in centers])
# array([[ 0.19313183,  1.80387958],
#       [-9.12488396, -6.32638926],
#       [ 3.65986315,  6.69035824]])

5) Постройте исходные данные и центры

sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y) 
for nc in nearest_centers:
    plt.scatter(nc[0],nc[1],marker='*',color='r')

Центры показаны красными крестами:

The centers are shows by the red crosses

0
Sruthi V 5 Дек 2018 в 04:55