k平均++法;k-means++
概要
k-means++は多次元ベクトルで表現されたデータ集合を与えた数のクラスタに分けるクラスタリング手法の一つである。
サポートベクタマシンなどの教師あり学習に属するクラスタリングとは異なり、データに対してどのクラスに属するかは学習時に与えない、教師なし学習に属するクラスタリング手法である。
各クラスタの境界面は母点をクラスタ中心としたボロノイ図の境界になる
アルゴリズム
・初期値決定
データ集合からデータ点をランダムに選び1つ目のクラスタ中心にする
\(k\)個のクラスタ中心を選ぶまで以下を繰り返す
-各データ点\(\boldsymbol{x_i}\)の最近傍クラスタ中心との距離\(D_i\)を求める
-新しいクラスタ中心として\(\boldsymbol{x_i}\)が選ばれる確率を\(D_i / \sum D_i \)としランダムに決定
・初期クラスタ割当て
各データ点\(\boldsymbol{x_i}\)のクラスタ番号を最近傍のクラスタ中心に割当てられているクラスタ番号にする
・k-mean収束ループ
各データ点\(\boldsymbol{x_i}\)のクラスタ番号に変化がなくなるまで以下を繰り返す
-クラスタ番号ごとのデータ点の平均ベクトルを新しいクラスタ中心とする
-各データ点\(\boldsymbol{x_i}\)のクラスタ番号を最近傍のクラスタ中心に割当てられているクラスタ番号にする
k-means++のコード
namespace Clustering {
/// <summary>k-means++法</summary>
public class KMeansClustering : IClusteringMethod {
protected class LabelVector {
public Vector Vector { get; set; }
public int Label { get; set; }
}
Vector[] center_vectors;
/// <summary>コンストラクタ</summary>
/// <param name="group_count">データクラス数</param>
public KMeansClustering(int group_count) {
if(group_count <= 1) {
throw new ArgumentException(nameof(group_count));
}
this.GroupCount = group_count;
}
/// <summary>データクラス数</summary>
public int GroupCount {
get; private set;
}
/// <summary>ベクトルの次元数</summary>
public int VectorDim {
get; private set;
}
/// <summary>クラスタ中心ベクトル</summary>
public Vector[] CenterVectors => center_vectors;
/// <summary>単一サンプルを分類</summary>
/// <param name="vector">サンプルベクタ</param>
public int Classify(Vector vector) {
return NearestVector(vector);
}
/// <summary>複数サンプルを分類</summary>
/// <param name="vectors">サンプルベクタ集合</param>
public IEnumerable<int> Classify(IEnumerable<Vector> vectors) {
return vectors.Select((vector) => Classify(vector));
}
/// <summary>学習</summary>
/// <param name="vector_dim">サンプルベクタ次元数</param>
/// <param name="vectors_groups">データクラスごとのサンプルベクタ集合</param>
public void Learn(int vector_dim, params List<Vector>[] vectors_groups) {
Initialize();
ValidateSample(vector_dim, vectors_groups);
center_vectors = new Vector[GroupCount];
VectorDim = vector_dim;
Random random = new Random(0);
var vectors = vectors_groups[0].ToArray();
int vector_count = vectors.Length;
// K-means++初期値決定シークエンス
center_vectors[0] = vectors[random.Next(vector_count)];
for(int group_index = 1; group_index < GroupCount; group_index++) {
double dist_sum = 0;
double[] dist_list = new double[vector_count];
for(int vector_index = 0; vector_index < vector_count; vector_index++) {
double dist_min = double.PositiveInfinity;
for(int cluster_index = 0; cluster_index < group_index; cluster_index++) {
double dist = Vector.SquareDistance(vectors[vector_index], center_vectors[cluster_index]);
if(dist < dist_min) {
dist_min = dist;
}
}
dist_sum += dist_list[vector_index] = dist_min;
}
double r = random.NextDouble() * dist_sum;
for(int vector_index = 0; vector_index < vector_count; vector_index++) {
r -= dist_list[vector_index];
if(r < 0) {
center_vectors[group_index] = vectors[vector_index];
break;
}
center_vectors[group_index] = vectors[vector_count - 1];
}
}
// クラスタ割当て
var labeled_vectors = vectors.Select((vector) => new LabelVector { Vector = vector, Label = NearestVector(vector) }).ToArray();
bool ischanged_label = true;
// k-mean収束ループ
while(ischanged_label) {
ischanged_label = false;
for(int cluster_index = 0; cluster_index < center_vectors.Length; cluster_index++) {
center_vectors[cluster_index] = Vector.Zero(VectorDim);
}
int[] label_count = new int[center_vectors.Length];
foreach(var vector in labeled_vectors) {
center_vectors[vector.Label] += vector.Vector;
label_count[vector.Label]++;
}
for(int cluster_index = 0; cluster_index < center_vectors.Length; cluster_index++) {
center_vectors[cluster_index] /= label_count[cluster_index];
}
for(int vector_index = 0; vector_index < vector_count; vector_index++) {
var labeled_vector = labeled_vectors[vector_index];
int label_old = labeled_vector.Label;
int label_new = NearestVector(labeled_vector.Vector);
if(label_old != label_new) {
ischanged_label = true;
}
labeled_vector.Label = label_new;
}
}
}
/// <summary>初期化</summary>
public void Initialize() {
center_vectors = null;
}
/// <summary>最近傍のベクトルを探索</summary>
protected int NearestVector(Vector vector) {
double dist_min = double.PositiveInfinity;
int nearest_cluster_index = 0;
for(int cluster_index = 0; cluster_index < center_vectors.Length; cluster_index++) {
double dist = Vector.SquareDistance(vector, center_vectors[cluster_index]);
if(dist < dist_min) {
dist_min = dist;
nearest_cluster_index = cluster_index;
}
}
return nearest_cluster_index;
}
/// <summary>サンプルの正当性を検証</summary>
private void ValidateSample(int vector_dim, List<Vector>[] vectors_groups) {
if(vector_dim < 1) {
throw new ArgumentException(nameof(vector_dim));
}
if(vectors_groups == null) {
throw new ArgumentNullException(nameof(vectors_groups));
}
if(vectors_groups.Length != 1) {
throw new ArgumentException(nameof(vectors_groups));
}
foreach(var vectors in vectors_groups) {
if(vectors.Count < GroupCount) {
throw new ArgumentException(nameof(vectors_groups));
}
foreach(var vector in vectors) {
if(vector.Dim != vector_dim) {
throw new ArgumentException(nameof(vectors_groups));
}
}
}
}
}
}
実行サンプル
Random random = new Random(9);
List<Vector> vectors = new List<Vector>();
Vector center;
center = new Vector(2 * random.NextDouble() - 1, 2 * random.NextDouble() - 1);
for(int i = 0; i < 25; i++) {
vectors.Add(center + new Vector(0.8 * random.NextDouble() - 0.4, 0.8 * random.NextDouble() - 0.4));
}
center = new Vector(2 * random.NextDouble() - 1, 2 * random.NextDouble() - 1);
for(int i = 0; i < 20; i++) {
vectors.Add(center + new Vector(0.8 * random.NextDouble() - 0.4, 0.8 * random.NextDouble() - 0.4));
}
center = new Vector(2 * random.NextDouble() - 1, 2 * random.NextDouble() - 1);
for(int i = 0; i < 20; i++) {
vectors.Add(center + new Vector(0.8 * random.NextDouble() - 0.4, 0.8 * random.NextDouble() - 0.4));
}
center = new Vector(2 * random.NextDouble() - 1, 2 * random.NextDouble() - 1);
for(int i = 0; i < 25; i++) {
vectors.Add(center + new Vector(0.8 * random.NextDouble() - 0.4, 0.8 * random.NextDouble() - 0.4));
}
KMeansClustering kmean = new KMeansClustering(4);
kmean.Learn(2, vectors);
Plot(kmean, vectors, $"../../../plot/kmeans.png");
サンプルコードPlot関数
static void Plot(KMeansClustering kmean, List<Vector> vectors, string filepath) {
Bitmap image = new Bitmap(500, 500);
Color[] colors = new Color[]{ Color.FromArgb(255, 128, 128), Color.FromArgb(128, 128, 255), Color.FromArgb(128, 255, 128), Color.FromArgb(255, 255, 128) };
for(int x, y = 0; y < image.Height; y++) {
for(x = 0; x < image.Width; x++) {
double vx = (x - image.Width / 2) / (image.Width * 0.4);
double vy = (y - image.Height / 2) / (image.Height * 0.4);
int cluster_index = kmean.Classify(new Vector(vx, vy));
image.SetPixel(x, y, colors[cluster_index]);
}
}
using(Graphics g = Graphics.FromImage(image)) {
g.SmoothingMode = System.Drawing.Drawing2D.SmoothingMode.HighQuality;
foreach(var vector in vectors) {
int x = (int)((4 * vector.X + 5) / 10 * image.Width);
int y = (int)((4 * vector.Y + 5) / 10 * image.Height);
int cluster_index = kmean.Classify(vector);
var color = colors[cluster_index];
g.DrawEllipse(new Pen(Color.Black, 2), new Rectangle(new Point(x - 3, y - 3), new Size(6, 6)));
g.FillEllipse(new SolidBrush(color), new Rectangle(new Point(x - 3, y - 3), new Size(6, 6)));
}
foreach(var vector in kmean.CenterVectors) {
int x = (int)((4 * vector.X + 5) / 10 * image.Width);
int y = (int)((4 * vector.Y + 5) / 10 * image.Height);
int cluster_index = kmean.Classify(vector);
var color = colors[cluster_index];
g.DrawRectangle(new Pen(Color.Black, 2), new Rectangle(new Point(x - 5, y - 5), new Size(10, 10)));
g.FillRectangle(new SolidBrush(color), new Rectangle(new Point(x - 5, y - 5), new Size(10, 10)));
}
}
image.Save(filepath, ImageFormat.Png);
}
結果関連項目
クラスタリング手法インタフェースクラス