Hey,
nachdem ich versucht habe der Stackoverflowcommunity eine bessere Antwort herauszulocken, versuche ich es mal hier.
Ich versuche eine AI zu entwickeln, die zum späteren Teil einmal Astroids spielt. Nun soll aber das Q - Learning zu nächst allgemein funktionieren.
Die Brain Klasse sieht wie folgt aus:
Die rLearn Methode evaluiert den Fehler bzw. den neuen Q - value, während die experienceReplay Methode versucht dem Netz etwas beizubringen. Unglücklicherweise ist in dieser Klasse ein ziemlich idiotischer Bug, ich weiß nur nicht wo... Ich benutze ein älteres Netz um den maximalen Q - Value des jeweils nächsten States zu berechnen, das soll Stabilität gewährleisten.
Das Tuplecode findet sich hier:
package rlgame;
import java.util.ArrayList;
public class Tuple {
VectorND statefirst = new VectorND();
VectorND stateafter = new VectorND();
VectorND qactions = new VectorND();
double rewardafter;
int actionTaken;
}
Den gesamten Code findet ihr hier:
https://github.com/SuchtyTV/RLearningBird
nachdem ich versucht habe der Stackoverflowcommunity eine bessere Antwort herauszulocken, versuche ich es mal hier.
Ich versuche eine AI zu entwickeln, die zum späteren Teil einmal Astroids spielt. Nun soll aber das Q - Learning zu nächst allgemein funktionieren.
Die Brain Klasse sieht wie folgt aus:
Java:
package rlgame;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.encog.engine.network.activation.ActivationLOG;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationSoftMax;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
public class Brain {
private ArrayList<ArrayList<Tuple>> biglist = new ArrayList<ArrayList<Tuple>>();
BasicNetwork nn;
BasicNetwork oldnn;
private int index = 0;
MLDataSet set = new BasicMLDataSet();
public Brain() {
nn = new BasicNetwork();
nn.addLayer(new BasicLayer(new ActivationLinear(),true,29));
nn.addLayer(new BasicLayer(new ActivationSigmoid(),true,20));
nn.addLayer(new BasicLayer(new ActivationSigmoid(),true,20));
nn.addLayer(new BasicLayer(new ActivationLinear(),false,5));
nn.getStructure().finalizeStructure();
nn.reset();
oldnn = (BasicNetwork) nn.clone();
}
public void rlearn(ArrayList<Tuple> tupels, double learningrate, double discountfactor, boolean rememberTuples) {
if(rememberTuples)biglist.add(tupels);
//newQ = sum of all rewards you have got through
for(int i = tupels.size()-1; i > 0; i--) {
MLData in = new BasicMLData(29);
MLData out = new BasicMLData(5);
//Add State as in
int index = 0;
for(double w : tupels.get(i).statefirst.elements) {
in.add(index++, w);
}
//Now start updating Q - Values
double qnew = 0;
if(i <= tupels.size()-2){
qnew = tupels.get(i).rewardafter + discountfactor*qMax(tupels.get(i).stateafter);
} else {
qnew = tupels.get(i).rewardafter;
}
tupels.get(i).qactions.elements[tupels.get(i).actionTaken] = qnew;
//Add Q Values as out
index = 0;
for(double w : tupels.get(i).qactions.elements) {
out.add(index++, w);
}
set.add(in, out);
}
}
private double qMax(VectorND stateafter) {
double[] qactions = oldnn.compute(new BasicMLData(stateafter.elements)).getData();
double max = Double.MIN_VALUE;
for(double w : qactions) {
if(w > max) {
max = w;
}
}
return max;
}
public double[] getOutput(MLData input) {
return nn.compute(input).getData();
}
public void experienceReplay(double learningRate, double discountFactor) {
for(int i = 0; i < 10; i++) {
Collections.shuffle(biglist);
List<ArrayList<Tuple>> list = biglist.subList(0, (int)(biglist.size()*0.3));
for(ArrayList<Tuple> tuples : list) {
rlearn(tuples,learningRate, discountFactor, false);
}
Backpropagation prop = new Backpropagation(nn, set);
prop.setLearningRate(learningRate);
prop.iteration(10);
System.out.println(prop.getError());
}
oldnn = (BasicNetwork) nn.clone();
if(biglist.size() > 10000) {
System.out.println("List trimmed.");
while(biglist.size() > 10000) {
biglist.remove(biglist.size()-1);
}
}
set = new BasicMLDataSet();
}
public void addTuples(ArrayList<Tuple> tuples) {
biglist.add(tuples);
}
}
Die rLearn Methode evaluiert den Fehler bzw. den neuen Q - value, während die experienceReplay Methode versucht dem Netz etwas beizubringen. Unglücklicherweise ist in dieser Klasse ein ziemlich idiotischer Bug, ich weiß nur nicht wo... Ich benutze ein älteres Netz um den maximalen Q - Value des jeweils nächsten States zu berechnen, das soll Stabilität gewährleisten.
Das Tuplecode findet sich hier:
Code:
package rlgame;
import java.util.ArrayList;
public class Tuple {
VectorND statefirst = new VectorND();
VectorND stateafter = new VectorND();
VectorND qactions = new VectorND();
double rewardafter;
int actionTaken;
}
package rlgame;
import java.util.ArrayList;
public class Tuple {
VectorND statefirst = new VectorND();
VectorND stateafter = new VectorND();
VectorND qactions = new VectorND();
double rewardafter;
int actionTaken;
}
Den gesamten Code findet ihr hier:
https://github.com/SuchtyTV/RLearningBird