Machine-learning Framework

Hallo Leute, ich habe mich in den letzten Wochen mit neuronalen Netzen beschäftigt und bereits ein neuronales Netz erstellt, das eine CSV-Datei einliest. Jetzt möchte ich gerne mit Backpropagation arbeiten. Ich weiß, wie das theoretisch funktioniert, aber nicht, wie man es implementiert.
Die Aufgabe, die ich von der Uni aus lösen muss, ist folgende:

The resulting software package shall not contain any problem-specific part (e.g. how to recognize a digit). But it shall work with an ANN that supports multiple hidden layers (as specified with your API or XML document from assignment 11). This requires adapting the backpropagation algorithm.

Use this framework to implement the digit recognition example with a 5 layer ANN.

Avoid making the same mistakes as in the pouring-problem assignment (particularly, be sure not to mix problem-specific code with problem-independent one). You'll notice that there is a tradeoff between the amount of code that is required to use the framework and the degree of flexibility that your framework supports. Which assumptions can you make (e.g. input-format, input data type, ...) without restricting the system too much?

Ich soll also ein Framework erstellen, das mit 5 Layern eine Ziffer erkennen kann. Die CSV-Datei ist schon vorhanden, und ich habe bereits folgenden Code vom vorherigen Beispiel, der angepasst werden soll.

import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;

