#==========
Julia の修行をするときに,いろいろなプログラムを書き換えるのは有効な方法だ。
以下のプログラムを Julia に翻訳してみる。
判別分析(線形判別関数)
http://aoki2.si.gunma-u.ac.jp/R/disc.html
ファイル名: disc.jl 関数名: disc, printdisc, plotdisc
翻訳するときに書いたメモ
結果の表示に NamedArray を使ってみた。
==========#
using CSV, DataFrames, Statistics, LinearAlgebra, Rmath, NamedArrays, Plots, StatsPlots
"""
    disc(data::DataFrame, group)

Run the discriminant analysis on the columns of `data`, using the column names of the
data frame as variable names.

Every column is converted to `Float64` so that the call always reaches the
`Array{Float64,2}` method, whatever numeric element type the columns have
(a plain `Matrix(data)` would produce e.g. a `Matrix{Int64}` for all-integer frames).

# Arguments
- `data`: data frame whose columns are the numeric predictor variables.
- `group`: group label of each row of `data`.

# Returns
The result dictionary produced by the matrix method of `disc`.
"""
function disc(data::DataFrame, group)
    return disc(Matrix{Float64}(data), group; name=names(data))
end
"""
    disc(data::AbstractMatrix{<:Integer}, group; name=[])

Convert an integer-valued data matrix to `Float64` and forward to the
`Array{Float64,2}` method of `disc`.

# Arguments
- `data`: matrix with one row per observation and one column per variable.
- `group`: group label of each row of `data`.
- `name`: variable names, forwarded unchanged as the `name` keyword.

# Returns
Whatever the `Array{Float64,2}` method returns.
"""
function disc(data::AbstractMatrix{<:Integer}, group; name=[])
    # `Matrix(data)` would keep the integer element type and dispatch straight back to
    # this method, so the element type is converted explicitly. `name` has to be passed
    # as a keyword: no method takes it as a third positional argument.
    return disc(Matrix{Float64}(data), group; name=name)
