Winner-Take-All

  以下,複数のファイル構成になっています.ファイル間の区切りを「---・・・」で示します.

------------------------入力ファイル--------------
最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4 乱数 123
入力データファイル or.dat

------------------------or.dat--------------------
OR演算の訓練例.各行における最後の2つのデータが目標出力値
-1 -1 -1 1
-1  1  1 -1
 1 -1  1 -1
 1  1  1 -1

----------------------プログラム-----------------
/****************************/
/* Winner-Take-All Groups   */
/*      coded by Y.Suganuma */
/****************************/
import java.io.*;
import java.util.Random;
import java.util.StringTokenizer;

public class Test {
	/****************/
	/* main program */
	/****************/
	public static void main(String args[]) throws IOException, FileNotFoundException
	{
		int max, n, o, p, seed;
		StringTokenizer str;
		String name;

		if (args.length > 0) {
					// 基本データの入力
			BufferedReader in = new BufferedReader(new FileReader(args[0]));

			str = new StringTokenizer(in.readLine(), " ");
			str.nextToken();
			max = Integer.parseInt(str.nextToken());
			str.nextToken();
			p = Integer.parseInt(str.nextToken());
			str.nextToken();
			o = Integer.parseInt(str.nextToken());
			str.nextToken();
			n = Integer.parseInt(str.nextToken());
			str.nextToken();
			seed = Integer.parseInt(str.nextToken());

			str = new StringTokenizer(in.readLine(), " ");
			str.nextToken();
			name = str.nextToken();

			in.close();
					// ネットワークの定義
			Winner net = new Winner (max, n, o, p, seed);
			net.input(name);
					// 学習と結果の出力
			if (args.length == 1)
				net.learn(0, "");
			else
				net.learn(1, args[1]);
		}
					// エラー
		else {
			System.out.print("***error   入力データファイル名を指定して下さい\n");
			System.exit(1);
		}
	}
}

/**********************/
/* Winnerクラスの定義 */
/**********************/
class Winner {

	private int max;   // 最大学習回数
	private int n;   // 訓練例の数
	private int o;   // 出力セルの数
	private int p;   // 入力セルの数
	private int W_p[][];   // 重み(ポケット)
	private int W[][];   // 重み
	private int E[][];   // 訓練例
	private int C[][];   // 各訓練例に対する正しい出力
	private int Ct[];   // 作業領域
	private Random rn;   // 乱数

	/******************/
	/* コンストラクタ */
	/******************/
	Winner (int max_i, int n_i, int o_i, int p_i, int seed)
	{
	/*
	     設定
	*/
		max = max_i;
		n   = n_i;
		o   = o_i;
		p   = p_i;

		rn  = new Random(seed);   // 乱数の初期設定
	/*
	     領域の確保
	*/
		E   = new int [n][p+1];
		W_p = new int [o][p+1];
		W   = new int [o][p+1];
		C   = new int [n][o];
		Ct  = new int [o];
	}

	/******************************************/
	/* 訓練例の分類                           */
	/*      return : 正しく分類した訓練例の数 */
	/******************************************/
	int bunrui()
	{
		int cor, i1, i2, i3, mx = 0, mx_v = 0, num = 0, s, sw = 0;

		for (i1 = 0; i1 < n; i1++) {
			cor = 0;
			for (i2 = 0; i2 < o; i2++) {
				if (C[i1][i2] == 1)
					cor = i2;
				s = 0;
				for (i3 = 0; i3 <= p; i3++)
					s += W[i2][i3] * E[i1][i3];
				if (i2 == 0) {
					mx   = 0;
					mx_v = s;
				}
				else {
					if (s > mx_v) {
						mx   = i2;
						mx_v = s;
						sw   = 0;
					}
					else {
						if (s == mx_v)
							sw = 1;
					}
				}
			}

			if (sw == 0 && cor == mx)
				num++;
		}

		return num;
	}

	/**************************/
	/* 学習データの読み込み   */
	/*      name : ファイル名 */
	/**************************/
	void input (String name) throws IOException, FileNotFoundException
	{
		int i1, i2;
		StringTokenizer str;

		BufferedReader st = new BufferedReader(new FileReader(name));

		str = new StringTokenizer(st.readLine(), " ");

		for (i1 = 0; i1 < n; i1++) {
			E[i1][0] = 1;
			str = new StringTokenizer(st.readLine(), " ");
			for (i2 = 1; i2 <= p; i2++)
				E[i1][i2] = Integer.parseInt(str.nextToken());
			for (i2 = 0; i2 < o; i2++)
				C[i1][i2] = Integer.parseInt(str.nextToken());
		}

		st.close();
	}

