Commit ce6ce788 authored by Simon Peyton Jones's avatar Simon Peyton Jones

Set strictness correctly for JoinIds

We were failing to keep correct strictness info when eta-expanding
join points; Trac #15517.   The situation was something like

  \q v eta ->
     let j x = error "blah
         -- STR Lx   bottoming!
     in case y of
           A -> j x eta
           B -> blah
           C -> j x eta

So we spot j as a join point and eta-expand it.  But we must
also adjust the stricness info, else it vlaimes to bottom after
one arg is applied but now it has become two.

I fixed this in two places:

 - In CoreOpt.joinPointBinding_maybe, adjust strictness info

 - In SimplUtils.tryEtaExpandRhs, return consistent values
   for arity and bottom-ness
parent 9c4e6c6b
......@@ -39,7 +39,7 @@ module Demand (
nopSig, botSig, exnSig, cprProdSig,
isTopSig, hasDemandEnvSig,
splitStrictSig, strictSigDmdEnv,
increaseStrictSigArity,
increaseStrictSigArity, etaExpandStrictSig,
seqDemand, seqDemandList, seqDmdType, seqStrictSig,
......@@ -1737,8 +1737,23 @@ splitStrictSig (StrictSig (DmdType _ dmds res)) = (dmds, res)
increaseStrictSigArity :: Int -> StrictSig -> StrictSig
-- Add extra arguments to a strictness signature
increaseStrictSigArity arity_increase (StrictSig (DmdType env dmds res))
= StrictSig (DmdType env (replicate arity_increase topDmd ++ dmds) res)
increaseStrictSigArity arity_increase sig@(StrictSig dmd_ty@(DmdType env dmds res))
| isTopDmdType dmd_ty = sig
| arity_increase <= 0 = sig
| otherwise = StrictSig (DmdType env dmds' res)
where
dmds' = replicate arity_increase topDmd ++ dmds
etaExpandStrictSig :: Arity -> StrictSig -> StrictSig
-- We are expanding (\x y. e) to (\x y z. e z)
-- Add exta demands to the /end/ of the arg demands if necessary
etaExpandStrictSig arity sig@(StrictSig dmd_ty@(DmdType env dmds res))
| isTopDmdType dmd_ty = sig
| arity_increase <= 0 = sig
| otherwise = StrictSig (DmdType env dmds' res)
where
arity_increase = arity - length dmds
dmds' = dmds ++ replicate arity_increase topDmd
isTopSig :: StrictSig -> Bool
isTopSig (StrictSig ty) = isTopDmdType ty
......
......@@ -36,6 +36,7 @@ import Var ( varType )
import VarSet
import VarEnv
import DataCon
import Demand( etaExpandStrictSig )
import OptCoercion ( optCoercion )
import Type hiding ( substTy, extendTvSubst, extendCvSubst, extendTvSubstList
, isInScope, substTyVarBndr, cloneTyVarBndr )
......@@ -658,7 +659,11 @@ joinPointBinding_maybe bndr rhs
| AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
, (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
= Just (bndr `asJoinId` join_arity, mkLams bndrs body)
, let str_sig = idStrictness bndr
str_arity = count isId bndrs -- Strictness demands are for Ids only
join_bndr = bndr `asJoinId` join_arity
`setIdStrictness` etaExpandStrictSig str_arity str_sig
  • Couldn't this use Demand.ensureArgs instead of essentially redefining it as Demand.etaExpandStrictSig?

    From what I can tell, the only difference is in handling of DmdResult and the case when there's an arity decrease. ensureArgs will shorten the demand signature appropriately, whereas etaExpandStrictSig doesn't.

    Actually, that's kind of broken: In the case of an arity decrease (happens when we have something like go x y = x `seq` ... in go), we really want to zap strictness altogether, wouldn't we? Just re-using the signature for arity 0 in that example seems unsound in a call like go (error "boom").

    I'm trying to resurrect an old CoreLint check that checks that dmdTypeDepth == idArity and spent some time debugging this.

    Edited by Sebastian Graf
  • I don't think I have enough context to reply meaningfully. If you'd like my input, can you explain a bit more? (Alternatively, if you know what you are doing, go ahead.)

  • It's probably more an observation and the consequences on !312 (merged), where I 'fix' this. I should probably write this down as a Note. Read on if you like, but I think I have a pretty good picture.

    TLDR; This code (etaExpandStrictSig) isn't unsound currently, but would become so if we had it in !312 (merged) unchanged.

    So, it turns out the simplifier turns a binding like go x y = x `seq` $rhs in go (side note: the expression occurs this way probably because it's wrapped in a cast) into a nullary join point. It will call joinPointBinding_maybe go (x `seq` $rhs). This returns the updated binder, which is now a join point with arity 0. But go's idArity was 2 before, so we must account for 'eta-expansion' (reduction, rather) in the strictness signature.

    We do so by calling etaExpandStrictSig, but its definition doesn't really have a story for when arity decreases, other than returning the old signature for the mismatching arity.

    This becomes unsound when we only look at the arity of the binder instead of the dmdTypeDepth to unleash a strictness signature, like I propose to do in !312 (merged). Consider what happened if we went ahead and found a call site of go with one argument somewhere, like go (error "boom"). Before the arity decrease, we couldn't unleash the strictness signature, because it was for a call of at least idArity go == 2 arguments, but the call site only supplies one. Now that idArity is 0, we would happily unleash the strictness signature of go (which is strict in its first argument) and suddenly would be strict in error "boom".

  • Read on if you like, but I think I have a pretty good picture.

    Great -- go for it!

Please register or sign in to reply
= Just (join_bndr, mkLams bndrs body)
| otherwise
= Nothing
......@@ -668,6 +673,27 @@ joinPointBindings_maybe bndrs
= mapM (uncurry joinPointBinding_maybe) bndrs
{- Note [Strictness and join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Suppose we have
let f = \x. if x>200 then e1 else e1
and we know that f is strict in x. Then if we subsequently
discover that f is an arity-2 join point, we'll eta-expand it to
let f = \x y. if x>200 then e1 else e1
and now it's only strict if applied to two arguments. So we should
adjust the strictness info.
A more common case is when
f = \x. error ".."
and again its arity increses (Trac #15517)
-}
{- *********************************************************************
* *
exprIsConApp_maybe
......
......@@ -1511,9 +1511,12 @@ tryEtaExpandRhs :: SimplMode -> OutId -> OutExpr
-- (a) rhs' has manifest arity
-- (b) if is_bot is True then rhs' applied to n args is guaranteed bottom
tryEtaExpandRhs mode bndr rhs
| isJoinId bndr
= return (manifestArity rhs, False, rhs)
-- Note [Do not eta-expand join points]
| Just join_arity <- isJoinId_maybe bndr
= do { let (join_bndrs, join_body) = collectNBinders join_arity rhs
; return (count isId join_bndrs, exprIsBottom join_body, rhs) }
-- Note [Do not eta-expand join points]
-- But do return the correct arity and bottom-ness, because
-- these are used to set the bndr's IdInfo (Trac #15517)
| otherwise
= do { (new_arity, is_bot, new_rhs) <- try_expand
......
{-# LANGUAGE PatternSynonyms #-}
module T15517 where
data Nat = Z | S Nat
pattern Zpat = Z
sfrom :: Nat -> () -> Bool
sfrom Zpat = \_ -> False
sfrom (S Z) = \_ -> False
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module T15517a () where
import Data.Proxy
newtype Rep (ki :: kon -> *) (phi :: Nat -> *) (code :: [[Atom kon]])
= Rep (NS (PoA ki phi) code)
data NA :: (kon -> *) -> (Nat -> *) -> Atom kon -> * where
NA_I :: (IsNat k) => phi k -> NA ki phi (I k)
NA_K :: ki k -> NA ki phi (K k)
data NP :: (k -> *) -> [k] -> * where
NP0 :: NP p '[]
(:*) :: p x -> NP p xs -> NP p (x : xs)
class IsNat (n :: Nat) where
getSNat :: Proxy n -> SNat n
instance IsNat Z where
getSNat _ = SZ
instance IsNat n => IsNat (S n) where
getSNat p = SS (getSNat $ proxyUnsuc p)
proxyUnsuc :: Proxy (S n) -> Proxy n
proxyUnsuc _ = Proxy
type PoA (ki :: kon -> *) (phi :: Nat -> *) = NP (NA ki phi)
data Atom kon
= K kon
| I Nat
data Nat = S Nat | Z
data SNat :: Nat -> * where
SZ :: SNat Z
SS :: SNat n -> SNat (S n)
data Kon = KInt
data Singl (kon :: Kon) :: * where
SInt :: Int -> Singl KInt
type family Lkup (n :: Nat) (ks :: [k]) :: k where
Lkup Z (k : ks) = k
Lkup (S n) (k : ks) = Lkup n ks
data El :: [*] -> Nat -> * where
El :: IsNat ix => Lkup ix fam -> El fam ix
data NS :: (k -> *) -> [k] -> * where
There :: NS p xs -> NS p (x : xs)
Here :: p x -> NS p (x : xs)
class Family (ki :: kon -> *) (fam :: [*]) (codes :: [[[Atom kon]]])
| fam -> ki codes , ki codes -> fam where
sfrom' :: SNat ix -> El fam ix -> Rep ki (El fam) (Lkup ix codes)
data Rose a = a :>: [Rose a]
| Leaf a
type FamRoseInt = '[Rose Int, [Rose Int]]
type CodesRoseInt =
'[ '[ '[K KInt, I (S Z)], '[K KInt]], '[ '[], '[I Z, I (S Z)]]]
pattern IdxRoseInt = SZ
pattern IdxListRoseInt = SS SZ
pat1 :: PoA Singl (El FamRoseInt) '[I Z, I (S Z)]
-> NS (PoA Singl (El FamRoseInt)) '[ '[], '[I Z, I (S Z)]]
pat1 d = There (Here d)
pat2 :: PoA Singl (El FamRoseInt) '[]
-> NS (PoA Singl (El FamRoseInt)) '[ '[], '[I Z, I (S Z)]]
pat2 d = Here d
pat3 :: PoA Singl (El FamRoseInt) '[K KInt]
-> NS (PoA Singl (El FamRoseInt)) '[ '[K KInt, I (S Z)], '[K KInt]]
pat3 d = There (Here d)
pat4 :: PoA Singl (El FamRoseInt) '[K KInt, I (S Z)]
-> NS (PoA Singl (El FamRoseInt)) '[ '[K KInt, I (S Z)], '[K KInt]]
pat4 d = Here d
instance Family Singl FamRoseInt CodesRoseInt where
sfrom' = \case IdxRoseInt -> \case El (x :>: xs) -> Rep (pat4 (NA_K (SInt x) :* (NA_I (El xs) :* NP0)))
El (Leaf x) -> Rep (pat3 (NA_K (SInt x) :* NP0))
IdxListRoseInt -> \case El [] -> Rep (pat2 NP0)
El (x:xs) -> Rep (pat1 (NA_I (El x) :* (NA_I (El xs) :* NP0)))
......@@ -318,4 +318,6 @@ test('T15005', normal, compile, ['-O'])
# we omit profiling because it affects the optimiser and makes the test fail
test('T15056', [extra_files(['T15056a.hs']), omit_ways(['profasm'])], multimod_compile, ['T15056', '-O -v0 -ddump-rule-firings'])
test('T15186', normal, multimod_compile, ['T15186', '-v0'])
test('T15453', normal, compile, ['-dcore-lint -O1'])
test('T15453', normal, compile, ['-O1'])
test('T15517', normal, compile, ['-O0'])
test('T15517a', normal, compile, ['-O0'])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment