In [None]:

from numpy import *
# NumPy 2 prints scalars together with their type, e.g. np.float64(3.0);
# legacy = "1.25" switches back to the plain 3.0 style of NumPy 1.25.
set_printoptions(legacy = "1.25")

from scipy.linalg import norm

def nearest_index(x,means):
	"""Return the row index of the mean that lies closest to x.

	means holds one candidate per row; closeness is the Euclidean norm of
	the row-wise difference. On a tie the lowest index is returned (argmin).
	"""
	offsets = means - x
	return argmin(norm(offsets,axis=1))


In [None]:

def assign_clusters(dataset,means):
	"""Partition the points of dataset by their nearest mean.

	Every x in dataset is appended to the cluster of the mean that
	nearest_index picks for it. Returns a list of clusters (each a list of
	points) in the order of their means; means that attracted no point are
	left out, so the result can be shorter than len(means).
	"""
	# one bucket per mean actually passed in, independent of any global k,
	# so the function works for every number of means
	clusters = [ [ ] for _ in range(len(means)) ]
	for x in dataset:
		i = nearest_index(x,means)
		clusters[i].append(x)
	# drop the buckets that stayed empty
	return [ c for c in clusters if len(c) ]


In [None]:

def update_means(clusters):
	"""Return the centroid of every cluster, stacked as rows of one array.

	Each cluster is a non-empty sequence of points; its centroid is the
	component-wise sum of the points divided by their count.
	"""
	centroids = []
	for members in clusters:
		total = sum(members,axis=0)
		centroids.append(total/len(members))
	return array(centroids)


In [None]:

from numpy.random import default_rng
# samples(shape) returns an array of uniform floats in [0, 1); the generator
# is created without a seed, so each kernel run produces different numbers
samples = default_rng().random

def kmeans(dataset,k):
	"""Cluster the rows of dataset around at most k means.

	dataset is a 2-D array with one point per row. Starting from k random
	means, the points are assigned to their nearest mean and every mean is
	recomputed from its cluster, until the means stop moving (allclose).
	The size of every cluster is printed once per iteration.

	Returns (means, clusters), where clusters[i] is the list of points
	whose mean is means[i]. Means that attract no point are dropped along
	the way, so fewer than k clusters may come back.
	"""
	(N,d) = dataset.shape
	means = samples((k,d))
	if N:
		# samples() draws from the unit cube; stretch the draws over the
		# bounding box of the data so the start suits data of any scale
		lo, hi = dataset.min(axis=0), dataset.max(axis=0)
		means = lo + means*(hi - lo)
	close_enough = False
	while not close_enough:
		clusters = assign_clusters(dataset,means)
		print([len(c) for c in clusters])
		new_means = update_means(clusters)
		# only check closeness if number of means unchanged
		# (allclose cannot compare arrays with different row counts)
		if len(new_means) == len(means):
			close_enough = allclose(means,new_means)
		means = new_means
	return means, clusters

# dimension of the points, number of means to start from, number of points
d = 2
k = 7
N = 100

# N uniform random points in the unit cube [0, 1)^d
dataset = samples((N,d))

# run the clustering; the cluster sizes are printed at every iteration
means, clusters = kmeans(dataset, k)


In [None]:

def plot_cluster(mean,cluster,color,marker):
	"""Scatter one cluster and its mean in the same colour.

	The points of cluster are drawn with the given marker at size 30; the
	mean is drawn on top as a '*' at size 20.
	"""
	points = array(cluster)
	# transpose so that each coordinate becomes one positional argument
	scatter(*points.T, marker = marker, c = color, s = 30)
	scatter(*mean, marker = '*', c = color, s = 20)


In [None]:

from matplotlib.pyplot import *

from numpy.random import default_rng
# fresh unseeded generator; samples(shape) returns uniform floats in [0, 1)
samples = default_rng().random

# dimension of the points, number of means to start from, number of points
d, k, N = 2, 7, 100

from random import choice

def hexcolor():
	"""Return a random colour as '#' followed by six lowercase hex digits."""
	digits = '0123456789abcdef'
	code = ''
	for _ in range(6):
		code += choice(digits)
	return "#" + code

# N uniform random points, and one random colour per potential cluster
dataset = samples((N,d))
colors = [ hexcolor() for _ in range(k) ]

# the raw data before clustering
scatter(*dataset.T,s = 10)
grid()
show()

means, clusters = kmeans(dataset,k)

# kmeans drops means that end up without points, so fewer than k clusters
# may come back; zip stops at the shortest of the three sequences
for i,(mean,cluster,color) in enumerate(zip(means,clusters,colors)):
	# mathtext marker: every point is drawn as the index of its cluster
	tex = f'${i}$'
	plot_cluster(mean,cluster,color,tex)

grid()
show()
