ジェネレータ式(内包表記)で書かれたプログラムがあったのだけど,説明はあるものの,読んだだけでは何をするプログラムかちょっとわかりにくかった。
Julia では,「ベクトル化関数をつかうよりむしろループを書け」という人もいるようなのだけど,ちょっと確かめてみようと思った次第である。
1. 紹介されていたプログラム
プログラム中の numerator は上式の分子,すなわち a の各項の指数をとったもの,sum_exp_a はその総和,最後に y は numerator の各項を sum_exp_a で割ったものを返すということで,素直といえば素直なのだけど,ループをジェネレータ式で書いているということ。 その割には,総和は reduce で計算したりしている(reduce は for ループで書いてもほとんど同じパフォーマンスであった)。
function func(a)
len_a = length(a) # S01
numerator = [exp(a[i]) for i=1:len_a] # S02
sum_exp_a = reduce(+, numerator) # S03
y = [numerator[i] / sum_exp_a for i=1:len_a] # S04
return y
end
func (generic function with 1 method)
2. 実行速度を計測する
第一引数は実行速度を計算される関数の名前である。
function test(FUNC, a, n=100)
for i = 1:n
FUNC(a)
end
end
test (generic function with 2 methods)
2.1. ジェネレータ式を使って書かれたプログラム
a = randn(10000000);
@time test(func, a)
14.059123 seconds (163.74 k allocations: 14.910 GiB, 3.66% gc time, 0.63% compilation time)
2.2. ベクトル化した関数を使って書かれたプログラム
sum_exp_a は sum() で求め,numerator の各項を ./ で割り算する。
関数が何をしているかは,Julia プログラムになれているとすぐに理解できる。
function func2(a)
numerator = exp.(a) # S02
numerator ./ sum(numerator) # S04
end
@time test(func2, a)
12.133699 seconds (32.60 k allocations: 14.903 GiB, 4.31% gc time, 0.37% compilation time)
2.3. for ループを使って書かれたプログラム
function func3(a)
numerator = similar(a)
len_a = length(a) # S01
sum_exp_a = 0
for i = 1:len_a
numerator[i] = exp(a[i]) # S02
sum_exp_a += numerator[i] # S03
end
for i = 1:len_a
numerator[i] /= sum_exp_a # S04
end
numerator
end
@time test(func3, a)
15.028953 seconds (23.41 k allocations: 7.452 GiB, 1.82% gc time, 0.16% compilation time)
2.4. for ループを使って書かれたプログラム その 2
function func4!(a)
len_a = length(a) # S01
sum_exp_a = 0
for i = 1:len_a
a[i] = exp(a[i]) # S02
sum_exp_a += a[i] # S03
end
for i = 1:len_a
a[i] /= sum_exp_a # S04
end
end
@time test(func4!, a)
14.219583 seconds (22.84 k allocations: 1.357 MiB, 0.16% compilation time)
3. 実行速度のまとめ
- 一番早いのは,関数(ベクトル化した関数)を使うプログラムである。ジェネレータ式を使って書かれたプログラムより 1.1 倍ほど速い。プログラムが,何をやっているかもすぐわかるし,短いし。
- for ループを使って書かれたプログラムはジェネレータ式を使ったプログラムとほぼ同じである(ジェネレータ式は結局は for ループと同じなので当たり前といえば当たり前。いずれも,わかりにくい!!!
結論:いつでもそうではないかも知れないが,関数(ベクトル化した関数)を使うべし?かな???
4. 最後に
アルゴリズム,計算順序の違いで,誤差範囲の違いは生じる。
まあ,計算誤差の範囲内で,どの関数も同じ結果を返すということがわかった。
a = randn(10000000)
result = func(a)
result2 = func2(a)
result3 = func3(a)
x = copy(a)
func4!(x)
println(all(abs.(result .- result2) .< eps()))
println(all(abs.(result .- result3) .< eps()))
println(all(abs.(result .- x) .< eps()))
true
true
true
println(eps())
2.220446049250313e-16