分数係数による多項式回帰
格子点を通る多項式を求める場合などで,多項式の係数が分数のままのものが欲しいときがある。
係数が浮動小数点でもよいなら R の lm() を使っても求まるが,特別なプログラムを書いた。
using Statistics, Plots
function RationalLM(x::Vector{Int64}, y::Vector{Int64}, p::Int64 = length(x) - 1)
n = length(x)
dat = ones(Rational, n, p + 1);
for j = 1:p
dat[:, j] = (x .// big(1)) .^ j
end
dat[:, p+1] = big.(y)
s = cov(dat)
b = s[1:end-1, 1:end-1] \ s[end, 1:end-1]
means = mean(dat, dims=1)
pushfirst!(b, means[end] .- b' * means[1:end-1])
den = lcm(denominator.(b))
num = numerator.(b * lcm(den))
(x=x, y=y, p=p, num=Int64.(num), den=Int64(den), b=b, fb=Float64.(b))
end
predict(obj, x) = [(xi .^ collect(0:obj.p))' * obj.b for xi in x]
function predict(obj)
pred = predict(obj, obj.x)
for i = 1:length(obj.x)
println("x = $(obj.x[i]) \ty = $(obj.y[i]) \tpredict = $(pred[i]) \t= $(Float64(pred[i]))")
end
end
function summary(obj)
for i = 1:obj.p + 1
println("coef$(i-1) \t= $(obj.b[i]) \t= $(obj.num[i]) // $(obj.den) \t= $(obj.fb[i])")
end
end
function plot_results(obj; width=400, height=300)
pyplot(grid=false, size=(width, height), label="")
scatter(obj.x, obj.y, tick_direction=:out)
minx, maxx = extrema(obj.x)
margin = (maxx - minx) * 0.1
x2 = range(minx - margin, maxx + margin, length=1000)
y2 = predict(obj, x2)
plot!(x2, y2)
end
使用法
x, y は整数ベクトルであること。
x = collect(0:5);
y = [1,3,7,12,8,15];
あてはめ(多項式の係数を求める)
a = RationalLM(x, y);
多項式の係数を,分数形式,分母を共通とする分数形式,浮動小数点形式でまとめて表示する。
summary(a)
#=
coef0 = 1//1 = 120 // 120 = 1.0
coef1 = 643//60 = 1286 // 120 = 10.716666666666667
coef2 = -151//8 = -2265 // 120 = -18.875
coef3 = 323//24 = 1615 // 120 = 13.458333333333334
coef4 = -29//8 = -435 // 120 = -3.625
coef5 = 13//40 = 39 // 120 = 0.325
=#
y の予測値を表示する。
predict(a)
#=
x = 0 y = 1 predict = 1//1 = 1.0)
x = 1 y = 3 predict = 3//1 = 3.0)
x = 2 y = 7 predict = 7//1 = 7.0)
x = 3 y = 12 predict = 12//1 = 12.0)
x = 4 y = 8 predict = 8//1 = 8.0)
x = 5 y = 15 predict = 15//1 = 15.0)
=#
図に示す。
plot_results(a)
第3引数で,多項式の次数を指定できる。
b = RationalLM(x, y, 3);
summary(b)
#=
coef0 = 17//42 = 102 // 252 = 0.40476190476190477
coef1 = 1285//252 = 1285 // 252 = 5.099206349206349
coef2 = -7//6 = -294 // 252 = -1.1666666666666667
coef3 = 5//36 = 35 // 252 = 0.1388888888888889
=#
predict(b)
#=
x = 0 y = 1 predict = 17//42 = 0.40476190476190477
x = 1 y = 3 predict = 94//21 = 4.476190476190476
x = 2 y = 7 predict = 148//21 = 7.0476190476190474
x = 3 y = 12 predict = 188//21 = 8.952380952380953
x = 4 y = 8 predict = 463//42 = 11.023809523809524
x = 5 y = 15 predict = 296//21 = 14.095238095238095
=#
plot_results(b)
※コメント投稿者のブログIDはブログ作成者のみに通知されます