裏 RjpWiki

Julia ときどき R, Python によるコンピュータプログラム,コンピュータ・サイエンス,統計学

Julia に翻訳--193 判別分析,線形判別関数

2021年04月17日 | ブログラミング

#==========
Julia の修行をするときに,いろいろなプログラムを書き換えるのは有効な方法だ。
以下のプログラムを Julia に翻訳してみる。

判別分析(線形判別関数)
http://aoki2.si.gunma-u.ac.jp/R/disc.html

ファイル名: disc.jl  関数名: disc, printdisc, plotdisc

翻訳するときに書いたメモ

結果の表示に NamedArray を使ってみた。

==========#

using CSV, DataFrames, Statistics, LinearAlgebra, Rmath, NamedArrays, Plots, StatsPlots

"""
    disc(data::DataFrame, group)

Run the discriminant analysis on the columns of a `DataFrame`.

The frame is converted to a `Matrix{Float64}` so that integer or mixed numeric
columns are accepted as well, and the column names are forwarded as the
variable names (`name` keyword of the matrix method).

# Arguments
- `data`: one row per case, one numeric column per variable.
- `group`: group label of each case, in row order.

# Returns
The result dictionary produced by the `Float64` matrix method of `disc`.
"""
function disc(data::DataFrame, group)
    # Matrix(data) would keep an Int64 (or abstract) element type for
    # non-Float64 columns; converting explicitly always reaches the
    # Float64 matrix method.
    return disc(Matrix{Float64}(data), group, name=names(data))
end

"""
    disc(data::Array{Int64,2}, group; name=[])

Convenience method for integer data: convert `data` to `Float64` and delegate
to the `Float64` matrix method, forwarding `name` as a keyword argument.

# Arguments
- `data`: integer matrix, one row per case and one column per variable.
- `group`: group label of each case, in row order.

# Keywords
- `name`: variable names; passed through unchanged.
"""
function disc(data::Array{Int64,2}, group; name=[])
    # `name` must be passed as a keyword (there is no three-positional-argument
    # method), and the data must really become Float64 -- Matrix(data) keeps
    # the Int64 element type and would dispatch back to this method.
    return disc(Float64.(data), group, name=name)
end