public class App {
    public static void main(String[] args) throws Exception {
        DigitRecognizer recognizer = new DigitRecognizer();
        recognizer.init(new File("mnist_train_100.csv"));
        // scorecard for how well the network performs, initially empty
        int attemptsOk = 0;
        int attemptsFailed = 0;
        try (Stream<String> stream = Files
                .lines(Paths.get("mnist_test_10.csv"))) {
            List<String> testDataList = stream.collect(Collectors.toList());

            System.out.println("correct | recognized");

            // go through all the records in the test data set
            for (String record : testDataList) {
                // split the record by the ',' commas
                String[] allValues = record.split(",", 2);

                // correct answer is first value
                int correctDigit = Integer.parseInt(allValues[0]);
                int recognizedDigit = recognizer.recognize(allValues[1]);

                if (correctDigit == recognizedDigit)

                System.out.println("  " + correctDigit + "     |     " + recognizedDigit);
        // calculate the performance score, the fraction of correct answers
        System.out.println("performance = " + (double) attemptsOk / (attemptsOk + attemptsFailed) * 100 + "%");
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;

public class DigitRecognizer implements Assignment9 {

    private List<Integer> trainingDataDigit;
    private List<List<Double>> trainingData;
    private static final int amountOfTargets = 10;
    private double learningRate;
    private Function<Double, Double> activation;
    private Random random = new Random();
    private Matrix wih;
    private Matrix who;

     * Loads the .csv file with the training data or throws an Exception if anything goes wrong;
     * returns true iff the initialization completed successfully.
     * @param csvTrainingData
     *            the data used to train the neural network
     * @return true if the initialization was successful
    public boolean init(File csvTrainingData) throws Exception {
        try (BufferedReader br = new BufferedReader(new FileReader(csvTrainingData))) {
            trainingData = new ArrayList<>();
            trainingDataDigit = new ArrayList<>();

            br.lines().forEach(s -> {
                String[] allValues = s.split(",", 2);


                // scale and shift the inputs
                        .mapToDouble(x -> Double.parseDouble(x) / 255.0 * 0.99 + 0.01)

        if (trainingData.size() > 0) {
            init(trainingData.get(0).size(), 200, amountOfTargets, 0.1, (Double x) -> 1.0 / (1.0 + Math.exp(-x)));

            return true;
        } else
            return false;

     * trains the neural network used for digit recogniztion.
     * @return true iff the training of the neural network was successful.
     * @throws Exception
    public boolean train() throws Exception {
        // create the target output values (all 0.01, except the desired label which is 0.99)
        double[] targets = DoubleStream.generate(() -> 0.01).limit(amountOfTargets).toArray();

        for (int epochs = 0; epochs < 5; epochs++) {
            for (int i = 0; i < trainingData.size(); i++) {
                targets[trainingDataDigit.get(i)] = 0.99;

                double[] inputs = -> (double) d).toArray();

                train(inputs, targets);

                targets[trainingDataDigit.get(i)] = 0.01;

        return true;

    private int indexFromMax(double[] data) {
        int max = 0;

        for (int i = 0; i < data.length; i++)
            if (data[i] > data[max])
                max = i;

        return max;

     * Tries to recognize the digit represented by csvString.
     * @param csvString
     *            the digit pattern as CSV string.
     * @return the recognized digit
    public int recognize(String csvString) throws Exception {
        // scale and shift the inputs
        double[] inputs =","))
                .mapToDouble(s -> Double.parseDouble(s) / 255.0 * 0.99 + 0.01)

        double[] outputs = query(inputs);

        // the index of the highest value corresponds to the label
        return indexFromMax(outputs);

     * @param inputNodes
     *            the amount of input nodes
     * @param hiddenNodes
     *            the amount of hidden nodes
     * @param outputNodes
     *            the amount of output nodes
     * @param learningRate
     *            the learning rate
     * @param activation
     *            the activation function
    public void init(int inputNodes, int hiddenNodes, int outputNodes, double learningRate, Function<Double, Double> activation) {
        this.learningRate = learningRate;
        this.activation = activation;

        /*link weight matrices, wih and who
         weights inside the arrays are w_i_j, where link is from node i to node j in the next layer
         w11 w21
         w12 w22 etc*/
        wih = new Matrix(hiddenNodes, inputNodes);
        who = new Matrix(outputNodes, hiddenNodes);

        fillWithRandomValues(wih, 0, Math.pow(inputNodes, -0.5));
        fillWithRandomValues(who, 0, Math.pow(hiddenNodes, -0.5));

    private void fillWithRandomValues(Matrix matrix, double mean, double variance) {
        for (int row = 0; row < matrix.getRowDimension(); row++)
            for (int col = 0; col < matrix.getColumnDimension(); col++)
                matrix.set(row, col, nextRandom(mean, variance));

    private double nextRandom(double mean, double variance) {
        return mean + random.nextGaussian() * variance;

     * Trains the neural network with the given input and target values.
     * @param inputsList
     *            the input values to be used
     * @param targetsList
     *            the target values for the given input values
    public void train(double[] inputsList, double[] targetsList) {
        Matrix inputs = toMatrix(inputsList);
        Matrix targets = toMatrix(targetsList);

        // calculate signals into layer
        Matrix hiddenInputs = wih.matrixMultiplication(inputs);
        // calculate the signals emerging from hidden layer
        Matrix hiddenOutputs = hiddenInputs.applyFuntion(activation);

        // calculate signals into final output layer
        Matrix finalInputs = who.matrixMultiplication(hiddenOutputs);
        // calculate the signals emerging from final output layer
        Matrix finalOutputs = finalInputs.applyFuntion(activation);

        // output layer error is the (target - actual)
        Matrix outputErrors = targets.matrixSubstraction(finalOutputs);
        // hidden layer error is the output_errors, split by weights, recombined at hidden nodes
        Matrix hiddenErrors = who.transposeMatrix().matrixMultiplication(outputErrors);

        // update the weights for the links between the hidden and output layers
        who = who.matrixAddition(outputErrors.multByElement(finalOutputs)
                .multByElement(finalOutputs.applyFuntion(d -> 1.0 - d))

        // update the weights for the links between the input and hidden layers
        wih = wih.matrixAddition(hiddenErrors.multByElement(hiddenOutputs)
                .multByElement(hiddenOutputs.applyFuntion(d -> 1.0 - d))

     * Queries the output of the neural network for a given input.
     * @param inputsList
     *            the input to query for.
     * @return the output from the network.
    public double[] query(double[] inputsList) {
        Matrix inputs = toMatrix(inputsList);

        // calculate signals into hidden layer
        Matrix hiddenInputs = wih.matrixMultiplication(inputs);
        // calculate the signals emerging from hidden layer
        Matrix hiddenOutputs = hiddenInputs.applyFuntion(activation);

        // calculate signals into final output layer
        Matrix finalInputs = who.matrixMultiplication(hiddenOutputs);
        // calculate the signals emerging from final output layer
        Matrix finalOutputs = finalInputs.applyFuntion(activation);

        return toArray(finalOutputs);

    private Matrix toMatrix(double[] data) {
        Matrix result = new Matrix(data.length, 1);

        for (int i = 0; i < data.length; i++)
            result.set(i, 0, data[i]);

        return result;

    private double[] toArray(Matrix matrix) {
        double[] result = new double[matrix.getRowDimension()];

        for (int i = 0; i < result.length; i++)
            result[i] = matrix.get(i, 0);

        return result;
import java.lang.StringBuilder;
import java.util.function.Function;

/**
 * A dense, fixed-size matrix of {@code double} values.
 *
 * <p>All arithmetic methods leave the receiver untouched and return a new matrix; only
 * {@link #set(int, int, double)} mutates. Operations on two matrices throw an
 * {@link IllegalArgumentException} when the operand dimensions are incompatible.
 */
public class Matrix {

    private final int row;
    private final int col;

    private final double[][] elements;

    /**
     * Creates a matrix filled with zeros.
     *
     * @param row the number of rows
     * @param col the number of columns
     */
    public Matrix(int row, int col) {
        this.row = row;
        this.col = col;
        this.elements = new double[row][col];
    }

    /** @return the number of rows */
    public int getRowDimension() {
        return row;
    }

    /** @return the number of columns */
    public int getColumnDimension() {
        return col;
    }

    /** @return a new matrix whose element (i, j) is this matrix's element (j, i) */
    public Matrix transposeMatrix() {
        Matrix result = new Matrix(this.col, this.row);
        for (int r = 0; r < this.col; r++) {
            for (int c = 0; c < this.row; c++) {
                result.set(r, c, this.get(c, r));
            }
        }
        return result;
    }

    /** Stores {@code e} at the given zero-based position. */
    public void set(int row, int col, double e) {
        elements[row][col] = e;
    }

    /** @return the element at the given zero-based position */
    public double get(int row, int col) {
        return elements[row][col];
    }

    /** @return a new matrix with {@code a} added to every element */
    public Matrix scalarAddition(double a) {
        return applyFuntion(x -> a + x);
    }

    /**
     * Subtracts every element FROM the scalar.
     *
     * @return a new matrix whose elements are {@code a - element} (not {@code element - a})
     */
    public Matrix scalarSubstraction(double a) {
        return applyFuntion(x -> a - x);
    }

    /** @return a new matrix with every element multiplied by {@code a} */
    public Matrix scalarMultiplication(double a) {
        return applyFuntion(x -> x * a);
    }

    /**
     * Applies a function to every element.
     *
     * @param f the function to apply
     * @return a new matrix holding the function results
     */
    public Matrix applyFuntion(Function<Double, Double> f) {
        Matrix result = new Matrix(this.row, this.col);
        for (int r = 0; r < this.row; r++) {
            for (int c = 0; c < this.col; c++) {
                result.set(r, c, f.apply(this.get(r, c)));
            }
        }
        return result;
    }

    /**
     * Element-wise (Hadamard) product.
     *
     * @param B a matrix with the same dimensions as this one
     * @return a new matrix of the element-wise products
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Matrix multByElement(Matrix B) {
        // Without this check a larger B would be accepted silently and partly ignored.
        requireSameDimensions(B);

        Matrix result = new Matrix(this.row, this.col);
        for (int r = 0; r < this.row; r++) {
            for (int c = 0; c < this.col; c++) {
                result.set(r, c, this.get(r, c) * B.get(r, c));
            }
        }
        return result;
    }

    /**
     * Element-wise sum.
     *
     * @param B a matrix with the same dimensions as this one
     * @return a new matrix holding {@code this + B}
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Matrix matrixAddition(Matrix B) {
        requireSameDimensions(B);

        Matrix result = new Matrix(this.row, this.col);
        for (int r = 0; r < this.row; r++) {
            for (int c = 0; c < this.col; c++) {
                result.set(r, c, this.get(r, c) + B.get(r, c));
            }
        }
        return result;
    }

    /**
     * Element-wise difference.
     *
     * @param B a matrix with the same dimensions as this one
     * @return a new matrix holding {@code this - B}
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Matrix matrixSubstraction(Matrix B) {
        requireSameDimensions(B);

        Matrix result = new Matrix(this.row, this.col);
        for (int r = 0; r < this.row; r++) {
            for (int c = 0; c < this.col; c++) {
                result.set(r, c, this.get(r, c) - B.get(r, c));
            }
        }
        return result;
    }

    /**
     * Matrix product.
     *
     * @param B a matrix with as many rows as this matrix has columns
     * @return a new matrix holding {@code this x B}
     * @throws IllegalArgumentException if the inner dimensions differ
     */
    public Matrix matrixMultiplication(Matrix B) {
        if (this.col != B.row) {
            throw new IllegalArgumentException("the matrices dimensions do not add up");
        }

        Matrix result = new Matrix(this.row, B.col);
        for (int r = 0; r < this.row; r++) {
            for (int c = 0; c < B.col; c++) {
                double sum = 0;
                for (int k = 0; k < this.col; k++) {
                    sum += this.get(r, k) * B.get(k, c);
                }
                result.set(r, c, sum);
            }
        }
        return result;
    }

    /** @return the elements, each followed by a space, one line per row */
    @Override
    public String toString() {
        StringBuilder str = new StringBuilder();
        for (double[] line : elements) {
            for (double d : line) {
                str.append(d).append(' ');
            }
            str.append('\n');
        }
        return str.toString();
    }

    /** Rejects operands whose row or column count differs from this matrix. */
    private void requireSameDimensions(Matrix B) {
        if (this.row != B.row || this.col != B.col) {
            throw new IllegalArgumentException("the matrices dimensions do not add up");
        }
    }
}

Ich hoffe, ihr könnt mir zeigen, wie ich das Programm so umschreiben kann, dass es mit Backpropagation und 5 Layern funktioniert!

Ganz liebe Grüße!!
