log1mexp sign confusion
Summary
The `log1mexp` implementation for `Float` and `Double` has a sign error in its threshold. This is possibly due to copy/paste confusion with implementations in other languages, which define `log1mexp x = log (1 - exp (- x))` instead of Haskell's `log1mexp x = log (1 - exp x)`.
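The sign flip follows from a change of variables. The derivation below assumes the other convention switches between the two formulas at $\log 2$; substituting $a = -x$ moves that switch point to $-\log 2$:

```latex
\[
\operatorname{log1mexp}_{\mathrm{Hs}}(x) = \log\bigl(1 - e^{x}\bigr),\ x \le 0,
\qquad
\operatorname{log1mexp}_{\mathrm{other}}(a) = \log\bigl(1 - e^{-a}\bigr),\ a \ge 0,
\]
so $\operatorname{log1mexp}_{\mathrm{Hs}}(x) = \operatorname{log1mexp}_{\mathrm{other}}(-x)$, and
\[
\operatorname{log1mexp}_{\mathrm{other}}(a) =
\begin{cases}
\log\bigl(-\operatorname{expm1}(-a)\bigr), & a < \log 2,\\
\operatorname{log1p}\bigl(-e^{-a}\bigr), & a \ge \log 2,
\end{cases}
\quad\Longrightarrow\quad
\operatorname{log1mexp}_{\mathrm{Hs}}(x) =
\begin{cases}
\log\bigl(-\operatorname{expm1}(x)\bigr), & x > -\log 2,\\
\operatorname{log1p}\bigl(-e^{x}\bigr), & x \le -\log 2.
\end{cases}
\]
```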
Maybe Haskell's specification should be changed to match other languages? With the current specification, the threshold should be changed to `-(log 2)`.
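A minimal sketch contrasting the two guards. `log1mexpBuggy` is a hypothetical reconstruction that assumes base compares against `log 2`; `log1mexpFixed` uses the same guard as `log1mexp'` in the reproduction script. Neither name exists in base.

```haskell
import Numeric (expm1, log1p)

-- Hypothetical reconstruction of the reported bug: the guard uses log 2,
-- the threshold from the log (1 - exp (-a)) convention. Every input in the
-- domain satisfies x <= 0 < log 2, so the log1p branch is never taken.
log1mexpBuggy :: (Ord a, Floating a) => a -> a
log1mexpBuggy x
  | x <= log 2 = log (negate (expm1 x))
  | otherwise  = log1p (negate (exp x))

-- Guard with the sign matching log1mexp x = log (1 - exp x).
log1mexpFixed :: (Ord a, Floating a) => a -> a
log1mexpFixed x
  -- near 0: expm1 x avoids the cancellation in 1 - exp x
  | x > negate (log 2) = log (negate (expm1 x))
  -- large negative x: exp x is small and log1p keeps its digits
  | otherwise          = log1p (negate (exp x))
```

If base's guard is as assumed, it always takes the `expm1` branch. That would explain the pattern in the output below: base agrees with the fixed version near 0, but fails for large negative `x`, where `negate (expm1 x)` rounds to a value at or very near 1. For example, `log1mexpBuggy (-32 :: Float)` evaluates to `0.0`, while `log1mexpFixed (-32 :: Float)` stays a small negative number close to `-exp (-32)`.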
Steps to reproduce
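```haskell
{-# LANGUAGE RankNTypes #-}
module Test where

import Numeric

-- Direct formula: loses accuracy near x = 0 and for large negative x.
naive x = log (1 - exp x)

-- Proposed fix: switch branches at -(log 2), matching
-- Haskell's convention log1mexp x = log (1 - exp x).
log1mexp' x
  | x <= -(log 2) = log1p (negate (exp x))
  | otherwise = log (negate (expm1 x))

-- Relative error of f evaluated in Float, with the Double result as reference.
relativeError :: Double -> (forall a . RealFloat a => a -> a) -> Double
relativeError x f = (realToFrac yf - y) / y
  where
    xf :: Float
    xf = realToFrac x
    y = f x
    yf = f xf

-- Columns: input, then the errors of naive, base's log1mexp, and log1mexp'.
test x = (x, relativeError x naive, relativeError x log1mexp, relativeError x log1mexp')

-- Inputs -2^k for k = -10 .. 10, i.e. -1/1024 down to -1024.
main = mapM_ (print . test . negate . (2 ^^)) [-10 .. 10]
```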
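Each output line is the input `x` followed by the relative errors of `naive`, base's `log1mexp`, and the fixed `log1mexp'`. For `-1 <= x < 0` the base and fixed versions agree, but from `x = -2` on the fixed version (last column) is far more accurate, and the gap widens as the magnitude of the input grows. The `-1.0` entries in the last column from `x = -128` on, and its final `NaN`, arise because `exp x` underflows to zero, first in `Float` and then in `Double`: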
```
(-9.765625e-4,8.479991916086323e-9,8.479992941110612e-9,8.479992941110612e-9)
(-1.953125e-3,1.0464709616203075e-7,2.8222328488136576e-8,2.8222328488136576e-8)
(-3.90625e-3,4.61205828590223e-7,3.140053949552827e-8,3.140053949552827e-8)
(-7.8125e-3,5.264616060243174e-7,3.547762531997476e-8,3.547762531997476e-8)
(-1.5625e-2,-3.0243668632982033e-7,4.088452396255435e-8,4.088452396255435e-8)
(-3.125e-2,-1.5715816024124222e-7,-2.0187950435429786e-8,-2.0187950435429786e-8)
(-6.25e-2,5.751965485445792e-8,-2.7518195971898168e-8,-2.7518195971898168e-8)
(-0.125,-7.402495028994113e-8,3.7318452255312984e-8,3.7318452255312984e-8)
(-0.25,-9.60333945667006e-11,-9.60333945667006e-11,-9.60333945667006e-11)
(-0.5,3.0539548876919827e-9,3.0539548876919827e-9,3.0539548876919827e-9)
(-1.0,1.5606325135389306e-9,1.5606325135389306e-9,1.5606325135389306e-9)
(-2.0,-9.918407768149559e-8,-9.918407768149559e-8,3.2903522112743677e-9)
(-4.0,-1.38006136898811e-6,-1.38006136898811e-6,3.061765740951412e-8)
(-8.0,-2.2891326354666136e-5,-2.2891326354666136e-5,8.757843622459283e-9)
(-16.0,5.930686227212524e-2,5.930686227212524e-2,1.687229115013425e-8)
(-32.0,-1.0,-1.0,1.8016016639892746e-8)
(-64.0,NaN,NaN,3.01797060757898e-8)
(-128.0,NaN,NaN,-1.0)
(-256.0,NaN,NaN,-1.0)
(-512.0,NaN,NaN,-1.0)
(-1024.0,NaN,NaN,NaN)
```
Expected behavior
Base's `log1mexp` should be as accurate as `log1mexp'` above.
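One way to turn "accurate" into a check is a sketch like the following. It reuses the report's measure (relative error of the `Float` result against the `Double` result), but takes the two instances as separate arguments to avoid `RankNTypes`. The names `relErr`, `tolerance`, `checkInputs`, and `accurate` are hypothetical, and the tolerance of `1e-6` is an assumed bound chosen for illustration, not a documented guarantee.

```haskell
-- Relative error of the Float result of a function, using its Double
-- result as the reference (the same measure as relativeError in the report).
relErr :: (Float -> Float) -> (Double -> Double) -> Double -> Double
relErr fF fD x = (realToFrac (fF (realToFrac x)) - y) / y
  where
    y = fD x

-- Assumed bound, roughly 8 times Float's machine epsilon (2^-23).
tolerance :: Double
tolerance = 1e-6

-- The report's inputs -2^k restricted to k = -10 .. 6 (-1/1024 down to -64),
-- where exp x is still representable in Float.
checkInputs :: [Double]
checkInputs = [negate (2 ^^ k) | k <- [-10 .. 6 :: Int]]

-- True when the Float result stays within tolerance of the Double reference
-- at every checked input; a NaN error counts as a failure.
accurate :: (Float -> Float) -> (Double -> Double) -> Bool
accurate fF fD = all (\x -> abs (relErr fF fD x) <= tolerance) checkInputs
```

With the reproduction script loaded, `accurate log1mexp' log1mexp'` checks the proposed guard and `accurate log1mexp log1mexp` checks base's method on the same inputs.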
Environment
- GHC 8.4.4, 8.6.5, 8.8.1
- Debian Buster
- x86_64