以下,複数のファイル構成になっていますので,ファイル間の 区切りを「---・・・」で示します. ------------------------makefile------------------ # # リンク # CFLAGS = -c -Wall -O2 OBJECT = comp.o Cons.o Dest.o Input.o Learn.o pgm: $(OBJECT) g++ $(OBJECT) -o comp -lm # # コンパイル # comp.o: competition.h comp.cpp g++ $(CFLAGS) comp.cpp Cons.o: competition.h Cons.cpp g++ $(CFLAGS) Cons.cpp Dest.o: competition.h Dest.cpp g++ $(CFLAGS) Dest.cpp Input.o: competition.h Input.cpp g++ $(CFLAGS) Input.cpp Learn.o: competition.h Learn.cpp g++ $(CFLAGS) Learn.cpp ------------------------入力ファイル-------------- 最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 乱数 123 係数(0〜1) 0.1 入力データファイル pat.dat ------------------------pat.dat------------------- 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 ------------------------competition.h------------- /*******************************/ /* Competitionクラスの定義 */ /*******************************/ class Competition { long max; // 最大学習回数 long n; // 訓練例の数 long o; // 出力セルの数 long p; // 入力セルの数 long **E; // 訓練例 double **W; // 重み double sig; // 重み修正係数 public: Competition(long, long, long, long, long, double); ~Competition(); void Input(char *); void Learn(int, char *name = ""); }; ------------------------constructor--------------- /**********************/ /* コンストラクタ */ /**********************/ #include #include "competition.h" void srand48(long); Competition::Competition(long max_i, long n_i, long o_i, long p_i, long seed, double sig_i) { long i1; /* 設定 */ sig = sig_i; max = max_i; n = n_i; o = o_i; p = p_i; srand48(seed); /* 領域の確保 */ E = new long * [n]; for (i1 = 0; i1 < n; i1++) E[i1] = new long [p]; W = new double * [o]; for (i1 = 0; i1 < o; i1++) W[i1] = new double [p]; } ------------------------destructor---------------- /********************/ /* デストラクタ */ /********************/ #include "competition.h" Competition::~Competition() { int i1; for (i1 = 0; i1 < n; i1++) delete [] E[i1]; delete [] E; for (i1 = 0; i1 < o; i1++) delete [] W[i1]; delete [] W; } ------------------------Input.cpp----------------- /******************************/ /* 学習データの読み込み */ /* name : ファイル名 */ /******************************/ #include #include "competition.h" void Competition::Input(char *name) { long i1, i2; FILE *st; st = fopen(name, "r"); for (i1 = 0; i1 < n; i1++) { for (i2 = 0; i2 < p; i2++) fscanf(st, "%ld", &E[i1][i2]); } fclose(st); } ------------------------Learn.cpp----------------- /*************************************/ /* 学習と結果の出力 */ /* pr : =0 : 画面に出力 */ /* =1 : ファイルに出力 */ /* name : 出力ファイル名 */ /*************************************/ #include #include #include "competition.h" double drand48(); void Competition::Learn(int pr, char *name) { double mx_v = 0.0, s, sum; long count, i1, i2, i3, k, mx = 0; FILE *out; /* 初期設定 */ for (i1 = 0; i1 < o; i1++) { sum = 0.0; for (i2 = 0; i2 < p; i2++) { W[i1][i2] = drand48(); sum += W[i1][i2]; } sum = 1.0 / sum; for (i2 = 0; i2 < p; i2++) W[i1][i2] *= sum; } /* 学習 */ for (count = 0; count < max; count++) { // 訓練例の選択 k = (long)(drand48() * n); if (k >= n) k = n - 1; // 出力の計算 for (i1 = 0; i1 < o; i1++) { s = 0.0; for (i2 = 0; i2 < p; i2++) s += W[i1][i2] * E[k][i2]; if (i1 == 0 || s > mx_v) { mx = i1; mx_v = s; } } // 重みの修正 sum = 0.0; for (i1 = 0; i1 < p; i1++) sum += E[k][i1]; for (i1 = 0; i1 < p; i1++) W[mx][i1] += sig * (E[k][i1] / sum - W[mx][i1]); } /* 出力 */ if (pr == 0) out = stdout; else out = fopen(name, "w"); fprintf(out, "分類結果\n"); for (i1 = 0; i1 < n; i1++) { for (i2 = 0; i2 < p; i2++) fprintf(out, "%2ld", E[i1][i2]); fprintf(out, " Res "); for (i2 = 0; i2 < o; i2++) { s = 0.0; for (i3 = 0; i3 < p; i3++) s += W[i2][i3] * E[i1][i3]; if (i2 == 0 || s > mx_v) { mx = i2; mx_v = s; } } fprintf(out, "%3ld\n", mx+1); } } ------------------------main---------------------- /********************************/ /* 競合学習 */ /* coded by Y.Suganuma */ /********************************/ #include #include #include "competition.h" /**********************/ /* 乱数の初期設定 */ /**********************/ void srand48(long seed) { srand((int)seed); } /**********************************/ /* [0, 1]区間の一様乱数の発生 */ /* rerutn : 乱数 */ /**********************************/ double drand48() { double x; while ((x = (double)rand() / RAND_MAX) == 0.0) ; return x; } /********************/ /* main program */ /********************/ int main(int argc, char *argv[]) { double sig; long max, n, p, o, seed; char name[100]; FILE *st; if (argc > 1) { // 基本データの入力 st = fopen(argv[1], "r"); fscanf(st, "%*s %ld %*s %ld %*s %ld %*s %ld", &max, &p, &o, &n); fscanf(st, "%*s %ld %*s %lf", &seed, &sig); fscanf(st, "%*s %s", name); fclose(st); // ネットワークの定義 Competition net(max, n, o, p, seed, sig); net.Input(name); // 学習と結果の出力 if (argc == 2) net.Learn(0); else net.Learn(1, argv[2]); } else { printf("***error 入力データファイル名を指定して下さい\n"); exit(1); } return 0; }