裏 RjpWiki

Julia ときどき R, Python によるコンピュータプログラム,コンピュータ・サイエンス,統計学

Julia に翻訳--189 多重ロジスティックモデル,ロジスティック回帰

2021年04月14日 | ブログラミング

#==========
Julia の修行をするときに,いろいろなプログラムを書き換えるのは有効な方法だ。
以下のプログラムを Julia に翻訳してみる。

多重ロジスティックモデル(ロジスティック回帰)
http://aoki2.si.gunma-u.ac.jp/R/lr.html

ファイル名: logisticregression.jl  関数名: logisticregression

翻訳するときに書いたメモ

数値解析をするときには,Julia のデータフレームは邪魔でしかない。

==========#

using Rmath, Plots, Statistics, Printf

function logisticregression(x)
    n, mp1 = size(x)
    m = mp1 - 1
    vnames = names(x)
    x = Matrix(x)
    index, num = table(x[:, mp1])
    @printf("***** 多重ロジスティック回帰 *****\n\n")
    @printf("サンプルサイズ   %5i\n", n)
    @printf("  生存(打ち切り)%5i\n", num[1])
    @printf("  死亡(故障)  %5i\n", num[2])
    if num[1] == 0 || num[2] == 0 || num[1] + num[2] == 2 || n <= mp1
        error("有効ケース数が 1 以下です")
    end
    means = mean(x, dims=1)
    sds = std(x, dims=1)
    @printf("\n         %15s %15s\n", "平均値", "標準偏差")
    for i in 1:m
        @printf("%8s %15.7f %15.7f\n", vnames[i], means[i], sds[i])
    end
    coeff = newtonlogist(x, vnames, m, n, sds)
    fitness(x, m, n, coeff)
end

function diff(x, m, n, coeff)
    mp1 = m + 1
    temp = coeff[mp1] .+ x[:, 1:m] * coeff[1:m];
    p0 = 1 ./ (1 .+ exp.(-temp));
    p1 = 1 .- p0;
    pp = [x[i, mp1] == 1 ? p1[i] : -p0[i] for i = 1:n];
    diff1 = zeros(mp1)
    diff1[mp1] = sum(pp)
    diff1[1:m] .+= vec(sum(x[:, 1:m] .* pp, dims=1))
    temp = -x[:, 1:m] .* p0 .* p1;
    diff2 = zeros(mp1, mp1)
    diff2[1:m, 1:m] = transpose(x[:, 1:m]) * temp
    [diff2[i, j] = 0 for i=1:m, j = 1:m if i > j]
    diff2[mp1, mp1] = -sum(p0 .* p1)
    diff2[1:m, mp1] += vec(sum(temp[:, 1:m], dims=1))
    [diff2[i, j] = diff2[j, i] for i=1:mp1, j = 1:mp1 if i > j]
    return diff1, diff2
end

function llh(x, m, coeff)
    temp = x[:, 1:m] * coeff[1:m] .+ coeff[m + 1];
    -sum(log.(1 .+ exp.([x[i, m + 1] == 1 ? -temp[i] : temp[i] for i = 1:size(x, 1)])))
end

function newtonlogist(x, vnames, m, n, sds)
    mp1 = m + 1
    coeff0 = zeros(mp1)
    coeff = fill(1e-14, mp1)
    for itr = 1:500
        diff1, diff2 = diff(x, m, n, coeff)
        coeff0 = diff2 \ diff1
        converge = all(abs.(coeff0 ./ coeff) .< 1e-10)
        coeff .-= coeff0
        if converge
            se = sqrt.(-[inv(diff2)[i, i] for i = 1:mp1])
            @printf("\n対数尤度 = %.14g\n", llh(x, m, coeff))
            @printf("\nパラメータ推定値\n\n")
            @printf("         %14s %14s %12s %8s %14s\n", "偏回帰係数", "標準偏差", "t 値", "P 値", "標準化偏回帰係数")
            t = abs.(coeff[1:m] ./ se[1:m])
            p = pt.(t, n - mp1, false) * 2
            for i in 1:m
                @printf(" %8s %14.7g %14.7g %12.7g %8.5f %14.7g\n", vnames[i], coeff[i], se[i], t[i], p[i], coeff[i] * sds[i])
            end
            t = abs(coeff[mp1] / se[mp1])
            p = pt(t, n - mp1, false) * 2
            @printf(" %8s %14.7g %14.7g %12.7g %8.5f\n %60s %i\n\n", " 定数項", coeff[mp1], se[mp1], t, p, "t の自由度 = ", n - mp1)
            return coeff
        end
    end
    error("not converged")
