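// Нейронная сеть на Java (автор А.Берсенев, 2005 год)
// Neural network in Java (author: A. Bersenev, 2005)
// (Фрагмент кода системы “Predictor” / code fragment from the “Predictor” system)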
package
com.bers.forecast.calc;
import
java.beans.*;
import
java.text.*;
import
javax.swing.*;
import
com.bers.db.*;
import
com.bers.forecast.common.*;
import
com.bers.forecast.data.pref.*;
import
com.bers.nn.perceptron.*;
import
com.bers.nn.ui.*;
public
class NNMethod extends EmptyFcastMethod implements PropertyChangeListener {
/**
* @directed true
*/
private Trainer tr = new Trainer();
private NumberFormat nf =
NumberFormat.getNumberInstance();
private NumberFormat nf2 =
NumberFormat.getNumberInstance();
private int historyCount = 10;
private TrainProgressDialog dlg = new
TrainProgressDialog();
private double[][] diffs = null;
private double[][] res = null;
private double[] difmuls = null;
public NNMethod() {
nf.setMaximumFractionDigits( 6 );
nf.setMinimumFractionDigits( 6 );
nf2.setMaximumFractionDigits( 3 );
nf2.setMinimumFractionDigits( 3 );
nf.setGroupingUsed( false );
dlg.setDefaultCloseOperation(
WindowConstants.HIDE_ON_CLOSE );
dlg.setModal( true );
dlg.setTrainer( tr );
//íàñòðàèâàåì ñëîè
IFunction fn = new ThSigmoid();
tr.getNetwork().getLayerAt( 0
).setOutCount( 10 );
for ( int i = 1; i <= 3; i++ ) {
tr.addLayer( i );
tr.getNetwork().getLayerAt( i
).setFn( fn );
tr.getNetwork().getLayerAt( i
).setOutCount( 10 );
}
loadProps();
}
private void loadProps() {
DM dm = DM.getInstance();
PrefVO pref = dm.findPrefByName(
"forecast.nn.initialSpeed" );
if ( pref.getValue() != null ) {
try {
tr.setInitialSpeed(
Double.parseDouble( pref.getValue() ) );
} catch ( NumberFormatException ex ) {
}
}
pref = dm.findPrefByName(
"forecast.nn.stopError" );
if ( pref.getValue() != null ) {
try {
tr.setStopError(
Double.parseDouble( pref.getValue() ) );
} catch ( NumberFormatException ex
) {
}
}
pref = dm.findPrefByName(
"forecast.nn.stopIteration" );
if ( pref.getValue() != null ) {
try {
tr.setStopIteration(
Integer.parseInt( pref.getValue() ) );
} catch ( NumberFormatException ex
) {
}
}
pref = dm.findPrefByName(
"forecast.nn.speedDecrement" );
if ( pref.getValue() != null ) {
try {
tr.setSpeedDecrement(
Double.parseDouble( pref.getValue() ) );
} catch ( NumberFormatException ex
) {
}
}
}
private void saveProps() {
DM dm = DM.getInstance();
try {
PrefVO pref = dm.findPrefByName(
"forecast.nn.initialSpeed" );
pref.setValue( tr.getInitialSpeed()
+ "" );
pref.store();
pref = dm.findPrefByName(
"forecast.nn.stopError" );
pref.setValue( tr.getStopError() +
"" );
pref.store();
pref = dm.findPrefByName(
"forecast.nn.stopIteration" );
pref.setValue(
tr.getStopIteration() + "" );
pref.store();
pref = dm.findPrefByName(
"forecast.nn.speedDecrement" );
pref.setValue(
tr.getSpeedDecrement() + "" );
pref.store();
} catch ( DBException ex ) {
System.err.println( "cannot
save forecast parameters" );
}
}
public void prepare() {
super.prepare();
//Âû÷èñëÿåì ðàçíîñòè
prepareDiffs();
//Ïîäãîòàâëèâàåì ó÷èòåëÿ
initTrainer();
//Ñîçäà¸ì ïðèìåðû
addSamples();
res = null;
}
private void addSamples() {
int outsize = diffs[0].length;
int insize = outsize * historyCount;
tr.removeAllSamples();
for ( int i = historyCount; i <
diffs.length; i++ ) {
double[] x = new double[insize];
for ( int j = 1; j <= historyCount; j++ ) {
for ( int k = 0; k <
outsize; k++ ) {
x[ ( j - 1 ) * outsize + k]
= diffs[i - j][k];
}
}
double[] y = new double[outsize];
for ( int j = 0; j < outsize; j++
) {
y[j] = diffs[i][j];
}
ArraySample sample = new
ArraySample( x, y );
tr.addSample( sample );
}
}
private void initTrainer() {
int outsize = diffs[0].length;
int insize = outsize * historyCount;
int midsize = outsize + ( int )
Math.sqrt( insize ) * 2;
tr.setInputCount( insize );
tr.getNetwork().getLayerAt( 0
).setInCount( insize );
tr.getNetwork().getLayerAt( 0
).setOutCount( midsize );
int lcount = tr.getLayersCount();
for ( int i = 1; i < lcount - 1; i++
) {
tr.getNetwork().getLayerAt( i
).setInCount( midsize );
tr.getNetwork().getLayerAt( i
).setOutCount( midsize );
}
tr.getNetwork().getLayerAt( lcount - 1
).setInCount( midsize );
tr.getNetwork().getLayerAt( lcount - 1
).setOutCount( outsize );
tr.setOutputCount( outsize );
}
private void prepareDiffs() {
diffs = new double[m.length - 1][];
difmuls = new double[m[0].length];
for ( int i = 0; i < m.length - 1;
i++ ) {
diffs[i] = new double[m[i +
1].length];
for ( int j = 0; j < m[i +
1].length; j++ ) {
diffs[i][j] = m[i + 1][j] -
m[i][j];
if ( Math.abs( diffs[i][j] )
> difmuls[j] ) {
difmuls[j] = Math.abs(
diffs[i][j] );
}
}
}
for ( int i = 0; i < diffs.length;
i++ ) {
for ( int j = 0; j <
diffs[0].length; j++ ) {
diffs[i][j] /= difmuls[j];
}
}
}
public void calc() throws FcastException {
if ( !isPrepared() ) {
prepare();
}
saveProps();
int lcount = tr.getLayersCount();
for ( int i = 0; i < lcount; i++ ) {
tr.getNetwork().getLayerAt( i
).reinit();
}
//Ïîäãîòàâëèâàåì ê çàïóñêó
dlg.reset();
dlg.setVisible( false );
dlg.setTrainer( tr );
//çàïóñêàåì ïîòîê
tr.removeListener( this );
tr.addListener( this );
tr.start();
//Ïîäãîòàâëèâàåì ðåçóëüòàò
cleanResult();
}
private void cleanResult() {
int fcount = getFcastValue() /
fixedStep;
double step = ( max - min ) / m.length;
res = new double[m.length +
fcount][regions.length * 2 + 1];
//Ïåðâîíà÷àëüíîå çàïîëíåíèå
initialFill(
step );
try {
//Çàïîëíåíèå ïî ñòàòèñòèêå
statFill(
step );
//Çàïîëíåíèå ïî ïðîãíîçó
fcastFill(
step );
} catch ( NoLayersException ex ) {
ex.printStackTrace();
} catch ( SizesMismatchException ex ) {
ex.printStackTrace();
}
}
private void initialFill( double step ) {
for ( int i = 0; i < historyCount;
i++ ) {
double t = min + i * step;
res[i][0] = t;
int k = 0;
for ( int j = 0; j <
regions.length; j++ ) {
int rccount =
regions[j].getArraysCount();
res[i][j * 2 + 1] = m[i][k];
res[i][j * 2 + 2] = m[i][k];
k += rccount;
}
}
}
private void statFill( double step ) throws
NoLayersException, SizesMismatchException {
int incount = tr.getInputCount();
int outcount = tr.getOutputCount();
for ( int i = historyCount; i <
diffs.length; i++ ) {
double t = min + i * step;
res[i][0] = t;
int k = 0;
for ( int j = 0; j <
regions.length; j++ ) {
int rccount =
regions[j].getArraysCount();
res[i][j * 2 + 1] = m[i][k];
k += rccount;
}
//ñîçäà¸ì ïðèìåð
double[] x = new double[incount];
for ( int iback = 1; iback <=
historyCount; iback++ ) {
for ( k = 0; k < outcount;
k++ ) {
x[ ( iback - 1 ) * outcount + k] =
diffs[i - iback][k];
}
}
double[] y = new double[outcount];
for ( int iback = 0; iback <
outcount; iback++ ) {
y[iback] = diffs[i][iback];
}
ArraySample sample = new
ArraySample( x, y );
//Âû÷èñëÿåì
tr.getNetwork().calc( sample );
double[] outs =
tr.getNetwork().getLayerAt( tr.getLayersCount() - 1 ).getOuts();
k = 0;
for ( int j = 0; j <
regions.length; j++ ) {
int rccount =
regions[j].getArraysCount();
res[i][j * 2 + 2] = outs[k] *
difmuls[k] + m[i - 1][k];
k += rccount;
}
}
}
private void fcastFill( double step )
throws NoLayersException, SizesMismatchException {
int fcount = getFcastValue() / 14;
int mcount = m.length;
int rcount = mcount + fcount;
int incount = tr.getInputCount();
int outcount = tr.getOutputCount();
double[][] tmp = new
double[rcount][incount];
for ( int i = 0; i < diffs.length;
i++ ) {
tmp[i] = diffs[i];
}
for ( int ifcast = 0; ifcast <=
fcount; ifcast++ ) {
int i = mcount + ifcast - 1;
double t = min + i * step;
res[i][0] = t;
//ñîçäà¸ì ïðèìåð
double[] x = new double[incount];
for ( int iback = 1; iback <=
historyCount; iback++ ) {
for ( int k = 0; k <
outcount; k++ ) {
x[ ( iback - 1 ) * outcount
+ k] = tmp[i - iback][k];
}
}
double[] y = new double[outcount];
for ( int iback = 0; iback <
outcount; iback++ ) {
y[iback] = tmp[i][iback];
}
ArraySample sample = new
ArraySample( x, y );
//Âû÷èñëÿåì
tr.getNetwork().calc( sample );
double[] outs =
tr.getNetwork().getLayerAt( tr.getLayersCount() - 1 ).getOuts();
int k = 0;
for ( int j = 0; j <
regions.length; j++ ) {
int rccount =
regions[j].getArraysCount();
for ( int o = 0; o <
rccount; o++ ) {
tmp[i][k + o] = outs[k +
o];
}
res[i][j * 2 + 2] = outs[k] *
difmuls[k] + res[i - 1][j * 2 + 2];
k += rccount;
}
}
}
public double[][] getFcast() {
return res;
}
public int getHistoryCount() {
return historyCount;
}
public void setHistoryCount( int
historyCount ) {
this.historyCount = historyCount;
m = null;
}
public Trainer getTrainer() {
return tr;
}
public void setTrainer( Trainer trainer ) {
if ( tr != null ) {
tr.removeListener( this );
}
tr = trainer;
dlg.setTrainer( tr );
}
public void propertyChange(
PropertyChangeEvent evt ) {
if ( evt.getPropertyName().equals(
"trainIteration" ) ) {
TrainEvent te = ( TrainEvent )
evt.getNewValue();
if ( ( te.getIteration() % 500 == 0
) || ( te.getNewError() <= tr.getStopError() ) ) {
double err = te.getNewError();
double d = Math.sqrt( err );
System.out.println(
"it=" + te.getIteration() + " err=" + nf.format( err ) +
" (" + nf2.format( d ) +
")
spd=" + nf.format( te.getSpeed() ) );
}
} else if (
evt.getPropertyName().equals( "trainThread" ) ) {
boolean val = ( ( Boolean )
evt.getNewValue() ).booleanValue();
if ( !val ) {
tr.removeListener( this );
// showFcast();
//
System.exit(0);
}
}
}
}