2016年4月6日水曜日

最小二乗法をやってみた

たまには数値計算でもしてみようかなぁと思い、最近はやりの最小二乗法を作ってみました。
グラフツールとかインストールしていないので、コンソールにグラフを描く簡単なライブラリも作ってみました。


おーなんかあってるっぽいね。

-----------------------------------------------
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

// Console canvas size in characters: WIDTH columns by HEIGHT rows (see dot[]).
#define WIDTH 80
#define HEIGHT 25

//
// myplot: minimal character-canvas plotting to the console
//

// Canvas cells, indexed dot[row][column]. myplot_init() writes a NUL at
// column WIDTH-1 of each row so rows can be printed as strings, and
// myplot_print() only emits rows 0..HEIGHT-2.
static char dot[HEIGHT][WIDTH];
// Mapping from data coordinates to canvas cells: set via myplot_set_scale()
// and applied in myplot_plot().
static double off_x = 0.0;
static double off_y = 0.0;
static double scale = 1.0;

// Resets the canvas: blanks every cell, draws the horizontal axis ('-') along
// row 0 and the vertical axis ('|') down column 0, and NUL-terminates every
// row at column WIDTH-1 so each row can be printed as a C string.
static void myplot_init(void)
{
    int row;

    memset(dot, ' ', sizeof dot);
    memset(dot[0], '-', WIDTH);

    // The vertical axis is drawn after the horizontal one, so the origin
    // cell ends up as '|', and the terminator replaces the last '-'.
    for (row = 0; row < HEIGHT; row++) {
        char *line = dot[row];
        line[0] = '|';
        line[WIDTH - 1] = '\0';
    }
}

// Puts character c into the canvas cell for data point (x, y).
//
// Column = floor((x - off_x) * scale * 2) and row = floor((y - off_y) * scale).
// The extra factor of 2 on x presumably compensates for console cells being
// roughly twice as tall as they are wide. Row 0 is the bottom line when the
// canvas is printed by myplot_print(). Points outside columns 0..WIDTH-2 or
// rows 0..HEIGHT-1, and NaN coordinates, are silently dropped.
static void myplot_plot(double x, double y, char c)
{
    // Stay in floating point until the range check has passed, for two reasons.
    // Converting a double that does not fit in int (or NaN) is undefined
    // behavior. Also, floor() keeps values in (-1, 0) from being truncated
    // onto column/row 0, where they would overwrite the axis.
    double col = floor((x - off_x) * scale * 2);
    double row = floor((y - off_y) * scale);

    // The checks are written as !(in range) so NaN is rejected too.
    // Column WIDTH-1 holds the row's NUL terminator and must not be overwritten.
    if (!(col >= 0.0 && col <= WIDTH - 2))
        return;
    if (!(row >= 0.0 && row <= HEIGHT - 1))
        return;

    dot[(int)row][(int)col] = c;
}

// Writes the canvas to stdout from row HEIGHT-2 down to row 0, one line per
// row, so larger y values appear higher on screen. Row HEIGHT-1 is never
// printed.
static void myplot_print(void)
{
    int row;

    for (row = HEIGHT - 2; row >= 0; row--)
        puts(dot[row]);   // puts appends the newline
}

// Configures the data-to-canvas mapping used by myplot_plot(): a point is
// shifted by (off_x_, off_y_) and then multiplied by scale_.
void myplot_set_scale(double off_x_, double off_y_, double scale_)
{
    scale = scale_;
    off_y = off_y_;
    off_x = off_x_;
}


//
// saisyou2zyou
//

// Solves linear systems held in the augmented matrix w in place, using
// Gauss-Jordan elimination with partial (row) pivoting.
//   w   : n rows, each with n coefficient columns followed by m right-hand-side
//         columns
//   n   : number of equations / unknowns
//   m   : number of right-hand-side columns
//   eps : a pivot whose magnitude is below eps is treated as zero
// Returns 0 on success, with the solutions in w[i][n .. n+m-1]. Returns 1 as
// soon as a pivot below eps is met, treating the matrix as singular.
// w is modified in either case. On success, the coefficient columns are not
// reduced to the identity; they hold leftover intermediate values.
int gauss(double **w, int n, int m, double eps)
{
    const int cols = n + m;

    for (int piv = 0; piv < n; piv++) {
        // Choose the row at or below piv with the largest |w[row][piv]|.
        double best = 0.0;
        int best_row = 0;
        for (int row = piv; row < n; row++) {
            double mag = fabs(w[row][piv]);
            if (best < mag) {
                best = mag;
                best_row = row;
            }
        }

        if (best < eps)
            return 1;

        // Swap rows piv and best_row; columns left of piv are no longer used.
        for (int k = piv; k < cols; k++) {
            double tmp = w[piv][k];
            w[piv][k] = w[best_row][k];
            w[best_row][k] = tmp;
        }

        // Normalize the pivot row (the pivot cell itself is not rewritten).
        double inv = 1.0 / w[piv][piv];
        for (int k = piv + 1; k < cols; k++)
            w[piv][k] *= inv;

        // Eliminate column piv from every other row. w[row][piv] is read as
        // the factor and is left unchanged, since k starts at piv + 1.
        for (int row = 0; row < n; row++) {
            if (row == piv)
                continue;
            for (int k = piv + 1; k < cols; k++)
                w[row][k] -= w[row][piv] * w[piv][k];
        }
    }

    return 0;
}

