-
Notifications
You must be signed in to change notification settings - Fork 56
/
kmeans.py
24 lines (19 loc) · 856 Bytes
/
kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
class KMeans:
    """K-means clustering using Lloyd's algorithm.

    ``fit`` initialises the centroids from ``n_clusters`` distinct rows of the
    training data chosen at random. It then alternates two steps until the
    assignments stop changing (or ``max_iter`` rounds have run):

    * assign every sample to its nearest centroid (Euclidean distance);
    * move every centroid to the mean of its assigned samples.

    Parameters
    ----------
    n_clusters : int, default 4
        Number of clusters ``K``.
    max_iter : int, default 300
        Upper bound on assignment/update rounds. This guards against
        floating-point cycling; Lloyd's algorithm normally converges much
        sooner.
    random_state : int, numpy.random.Generator or None, default None
        Seed for choosing the initial centroids. ``None`` uses NumPy's global
        random state (``np.random.choice``), so ``np.random.seed`` still
        controls it.

    Attributes set by ``fit``
    -------------------------
    centroids : ndarray, one row per cluster.
    initial_centroids : copy of the centroids ``fit`` started from (also
        available under the legacy misspelt name ``intial_centroids``).
    labels : ndarray, cluster index of each training sample.
    prev_label : labels from the previous round.
    """

    def __init__(self, n_clusters=4, max_iter=300, random_state=None):
        self.K = n_clusters
        self.max_iter = max_iter
        self.random_state = random_state

    def fit(self, X):
        """Cluster the rows of ``X`` and return ``self``.

        Raises
        ------
        ValueError
            If ``n_clusters`` is not between 1 and the number of samples.
        """
        X = np.asarray(X)
        if not 1 <= self.K <= len(X):
            raise ValueError(
                f"n_clusters={self.K} must be between 1 and the number of "
                f"samples ({len(X)})"
            )
        if self.random_state is None:
            # Global RNG, exactly as before, so np.random.seed() keeps working.
            idx = np.random.choice(len(X), self.K, replace=False)
        else:
            rng = np.random.default_rng(self.random_state)
            idx = rng.choice(len(X), self.K, replace=False)
        self.centroids = X[idx]
        self.initial_centroids = self.centroids.copy()
        # Backward-compatible alias for the original misspelt attribute name.
        self.intial_centroids = self.initial_centroids
        self.prev_label, self.labels = None, np.zeros(len(X))
        for _ in range(self.max_iter):
            self.prev_label = self.labels
            self.labels = self.predict(X)
            self.update_centroid(X)
            if np.array_equal(self.labels, self.prev_label):
                break  # assignments are stable: converged
        return self

    def predict(self, X):
        """Return the index of the nearest centroid for each row of ``X``."""
        X = np.asarray(X)
        # Pairwise differences have shape (n_samples, K, n_features). The
        # square root is skipped because it is monotonic and so does not
        # change which centroid is nearest.
        diff = X[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]
        return np.argmin((diff ** 2).sum(axis=2), axis=1)

    def compute_label(self, x):
        """Return the index of the centroid nearest to the single sample ``x``."""
        return np.argmin(np.sqrt(np.sum((self.centroids - x) ** 2, axis=1)))

    def update_centroid(self, X):
        """Move each centroid to the mean of the samples currently assigned to it.

        A cluster with no assigned samples keeps its previous centroid. The
        mean of an empty selection is NaN, and a NaN centroid would capture
        points through ``argmin`` and could make ``fit`` oscillate forever.
        """
        X = np.asarray(X)
        new_centroids = np.array(self.centroids, dtype=float)  # copy
        for k in range(self.K):
            members = X[self.labels == k]
            if len(members):
                new_centroids[k] = members.mean(axis=0)
        self.centroids = new_centroids