end
"""
    disc(data::Array{Float64,2}, group; name=[])

Fit linear discriminant functions to the rows of `data`, which are classified by `group`.

# Arguments
- `data`: `ncase × p` matrix; each row is one observation, each column one variable.
- `group`: collection holding one group label per row of `data`.
- `name`: variable names; when empty, `"X1"`, …, `"Xp"` are generated.

# Returns
A `Dict` keyed by `Symbol`:
- `:dfunction`, `:header`: pairwise discriminant functions (one column per pair of
  groups, last row is the constant) and the `"g1:g2"` label of each column.
- `:cfunction`: per-group classification functions (last row is the constant).
- `:partialF`, `:partialFP`, `:df1`, `:df2`: partial F value of each variable, its
  p value and the two degrees of freedom.
- `:wilkslambda`, `:wilkslambdaF`, `:wilkslambdaP`, `:wilkslambdadf1`,
  `:wilkslambdadf2`: Wilks' lambda with its F approximation.
- `:distance`, `:Pvalue`: squared distance of every case to every group centroid and
  the corresponding chi-square p value.
- `:prediction`, `:correct`, `:correcttable`, `:factor1`, `:factor2`, `:correctrate`:
  predicted group, per-case hit flag, cross table of actual × predicted group with its
  row and column labels, and the percentage of correctly classified cases.
- `:discriminantvalue`: discriminant score of each case when there are exactly two
  groups, otherwise an empty vector.
- `:group`, `:name`, `:gname`, `:num`: the inputs, the sorted group labels and the
  group sizes.

# Throws
- `DimensionMismatch` if `length(group)` differs from the number of rows of `data`.
- `ArgumentError` if `group` contains fewer than two distinct labels.
"""
function disc(data::Array{Float64,2}, group; name=[])
    ncase, p = size(data)
    length(group) == ncase || throw(DimensionMismatch(
        "group has $(length(group)) elements but data has $ncase rows"))
    if length(name) == 0
        name = "X" .* string.(1:p)
    end
    gname, num = table(group)
    ng = length(num)
    # With a single group idf1 below is zero and every statistic degenerates.
    ng >= 2 || throw(ArgumentError("at least two groups are required, got $ng"))

    # Total (t) and pooled within-group (w) sums of squares and cross products.
    t = cov(data, corrected=false) .* ncase
    means = zeros(ng, p)
    vars = zeros(p, p, ng)
    for i = 1:ng
        gdata = data[group .== gname[i], :]
        means[i, :] = vec(mean(gdata, dims=1))
        vars[:, :, i] = cov(gdata, corrected=false) .* num[i]
    end
    w = sum(vars, dims=3)[:, :, 1]
    wl = det(w) / det(t)

    # Classification functions: one column per group, constant term in the last row.
    temp = w \ transpose(means)
    a = -2 * (ncase - ng) .* temp
    a0 = sum(transpose(temp) .* means, dims=2) * (ncase - ng)
    cfunction = vcat(a, a0')

    # Discriminant functions: half the difference of each pair of classification
    # functions.
    m = (ng - 1) * ng ÷ 2
    dfunction = zeros(p + 1, m)
    header = fill("", m)
    k = 0
    for i = 1:ng-1
        for j = i+1:ng
            k += 1
            # string() instead of `*` so that non-string group labels work as well.
            header[k] = string(gname[i], ":", gname[j])
            dfunction[:, k] = (cfunction[:, j] - cfunction[:, i]) ./ 2
        end
    end

    # Partial F value of each variable and its upper-tail p value.
    invw = inv(w)
    ratio = diag(inv(t) ./ invw)
    idf1 = ng - 1
    idf2 = ncase - idf1 - p
    F = idf2 / idf1 * (1 .- ratio) ./ ratio
    P = pf.(F, idf1, idf2, false)

    # F approximation for Wilks' lambda.
    c1 = (p ^ 2 + idf1 ^ 2 != 5) ?
         sqrt((p ^ 2 * idf1 ^ 2 - 4) / (p ^ 2 + idf1 ^ 2 - 5)) : 1.0
    c2 = wl ^ (1 / c1)
    df1 = p * idf1
    df2 = (ncase - 1 - (p + ng) / 2) * c1 + 1 - 0.5 * p * idf1
    Fwl = df2 * (1 - c2) / (df1 * c2)
    Pwl = pf(Fwl, df1, df2, false)

    # Squared distance of every case (row) to every group centroid (column).
    D2 = zeros(ncase, ng)
    tdata = transpose(data)
    for i = 1:ng
        centered = tdata .- means[i, :]
        for j = 1:ncase
            D2[j, i] = centered[:, j]' * invw * centered[:, j]
        end
    end
    D2 = (ncase - ng) .* D2
    P2 = pchisq.(D2, p, false)

    # Assign each case to the nearest centroid. The p values decrease with the distance,
    # so this is the group with the largest p value, but it stays well defined when the
    # p values of distant cases all underflow to zero.
    prediction = [gname[argmin(D2[j, :])] for j = 1:ncase]
    correct = prediction .== group
    factor1, factor2, correcttable = table(group, prediction)
    # Count the hits directly: the diagonal of `correcttable` is only meaningful when
    # every group occurs among the predictions (otherwise the table is not square).
    correctrate = count(correct) / ncase * 100

    if ng == 2
        discriminantvalue = data * dfunction[1:p] .+ dfunction[p + 1]
    else
        discriminantvalue = Float64[]
    end

    return Dict(:dfunction => dfunction, :header => header,
                :cfunction => cfunction, :partialF => F,
                :partialFP => P, :df1 => idf1, :df2 => idf2, :wilkslambda => wl,
                :wilkslambdaF => Fwl, :wilkslambdaP => Pwl, :wilkslambdadf1 => df1,
                :wilkslambdadf2 => df2, :distance => D2, :Pvalue => P2,
                :prediction => prediction, :correct => correct,
                :correcttable => correcttable, :correctrate => correctrate,
                :discriminantvalue => discriminantvalue, :group => group,
                :factor1 => factor1, :factor2 => factor2,
                :name => name, :gname => gname, :num => num)
end
"""
    table(x)

One-way frequency table, intended for inputs with few distinct values.

Returns a tuple `(levels, tally)`: the sorted distinct values of `x` and how often
each of them occurs in `x`.
"""
function table(x)
    levels = sort(unique(x))
    tally = zeros(Int, length(levels))
    # indexin gives, for every element of x, its position within `levels`.
    foreach(pos -> tally[pos] += 1, indexin(x, levels))
    return levels, tally
end
"""
    table(x, y)

Two-way frequency table of the paired values of `x` and `y`.

Returns `(rowlevels, collevels, tally)`: the sorted distinct values of `x`, the sorted
distinct values of `y`, and a matrix whose `[i, j]` entry counts the pairs equal to
`(rowlevels[i], collevels[j])`.
"""
function table(x, y)
    rowlevels = sort(unique(x))
    collevels = sort(unique(y))
    rowpos = indexin(x, rowlevels)
    colpos = indexin(y, collevels)
    tally = zeros(Int, length(rowlevels), length(collevels))
    foreach(zip(rowpos, colpos)) do (r, c)
        tally[r, c] += 1
    end
    return rowlevels, collevels, tally
end
"""
    printdisc(obj::Dict{Symbol, Any})

Print the result dictionary returned by `disc` as labelled tables: the discriminant
functions (followed by the group pair each one separates), the partial F values with
their p values and degrees of freedom, the classification functions, the cross table
of the classification result, and the percentage of correctly classified cases.
"""
function printdisc(obj::Dict{Symbol, Any})
    gname = obj[:gname]
    pairlabels = obj[:header]

    # Discriminant functions, one column per pair of groups.
    funcnames = "Func." .* string.(1:length(pairlabels))
    println("判別関数\n",
            NamedArray(obj[:dfunction], (vcat(obj[:name], "Constant"), funcnames)))
    for (i, pair) in enumerate(pairlabels)
        println(" Func.$i: $pair")
    end

    # Partial F value and p value of each variable.
    fstats = hcat(obj[:partialF], obj[:partialFP])
    println("偏 F 値,p 値\n",
            NamedArray(fstats, (obj[:name], ["Partial F", "p value"])))
    println(" 偏 F 値の自由度 ($(obj[:df1]), $(obj[:df2]))")

    # Classification functions, one column per group.
    println("分類関数\n",
            NamedArray(obj[:cfunction], (vcat(obj[:name], "Constant"), gname)))

    # Cross table of the classification result and the hit rate.
    println("判別結果\n",
            NamedArray(obj[:correcttable], (obj[:factor1], obj[:factor2])))
    println("正判別率 = $(round(obj[:correctrate], digits=2))%")
end
"""
    plotdisc(obj; which="boxplot", nclass=20, color1=:blue, color2=:red)

Plot the result dictionary returned by `disc`.

# Keywords
- `which`: `"boxplot"` draws a box plot of the discriminant values per group,
  `"barplot"` a grouped bar chart of the discriminant values binned into classes;
  any other value draws a scatter plot of the squared distances to the first two
  group centroids.
- `nclass`: number of classes used by the bar plot.
- `color1`, `color2`: marker colors of the first and second group in the scatter plot.

Raises an error when `obj[:discriminantvalue]` is empty.
"""
function plotdisc(obj; which = "boxplot", # or "barplot" or "scatterplot"
                  nclass = 20, color1=:blue, color2=:red)
    if length(obj[:discriminantvalue]) == 0
        error("3群以上の場合にはグラフ表示は用意されていません")
    end
    pyplot()
    if which == "boxplot"
        return boxplot(string.(obj[:group]), obj[:discriminantvalue],
                       xlabel = "群", ylabel = "判別値", label="")
    elseif which == "barplot"
        scores = obj[:discriminantvalue]
        lo, hi = extrema(scores)
        binwidth = (hi - lo) / (nclass - 1)
        # Class index of every case, counted from the smallest discriminant value.
        bins = floor.(Int, (scores .- lo) ./ binwidth)
        _, _, counts = table(bins, obj[:group])
        return groupedbar(counts, xlabel = "判別値($nclass 階級に基準化)", label="")
    end
    gname = obj[:gname]
    first_group = obj[:group] .== obj[:gname][1]
    second_group = obj[:group] .== obj[:gname][2]
    scatter(obj[:distance][first_group, 1], obj[:distance][first_group, 2],
            color = color1, markerstrokecolor = color1,
            xlabel = "$(gname[1]) の重心への二乗距離",
            ylabel = "$(gname[2]) の重心への二乗距離", aspect_ratio = 1,
            label=gname[1])
    return scatter!(obj[:distance][second_group, 1], obj[:distance][second_group, 2],
                    color = color2, markerstrokecolor = color2, label=gname[2])
end
# Example 1: run the analysis on every row of the iris data set and print the result.
using RDatasets
iris = dataset("datasets", "iris");
data = Matrix(iris[:, 1:4]); # columns 1-4 form the data matrix; check with typeof(data)
name = names(iris)[1:4]; # names of those four columns; check with typeof(name)
group = iris[:, 5]; # column 5 supplies the group labels; check with typeof(group)
res = disc(data, group, name=name)
printdisc(res)
判別関数
5×3 Named Matrix{Float64}
A ╲ B │ Func.1 Func.2 Func.3
────────────┼─────────────────────────────
SepalLength │ 7.84596 11.0983 3.25236
SepalWidth │ 16.5154 19.9026 3.38723
PetalLength │ -21.6421 -29.1972 -7.55509
PetalWidth │ -23.8326 -38.4775 -14.6449
Constant │ -13.4559 18.0599 31.5157
Func.1: setosa:versicolor
Func.2: setosa:virginica
Func.3: versicolor:virginica
偏 F 値,p 値
4×2 Named Matrix{Float64}
A ╲ B │ Partial F p value
────────────┼─────────────────────────
SepalLength │ 4.72115 0.0103288
SepalWidth │ 21.9359 4.8312e-9
PetalLength │ 35.5902 2.75621e-13
PetalWidth │ 24.9043 5.14315e-10
偏 F 値の自由度 (2, 144)
分類関数
5×3 Named Matrix{Float64}
A ╲ B │ setosa versicolor virginica
────────────┼───────────────────────────────────
SepalLength │ -47.0883 -31.3964 -24.8917
SepalWidth │ -47.1757 -14.145 -7.37056
PetalLength │ 32.8613 -10.4229 -25.5331
PetalWidth │ 34.7968 -12.8685 -42.1582
Constant │ 170.42 143.508 206.539
判別結果
3×3 Named Matrix{Int64}
A ╲ B │ setosa versicolor virginica
───────────┼───────────────────────────────────
setosa │ 50 0 0
versicolor │ 0 48 2
virginica │ 0 1 49
正判別率 = 98.0%
# Example 2: the same analysis restricted to rows 51-150 of the iris data set.
data = Matrix(iris[51:150, 1:4]); # columns 1-4 of the selected rows; check with typeof(data)
group = iris[51:150, 5]; # group labels of the selected rows; check with typeof(group)
name = names(iris)[1:4]; # names of the four data columns; check with typeof(name)
obj = disc(data, group, name=name)
printdisc(obj)
判別関数
5×1 Named Matrix{Float64}
A ╲ B │ Func.1
────────────┼─────────
SepalLength │ 3.5563
SepalWidth │ 5.57862
PetalLength │ -6.97013
PetalWidth │ -12.386
Constant │ 16.6631
Func.1: versicolor:virginica
偏 F 値,p 値
4×2 Named Matrix{Float64}
A ╲ B │ Partial F p value
────────────┼───────────────────────
SepalLength │ 7.36791 0.00788617
SepalWidth │ 10.5875 0.00157755
PetalLength │ 24.1566 3.70346e-6
PetalWidth │ 37.0916 2.38356e-8
偏 F 値の自由度 (1, 95)
分類関数
5×2 Named Matrix{Float64}
A ╲ B │ versicolor virginica
────────────┼───────────────────────
SepalLength │ -30.802 -23.6894
SepalWidth │ -31.9428 -20.7856
PetalLength │ 5.16328 -8.77697
PetalWidth │ 1.17289 -23.5992
Constant │ 123.886 157.212
判別結果
2×2 Named Matrix{Int64}
A ╲ B │ versicolor virginica
───────────┼───────────────────────
versicolor │ 48 2
virginica │ 1 49
正判別率 = 97.0%
# Draw the three available plots for the result `obj` obtained above.
plotdisc(obj) # default: which = "boxplot"
plotdisc(obj, which="barplot")
plotdisc(obj, which="scatterplot")