Hallo,
ich habe eine Klasse für univariate lineare Regression:
[CODE lang="java" title="UnivariateLinearRegression"]public class UnivariateLinearRegression implements LinearModel {
private double w0, w1; // weights hw(x) = w1*x + w0
public UnivariateLinearRegression() {
w0 = Math.random();
w1 = Math.random();
}
@Override
public void train(double[][] data, int epochs, double alpha) {
for(int i = 0; i < epochs; i++) {
double errorw0 = sumErrorW0(data);
double errorw1 = sumErrorW1(data);
w0 = w0 + alpha * errorw0;
w1 = w1 + alpha * errorw1;
}
}
private double sumErrorW0(double[][] data) {
double sum = 0;
for (int i = 0; i < data.length; i++) {
if (data.length != 2)
throw new IllegalArgumentException("Each Data-Point must have two elements");
sum += data[1] - (w1 * data[0] + w0);
}
return sum;
}
private double sumErrorW1(double[][] data) {
double sum = 0;
for (int i = 0; i < data.length; i++) {
if (data.length != 2)
throw new IllegalArgumentException("Each Data-Point must have two elements");
sum += (data[1] - (w1 * data[0] + w0)) * data[0];
}
return sum;
}
}[/CODE]
Nun habe ich das Problem, dass die Qualität des Ergebnisses sehr stark von den Startwerten abhängt. Momentan mache ich das ja mit Math.random() was (offensichtlich) nicht die beste lösung ist. Wie würde man das besser machen?
LG
ich habe eine Klasse für univariate lineare Regression:
[CODE lang="java" title="UnivariateLinearRegression"]public class UnivariateLinearRegression implements LinearModel {
private double w0, w1; // weights hw(x) = w1*x + w0
public UnivariateLinearRegression() {
w0 = Math.random();
w1 = Math.random();
}
@Override
public void train(double[][] data, int epochs, double alpha) {
for(int i = 0; i < epochs; i++) {
double errorw0 = sumErrorW0(data);
double errorw1 = sumErrorW1(data);
w0 = w0 + alpha * errorw0;
w1 = w1 + alpha * errorw1;
}
}
private double sumErrorW0(double[][] data) {
double sum = 0;
for (int i = 0; i < data.length; i++) {
if (data.length != 2)
throw new IllegalArgumentException("Each Data-Point must have two elements");
sum += data[1] - (w1 * data[0] + w0);
}
return sum;
}
private double sumErrorW1(double[][] data) {
double sum = 0;
for (int i = 0; i < data.length; i++) {
if (data.length != 2)
throw new IllegalArgumentException("Each Data-Point must have two elements");
sum += (data[1] - (w1 * data[0] + w0)) * data[0];
}
return sum;
}
}[/CODE]
Nun habe ich das Problem, dass die Qualität des Ergebnisses sehr stark von den Startwerten abhängt. Momentan mache ich das ja mit Math.random() was (offensichtlich) nicht die beste lösung ist. Wie würde man das besser machen?
LG