http://minato.sip21c.org/seek_pi_by_Archimedes.R
中澤さんの,「100ビット実数を使って円周率を10桁正しく出すには正何角形まで必要かを求めるためのRコード」について
Rmpfr というパッケージの使い方がこれだけでわかり,ありがたかった。
もともと,計算時間がかかるというプログラムではなかったが,ちょっとした工夫で計算時間が短くなることがわかったので,記録しておこう。
末尾に追記あり。
中澤さんのオリジナルプログラム
func1 = function() {
library(Rmpfr)
n = 6
p = sqrt(mpfr(3, 100)) / mpfr(2, 100)
a = sqrt(mpfr(1, 100) - p ^ 2)
ad = a / p
L = a * n
M = ad * n
while ((M - L) > mpfr(1e-10, 100)) {
n = n * mpfr(2, 100)
p = sqrt((mpfr(1, 100) + p) / mpfr(2, 100))
a = sqrt(mpfr(1, 100) - p ^ mpfr(2, 100))
ad = a / p
L = a * n
M = ad * n
}
invisible(c(n, L, M))
}
毎回 mpfr を使うのを避ける
func2 = function() {
library(Rmpfr)
n = mpfr(6, 100)
one = mpfr(1, 100)
two = mpfr(2, 100)
epsilon = mpfr(1e-10, 100)
p = sqrt(mpfr(3, 100)) / two
a = sqrt(one - p ^ 2)
ad = a / p
L = a * n
M = ad * n
while ((M - L) > epsilon) {
n = n * two
p = sqrt((one + p) / two)
a = sqrt(one - p ^ two)
ad = a / p
L = a * n
M = ad * n
}
invisible(c(n, L, M))
}
混合演算の場合は,自動的に mpfr してくれるようなので,そのまま定数を書く
func3 = function() {
library(Rmpfr)
n = 6
epsilon = mpfr(1e-10, 100)
p = sqrt(mpfr(3, 100)) / 2
a = sqrt(1 - p ^ 2)
ad = a / p
L = a * n
M = ad * n
while ((M - L) > epsilon) {
n = n * 2
p = sqrt((1 + p) / 2)
a = sqrt(1 - p ^ 2)
ad = a / p
L = a * n
M = ad * n
}
invisible(c(n, L, M))
}
更に,重箱の隅をつつき,計算時間を搾り取る
func4 = function() {
library(Rmpfr)
n = 6
epsilon = mpfr(1e-10, 100)
p = sqrt(mpfr(3, 100)) / 2
a = sqrt(1 - p ^ 2)
ad = a/p
L = a * n
M = ad * n
while ((M - L) > epsilon) {
n = n + n
p = (1 + p) * 0.5
a = sqrt(1 - p)
p = sqrt(p)
L = a * n
M = L / p
}
invisible(c(n, L, M))
}
library(rbenchmark)
benchmark(func1(), func2(), func3(), func4())
結局の所,mpfr が不要な定数はそのまま書くのが速いらしい。
> benchmark(func1(), func2(), func3(), func4())
test replications elapsed relative user.self sys.self
1 func1() 100 22.560 2.429 22.550 0.146
2 func2() 100 13.763 1.482 13.748 0.075
3 func3() 100 12.017 1.294 12.003 0.066
4 func4() 100 9.287 1.000 9.277 0.051
結果が正しいことを確認
> print(func1())
3 'mpfr' numbers of precision 100 bits
[1] 786432 3.1415926535814376697324981306951 3.1415926536065043758325431868181
> print(func2())
3 'mpfr' numbers of precision 100 bits
[1] 786432 3.1415926535814376697324981306951 3.1415926536065043758325431868181
> print(func3())
[[1]]
[1] 786432
[[2]]
'mpfr1' 3.1415926535814376697324981306951
[[3]]
'mpfr1' 3.1415926536065043758325431868181
> print(func4())
[[1]]
[1] 786432
[[2]]
'mpfr1' 3.1415926535814376697324981306951
[[3]]
'mpfr1' 3.1415926536065043758325431868181
Julia で書くと以下のようになる。
function func_julia(bits=256, epsilon=eps())
setprecision(BigFloat, bits)
n = 6
p = sqrt(big(3)) / 2
a = sqrt(1 - p ^ 2)
ad = a/p
L = a * n
M = ad * n
while (M - L) > epsilon
n = 2n
p = (1 + p) * 0.5
a = sqrt(1 - p)
p = sqrt(p)
L = a * n
M = L / p
end
(n, L, M)
end
@time n, L, M = func_julia(100, 1e-10)
n # 786432
L # 3.1415926535814376697324981306951
M # 3.1415926536065043758325431868181
func4 に比べると ほぼ 1000 倍速い。