// Least-squares fit of a polynomial of degree m to the n points (x[i], y[i]).
//
// Builds the design matrix A, where A[i][j] = x[i]^(m-j). It then forms the
// normal equations (A^T A) c = A^T y as an (m+1) x (m+2) augmented matrix and
// solves them with gauss().
//
// Returns a malloc'd array of m+1 coefficients, highest degree first:
//   y ~= z[0]*x^m + z[1]*x^(m-1) + ... + z[m]
// The caller owns the result and must free() it. Returns NULL in any of these
// cases:
//   - m < 0;
//   - n <= m (too few points to determine m+1 coefficients);
//   - x or y is NULL;
//   - an allocation fails;
//   - gauss() reports a (near-)singular system.
double *least(int m, int n, double *x, double *y)
{
    double **A = NULL;   // n x cols design matrix
    double **w = NULL;   // cols x (cols+1) augmented normal-equation matrix
    double *z = NULL;    // fitted coefficients (the result)
    int a_rows = 0;      // rows of A allocated so far, for cleanup
    int w_rows = 0;      // rows of w allocated so far, for cleanup
    int cols, i, j, k;

    if (m < 0 || n <= m || x == NULL || y == NULL)
        return NULL;

    cols = m + 1;   // cannot overflow: m < n <= INT_MAX

    // calloc is used for its overflow-checked size computation; every
    // element is overwritten below.
    z = calloc((size_t)cols, sizeof *z);
    w = calloc((size_t)cols, sizeof *w);
    A = calloc((size_t)n, sizeof *A);
    if (z == NULL || w == NULL || A == NULL)
        goto fail;

    for (; w_rows < cols; w_rows++) {
        w[w_rows] = calloc((size_t)cols + 1, sizeof **w);
        if (w[w_rows] == NULL)
            goto fail;
    }

    // Row i of A holds x[i]^m, ..., x[i]^1, x[i]^0. Filling from the constant
    // column leftwards also works for degree 0 (cols == 1); the original
    // wrote A[i][-1] in that case.
    for (; a_rows < n; a_rows++) {
        double p = 1.0;
        double *row = calloc((size_t)cols, sizeof *row);
        if (row == NULL)
            goto fail;
        A[a_rows] = row;
        for (j = cols - 1; j >= 0; j--) {
            row[j] = p;
            p *= x[a_rows];
        }
    }

    // w[i][j] = (A^T A)[i][j]; w[i][cols] = (A^T y)[i].
    for (i = 0; i < cols; i++) {
        double rhs = 0.0;
        for (j = 0; j < cols; j++) {
            double s = 0.0;
            for (k = 0; k < n; k++)
                s += A[k][i] * A[k][j];
            w[i][j] = s;
        }
        for (k = 0; k < n; k++)
            rhs += A[k][i] * y[k];
        w[i][cols] = rhs;
    }

    if (gauss(w, cols, 1, 1.0e-10) != 0)
        goto fail;

    for (i = 0; i < cols; i++)
        z[i] = w[i][cols];
    goto cleanup;

fail:
    free(z);   // the original leaked z when gauss() failed
    z = NULL;
cleanup:
    for (i = 0; i < a_rows; i++)
        free(A[i]);
    for (i = 0; i < w_rows; i++)
        free(w[i]);
    free(A);
    free(w);
    return z;
}

//
// main
//
int main()
{
double x[] = { 10, 15, 20, 26, 32, 40 };
double y[] = { 28.2, 47, 44.4, 32.8, 20.8, 0.8 };

double *z;
double d;
int i, m, n;

m = 2;
n = 6;

myplot_init();
myplot_set_scale(0, 0, 0.5);
for (i = 0; i < 10; i++){
myplot_plot(i * 10, 0, '+');
myplot_plot(0, i * 10, '+');
}
for (i = 0; i < n; i++)myplot_plot(x[i], y[i], '*');

z = least(m, n, x, y);

if (z == NULL)return 0;

for (i = 0; i < 80; i++){
d = z[0] * i * i + z[1] * i + z[2];
myplot_plot(i, d, 'o');
}

myplot_print();

free(z);

return 0;
}

-----------------------------------------------


0 件のコメント:

コメントを投稿