Нейронная сеть на Java (автор А.Берсенев, 2005 год)

(Фрагмент кода системы Predictor)[1]

 

 

package com.bers.forecast.calc;

 

import java.beans.*;

import java.text.*;

 

import javax.swing.*;

 

import com.bers.db.*;

import com.bers.forecast.common.*;

import com.bers.forecast.data.pref.*;

import com.bers.nn.perceptron.*;

import com.bers.nn.ui.*;

 

public class NNMethod extends EmptyFcastMethod implements PropertyChangeListener {

 

    /**

     * @directed true

     */

 

    private Trainer tr = new Trainer();

    private NumberFormat nf = NumberFormat.getNumberInstance();

    private NumberFormat nf2 = NumberFormat.getNumberInstance();

    private int historyCount = 10;

    private TrainProgressDialog dlg = new TrainProgressDialog();

    private double[][] diffs = null;

    private double[][] res = null;

    private double[] difmuls = null;

 

    public NNMethod() {

        nf.setMaximumFractionDigits( 6 );

        nf.setMinimumFractionDigits( 6 );

 

        nf2.setMaximumFractionDigits( 3 );

        nf2.setMinimumFractionDigits( 3 );

        nf.setGroupingUsed( false );

 

        dlg.setDefaultCloseOperation( WindowConstants.HIDE_ON_CLOSE );

        dlg.setModal( true );

        dlg.setTrainer( tr );

 

        //íàñòðàèâàåì ñëîè

        IFunction fn = new ThSigmoid();

        tr.getNetwork().getLayerAt( 0 ).setOutCount( 10 );

        for ( int i = 1; i <= 3; i++ ) {

            tr.addLayer( i );

            tr.getNetwork().getLayerAt( i ).setFn( fn );

            tr.getNetwork().getLayerAt( i ).setOutCount( 10 );

        }

 

        loadProps();

    }

 

    private void loadProps() {

        DM dm = DM.getInstance();

        PrefVO pref = dm.findPrefByName( "forecast.nn.initialSpeed" );

        if ( pref.getValue() != null ) {

            try {

                tr.setInitialSpeed( Double.parseDouble( pref.getValue() ) );

            } catch ( NumberFormatException ex ) {

            }

        }

        pref = dm.findPrefByName( "forecast.nn.stopError" );

        if ( pref.getValue() != null ) {

            try {

                tr.setStopError( Double.parseDouble( pref.getValue() ) );

            } catch ( NumberFormatException ex ) {

            }

        }

        pref = dm.findPrefByName( "forecast.nn.stopIteration" );

        if ( pref.getValue() != null ) {

            try {

                tr.setStopIteration( Integer.parseInt( pref.getValue() ) );

            } catch ( NumberFormatException ex ) {

            }

        }

        pref = dm.findPrefByName( "forecast.nn.speedDecrement" );

        if ( pref.getValue() != null ) {

            try {

                tr.setSpeedDecrement( Double.parseDouble( pref.getValue() ) );

            } catch ( NumberFormatException ex ) {

            }

        }

    }

 

    private void saveProps() {

        DM dm = DM.getInstance();

        try {

            PrefVO pref = dm.findPrefByName( "forecast.nn.initialSpeed" );

            pref.setValue( tr.getInitialSpeed() + "" );

            pref.store();

 

            pref = dm.findPrefByName( "forecast.nn.stopError" );

            pref.setValue( tr.getStopError() + "" );

            pref.store();

 

            pref = dm.findPrefByName( "forecast.nn.stopIteration" );

            pref.setValue( tr.getStopIteration() + "" );

            pref.store();

 

            pref = dm.findPrefByName( "forecast.nn.speedDecrement" );

            pref.setValue( tr.getSpeedDecrement() + "" );

            pref.store();

 

        } catch ( DBException ex ) {

            System.err.println( "cannot save forecast parameters" );

        }

    }

 

    public void prepare() {

        super.prepare();

        //Âû÷èñëÿåì ðàçíîñòè

        prepareDiffs();

 

        //Ïîäãîòàâëèâàåì ó÷èòåëÿ

        initTrainer();

 

        //Ñîçäà¸ì ïðèìåðû

        addSamples();

        res = null;

    }

 

