// K-Means clustering algorithm (Java implementation)

package com.redbean.noveldata;

import lombok.Data;

import java.util.*;

/**
 * Demo driver: runs {@link KMeans} with k = 3 on a small hard-coded data set and prints
 * the iteration count followed by the id, center and members of every cluster.
 *
 * <p>Each row is handed to {@code KMeans} as-is; {@code KMeans} treats the first value of
 * a row as the sample id and the remaining values as the feature vector.
 *
 * @author redbean
 * @create 2022/4/15 - 16:39
 */
public class kMeansTest {

    public static void main(String[] args) {
        // One row per sample: {id, feature, feature}.
        float[][] rows = {
            {1000025, 434464, 24143},
            {1000060, 557784, 62171},
            {1, 4, 4},
            {2, 6, 5},
            {3, 9, 6},
            {4, 5, 4},
            {5, 4, 2},
            {6, 9, 7},
            {7, 9, 8},
            {8, 2, 10},
            {9, 9, 12},
            {10, 8, 112},
            {11, 8, 4},
        };
        List<float[]> dataSet = new ArrayList<>(Arrays.asList(rows));

        KMeans kMeans = new KMeans(3, dataSet);
        Set<Cluster> clusters = kMeans.run();

        System.out.println("迭代次数:" + kMeans.getIterRunTimes());
        for (Cluster current : clusters) {
            printCluster(current);
        }
    }

    /** Prints one cluster (id, center, members) followed by a separator line. */
    private static void printCluster(Cluster cluster) {
        System.out.println(cluster.getId());
        System.out.println(cluster.getCenter());
        System.out.println(cluster.getMembers());
        System.out.println("====================");
    }
}

/**
 * A single data point: a feature vector plus bookkeeping used during clustering.
 *
 * <p>Getters, setters, {@code equals}, {@code hashCode} and {@code toString} are generated
 * by Lombok's {@code @Data}; field names and order are kept stable because they determine
 * the generated {@code toString} layout.
 */
@Data
class Point {
    /**
     * Cluster id of a point that has not been assigned to any cluster yet. Using a negative
     * sentinel keeps unassigned points from being mistaken for members of cluster 0.
     */
    private static final int UNASSIGNED_CLUSTER = -1;

    private float[] novelInfo;                   // feature vector of the point
    private float Id;                            // sample id; -1 when the point has no id of its own
    private int clusterId = UNASSIGNED_CLUSTER;  // id of the cluster the point currently belongs to
    private float dist;                          // Euclidean distance to the center of that cluster

    /**
     * Creates a point with an explicit id.
     *
     * @param Id        sample id
     * @param novelInfo feature vector (stored by reference, not copied)
     */
    public Point(float Id, float[] novelInfo) {
        this.novelInfo = novelInfo;
        this.Id = Id;
    }

    /**
     * Creates a point without an id of its own (its id is set to -1), e.g. a computed
     * cluster center.
     *
     * @param novelInfo feature vector (stored by reference, not copied)
     */
    public Point(float[] novelInfo) {
        this(-1, novelInfo);
    }
}

/**
 * A cluster: an id, its current center point and the points assigned to it.
 *
 * <p>Accessors and {@code toString} are generated by Lombok's {@code @Data}. Equality and
 * hashing are written by hand and depend on the id only: the center and the member list
 * change on every k-means iteration, and a hash derived from them would go stale while the
 * cluster sits inside a {@code HashSet}, making {@code contains}/{@code remove} fail.
 */
@Data
class Cluster {
    private int id;                                  // cluster id
    private Point center;                            // current center of the cluster
    private List<Point> members = new ArrayList<>(); // points currently assigned to the cluster

    /**
     * Creates a cluster with no members.
     *
     * @param id     cluster id
     * @param center initial center
     */
    public Cluster(int id, Point center) {
        this.id = id;
        this.center = center;
    }

    /**
     * Creates a cluster with an initial member list.
     *
     * @param id      cluster id
     * @param center  initial center
     * @param members initial members (stored by reference); {@code null} is treated as empty
     */
    public Cluster(int id, Point center, List<Point> members) {
        this.id = id;
        this.center = center;
        if (members != null) {
            this.members = members;
        }
    }