	/*********************************/
	/* 学習と結果の出力              */
	/*      pr : =0 : 画面に出力     */
	/*           =1 : ファイルに出力 */
	/*      name : 出力ファイル名    */
	/*********************************/
	void learn(int pr, String name) throws FileNotFoundException
	{
		int i1, i2, i3, mx = 0, mx_v = 0, n_tri, s, sw;
		int num[] = new int [1];
					// 学習
		n_tri = pocket(num);
					// 結果の出力
		if (pr == 0) {

			System.out.print("重み\n");
			for (i1 = 0; i1 < o; i1++) {
				for (i2 = 0; i2 <= p; i2++)
					System.out.print("  " + W_p[i1][i2]);
				System.out.println();
			}

			System.out.print("分類結果\n");
			for (i1 = 0; i1 < n; i1++) {
				sw = 0;
				for (i2 = 0; i2 < o; i2++) {
					s = 0;
					for (i3 = 0; i3 <= p; i3++)
						s += W_p[i2][i3] * E[i1][i3];
					if (i2 == 0) {
						mx_v = s;
						mx   = 0;
					}
					else {
						if (s > mx_v) {
							sw   = 0;
							mx_v = s;
							mx   = i2;
						}
						else {
							if (s == mx_v)
								sw = 1;
						}
					}
				}
				for (i2 = 1; i2 <= p; i2++)
					System.out.print(" " + E[i1][i2]);
				System.out.print(" Cor ");
				for (i2 = 0; i2 < o; i2++)
					System.out.print(" " + C[i1][i2]);
				if (sw > 0)
					mx = -1;
				System.out.println(" Res " + (mx+1));
			}
		}

		else {

			PrintStream out = new PrintStream(new FileOutputStream(name));

			out.print("重み\n");
			for (i1 = 0; i1 < o; i1++) {
				for (i2 = 0; i2 <= p; i2++)
					out.print("  " + W_p[i1][i2]);
				out.println();
			}

			out.print("分類結果\n");
			for (i1 = 0; i1 < n; i1++) {
				sw = 0;
				for (i2 = 0; i2 < o; i2++) {
					s = 0;
					for (i3 = 0; i3 <= p; i3++)
						s += W_p[i2][i3] * E[i1][i3];
					if (i2 == 0) {
						mx_v = s;
						mx   = 0;
					}
					else {
						if (s > mx_v) {
							sw   = 0;
							mx_v = s;
							mx   = i2;
						}
						else {
							if (s == mx_v)
								sw = 1;
						}
					}
				}
				for (i2 = 1; i2 <= p; i2++)
					out.print(" " + E[i1][i2]);
				out.print(" Cor ");
				for (i2 = 0; i2 < o; i2++)
					out.print(" " + C[i1][i2]);
				if (sw > 0)
					mx = -1;
				out.println(" Res " + (mx+1));
			}

			out.close();
		}

		if (n == num[0])
			System.out.print("  !!すべてを分類(試行回数:" + n_tri + ")\n");
		else
			System.out.print("  !!" + num[0] + " 個を分類\n");
	}