    private void addSamples() {

        int outsize = diffs[0].length;

        int insize = outsize * historyCount;

 

        tr.removeAllSamples();

        for ( int i = historyCount; i < diffs.length; i++ ) {

            double[] x = new double[insize];

            for ( int j = 1; j <= historyCount; j++ ) {

                for ( int k = 0; k < outsize; k++ ) {

                    x[ ( j - 1 ) * outsize + k] = diffs[i - j][k];

                }

            }

            double[] y = new double[outsize];

            for ( int j = 0; j < outsize; j++ ) {

                y[j] = diffs[i][j];

            }

            ArraySample sample = new ArraySample( x, y );

            tr.addSample( sample );

        }

    }

 

    private void initTrainer() {

        int outsize = diffs[0].length;

        int insize = outsize * historyCount;

        int midsize = outsize + ( int ) Math.sqrt( insize ) * 2;

 

        tr.setInputCount( insize );

        tr.getNetwork().getLayerAt( 0 ).setInCount( insize );

        tr.getNetwork().getLayerAt( 0 ).setOutCount( midsize );

 

        int lcount = tr.getLayersCount();

        for ( int i = 1; i < lcount - 1; i++ ) {

            tr.getNetwork().getLayerAt( i ).setInCount( midsize );

            tr.getNetwork().getLayerAt( i ).setOutCount( midsize );

        }

        tr.getNetwork().getLayerAt( lcount - 1 ).setInCount( midsize );

        tr.getNetwork().getLayerAt( lcount - 1 ).setOutCount( outsize );

        tr.setOutputCount( outsize );

    }

 

    private void prepareDiffs() {

        diffs = new double[m.length - 1][];

        difmuls = new double[m[0].length];

        for ( int i = 0; i < m.length - 1; i++ ) {

            diffs[i] = new double[m[i + 1].length];

            for ( int j = 0; j < m[i + 1].length; j++ ) {

                diffs[i][j] = m[i + 1][j] - m[i][j];

                if ( Math.abs( diffs[i][j] ) > difmuls[j] ) {

                    difmuls[j] = Math.abs( diffs[i][j] );

                }

            }

        }

        for ( int i = 0; i < diffs.length; i++ ) {

            for ( int j = 0; j < diffs[0].length; j++ ) {

                diffs[i][j] /= difmuls[j];

            }

        }

    }

 

    public void calc() throws FcastException {

        if ( !isPrepared() ) {

            prepare();

        }

        saveProps();

 

        int lcount = tr.getLayersCount();

        for ( int i = 0; i < lcount; i++ ) {

            tr.getNetwork().getLayerAt( i ).reinit();

        }

 

        //Ïîäãîòàâëèâàåì ê çàïóñêó

        dlg.reset();

        dlg.setVisible( false );

        dlg.setTrainer( tr );

 

        //çàïóñêàåì ïîòîê

        tr.removeListener( this );

        tr.addListener( this );

        tr.start();

 

        //Ïîäãîòàâëèâàåì ðåçóëüòàò

        cleanResult();

    }

 

    private void cleanResult() {

        int fcount = getFcastValue() / fixedStep;

 

        double step = ( max - min ) / m.length;

 

        res = new double[m.length + fcount][regions.length * 2 + 1];

        //Ïåðâîíà÷àëüíîå çàïîëíåíèå

        initialFill( step );

 

        try {

            //Çàïîëíåíèå ïî ñòàòèñòèêå

            statFill( step );

 

            //Çàïîëíåíèå ïî ïðîãíîçó

            fcastFill( step );

        } catch ( NoLayersException ex ) {

            ex.printStackTrace();

        } catch ( SizesMismatchException ex ) {

            ex.printStackTrace();

        }

    }

 

    private void initialFill( double step ) {

        for ( int i = 0; i < historyCount; i++ ) {

            double t = min + i * step;

            res[i][0] = t;

            int k = 0;

            for ( int j = 0; j < regions.length; j++ ) {

                int rccount = regions[j].getArraysCount();

                res[i][j * 2 + 1] = m[i][k];

                res[i][j * 2 + 2] = m[i][k];

                k += rccount;

            }

        }

    }

 