    /**
     * Adds a point to the cluster unless an equal point is already a member, in which case
     * a notice is printed to standard output and the point is ignored.
     *
     * @param newPoint point to add
     */
    public void addPoint(Point newPoint) {
        if (!members.contains(newPoint)) {
            members.add(newPoint);
        } else {
            System.out.println("<<<<<<<<<<<<<<<<<< 样本数据{" + newPoint + "} 已经存在>>>>>>>>>>>>>>>>>>>>>>>");
        }
    }

    /** Two clusters are equal when they have the same id. */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof Cluster)) {
            return false;
        }
        return id == ((Cluster) other).id;
    }

    /** Hash derived from the id only, so it stays stable while center and members change. */
    @Override
    public int hashCode() {
        return Integer.hashCode(id);
    }
}

/** Euclidean distance between the feature vectors of two points. */
class DistanceCompute {

    /**
     * Computes the Euclidean distance between two points.
     *
     * <p>The loop runs over the length of {@code point1}'s feature vector, so
     * {@code point2} must have at least that many components.
     *
     * @param point1 first point
     * @param point2 second point
     * @return square root of the sum of squared component differences
     */
    public double getEuclideanDis(Point point1, Point point2) {
        final float[] first = point1.getNovelInfo();
        final float[] second = point2.getNovelInfo();

        double squaredSum = 0;
        int dim = 0;
        while (dim < first.length) {
            float delta = first[dim] - second[dim];
            squaredSum += Math.pow(delta, 2);
            dim++;
        }
        return Math.sqrt(squaredSum);
    }
}

/**
 * K-means clustering. Input: the number of clusters {@code k} and the raw data rows.
 *
 * <p>Every row is interpreted as {@code {id, feature_1, ..., feature_n}}: the first value
 * becomes the point id and the remaining values its feature vector. The row width is taken
 * from the first row; extra trailing values in later rows are ignored.
 *
 * <p>Instances are independent of each other (the point list is per instance) but are not
 * thread-safe.
 */
class KMeans {
    private final int kNum; // number of clusters
    private static final int ITER_SUM = 10; // number of restarts; not used by the current implementation

    private static final int ITER_MAX_TIMES = 1000000000; // upper bound on iterations of a single run
    private int iterRunTimes = 0; // iterations actually performed by the most recent run
    private static final float ITER_STOP = 0.01f; // stop once no center moves farther than this

    private final List<float[]> novelData; // raw input rows
    private List<Point> pointList; // points built from the raw rows (per instance)
    private final DistanceCompute distanceCompute = new DistanceCompute();
    private int len = 0; // number of feature dimensions per point

    /**
     * Builds the point set from the raw rows.
     *
     * @param kNum      number of clusters to produce
     * @param novelData raw rows, each {@code {id, features...}}; must be non-null and non-empty
     * @throws NullPointerException     if {@code novelData} is null
     * @throws IllegalArgumentException if {@code novelData} is empty or a row is null, empty,
     *                                  or shorter than the first row
     */
    public KMeans(int kNum, List<float[]> novelData) {
        Objects.requireNonNull(novelData, "novelData");
        if (novelData.isEmpty()) {
            throw new IllegalArgumentException("novelData must contain at least one row");
        }
        this.kNum = kNum;
        this.novelData = novelData;
        init();
    }

    /** Converts every raw row into a {@link Point}: column 0 is the id, the rest the features. */
    private void init() {
        float[] firstRow = novelData.get(0);
        if (firstRow == null || firstRow.length == 0) {
            throw new IllegalArgumentException("the first row must contain at least an id");
        }
        len = firstRow.length - 1;
        int rowWidth = len + 1;

        pointList = new ArrayList<>(novelData.size());
        for (int i = 0, n = novelData.size(); i < n; i++) {
            float[] row = novelData.get(i);
            if (row == null || row.length < rowWidth) {
                throw new IllegalArgumentException(
                        "row " + i + " must contain at least " + rowWidth + " values");
            }
            pointList.add(new Point(row[0], Arrays.copyOfRange(row, 1, rowWidth)));
        }
    }