end

function fitness(x, m, n, coeff)
    lambda = x[:, 1:m] * coeff[1:m] .+ coeff[m + 1];
    pred = 1 ./ (1 .+ exp.(-lambda));
    y = x[:, m + 1];
    div = round.(Int, collect(range(0, n-1, step = n / 10)))[2:end]
    xs = sort(lambda);
    div2 = vcat((xs[div] .+ xs[div .+ 1]) ./ 2, Inf)
    g = zeros(Int, n);
    for i = 1:n
        for j = 1:10
            if lambda[i] <= div2[j]
                g[i] = j
                break
            end
        end
    end
    index, cnt = table(g)
    from = vcat(minimum(lambda), xs[div .+ 1]...)
    to = vcat(xs[div], maximum(lambda)...)
    mid = (from .+ to) ./ 2
    pred = [sum(pred[g .== i]) for i = 1:10]
    obs = [sum(y[g .== i]) for i = 1:10]
    @printf("   %10s %10s %10s %10s %6s %10s %10s\n",
            "以上", "以下", "期待値", "リスク", "観察値", "故障率", "サンプル数")
    for i = 1:10
        @printf("%2d %10.6f %10.6f %10.6f %10.6f %6d %10.6f %10d\n",
                i, from[i], to[i], pred[i], pred[i] / cnt[i], obs[i],
                obs[i] / cnt[i], cnt[i])
    end
    @printf(" % 60s % s\n", "", "左の2列は,各区間のλの値(最小値と最大値)")
    x2 = range(minimum(lambda), maximum(lambda), length=1000)
    y2 = 1 ./ (1 .+ exp.(-x2))
    plt = plot(x2, y2, grid=false, tick_direction=:out,
               color=:red, linewidth=0.5, xlabel="\$\\lambda\$",
               ylabel="Risk", label="")
    for i = 1:10
        plot!([from[i], to[i]], [pred[i], pred[i]] ./ cnt[i], color=:red, linewidth=0.5, label="")
        plot!([from[i], to[i]], [obs[i], obs[i]] ./ cnt[i], color=:black, linewidth=0.5, label="")
    end
    display(plt)
end

function table(x) # indices が少ないとき
    indices = sort(unique(x))
    counts = zeros(Int, length(indices))
    for i in indexin(x, indices)
        counts[i] += 1
    end
    return indices, counts
end

using CSV, DataFrames
x = CSV.read("/Users/aoki/Desktop/HD3/www/R/lr.data", DataFrame);
logisticregression(x)
#===
***** 多重ロジスティック回帰 *****

サンプルサイズ      98
  生存(打ち切り)   85
  死亡(故障)     13

                     平均値            標準偏差
      x1     132.6734694      14.5047276
      x2     223.2346939      49.2250881

対数尤度 = -36.09198547355

パラメータ推定値

                  偏回帰係数           標準偏差          t 値      P 値       標準化偏回帰係数
       x1    0.008297108     0.02120826    0.3912206  0.69651      0.1203473
       x2     0.01138648    0.005739863     1.983755  0.05017      0.5605007
      定数項      -5.645581       3.048239     1.852079  0.06712
                                                    t の自由度 =  95

           以上         以下        期待値        リスク    観察値        故障率      サンプル数
 1  -3.090951  -2.663472   0.540297   0.054030      0   0.000000         10
 2  -2.661530  -2.490380   0.729186   0.072919      0   0.000000         10
 3  -2.472638  -2.436890   0.712089   0.079121      0   0.000000          9
 4  -2.398670  -2.202452   0.916765   0.091676      3   0.300000         10
 5  -2.160525  -2.054780   1.097699   0.109770      1   0.100000         10
 6  -2.029713  -1.945594   1.206158   0.120616      1   0.100000         10
 7  -1.944181  -1.742756   1.350194   0.135019      1   0.100000         10
 8  -1.724572  -1.549538   1.427380   0.158598      3   0.333333          9
 9  -1.522529  -1.252166   1.986057   0.198606      0   0.000000         10
10  -1.238838   0.587855   3.034176   0.303418      4   0.400000         10
                                                              左の2列は,各区間のλの値(最小値と最大値)
===#

コメント
  • X
  • Facebookでシェアする
  • はてなブックマークに追加する
  • LINEでシェアする

PVアクセスランキング にほんブログ村

PVアクセスランキング にほんブログ村