KMeans聚类法

一、Kmeans内容

参考视频(Bilibili):ln异教徒 —— Kmeans

  • 随机生成k个聚类中心点

    • 其中k值的确定可以采用“手肘法”:通过分析不同k值下误差平方和(SSE)的变化,选取曲线的“肘部”作为合适的聚类数目

      $$SSE=\sum_{i=1}^{k}\sum_{p\in C_i}\lVert p-m_i\rVert^2$$

      其中$C_i$是第i个类,p是$C_i$中的样本点,$m_i$是$C_i$的质心($C_i$中所有样本点的均值),SSE是所有样本的聚类误差,代表了聚类效果的好坏(SSE越小,类内越紧凑)。随着k增大,SSE总体上逐渐减小;寻找SSE的下降幅度由剧烈突然变得平缓的位置(即“肘部”),即可判断合适的k取值

      function SSE = funcSSE(data,total)
      % FUNCSSE  Elbow-method curve: k-means SSE for k = 2..total.
      %   SSE = FUNCSSE(DATA, TOTAL) clusters DATA (one observation per row,
      %   one feature per column) with KMEANS for every k = 2..TOTAL and
      %   returns a row vector where SSE(k-1) is the total within-cluster sum
      %   of squared Euclidean distances for that k. Plot it with
      %   plot(2:TOTAL, SSE) and pick the k at the "elbow".
      SSE = zeros(1, total - 1);   % preallocate; SSE(1) corresponds to k = 2
      for k = 2:total
          % sumd(j) is the within-cluster sum of point-to-centroid distances of
          % cluster j; with the default 'sqeuclidean' distance these are squared
          % Euclidean distances, so sum(sumd) is exactly the SSE.
          % 'Replicates' keeps the best of several random starts, which makes
          % the elbow curve much less noisy than a single run.
          [~, ~, sumd] = kmeans(data, k, 'Replicates', 5);
          SSE(k-1) = sum(sumd);
      end
      end
  • 根据聚类中心点,将数据分为k类。分类的原则是数据离哪个中心点最近,就将它分到哪一类

    • 距离的判断采用欧式距离(欧几里得距离):

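      令各类中心为 $\mu_j$ ,数据点 $x$ 到第 j 类中心的距离为 $d(x,\mu_j)=\lVert x-\mu_j\rVert_2=\sqrt{\sum_{l}(x_l-\mu_{j,l})^2}$;由于只需比较大小,实际计算中也常直接使用其平方 $\lVert x-\mu_j\rVert_2^2$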

  • 根据分好类的数据,重新计算每一类的中心点(取该类所有样本点的均值)

  • 不断重复上述的两步,直到中心点不再变化,得到聚类结果

    clear
    clc
    %% Load data
    data = xlsread('./k-means.xlsx');   % readmatrix is the recommended replacement in newer MATLAB
    %% Elbow method: SSE for k = 2..km
    km = 10;                 % largest k to test
    SSE = funcSSE(data, km);
    plot(2:km, SSE, '-o')
    xlabel('k'), ylabel('SSE'), title('Elbow method')
    k = 5;                   % chosen from the elbow plot
    %% k-means clustering
    [label, center] = kmeans(data, k);
    %% Plot the clusters (assumes DATA has a single column, i.e. 1-D data)
    figure
    hold on
    palette = {'cyan', 'b', 'g', 'y', 'black'};   % cycled, so any k works
    for i = 1:k
        yi = data(label == i, 1);
        % one x value (0) per point, so x and y have the same length
        scatter(zeros(size(yi)), yi, [], palette{mod(i - 1, numel(palette)) + 1})
        plot([-0.1, 0.1], [center(i, 1), center(i, 1)], 'r')   % cluster center
    end
    hold off
    xlim([-1, 1])

二、Python实现(调用scikit-learn)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import sklearn
# Render all plot text in the monospace Consolas font.
plt.rcParams.update({
    'font.family': 'Consolas',
    'font.sans-serif': ['Consolas'],
})
# Random sample data: 100 points with integer coordinates in [1, 100]
import random

random.seed(66)
x = [random.randint(1, 100) for _ in range(100)]   # every x draw happens first ...
y = [random.randint(1, 100) for _ in range(100)]   # ... then every y draw
data = [(xv, yv) for xv, yv in zip(x, y)]
df = pd.DataFrame(data, columns=['v1', 'v2'])
df.head()
# Quick look at the raw random points before clustering.
plt.scatter(x, y)
plt.gca().set(xlabel='X', ylabel='Y')
plt.show()
# Elbow plot: SSE (inertia) for k = 2..6, with the silhouette score as a cross-check.
from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans

k_range = range(2, 7)
sse01 = []     # within-cluster sum of squared errors -> look for the "elbow"
score01 = []   # mean silhouette coefficient -> higher is better
for i in k_range:
    # n_init is set explicitly: its default changed across scikit-learn versions
    km1 = KMeans(n_clusters=i, n_init=10, random_state=42)
    kcat1 = km1.fit_predict(df)
    sse01.append(km1.inertia_)
    score01.append(silhouette_score(df, kcat1))

fig, (ax_sse, ax_sil) = plt.subplots(1, 2, figsize=(10, 4))
ax_sse.plot(k_range, sse01, marker='o')
ax_sse.set(title='Elbow (SSE)', xlabel='k', ylabel='SSE')
ax_sil.plot(k_range, score01, marker='o')
ax_sil.set(title='Silhouette', xlabel='k', ylabel='score')
plt.show()
# Fit the final model with the chosen number of clusters and attach the labels.
from sklearn.cluster import KMeans

n = 4  # number of clusters, chosen from the elbow / silhouette plots
km = KMeans(n_clusters=n, n_init=10, random_state=42)
kcat = km.fit_predict(df)      # one cluster label per row of df
centers = km.cluster_centers_  # one row of coordinates per cluster
print(centers)
# assign() reuses df's own index, so each label stays on its row even when df
# does not have a default RangeIndex (pd.concat would align on the index).
df_cat = df.assign(Cate=kcat)
df_cat.head()
# Plot the clustering result: each point is linked to its cluster center.
colorg1 = ['#1f77b4', '#2ca02c', '#ff7f0e', '#7f7f7f', '#9467bd']
colorg2 = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6', '#f39c12']  # alternative palette (unused below)

for i in range(n):
    color = colorg1[i % len(colorg1)]  # cycle the palette so n > 5 does not raise IndexError
    members = df_cat[df_cat['Cate'] == i]
    cx, cy = centers[i]
    # Member-to-center segments go underneath (zorder=1) so the points stay visible.
    for px, py in zip(members['v1'], members['v2']):
        plt.plot([px, cx], [py, cy], color=color, zorder=1)
    plt.scatter(members['v1'], members['v2'], label=f'Cate{i}', color=color, zorder=2)

plt.scatter(centers[:, 0], centers[:, 1], color='r', s=100, zorder=5, label='Center')
plt.title('KMeans Result', size=20)
plt.xlabel('V1', size=15)
plt.ylabel('V2', size=15)
plt.legend()
plt.savefig(r'./PIC.svg', dpi=500)
plt.show()
# Persist the labelled data (the DataFrame index is written as the first column).
df_cat.to_csv(path_or_buf=r'./kmoutput.csv')

三、Python手写实现Kmeans聚类

# Generate test data: 30 random integer points in [1, 30] x [1, 30]
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

random.seed(1)
# All x draws happen before all y draws, matching the seeded sequence.
kdatax = [random.randint(1, 30) for _ in range(30)]
kdatay = [random.randint(1, 30) for _ in range(30)]
# One row per point; columns are labelled 0 (x) and 1 (y), with a default index.
kdata = pd.DataFrame(list(zip(kdatax, kdatay)))
plt.plot(kdata[0], kdata[1], 'ro')
kdata.head()

# Cluster colour palette (matplotlib's default "tab" colours), indexed by label.
colorg1 = '#1f77b4 #2ca02c #ff7f0e #7f7f7f #9467bd'.split()

def show(kdata, cen, ax, k):
    """Draw clustered points and their centers onto the current figure.

    Parameters
    ----------
    kdata : DataFrame
        Labelled data in the layout ``KmeansH`` returns: the first ``ax``
        columns are coordinates and column position ``ax`` holds the cluster
        label. Only the first two coordinate columns are plotted, so
        ``ax >= 2`` is expected.
    cen : array-like, shape (k, ax)
        Cluster centers, drawn as red stars.
    ax : int
        Number of coordinate columns (i.e. the position of the label column).
    k : int
        Number of clusters to draw.

    The function only draws; call ``plt.show()`` afterwards to display.
    """
    labels = kdata.iloc[:, ax].to_numpy()
    for i in range(k):
        members = kdata[labels == i]
        plt.plot(members.iloc[:, 0], members.iloc[:, 1], 'o',
                 color=colorg1[i % len(colorg1)])
    cen = np.asarray(cen, dtype=float)
    plt.plot(cen[:, 0], cen[:, 1], 'r*')

def Dep(ax, t, cen):
    """Assign every sample to its nearest center by Euclidean distance.

    Parameters
    ----------
    ax : int
        Number of feature columns. Not used by the computation; kept so that
        existing callers (``KmeansH``) keep working.
    t : array-like, shape (n_samples, n_features)
        Samples to assign.
    cen : array-like, shape (k, n_features)
        Current cluster centers.

    Returns
    -------
    list of int
        Element ``r`` is the index of the center closest to row ``r`` of ``t``.
        Ties go to the lowest center index.
    """
    if len(t) == 0:
        return []
    points = np.asarray(t, dtype=float)
    centers = np.asarray(cen, dtype=float)
    # (n, 1, d) - (1, k, d) -> (n, k, d): every point-to-center difference at once
    dists = np.sqrt(((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2))
    # argmin returns the first minimum, i.e. the same tie-breaking as list.index(min(...))
    return dists.argmin(axis=1).tolist()

def Init_cen(kdata, k, ax):
    """Draw ``k`` random initial centers inside the data's bounding box.

    Parameters
    ----------
    kdata : DataFrame
        Samples; the first ``ax`` columns are used.
    k : int
        Number of centers to draw.
    ax : int
        Number of feature columns.

    Returns
    -------
    list of list
        ``k`` centers with ``ax`` coordinates each. Integer-typed columns get
        ``random.randint`` draws (as before); other numeric columns get
        ``random.uniform`` draws, because ``randint`` rejects non-integer bounds.
        The centers are also printed.
    """
    # Column bounds do not depend on the center being drawn, so compute them once.
    bounds = []
    for i in range(ax):
        col = kdata.iloc[:, i]
        bounds.append((col.min(), col.max(), pd.api.types.is_integer_dtype(col)))

    centers = []
    for _ in range(k):
        subcenters = []
        for lo, hi, is_int in bounds:
            if is_int:
                subcenters.append(random.randint(int(lo), int(hi)))
            else:
                subcenters.append(random.uniform(float(lo), float(hi)))
        centers.append(subcenters)
    print('Init:', centers)
    return centers


def KmeansH(kdata, k, seed=0, centers=None, co=0, ax=0, karray=None, max_iter=300):
    """Cluster ``kdata`` into ``k`` groups with Lloyd's k-means iteration.

    Parameters
    ----------
    kdata : DataFrame
        Samples, one row per point, one column per feature.
    k : int
        Number of clusters.
    seed : int, optional
        If non-zero, ``random.seed(seed)`` is called before the initial
        centers are drawn, making the run reproducible.
    centers : list of lists, optional
        Starting centers. Used only when ``co`` is non-zero; otherwise
        random centers are drawn with ``Init_cen``.
    co : int, optional
        Iteration counter used for progress printing; 0 means "fresh start".
    ax : int, optional
        Number of feature columns; defaults to ``len(kdata.columns)``.
    karray : ndarray, optional
        ``np.array(kdata)``; computed when not supplied.
    max_iter : int, optional
        Safety cap on the number of iterations.

    Returns
    -------
    DataFrame
        ``kdata`` with one extra column (named ``0``) holding each row's
        cluster label.

    Notes
    -----
    Before returning, ``random.seed()`` is called, which re-seeds the global
    ``random`` generator with a fresh, non-deterministic seed.
    """
    if ax == 0:
        ax = len(kdata.columns)
        karray = np.array(kdata)
    elif not isinstance(karray, np.ndarray):
        karray = np.array(kdata)

    if seed != 0:
        random.seed(seed)
    if co == 0 or centers is None:
        centers = Init_cen(kdata, k, ax)
        co = 1
    else:
        co += 1

    label = Dep(ax, karray, centers)
    for _ in range(max_iter):
        labels = np.asarray(label)
        old = np.asarray(centers, dtype=float)
        new = old.copy()
        for i in range(k):
            members = karray[labels == i, :ax]
            # An empty cluster keeps its previous center instead of becoming NaN
            # (NaN centers never compare as converged).
            if len(members):
                new[i] = members.mean(axis=0)

        print(co, centers)

        # Compare every coordinate: a signed sum of differences can cancel out
        # (one center moving +d, another -d) and stop the iteration too early.
        if np.abs(new - old).max() <= 1e-9:
            break
        centers = new.tolist()
        co += 1
        label = Dep(ax, karray, centers)

    random.seed()
    # Build the label column on kdata's own index so rows line up even when
    # kdata does not use a default RangeIndex.
    return pd.concat((kdata, pd.DataFrame(label, index=kdata.index)), axis=1)

# Run the hand-written k-means and plot the result.
k = 4
kdatak = KmeansH(kdata, k, 42)

# Derive the final centers from the returned labels instead of hard-coding
# numbers copied from an earlier run (those go stale when the data, seed or k
# change). At convergence the per-cluster means equal KmeansH's centers.
labels = kdatak.iloc[:, -1].to_numpy()   # KmeansH appends the label as the last column
centers = kdatak.iloc[:, :-1].groupby(labels).mean()
centers

for i in range(k):
    members = kdatak[labels == i]
    plt.plot(members.iloc[:, 0], members.iloc[:, 1], 'o',
             color=colorg1[i % len(colorg1)])  # cycle the palette for k > 5

plt.plot(centers.iloc[:, 0], centers.iloc[:, 1], 'r*')
plt.show()