#==========
Julia の修行をするときに,いろいろなプログラムを書き換えるのは有効な方法だ。
以下のプログラムを Julia に翻訳してみる。
判別分析(線形判別関数;ステップワイズ変数選択)
http://aoki2.si.gunma-u.ac.jp/R/sdis.html
ファイル名: sdis.jl 関数名: sdis
翻訳するときに書いたメモ
==========#
using CSV, DataFrames, Rmath, Statistics, LinearAlgebra, NamedArrays, Printf, Plots
"""
    sdis(data::Array{Float64,2}, group; name=[], stepwise=true, Pin=0.05, Pout=0.05,
         predict=false, verbose=false)

Linear discriminant analysis with optional stepwise variable selection
(Julia translation of http://aoki2.si.gunma-u.ac.jp/R/sdis.html).
Prints the classification functions, pairwise discriminant functions and the
resubstitution classification table, and returns all results.

# Arguments
- `data`: `ncase × p` matrix; each row is a case, each column a variable.
- `group`: group label of each case.

# Keywords
- `name`: variable names; when empty, `"x1"`, `"x2"`, … are used.
- `stepwise`: select variables stepwise; forced to `false` when `p == 1`.
  When `false`, all `p` variables are used.
- `Pin`, `Pout`: P-value thresholds for entering / removing a variable.
  `Pout` is raised to `Pin` when it is smaller.
- `predict`: also print per-case squared distances, P values and results.
- `verbose`: print means, the pooled within-group correlation matrix and every
  selection step.

# Returns
A `Dict{Symbol,Any}`, including:
- `:a`, `:a0`: classification functions.
- `:wl`, `:wlf`, `:wldf1`, `:wldf2`, `:wlp`: Wilks' Λ and its F approximation.
- `:dfunc`, `:dist`, `:errorp`: pairwise discriminant functions.
- `:dis`, `:prediction`, `:correct`, `:correcttable`, `:correctrate`, `:dfv`:
  classification of the input cases.

Note that `:P` first holds the partial-F P values and is then overwritten with the
per-case chi-square P values.

# Throws
`ErrorException` when:
- there is only one group;
- some group has fewer than two cases;
- no variable satisfies `Pin`;
- a sweep pivot has absolute value ≤ 1e-6.
"""
function sdis(data::Array{Float64,2}, group; name=[], stepwise=true, Pin=0.05, Pout=0.05,
              predict=false, verbose=false)
    # The nested functions below share the local variables of `sdis` (closures);
    # `stepout` updates `step`, and several of them replace `results`.

    # Diagonal elements m[i, i] for the indices in idx.
    getitem(m, idx) = [m[i, i] for i in idx]

    # Format a P value; values below 1e-6 are printed as "< 0.000001".
    formatpval(pv) = pv >= 0.000001 ? @sprintf("%.6f", pv) : "< 0.000001"

    # For the selected variables lx[1:ni], compute:
    #   - the classification functions,
    #   - the partial F values with their P values,
    #   - Wilks' Λ with Rao's F approximation,
    # and merge them into `results`.
    # `isw` is 0 (no change), 1 (variable `ip` entered) or 2 (variable `ip` removed);
    # it only selects the verbose message.
    function stepout(isw)
        step += 1
        ncasek = ncase - ng
        isw != 0 && verbose && println("\n ***** ステップ $step ***** \n",
            " $(["編入", "除去"][isw])変数: $(name[ip])")
        lxi = lx[1:ni]
        a = zeros(ni, ng)
        a0 = zeros(ng)
        for g = 1:ng
            a[:, g] = -(w[lxi, lxi] * means[lxi, g]) * 2ncasek
            a0[g] = transpose(means[lxi, g]) * w[lxi, lxi] * means[lxi, g] * ncasek
        end
        # Partial F value of each selected variable and its upper-tail P value.
        idf1 = ng - 1
        idf2 = ncase - (ng - 1) - ni
        temp = idf2 / idf1
        f = getitem(t, lxi) ./ getitem(w, lxi)
        f = temp * (1 .- f) ./ f
        P = pf.(f, idf1, idf2, false)
        # Wilks' Λ from the accumulated sweep pivots, and its F approximation.
        alp = ng - 1
        b = ncase - 1 - 0.5 * (ni + ng)
        qa = ni ^ 2 + alp ^ 2
        c = qa == 5 ? 1.0 : sqrt((ni ^ 2 * alp ^ 2 - 4) / (qa - 5))
        wldf1 = ni * alp
        wldf2 = b * c + 1 - 0.5 * ni * alp
        wl = detw / dett
        cl = exp(log(wl) / c)
        wlf = wldf2 * (1 - cl) / (wldf1 * cl)
        wlp = pf(wlf, wldf1, wldf2, false)
        results = merge(results, Dict(:rownames => name[lxi], :a => a, :a0 => a0,
                                      :f => f, :idf1 => idf1, :idf2 => idf2, :P => P,
                                      :wl => wl, :wlf => wlf, :wldf1 => wldf1,
                                      :wldf2 => wldf2, :wlp => wlp))
    end

    # Print the classification-function table stored in `results`.
    function printclassfunc()
        println("\n***** 分類関数 *****")
        @printf("%12s", "")
        for i = 1:ng
            @printf(" %12s", gname[i])
        end
        @printf(" %12s %12s\n", "偏F値", "P値")
        rownames = results[:rownames]
        for j = 1:length(rownames)
            @printf("%12s", rownames[j])
            for i = 1:ng
                @printf(" %12.6f", results[:a][j, i])
            end
            @printf(" %12.6f %12s\n", results[:f][j], formatpval(results[:P][j]))
        end
        @printf("%12s", "定数項")
        for i = 1:ng
            @printf(" %12.6f", results[:a0][i])
        end
        println("\n\nウィルクスのΛ: $(results[:wl])")
        println("等価なF値: $(results[:wlf])")
        println("自由度: ($(results[:wldf1]), $(results[:wldf2]))")
        println("P値: $(results[:wlp])")
    end

    # Entry candidate: among the variables not in lx[1:ni], return the largest
    # (1 - λ) / λ with λ = w[i, i] / t[i, i], together with that variable's index.
    function fmax()
        selected = falses(p)
        selected[lx[1:ni]] .= true
        kouho = (1:p)[.!selected]
        ratio = getitem(w, kouho) ./ getitem(t, kouho)
        ratio = (1 .- ratio) ./ ratio
        k = argmax(ratio)
        return ratio[k], kouho[k]
    end

    # Removal candidate: among lx[1:ni], return the smallest (1 - λ) / λ with
    # λ = t[i, i] / w[i, i] (swept matrices), together with that variable's index.
    function fmin()
        kouho = lx[1:ni]
        ratio = getitem(t, kouho) ./ getitem(w, kouho)
        ratio = (1 .- ratio) ./ ratio
        k = argmin(ratio)
        return ratio[k], kouho[k]
    end

    # Sweep the p×p matrix `r` in place on pivot `ip` and return `det` multiplied
    # by the pivot. Sweeping the same pivot twice restores `r` (and `det`).
    # Errors when |pivot| <= EPSINV.
    function sweepsdis!(r, det, ip, p)
        ap = r[ip, ip]
        abs(ap) > EPSINV || error("正規方程式の係数行列が特異行列です")
        det *= ap
        for i = 1:p
            if i != ip
                temp = r[ip, i] / ap
                for j = 1:p
                    if j != ip
                        r[j, i] -= r[j, ip] * temp
                    end
                end
            end
        end
        r[:, ip] /= ap
        r[ip, :] /= -ap
        r[ip, ip] = 1 / ap
        det
    end

    # For every pair of groups (g1 < g2), compute:
    #   - the linear discriminant function (with constant) and its standardized form,
    #   - the Mahalanobis distance between the two centroids,
    #   - the theoretical misclassification rate,
    # then print them and merge them into `results`.
    function discriminantfunction()
        lxi = lx[1:ni]
        side = name[lxi]
        ncasek = ncase - ng
        p0 = length(lxi)
        m = ng * (ng - 1) ÷ 2
        dfunc = zeros(p0 + 1, m)
        stddfunc = zeros(p0, m)
        header = fill("", m)
        dist = zeros(m)
        errorp = zeros(m)
        k = 0
        for g1 in 1:ng - 1
            for g2 in g1 + 1:ng
                k += 1
                header[k] = gname[g1] * ":" * gname[g2]
                xx = means[lxi, g1] .- means[lxi, g2]
                fn = w[lxi, lxi] * xx .* ncasek
                stddfunc[1:p0, k] = sd[lxi] .* fn
                fn0 = -sum(fn .* (means[lxi, g1] .+ means[lxi, g2]) .* 0.5)
                dfunc[1:p0, k] = fn
                dfunc[p0 + 1, k] = fn0
                dist[k] = sqrt(sum(xx .* fn))
                errorp[k] = pnorm(dist[k] * 0.5, false)
            end
        end
        printdfunc(header, side, dfunc, stddfunc, dist, errorp)
        results = merge(results, Dict(:header => header, :side => side, :dfunc => dfunc,
                                      :dist => dist, :errorp => errorp))
    end

    # Print the discriminant functions, standardized functions and distances.
    function printdfunc(header, side, dfunc, stddfunc, dist, errorp)
        simpleheader = "Func." .* string.(1:length(header))
        array1 = NamedArray(dfunc,
            (vcat(side, "Constant"), simpleheader))
        println("\n判別関数\n", array1)
        array2 = NamedArray(stddfunc, (side, simpleheader))
        println("\n標準化判別関数\n", array2)
        array3 = NamedArray(vcat(dist', errorp'),
            (["マハラノビスの汎距離", "理論的誤判別率"], simpleheader))
        println("\n", array3)
        for i = 1:length(header)
            println(" Func.$i: $(header[i])")
        end
    end

    # Classify every input case:
    #   - squared distance to each group centroid, using the swept w[lxi, lxi]
    #     scaled by (ncase - ng);
    #   - upper-tail chi-square P value;
    #   - assignment to the group with the largest P value.
    function procpredict()
        ncasek = ncase - ng
        lxi = lx[1:ni]
        xsel = data[:, lxi]          # new local; the captured `data` is left untouched
        tdata = transpose(xsel)
        dis = zeros(ncase, ng)
        for g = 1:ng
            temp = tdata .- means[lxi, g]
            for j = 1:ncase
                dis[j, g] = temp[:, j]' * w[lxi, lxi] * temp[:, j]
            end
        end
        dis .*= ncasek
        # The distances use only the ni selected variables, hence ni degrees of freedom.
        P = pchisq.(dis, ni, false)
        prediction = [gname[argmax(P[j, :])] for j = 1:ncase]
        correct = prediction .== group
        factor1, factor2, correcttable = table(group, prediction)
        # Count matches directly: diag(correcttable) is misaligned (or the table is
        # not square) whenever some group is never predicted.
        correctrate = count(correct) / ncase * 100
        dfv = if ng == 2
            fn = results[:dfunc][1:end-1]
            fn0 = results[:dfunc][end]
            xsel * fn .+ fn0
        else
            NaN
        end
        results = merge(results, Dict(:dis => dis, :P => P,
                                      :prediction => prediction, :correct => correct,
                                      :correcttable => correcttable,
                                      :factor1 => factor1, :factor2 => factor2,
                                      :correctrate => correctrate, :dfv => dfv))
    end

    EPSINV = 1e-6                      # pivot tolerance used by sweepsdis!
    Pout >= Pin || (Pout = Pin)
    step = 0
    ip = 0                             # variable entered/removed in the latest step
    ncase, p = size(data)
    p > 1 || (stepwise = false)
    isempty(name) && (name = ["x" * string(i) for i = 1:p])
    gname, num = table(group)
    ng = length(gname)
    ng > 1 || error("1群しかありません")
    # Every group must have at least two cases.
    all(num .>= 2) || error("ケース数が1以下の群があります")
    gmeans = vec(mean(data, dims=1))
    t = cov(data, corrected=false) .* ncase              # total SSCP matrix
    means = zeros(p, ng)
    vars = zeros(p, p, ng)
    for i = 1:ng
        gdata = data[group .== gname[i], :]
        means[:, i] = vec(mean(gdata, dims=1))
        vars[:, :, i] = cov(gdata, corrected=false) .* num[i]   # within-group SSCP
    end
    if verbose
        println("有効ケース数: $ncase")
        println("判別するグループ: $gname")
        println("***** 平均値 *****")
        println(NamedArray(hcat(means, gmeans),
            (name, vcat(gname, "全体"))))
    end
    w = sum(vars, dims=3)[:, :, 1]                        # pooled within-group SSCP
    detw = dett = 1.0
    sd2 = sqrt.(diag(w) ./ ncase)
    r = w ./ (sd2 * sd2') ./ ncase
    if verbose
        println("\n***** プールされた群内相関係数行列 *****")
        println(NamedArray(r, (name, name)))
    end
    sd = sqrt.(diag(t) ./ ncase)
    results = Dict(:ncase => ncase, :gname => gname,
                   :means => means, :gmeans => gmeans, :r => r)
    if stepwise == false
        for v = 1:p
            detw = sweepsdis!(w, detw, v, p)
            dett = sweepsdis!(t, dett, v, p)
        end
        lx = 1:p
        ni = p
        stepout(0)
    else
        verbose && println("\n変数編入基準 Pin: $Pin\n",
            "変数除去基準 Pout: $Pout")
        lx = zeros(Int, p)             # lx[1:ni] are the selected variables
        ni = 0
        while ni != p
            ansmax, ip = fmax()
            P = (ncase - ng - ni) / (ng - 1) * ansmax
            P = pf(P, ng - 1, ncase - ng - ni, false)
            verbose && println("編入候補変数: $(name[ip]) P: $(formatpval(P))")
            if P > Pin
                verbose && println(" 編入されませんでした")
                break
            end
            verbose && println(" 編入されました")
            ni += 1
            lx[ni] = ip
            detw = sweepsdis!(w, detw, ip, p)
            dett = sweepsdis!(t, dett, ip, p)
            stepout(1)
            verbose && printclassfunc()
            while true
                ansmin, ip = fmin()
                P = (ncase - ng - ni + 1) / (ng - 1) * ansmin
                P = pf(P, ng - 1, ncase - ng - ni + 1, false)
                verbose && println("\n除去候補変数: $(name[ip]) P: $(formatpval(P))")
                if P <= Pout
                    verbose && println(" 除去されませんでした")
                    break
                else
                    verbose && println(" 除去されました")
                    # Drop `ip` but keep length p so that a later `lx[ni] = ip` stays
                    # in bounds.
                    lx = [lx[lx .!= ip]; 0]
                    ni -= 1
                    # Sweeping the same pivot again undoes its earlier sweep.
                    detw = sweepsdis!(w, detw, ip, p)
                    dett = sweepsdis!(t, dett, ip, p)
                    stepout(2)
                    verbose && printclassfunc()
                end
            end
        end
    end
    ni == 0 && error("条件(Pin < $Pin)を満たす独立変数がありません")
    verbose && println("\n========== 結果 ==========")
    printclassfunc()
    discriminantfunction()
    procpredict()
    if predict
        println("\n********** 各ケースの判別結果 **********")
        println("\n各群への二乗距離")
        if ng == 2
            println(NamedArray(hcat(results[:dis], results[:dfv]), (1:ncase, vcat(gname, "判別値"))))
        else
            println(NamedArray(results[:dis], (1:ncase, gname)))
        end
        println("\nP 値")
        println(NamedArray(results[:P], (1:ncase, gname)))
        println(" メモ:「二乗距離」とは,各群の重心までのマハラノビスの汎距離の二乗です。")
        println(" P値は各群に属する確率です。")
        println("\n判別,判別の正否")
        println(NamedArray(hcat(string.(group), results[:prediction], results[:correct]),
            (1:ncase, ["実際の群", "判別された群", "正否"])))
    end
    println("\n判別結果")
    println(NamedArray(results[:correcttable], (results[:factor1], results[:factor2])))
    println("\n正判別率 = $(results[:correctrate])\n")
    return results
end
"""
    table(x)

One-way frequency table.

Return `(vals, counts)`, where:
- `vals` is the sorted vector of distinct values of `x`;
- `counts[i]` is the number of elements of `x` that are `isequal` to `vals[i]`.

Intended for data with few distinct values.
"""
function table(x)
    vals = sort(unique(x))
    counts = [count(isequal(v), x) for v in vals]
    return vals, counts
end
"""
    table(x, y)

Two-way contingency table of the paired elements of `x` and `y`.

Return `(valsx, valsy, counts)`, where:
- `valsx` and `valsy` are the sorted distinct values of `x` and `y`;
- `counts[i, j]` is the number of pairs `(x[k], y[k])` that are `isequal` to
  `(valsx[i], valsy[j])`.

Pairs are formed with `zip`, so unmatched trailing elements are not counted.
"""
function table(x, y)
    valsx = sort(unique(x))
    valsy = sort(unique(y))
    counts = [count(((u, v),) -> isequal(u, a) && isequal(v, b), zip(x, y))
              for a in valsx, b in valsy]
    return valsx, valsy, counts
end
"""
    sdis(data::DataFrame, group; name=names(data), stepwise=true, Pin=0.05, Pout=0.05,
         predict=false, verbose=false)

Run the matrix method of `sdis` on the columns of `data`.

The data frame is converted to a `Matrix{Float64}`, and its column names are used
as variable names unless `name` is given. All other keywords are forwarded unchanged.
"""
function sdis(data::DataFrame, group; name=names(data), stepwise=true, Pin=0.05, Pout=0.05,
              predict=false, verbose=false)
    # The options must be forwarded as keywords: the matrix method has only two
    # positional arguments, so passing them positionally raises a MethodError.
    return sdis(Matrix{Float64}(data), group; name=name, stepwise=stepwise, Pin=Pin,
                Pout=Pout, predict=predict, verbose=verbose)
end
"""
    sdis(data::AbstractMatrix{<:Real}, group; name=[], stepwise=true, Pin=0.05, Pout=0.05,
         predict=false, verbose=false)

Convert a numeric matrix that is not already an `Array{Float64,2}` (for example a
`Matrix{Int64}`) to `Matrix{Float64}`. Then forward it, with every keyword, to the
`Array{Float64,2}` method of `sdis`.
"""
function sdis(data::AbstractMatrix{<:Real}, group; name=[], stepwise=true, Pin=0.05,
              Pout=0.05, predict=false, verbose=false)
    # Converting to Float64 is essential: `Matrix(data)` would keep the Int element
    # type and dispatch back to this method.
    return sdis(Matrix{Float64}(data), group; name=name, stepwise=stepwise, Pin=Pin,
                Pout=Pout, predict=predict, verbose=verbose)
end
# Example 1: only rows 51-150 of iris (two species), first four columns as variables.
using RDatasets
iris = dataset("datasets", "iris")
name = first(names(iris), 4)
data = Matrix(iris[51:150, 1:4])
group = vec(iris[51:150, 5])
sdis(data, group; name=name, verbose=false)
#=====
***** 分類関数 *****
versicolor virginica 偏F値 P値
PetalWidth 1.172895 -23.599188 37.091608 < 0.000001
SepalWidth -31.942816 -20.785575 10.587491 0.001578
PetalLength 5.163281 -8.776974 24.156597 0.000004
SepalLength -30.802050 -23.689444 7.367913 0.007886
定数項 123.885865 157.212036
ウィルクスのΛ: 0.21611029704367302
等価なF値: 86.14758620895533
自由度: (4, 95.0)
P値: 9.53987626477818e-31
判別関数
5×1 Named Matrix{Float64}
A ╲ B │ Func.1
────────────┼─────────
PetalWidth │ -12.386
SepalWidth │ 5.57862
PetalLength │ -6.97013
SepalLength │ 3.5563
Constant │ 16.6631
標準化判別関数
4×1 Named Matrix{Float64}
A ╲ B │ Func.1
────────────┼─────────
PetalWidth │ -5.23483
SepalWidth │ 1.84699
PetalLength │ -5.72554
SepalLength │ 2.34542
2×1 Named Matrix{Float64}
A ╲ B │ Func.1
───────────┼──────────
マハラノビスの汎距離 │ 3.77079
理論的誤判別率 │ 0.0296881
Func.1: versicolor:virginica
判別結果
2×2 Named Matrix{Int64}
A ╲ B │ versicolor virginica
───────────┼───────────────────────
versicolor │ 48 2
virginica │ 1 49
正判別率 = 97.0
=====#
#=====
Dict{Symbol, Any} with 30 entries:
:dis => [5.29832 23.9158; 2.105 16.7657; … ; 21.4011 4.50692; 10.154…
:factor1 => ["versicolor", "virginica"]
:means => [5.936 6.588; 2.77 2.974; 4.26 5.552; 1.326 2.026]
:ncase => 100
:wlf => 86.1476
:prediction => ["versicolor", "versicolor", "versicolor", "versicolor", "ve…
:header => ["versicolor:virginica"]
:gmeans => [6.262, 2.872, 4.906, 1.676]
:wlp => 9.53988e-31
:correct => Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1 … 1, 1, 1, 1, 1, 1, 1, 1…
:correctrate => 97.0
:dfv => [9.30873, 7.33037, 5.76261, 5.07121, 4.75754, 5.08672, 4.899…
:wldf2 => 95.0
:rownames => ["PetalWidth", "SepalWidth", "PetalLength", "SepalLength"]
:gname => ["versicolor", "virginica"]
:factor2 => ["versicolor", "virginica"]
:errorp => [0.0296881]
:dist => [3.77079]
:idf1 => 1
:correcttable => [48 2; 1 49]
:wldf1 => 4
:r => [1.0 0.48557 0.818971 0.378356; 0.48557 1.0 0.472261 0.58332…
:idf2 => 95
:wl => 0.21611
:a => [1.17289 -23.5992; -31.9428 -20.7856; 5.16328 -8.77697; -30.…
⋮ => ⋮
====#
# Example 2: all 150 rows of iris (three species), first four columns as variables.
name = first(names(iris), 4)
data = Matrix(iris[:, 1:4])
group = vec(iris[:, 5])
sdis(data, group; name=name, verbose=false)
#=====
***** 分類関数 *****
setosa versicolor virginica 偏F値 P値
PetalLength 32.861278 -10.422902 -25.533090 35.590175 < 0.000001
SepalWidth -47.175741 -14.145020 -7.370559 21.935928 < 0.000001
PetalWidth 34.796822 -12.868458 -42.158226 24.904333 < 0.000001
SepalLength -47.088333 -31.396418 -24.891698 4.721152 0.010329
定数項 170.419715 143.507990 206.539415
ウィルクスのΛ: 0.023438630650878322
等価なF値: 199.14534354008427
自由度: (8, 288.0)
P値: 1.3650058325897325e-112
判別関数
5×3 Named Matrix{Float64}
A ╲ B │ Func.1 Func.2 Func.3
────────────┼─────────────────────────────
PetalLength │ -21.6421 -29.1972 -7.55509
SepalWidth │ 16.5154 19.9026 3.38723
PetalWidth │ -23.8326 -38.4775 -14.6449
SepalLength │ 7.84596 11.0983 3.25236
Constant │ -13.4559 18.0599 31.5157
標準化判別関数
4×3 Named Matrix{Float64}
A ╲ B │ Func.1 Func.2 Func.3
────────────┼─────────────────────────────
PetalLength │ -38.0772 -51.3696 -13.2925
SepalWidth │ 7.17445 8.6459 1.47145
PetalWidth │ -18.1055 -29.2311 -11.1256
SepalLength │ 6.47528 9.15946 2.68418
2×3 Named Matrix{Float64}
A ╲ B │ Func.1 Func.2 Func.3
───────────┼──────────────────────────────────────
マハラノビスの汎距離 │ 9.47967 13.3935 4.14742
理論的誤判別率 │ 1.06946e-6 1.06568e-11 0.0190532
Func.1: setosa:versicolor
Func.2: setosa:virginica
Func.3: versicolor:virginica
判別結果
3×3 Named Matrix{Int64}
A ╲ B │ setosa versicolor virginica
───────────┼───────────────────────────────────
setosa │ 50 0 0
versicolor │ 0 48 2
virginica │ 0 1 49
正判別率 = 98.0
Dict{Symbol, Any} with 30 entries:
:dis => [0.29109 98.8847 191.789; 2.03135 80.9713 169.187; … ; 188.8…
:factor1 => ["setosa", "versicolor", "virginica"]
:means => [5.006 5.936 6.588; 3.428 2.77 2.974; 1.462 4.26 5.552; 0.24…
:ncase => 150
:wlf => 199.145
:prediction => ["setosa", "setosa", "setosa", "setosa", "setosa", "setosa",…
:header => ["setosa:versicolor", "setosa:virginica", "versicolor:virgin…
:gmeans => [5.84333, 3.05733, 3.758, 1.19933]
:wlp => 1.36501e-112
:correct => Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1 … 1, 1, 1, 1, 1, 1, 1, 1…
:correctrate => 98.0
:dfv => NaN
:wldf2 => 288.0
:rownames => ["PetalLength", "SepalWidth", "PetalWidth", "SepalLength"]
:gname => ["setosa", "versicolor", "virginica"]
:factor2 => ["setosa", "versicolor", "virginica"]
:errorp => [1.06946e-6, 1.06568e-11, 0.0190532]
:dist => [9.47967, 13.3935, 4.14742]
:idf1 => 2
:correcttable => [50 0 0; 0 48 2; 0 1 49]
:wldf1 => 8
:r => [1.0 0.530236 0.756164 0.364506; 0.530236 1.0 0.377916 0.470…
:idf2 => 144
:wl => 0.0234386
:a => [32.8613 -10.4229 -25.5331; -47.1757 -14.145 -7.37056; 34.79…
⋮ => ⋮
=====#