以下,複数のファイル構成になっていますので,ファイル間の 区切りを「---・・・」で示します. ------------------------makefile------------------ # # リンク # CFLAGS = -c -Wall -O2 OBJECT = win.o Bunrui.o Cons.o Dest.o Input.o Learn.o Pocket.o pgm: $(OBJECT) g++ $(OBJECT) -o win -lm # # コンパイル # win.o: winner.h win.cpp g++ $(CFLAGS) win.cpp Bunrui.o: winner.h Bunrui.cpp g++ $(CFLAGS) Bunrui.cpp Cons.o: winner.h Cons.cpp g++ $(CFLAGS) Cons.cpp Dest.o: winner.h Dest.cpp g++ $(CFLAGS) Dest.cpp Input.o: winner.h Input.cpp g++ $(CFLAGS) Input.cpp Learn.o: winner.h Learn.cpp g++ $(CFLAGS) Learn.cpp Pocket.o: winner.h Pocket.cpp g++ $(CFLAGS) Pocket.cpp ------------------------入力ファイル-------------- 最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4 乱数 123 入力データファイル or.dat ------------------------or.dat-------------------- OR演算の訓練例.最後の2つのデータが目標出力値 -1 -1 -1 1 -1 1 1 -1 1 -1 1 -1 1 1 1 -1 ------------------------winner.h------------------ /**************************/ /* Winnerクラスの定義 */ /**************************/ class Winner { long max; // 最大学習回数 long n; // 訓練例の数 long o; // 出力セルの数 long p; // 入力セルの数 long **W_p; // 重み(ポケット) long **W; // 重み long **E; // 訓練例 long **C; // 各訓練例に対する正しい出力 long *Ct; // 作業領域 public: Winner(long, long, long, long, long); ~Winner(); long Bunrui(); void Input(char *); void Learn(int, char *name = ""); long Pocket(long *); }; ------------------------constructor--------------- /**********************/ /* コンストラクタ */ /**********************/ #include #include "winner.h" void srand48(long); Winner::Winner(long max_i, long n_i, long o_i, long p_i, long seed) { long i1; /* 設定 */ max = max_i; n = n_i; o = o_i; p = p_i; srand48(seed); /* 領域の確保 */ E = new long * [n]; C = new long * [n]; for (i1 = 0; i1 < n; i1++) { E[i1] = new long [p+1]; C[i1] = new long [o]; } W_p = new long * [o]; W = new long * [o]; for (i1 = 0; i1 < o; i1++) { W_p[i1] = new long [p+1]; W[i1] = new long [p+1]; } Ct = new long [o]; } ------------------------destructor---------------- /********************/ /* デストラクタ */ /********************/ #include "winner.h" Winner::~Winner() { int i1; for (i1 = 0; i1 < n; i1++) { delete [] E[i1]; delete [] C[i1]; } delete [] E; delete [] C; for (i1 = 0; i1 < o; i1++) { delete [] W_p[i1]; delete [] W[i1]; } delete [] W_p; delete [] W; delete [] Ct; } ------------------------Bunrui.cpp---------------- /**********************************************/ /* 訓練例の分類 */ /* return : 正しく分類した訓練例の数 */ /**********************************************/ #include "winner.h" long Winner::Bunrui() { long cor, i1, i2, i3, mx = 0, mx_v = 0, num = 0, s; int sw = 0; for (i1 = 0; i1 < n; i1++) { cor = 0; for (i2 = 0; i2 < o; i2++) { if (C[i1][i2] == 1) cor = i2; s = 0; for (i3 = 0; i3 <= p; i3++) s += W[i2][i3] * E[i1][i3]; if (i2 == 0) { mx = 0; mx_v = s; } else { if (s > mx_v) { mx = i2; mx_v = s; sw = 0; } else { if (s == mx_v) sw = 1; } } } if (sw == 0 && cor == mx) num++; } return num; } ------------------------Input.cpp----------------- /******************************/ /* 学習データの読み込み */ /* name : ファイル名 */ /******************************/ #include #include "winner.h" void Winner::Input(char *name) { long i1, i2; FILE *st; st = fopen(name, "r"); fscanf(st, "%*s"); for (i1 = 0; i1 < n; i1++) { E[i1][0] = 1; for (i2 = 1; i2 <= p; i2++) fscanf(st, "%ld", &E[i1][i2]); for (i2 = 0; i2 < o; i2++) fscanf(st, "%ld", &C[i1][i2]); } fclose(st); } ------------------------Learn.cpp----------------- /*************************************/ /* 学習と結果の出力 */ /* seed : 乱数の初期値 */ /* pr : =0 : 画面に出力 */ /* =1 : ファイルに出力 */ /* name : 出力ファイル名 */ /*************************************/ #include #include "winner.h" void Winner::Learn(int pr, char *name) { long i1, i2, i3, mx = 0, mx_v = 0, n_tri, num, s; int sw; FILE *out; n_tri = Pocket(&num); if (pr == 0) out = stdout; else out = fopen(name, "w"); fprintf(out, "重み\n"); for (i1 = 0; i1 < o; i1++) { for (i2 = 0; i2 <= p; i2++) fprintf(out, "%5ld", W_p[i1][i2]); fprintf(out, "\n"); } fprintf(out, "分類結果\n"); for (i1 = 0; i1 < n; i1++) { sw = 0; for (i2 = 0; i2 < o; i2++) { s = 0; for (i3 = 0; i3 <= p; i3++) s += W_p[i2][i3] * E[i1][i3]; if (i2 == 0) { mx_v = s; mx = 0; } else { if (s > mx_v) { sw = 0; mx_v = s; mx = i2; } else { if (s == mx_v) sw = 1; } } } for (i2 = 1; i2 <= p; i2++) fprintf(out, "%2ld", E[i1][i2]); fprintf(out, " Cor "); for (i2 = 0; i2 < o; i2++) fprintf(out,"%2ld", C[i1][i2]); if (sw > 0) mx = -1; fprintf(out, " Res %2ld\n", mx+1); } if (n == num) printf(" !!すべてを分類(試行回数:%ld)\n", n_tri); else printf(" !!%ld 個を分類\n", num); } ------------------------Pocket.cpp---------------- /************************************************/ /* Pocket Algorith with Ratcet */ /* num_p : 正しく分類した訓練例の数 */ /* return : =0 : 最大学習回数 */ /* >0 : すべてを分類(回数) */ /************************************************/ #include #include "winner.h" #include double drand48(); long Winner::Pocket(long *num_p) { long cor, count = 0, i1, i2, k, mx = 0, num, run = 0, run_p = 0, s, sw = -1; int sw1; /* 初期設定 */ *num_p = 0; for (i1 = 0; i1 < o; i1++) { for (i2 = 0; i2 <= p; i2++) W[i1][i2] = 0; } /* 実行 */ while (sw < 0) { // 終了チェック count++;; if (count > max) sw = 0; else { // 訓練例の選択 k = (long)(drand48() * n); if (k >= n) k = n - 1; // 出力の計算 sw1 = 0; cor = -1; for (i1 = 0; i1 < o; i1++) { if (C[k][i1] == 1) cor = i1; s = 0; for (i2 = 0; i2 <= p; i2++) s += W[i1][i2] * E[k][i2]; Ct[i1] = s; if (i1 == 0) mx = 0; else { if (s > Ct[mx]) { mx = i1; sw1 = 0; } else { if (s == Ct[mx]) { sw1 = 1; if (cor >= 0 && mx == cor) mx = i1; } } } } // 正しい分類 if (sw1 == 0 && cor == mx) { run++; if (run > run_p) { num = Bunrui(); if (num > *num_p) { *num_p = num; run_p = run; for (i1 = 0; i1 < o; i1++) { for (i2 = 0; i2 <= p; i2++) W_p[i1][i2] = W[i1][i2]; } if (num == n) sw = count; } } } // 誤った分類 else { run = 0; for (i1 = 0; i1 <= p; i1++) { W[cor][i1] += E[k][i1]; W[mx][i1] -= E[k][i1]; } } } } return sw; } ------------------------main---------------------- /********************************/ /* Winner-Take-All Groups */ /* coded by Y.Suganuma */ /********************************/ #include #include #include "winner.h" /**********************/ /* 乱数の初期設定 */ /**********************/ void srand48(long seed) { srand((int)seed); } /**********************************/ /* [0, 1]区間の一様乱数の発生 */ /* rerutn : 乱数 */ /**********************************/ double drand48() { double x; while ((x = (double)rand() / RAND_MAX) == 0.0) ; return x; } /********************/ /* main program */ /********************/ int main(int argc, char *argv[]) { long max, n, o, p, seed; char name[100]; FILE *st; if (argc > 1) { // 基本データの入力 st = fopen(argv[1], "r"); fscanf(st, "%*s %ld %*s %ld %*s %ld %*s %ld %*s %ld", &max, &p, &o, &n, &seed); fscanf(st, "%*s %s", name); fclose(st); // ネットワークの定義と学習データ等の設定 Winner net(max, n, o, p, seed); net.Input(name); // 学習と結果の出力 if (argc == 2) net.Learn(0); else net.Learn(1, argv[2]); } else { printf("***error 入力データファイル名を指定して下さい\n"); exit(1); } return 0; }