    private void statFill( double step ) throws NoLayersException, SizesMismatchException {

        int incount = tr.getInputCount();

        int outcount = tr.getOutputCount();

        for ( int i = historyCount; i < diffs.length; i++ ) {

            double t = min + i * step;

            res[i][0] = t;

            int k = 0;

            for ( int j = 0; j < regions.length; j++ ) {

                int rccount = regions[j].getArraysCount();

                res[i][j * 2 + 1] = m[i][k];

                k += rccount;

            }

 

            //ñîçäà¸ì ïðèìåð

            double[] x = new double[incount];

            for ( int iback = 1; iback <= historyCount; iback++ ) {

                for ( k = 0; k < outcount; k++ ) {

                    x[ ( iback - 1 ) * outcount + k] = diffs[i - iback][k];

                }

            }

            double[] y = new double[outcount];

            for ( int iback = 0; iback < outcount; iback++ ) {

                y[iback] = diffs[i][iback];

            }

            ArraySample sample = new ArraySample( x, y );

 

            //Âû÷èñëÿåì

            tr.getNetwork().calc( sample );

            double[] outs = tr.getNetwork().getLayerAt( tr.getLayersCount() - 1 ).getOuts();

            k = 0;

            for ( int j = 0; j < regions.length; j++ ) {

                int rccount = regions[j].getArraysCount();

                res[i][j * 2 + 2] = outs[k] * difmuls[k] + m[i - 1][k];

                k += rccount;

            }

        }

    }

 

    private void fcastFill( double step ) throws NoLayersException, SizesMismatchException {

        int fcount = getFcastValue() / 14;

        int mcount = m.length;

        int rcount = mcount + fcount;

        int incount = tr.getInputCount();

        int outcount = tr.getOutputCount();

        double[][] tmp = new double[rcount][incount];

        for ( int i = 0; i < diffs.length; i++ ) {

            tmp[i] = diffs[i];

        }

        for ( int ifcast = 0; ifcast <= fcount; ifcast++ ) {

            int i = mcount + ifcast - 1;

            double t = min + i * step;

            res[i][0] = t;

 

            //ñîçäà¸ì ïðèìåð

            double[] x = new double[incount];

            for ( int iback = 1; iback <= historyCount; iback++ ) {

                for ( int k = 0; k < outcount; k++ ) {

                    x[ ( iback - 1 ) * outcount + k] = tmp[i - iback][k];

                }

            }

            double[] y = new double[outcount];

            for ( int iback = 0; iback < outcount; iback++ ) {

                y[iback] = tmp[i][iback];

            }

            ArraySample sample = new ArraySample( x, y );

 

            //Âû÷èñëÿåì

            tr.getNetwork().calc( sample );

            double[] outs = tr.getNetwork().getLayerAt( tr.getLayersCount() - 1 ).getOuts();

            int k = 0;

            for ( int j = 0; j < regions.length; j++ ) {

                int rccount = regions[j].getArraysCount();

                for ( int o = 0; o < rccount; o++ ) {

                    tmp[i][k + o] = outs[k + o];

                }

                res[i][j * 2 + 2] = outs[k] * difmuls[k] + res[i - 1][j * 2 + 2];

 

                k += rccount;

            }

        }

    }

 

    public double[][] getFcast() {

        return res;

    }

 

    public int getHistoryCount() {

        return historyCount;

    }

 

    public void setHistoryCount( int historyCount ) {

        this.historyCount = historyCount;

        m = null;

    }

 

    public Trainer getTrainer() {

        return tr;

    }

 

    public void setTrainer( Trainer trainer ) {

        if ( tr != null ) {

            tr.removeListener( this );

        }

        tr = trainer;

        dlg.setTrainer( tr );

    }

 

    public void propertyChange( PropertyChangeEvent evt ) {

        if ( evt.getPropertyName().equals( "trainIteration" ) ) {

            TrainEvent te = ( TrainEvent ) evt.getNewValue();

            if ( ( te.getIteration() % 500 == 0 ) || ( te.getNewError() <= tr.getStopError() ) ) {

                double err = te.getNewError();

                double d = Math.sqrt( err );

                System.out.println( "it=" + te.getIteration() + " err=" + nf.format( err ) + " (" + nf2.format( d ) +

                                    ") spd=" + nf.format( te.getSpeed() ) );

            }

        } else if ( evt.getPropertyName().equals( "trainThread" ) ) {

            boolean val = ( ( Boolean ) evt.getNewValue() ).booleanValue();

            if ( !val ) {

                tr.removeListener( this );

//        showFcast();

//                                 System.exit(0);

            }

        }

    }

}



[1] Код тянет класс персептрон сторонней разработки

 



Рейтинг@Mail.ru

 
Апшеронск Спорт VBA Форекс Сочи-2014 Нейросети Студентам
Связь с Администратором сайта, E-mail: apsheronka@mail.ru
Апшеронск, Краснодарский край

Размещение рекламы на сайте
Карта сайта