	/********************************************/
	/* Pocket Algorith with Ratcet              */
	/*      num_p : 正しく分類した訓練例の数    */
	/*      return : =0 : 最大学習回数          */
	/*               >0  : すべてを分類(回数) */
	/********************************************/
	int pocket(int num_p[])
	{
		int cor, count = 0, i1, i2, k, mx = 0, num, run = 0, run_p = 0, s, sw = -1, sw1;
	/*
	     初期設定
	*/
		num_p[0] = 0;

		for (i1 = 0; i1 < o; i1++) {
			for (i2 = 0; i2 <= p; i2++)
				W[i1][i2] = 0;
		}
	/*
	     実行
	*/
		while (sw < 0) {
					// 終了チェック
			count++;;
			if (count > max)
				sw = 0;

			else {
					// 訓練例の選択
				k = (int)(rn.nextDouble() * n);
				if (k >= n)
					k = n - 1;
					// 出力の計算
				sw1 = 0;
				cor = -1;

				for (i1 = 0; i1 < o; i1++) {

					if (C[k][i1] == 1)
						cor = i1;

					s = 0;
					for (i2 = 0; i2 <= p; i2++)
						s += W[i1][i2] * E[k][i2];
					Ct[i1] = s;

					if (i1 == 0)
						mx = 0;
					else {
						if (s > Ct[mx]) {
							mx = i1;
							sw1 = 0;
						}
						else {
							if (s == Ct[mx]) {
								sw1 = 1;
								if (cor >= 0 && mx == cor)
									mx = i1;
							}
						}
					}
				}
					// 正しい分類
				if (sw1 == 0 && cor == mx) {
					run++;
					if (run > run_p) {
						num = bunrui();
						if (num > num_p[0]) {
							num_p[0] = num;
							run_p    = run;
							for (i1 = 0; i1 < o; i1++) {
								for (i2 = 0; i2 <= p; i2++)
									W_p[i1][i2] = W[i1][i2];
							}
							if (num == n)
								sw = count;
						}
					}
				}
					// 誤った分類
				else {
					run  = 0;
					for (i1 = 0; i1 <= p; i1++) {
						W[cor][i1] += E[k][i1];
						W[mx][i1]  -= E[k][i1];
					}
				}
			}
		}

		return sw;
	}
}
		

  コンパイルした後,

java Test 入力ファイル名 出力ファイル名

と入力してやれば実行できます.出力ファイル名は,結果を出力するファイルの名前であり,省略すると画面に出力されます.また,入力ファイル名は,実行に必要なデータを記述したファイルの名前であり,たとえば以下のような形式で作成します.
最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4 乱数 123 入力データファイル or.dat
  日本語で記述した部分(「最大試行回数」,「入力セルの数」等)は,次に続くデータの説明ですのでどのように修正しても構いませんが,削除したり,または,複数の文(間に半角のスペースを入れる)にするようなことはしないでください.各データの意味は以下に示す通りです.
最大試行回数

  学習回数を入力します.この例では 100 を与えています.

入力セルの数

  入力セル(入力ユニット)の数を入力します.この例では 2 となっています.

出力セルの数

  出力セル(出力ユニット)の数を入力します.この例では 2 となっています.

訓練例の数

  訓練例の数を入力します(この例では 4 ).訓練例は,「入力データファイル」の項に入力されたファイル(この例では,or.dat )に記述します.ファイル or.dat は,たとえば,以下のようになります.
OR演算の訓練例.各行における最後の2つのデータが目標出力値 -1 -1 -1 1 -1 1 1 -1 1 -1 1 -1 1 1 1 -1
  1 行目はこのファイルの説明であり,何を記述しても構いません.ただし,文全体を削除したり,文の途中に半角のスペースを入れるようなことはしないでください.2 行目以下が 4 つの訓練例を表しています.各訓練例において,最初の 2 つの値が各入力ユニットに入力される値であり,3 番目および 4 番目の値が,そのときの 1 番目および 2 番目の出力ユニットに対する目標出力値になっています.たとえば,2 行目における「-1 1」とは,2 番目の出力ユニットが発火することを意味しています.

乱数

  乱数の初期値です.

  上で説明したデータの元で実行すると,たとえば,以下のような出力が得られます.出力ファイル名を指定した場合は,最後の 1 行だけがコンソールに,残りはファイルに出力されます.
重み 1 1 1 -1 -1 -1 分類結果 -1 -1 Cor -1 1 Res 2 -1 1 Cor 1 -1 Res 1 1 -1 Cor 1 -1 Res 1 1 1 Cor 1 -1 Res 1 !!すべてを分類(試行回数:11)
  2 行目および 3 行目の 2 番目以降のデータが,各入力ユニットから各出力ユニット( 2 行目が 1 番目,3 行目が 2 番目の出力ユニットに対応)へ向かう枝に付けられた重みです.また,各行の 1 番目のデータはバイアスです.分類結果において,Cor の次に出力された値が,目標出力値です.例えば,5 行目は,入力が「-1 -1」である時の目標出力値は「-1 1」であること,つまり,2 番目の出力ユニットが発火すべきであることを示しています.また,Res の後の値が実際の計算(分類)結果です.この例では,目標通り,2 番目の出力ユニットが発火したことを意味しています.なお,0 は,分類の失敗を意味しています.