"""
    disc(data::Array{Float64,2}, group; name=[])

Linear discriminant analysis of the rows of `data` classified by `group`.

# Arguments
- `data`: `ncase × p` matrix, one row per case and one column per variable.
- `group`: collection of `ncase` group labels, in row order.

# Keywords
- `name`: variable names; when empty, `"X1"`, `"X2"`, ... are generated.

# Returns
A `Dict{Symbol,Any}` with the keys
`:dfunction` (pairwise discriminant functions, last row is the constant),
`:header` (`"g1:g2"` label of each discriminant function),
`:cfunction` (classification functions, last row is the constant),
`:partialF`, `:partialFP`, `:df1`, `:df2` (partial F values, their p values and
degrees of freedom),
`:wilkslambda`, `:wilkslambdaF`, `:wilkslambdaP`, `:wilkslambdadf1`,
`:wilkslambdadf2`,
`:distance` (squared distance of every case to every group centroid),
`:Pvalue` (upper-tail chi-square probability of each distance),
`:prediction`, `:correct`, `:correcttable`, `:correctrate` (percent),
`:discriminantvalue` (only filled for two groups, otherwise empty),
`:group`, `:factor1`, `:factor2`, `:name`, `:gname`, `:num`.

# Throws
- `DimensionMismatch` if `group` does not have one label per row of `data`.
- `ArgumentError` if fewer than two distinct groups are present.
"""
function disc(data::Array{Float64,2}, group; name=[])
    ncase, p = size(data)
    length(group) == ncase || throw(DimensionMismatch(
        "group has $(length(group)) labels but data has $ncase rows"))
    if isempty(name)
        name = "X" .* string.(1:p)
    end
    gname, num = table(group)
    ng = length(num)
    ng >= 2 || throw(ArgumentError("at least two groups are required, got $ng"))

    # Total sums of squares and cross products (uncorrected covariance * n).
    t = cov(data, corrected=false) .* ncase
    # Group means and within-group sums of squares and cross products.
    means = zeros(ng, p)
    vars = zeros(p, p, ng)
    for i in 1:ng
        gdata = data[group .== gname[i], :]
        means[i, :] = vec(mean(gdata, dims=1))
        vars[:, :, i] = cov(gdata, corrected=false) .* num[i]
    end
    w = sum(vars, dims=3)[:, :, 1] # pooled within-group matrix
    wl = det(w) / det(t)           # Wilks' lambda

    # Classification functions: one column per group, constant in the last row.
    temp = w \ transpose(means)
    a = -2 * (ncase - ng) .* temp
    a0 = sum(transpose(temp) .* means, dims=2) * (ncase - ng)
    cfunction = vcat(a, a0')

    # Pairwise discriminant functions: half the difference of two
    # classification functions.
    m = (ng - 1) * ng ÷ 2
    dfunction = zeros(p + 1, m)
    header = fill("", m)
    k = 0
    for i in 1:ng-1, j in i+1:ng
        k += 1
        # string() also works for labels that are not strings (e.g. integers).
        header[k] = string(gname[i], ":", gname[j])
        dfunction[:, k] = (cfunction[:, j] - cfunction[:, i]) ./ 2
    end

    # Partial F value of each variable; only the diagonals are needed.
    invw = inv(w)
    F = diag(inv(t)) ./ diag(invw)
    idf1 = ng - 1
    idf2 = ncase - idf1 - p
    F = idf2 / idf1 * (1 .- F) ./ F
    P = pf.(F, idf1, idf2, false) # `false`: upper tail (Rmath lower_tail flag)

    # F approximation for Wilks' lambda.
    c1 = (p ^ 2 + idf1 ^ 2 != 5) ?
         sqrt((p ^ 2 * idf1 ^ 2 - 4) / (p ^ 2 + idf1 ^ 2 - 5)) : 1.0
    c2 = wl ^ (1 / c1)
    df1 = p * idf1
    df2 = (ncase - 1 - (p + ng) / 2) * c1 + 1 - 0.5 * p * idf1
    Fwl = df2 * (1 - c2) / (df1 * c2)
    Pwl = pf(Fwl, df1, df2, false)

    # Squared distance of every case to every group centroid, scaled by
    # (ncase - ng), i.e. measured with the pooled within-group covariance.
    D2 = zeros(ncase, ng)
    tdata = transpose(data)
    for i in 1:ng
        dev = tdata .- means[i, :]
        for j in 1:ncase
            col = dev[:, j]
            D2[j, i] = col' * invw * col
        end
    end
    D2 = (ncase - ng) .* D2
    P2 = pchisq.(D2, p, false)

    # Each case is assigned to the group with the largest probability.
    prediction = [gname[argmax(P2[j, :])] for j in 1:ncase]
    correct = prediction .== group
    factor1, factor2, correcttable = table(group, prediction)
    correctrate = sum(diag(correcttable)) / ncase * 100

    # A single discriminant score per case only exists for two groups.
    discriminantvalue = ng == 2 ?
                        data * dfunction[1:p] .+ dfunction[p + 1] : Float64[]

    return Dict{Symbol,Any}(
        :dfunction => dfunction, :header => header,
        :cfunction => cfunction, :partialF => F,
        :partialFP => P, :df1 => idf1, :df2 => idf2, :wilkslambda => wl,
        :wilkslambdaF => Fwl, :wilkslambdaP => Pwl, :wilkslambdadf1 => df1,
        :wilkslambdadf2 => df2, :distance => D2, :Pvalue => P2,
        :prediction => prediction, :correct => correct,
        :correcttable => correcttable, :correctrate => correctrate,
        :discriminantvalue => discriminantvalue, :group => group,
        :factor1 => factor1, :factor2 => factor2,
        :name => name, :gname => gname, :num => num)
end

"""
    table(x)

One-way frequency table (intended for a small number of distinct values).

Returns `(levels, tally)`: the sorted distinct values of `x` and, in the same
order, how many times each of them occurs in `x`.
"""
function table(x)
    levels = sort(unique(x))
    # Position of every level inside `levels`.
    slot = Dict(v => k for (k, v) in enumerate(levels))
    tally = zeros(Int, length(levels))
    for v in x
        tally[slot[v]] += 1
    end
    return levels, tally
end

"""
    table(x, y)

Two-way frequency table of the paired values of `x` and `y`.

Returns `(rowlevels, collevels, tally)`: the sorted distinct values of `x`,
the sorted distinct values of `y`, and a matrix whose entry `[i, j]` counts the
pairs equal to `(rowlevels[i], collevels[j])`.
"""
function table(x, y)
    rowlevels = sort(unique(x))
    collevels = sort(unique(y))
    # Lookup tables from a value to its row / column position.
    rowslot = Dict(v => k for (k, v) in enumerate(rowlevels))
    colslot = Dict(v => k for (k, v) in enumerate(collevels))
    tally = zeros(Int, length(rowlevels), length(collevels))
    for (xv, yv) in zip(x, y)
        tally[rowslot[xv], colslot[yv]] += 1
    end
    return rowlevels, collevels, tally
end

"""
    printdisc(obj::Dict{Symbol, Any})

Print the result dictionary returned by `disc` as labelled tables:
the discriminant functions (followed by the group pair each one separates),
the partial F values with their p values and degrees of freedom,
the classification functions, and the classification table together with the
percentage of correctly classified cases.
"""
function printdisc(obj::Dict{Symbol, Any})
    varnames = obj[:name]
    groups = obj[:gname]
    pairlabels = obj[:header]
    # Row labels shared by both coefficient tables: variables plus constant.
    rowlabels = vcat(varnames, "Constant")

    # Discriminant functions, one column per pair of groups.
    funclabels = "Func." .* string.(1:length(pairlabels))
    println("判別関数\n", NamedArray(obj[:dfunction], (rowlabels, funclabels)))
    for (k, label) in enumerate(pairlabels)
        println("     Func.$k: $label")
    end

    # Partial F values and their p values.
    ftable = NamedArray(hcat(obj[:partialF], obj[:partialFP]),
                        (varnames, ["Partial F", "p value"]))
    println("偏 F 値,p 値\n", ftable)
    println("     偏 F 値の自由度 ($(obj[:df1]), $(obj[:df2]))")

    # Classification functions, one column per group.
    println("分類関数\n", NamedArray(obj[:cfunction], (rowlabels, groups)))

    # Classification table (the two label sets come from :factor1 / :factor2).
    confusion = NamedArray(obj[:correcttable], (obj[:factor1], obj[:factor2]))
    println("判別結果\n", confusion)
    println("正判別率 = $(round(obj[:correctrate], digits=2))%")
    return nothing
end

"""
    plotdisc(obj; which="boxplot", nclass=20, color1=:blue, color2=:red)

Plot the result of a two-group `disc` analysis.

# Keywords
- `which`: `"boxplot"` draws a box plot of the discriminant values per group,
  `"barplot"` a grouped bar chart of the binned discriminant values; any other
  value draws a scatter plot of the squared distances to the first two group
  centroids.
- `nclass`: number of classes used to bin the discriminant values ("barplot").
- `color1`, `color2`: marker colors of the first and second group (scatter plot).

Raises an error when `obj[:discriminantvalue]` is empty, which is what `disc`
stores when there are not exactly two groups.
"""
function plotdisc(obj; which = "boxplot", # or "barplot" or "scatterplot"
                  nclass = 20, color1=:blue, color2=:red)
    scores = obj[:discriminantvalue]
    if length(scores) == 0
        error("3群以上の場合にはグラフ表示は用意されていません")
    end
    pyplot()
    if which == "boxplot"
        return boxplot(string.(obj[:group]), scores, xlabel = "群", ylabel = "判別値", label="")
    end
    if which == "barplot"
        # Map every discriminant value to an integer class index.
        lo, hi = extrema(scores)
        binwidth = (hi - lo) / (nclass - 1)
        bins = floor.(Int, (scores .- lo) ./ binwidth)
        _, _, counts = table(bins, obj[:group])
        return groupedbar(counts, xlabel = "判別値($nclass 階級に基準化)", label="")
    end
    # Fallback: squared distances to the two group centroids.
    labels = obj[:gname]
    dist = obj[:distance]
    in1 = obj[:group] .== labels[1]
    in2 = obj[:group] .== labels[2]
    scatter(dist[in1, 1], dist[in1, 2],
        color = color1, markerstrokecolor = color1,
        xlabel = "$(labels[1]) の重心への二乗距離",
        ylabel = "$(labels[2]) の重心への二乗距離", aspect_ratio = 1, label=labels[1])
    return scatter!(dist[in2, 1], dist[in2, 2],
        color = color2, markerstrokecolor = color2, label=labels[2])
end

# Example 1: analysis of the full iris data set.
using RDatasets
iris = dataset("datasets", "iris");
# Columns 1:4 are used as the variables, column 5 as the group labels.
data = Matrix(iris[:, 1:4]); # typeof(data)
name = names(iris)[1:4]; # typeof(name)
group = iris[:, 5]; # typeof(group)
# Run the analysis and print the result tables.
res = disc(data, group, name=name)
printdisc(res)

判別関数
5×3 Named Matrix{Float64}
      A ╲ B │   Func.1    Func.2    Func.3
────────────┼─────────────────────────────
SepalLength │  7.84596   11.0983   3.25236
SepalWidth  │  16.5154   19.9026   3.38723
PetalLength │ -21.6421  -29.1972  -7.55509
PetalWidth  │ -23.8326  -38.4775  -14.6449
Constant    │ -13.4559   18.0599   31.5157
     Func.1: setosa:versicolor
     Func.2: setosa:virginica
     Func.3: versicolor:virginica
偏 F 値,p 値
4×2 Named Matrix{Float64}
      A ╲ B │   Partial F      p value
────────────┼─────────────────────────
SepalLength │     4.72115    0.0103288
SepalWidth  │     21.9359    4.8312e-9
PetalLength │     35.5902  2.75621e-13
PetalWidth  │     24.9043  5.14315e-10
     偏 F 値の自由度 (2, 144)
分類関数
5×3 Named Matrix{Float64}
      A ╲ B │     setosa  versicolor   virginica
────────────┼───────────────────────────────────
SepalLength │   -47.0883    -31.3964    -24.8917
SepalWidth  │   -47.1757     -14.145    -7.37056
PetalLength │    32.8613    -10.4229    -25.5331
PetalWidth  │    34.7968    -12.8685    -42.1582
Constant    │     170.42     143.508     206.539
判別結果
3×3 Named Matrix{Int64}
     A ╲ B │     setosa  versicolor   virginica
───────────┼───────────────────────────────────
setosa     │         50           0           0
versicolor │          0          48           2
virginica  │          0           1          49
正判別率 = 98.0%

# Example 2: the same analysis restricted to rows 51:150 of iris, which leaves
# two groups (see the printed output), so the result can be passed to plotdisc.
data = Matrix(iris[51:150, 1:4]); # typeof(data)
group = iris[51:150, 5]; # typeof(group)
name = names(iris)[1:4]; # typeof(name)
obj = disc(data, group, name=name)
printdisc(obj)
判別関数
5×1 Named Matrix{Float64}
      A ╲ B │   Func.1
────────────┼─────────
SepalLength │   3.5563
SepalWidth  │  5.57862
PetalLength │ -6.97013
PetalWidth  │  -12.386
Constant    │  16.6631
     Func.1: versicolor:virginica
偏 F 値,p 値
4×2 Named Matrix{Float64}
      A ╲ B │  Partial F     p value
────────────┼───────────────────────
SepalLength │    7.36791  0.00788617
SepalWidth  │    10.5875  0.00157755
PetalLength │    24.1566  3.70346e-6
PetalWidth  │    37.0916  2.38356e-8
     偏 F 値の自由度 (1, 95)
分類関数
5×2 Named Matrix{Float64}
      A ╲ B │ versicolor   virginica
────────────┼───────────────────────
SepalLength │    -30.802    -23.6894
SepalWidth  │   -31.9428    -20.7856
PetalLength │    5.16328    -8.77697
PetalWidth  │    1.17289    -23.5992
Constant    │    123.886     157.212
判別結果
2×2 Named Matrix{Int64}
     A ╲ B │ versicolor   virginica
───────────┼───────────────────────
versicolor │         48           2
virginica  │          1          49
正判別率 = 97.0%

# Default plot: box plot of the discriminant values per group.
plotdisc(obj)

# Grouped bar chart of the binned discriminant values.
plotdisc(obj, which="barplot")



# Scatter plot of the squared distances to the two group centroids.
plotdisc(obj, which="scatterplot")

コメント
  • X
  • Facebookでシェアする
  • はてなブックマークに追加する
  • LINEでシェアする

PVアクセスランキング にほんブログ村

PVアクセスランキング にほんブログ村