    /**
     * Picks {@code k} distinct data points at random as the initial cluster centers. Each
     * returned cluster has a center but no members yet.
     *
     * <p>Candidates are visited in shuffled order, so the method always terminates, even
     * when the data holds fewer than {@code k} distinct points.
     *
     * @return clusters with ids {@code 0..k-1}, iterated in id order; empty if {@code k <= 0}
     * @throws IllegalStateException if fewer than {@code k} distinct points are available
     */
    public Set<Cluster> chooseCenterCluster() {
        Set<Cluster> clusterHashSet = new LinkedHashSet<>();
        List<Point> candidates = new ArrayList<>(pointList);
        Collections.shuffle(candidates, new Random());

        int id = 0;
        for (Point candidate : candidates) {
            if (id >= kNum) {
                break;
            }
            // Skip candidates equal to an already chosen center.
            boolean alreadyChosen = false;
            for (Cluster chosen : clusterHashSet) {
                if (chosen.getCenter().equals(candidate)) {
                    alreadyChosen = true;
                    break;
                }
            }
            if (!alreadyChosen) {
                clusterHashSet.add(new Cluster(id, candidate));
                id++;
            }
        }
        if (id < kNum) {
            throw new IllegalStateException(
                    "cannot choose " + kNum + " distinct centers from " + pointList.size()
                            + " points (only " + id + " distinct points found)");
        }
        return clusterHashSet;
    }

    /**
     * Assigns every point to its nearest cluster center and rebuilds the member lists.
     *
     * <p>On equal distances the cluster visited first wins. A point whose distance to every
     * center is not a finite comparable number keeps its previous cluster id.
     *
     * @param clusterSet clusters whose centers are used and whose members are replaced
     */
    public void cluster(Set<Cluster> clusterSet) {
        // Step 1: label each point with the id of the nearest center.
        for (Point point : pointList) {
            double minDis = Double.POSITIVE_INFINITY;
            for (Cluster cluster : clusterSet) {
                double dis = distanceCompute.getEuclideanDis(point, cluster.getCenter());
                if (dis < minDis) {
                    minDis = dis;
                    point.setClusterId(cluster.getId());
                    point.setDist((float) minDis);
                }
            }
        }
        // Step 2: the loop above only labelled the points; now reload each cluster's
        // member list from those labels.
        for (Cluster cluster : clusterSet) {
            cluster.getMembers().clear();
            for (Point point : pointList) {
                if (point.getClusterId() == cluster.getId()) {
                    cluster.addPoint(point);
                }
            }
        }
    }

    /**
     * Moves every cluster center to the mean of its members.
     *
     * <p>A cluster without members keeps its current center: averaging zero points would
     * divide by zero and yield a NaN center.
     *
     * @param hashSet clusters to update
     * @return {@code true} if at least one center moved farther than {@code ITER_STOP},
     *         i.e. another iteration is needed
     */
    public boolean getUpdate(Set<Cluster> hashSet) {
        boolean isNeedIter = false;
        for (Cluster cluster : hashSet) {
            List<Point> members = cluster.getMembers();
            if (members.isEmpty()) {
                continue;
            }
            // Per-dimension sums, accumulated in double to limit rounding error.
            double[] sums = new double[len];
            for (Point member : members) {
                float[] info = member.getNovelInfo();
                for (int i = 0; i < len; i++) {
                    sums[i] += info[i];
                }
            }
            float[] mean = new float[len];
            for (int i = 0; i < len; i++) {
                mean[i] = (float) (sums[i] / members.size());
            }

            Point newCenter = new Point(mean);
            if (distanceCompute.getEuclideanDis(cluster.getCenter(), newCenter) > ITER_STOP) {
                isNeedIter = true;
            }
            cluster.setCenter(newCenter);
        }
        return isNeedIter;
    }

    /**
     * Runs k-means: chooses random initial centers, then alternates assignment and center
     * update until no center moves farther than {@code ITER_STOP} or {@code ITER_MAX_TIMES}
     * iterations have been performed.
     *
     * @return the final clusters
     * @throws IllegalStateException if fewer than {@code k} distinct points are available
     */
    public Set<Cluster> run() {
        iterRunTimes = 0; // count iterations of this run only
        Set<Cluster> clusters = chooseCenterCluster();
        boolean ifNeedIter = true;
        while (ifNeedIter && iterRunTimes < ITER_MAX_TIMES) {
            cluster(clusters);
            ifNeedIter = getUpdate(clusters);
            iterRunTimes++;
        }
        return clusters;
    }

    /**
     * Returns the number of iterations performed by the most recent {@link #run()}.
     *
     * @return iteration count; 0 before the first run
     */
    public int getIterRunTimes() {
        return iterRunTimes;
    }
}