Commit b81fc821 authored by Simon Peyton Jones's avatar Simon Peyton Jones Committed by Ben Gamari

Set strictness correctly for JoinIds

We were failing to keep correct strictness info when eta-expanding
join points; Trac #15517.   The situation was something like

  \q v eta ->
     let j x = error "blah
         -- STR Lx   bottoming!
     in case y of
           A -> j x eta
           B -> blah
           C -> j x eta

So we spot j as a join point and eta-expand it.  But we must
also adjust the stricness info, else it vlaimes to bottom after
one arg is applied but now it has become two.

I fixed this in two places:

 - In CoreOpt.joinPointBinding_maybe, adjust strictness info

 - In SimplUtils.tryEtaExpandRhs, return consistent values
   for arity and bottom-ness

(cherry picked from commit ce6ce788)
parent 2d308da2
......@@ -39,7 +39,7 @@ module Demand (
nopSig, botSig, exnSig, cprProdSig,
isTopSig, hasDemandEnvSig,
splitStrictSig, strictSigDmdEnv,
increaseStrictSigArity,
increaseStrictSigArity, etaExpandStrictSig,
seqDemand, seqDemandList, seqDmdType, seqStrictSig,
......@@ -1737,8 +1737,23 @@ splitStrictSig (StrictSig (DmdType _ dmds res)) = (dmds, res)
increaseStrictSigArity :: Int -> StrictSig -> StrictSig
-- Add extra arguments to a strictness signature
increaseStrictSigArity arity_increase (StrictSig (DmdType env dmds res))
= StrictSig (DmdType env (replicate arity_increase topDmd ++ dmds) res)
increaseStrictSigArity arity_increase sig@(StrictSig dmd_ty@(DmdType env dmds res))
| isTopDmdType dmd_ty = sig
| arity_increase <= 0 = sig
| otherwise = StrictSig (DmdType env dmds' res)
where
dmds' = replicate arity_increase topDmd ++ dmds
etaExpandStrictSig :: Arity -> StrictSig -> StrictSig
-- We are expanding (\x y. e) to (\x y z. e z)
-- Add exta demands to the /end/ of the arg demands if necessary
etaExpandStrictSig arity sig@(StrictSig dmd_ty@(DmdType env dmds res))
| isTopDmdType dmd_ty = sig
| arity_increase <= 0 = sig
| otherwise = StrictSig (DmdType env dmds' res)
where
arity_increase = arity - length dmds
dmds' = dmds ++ replicate arity_increase topDmd
isTopSig :: StrictSig -> Bool
isTopSig (StrictSig ty) = isTopDmdType ty
......
......@@ -36,6 +36,7 @@ import Var ( varType )
import VarSet
import VarEnv
import DataCon
import Demand( etaExpandStrictSig )
import OptCoercion ( optCoercion )
import Type hiding ( substTy, extendTvSubst, extendCvSubst, extendTvSubstList
, isInScope, substTyVarBndr, cloneTyVarBndr )
......@@ -658,7 +659,11 @@ joinPointBinding_maybe bndr rhs
| AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
, (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
= Just (bndr `asJoinId` join_arity, mkLams bndrs body)
, let str_sig = idStrictness bndr
str_arity = count isId bndrs -- Strictness demands are for Ids only
join_bndr = bndr `asJoinId` join_arity
`setIdStrictness` etaExpandStrictSig str_arity str_sig
= Just (join_bndr, mkLams bndrs body)
| otherwise
= Nothing
......@@ -668,6 +673,27 @@ joinPointBindings_maybe bndrs
= mapM (uncurry joinPointBinding_maybe) bndrs
{- Note [Strictness and join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Suppose we have
let f = \x. if x>200 then e1 else e1
and we know that f is strict in x. Then if we subsequently
discover that f is an arity-2 join point, we'll eta-expand it to
let f = \x y. if x>200 then e1 else e1
and now it's only strict if applied to two arguments. So we should
adjust the strictness info.
A more common case is when
f = \x. error ".."
and again its arity increses (Trac #15517)
-}
{- *********************************************************************
* *
exprIsConApp_maybe
......
......@@ -1511,9 +1511,12 @@ tryEtaExpandRhs :: SimplMode -> OutId -> OutExpr
-- (a) rhs' has manifest arity
-- (b) if is_bot is True then rhs' applied to n args is guaranteed bottom
tryEtaExpandRhs mode bndr rhs
| isJoinId bndr
= return (manifestArity rhs, False, rhs)
-- Note [Do not eta-expand join points]
| Just join_arity <- isJoinId_maybe bndr
= do { let (join_bndrs, join_body) = collectNBinders join_arity rhs
; return (count isId join_bndrs, exprIsBottom join_body, rhs) }
-- Note [Do not eta-expand join points]
-- But do return the correct arity and bottom-ness, because
-- these are used to set the bndr's IdInfo (Trac #15517)
| otherwise
= do { (new_arity, is_bot, new_rhs) <- try_expand
......
{-# LANGUAGE PatternSynonyms #-}
module T15517 where
data Nat = Z | S Nat
pattern Zpat = Z
sfrom :: Nat -> () -> Bool
sfrom Zpat = \_ -> False
sfrom (S Z) = \_ -> False
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module T15517a () where
import Data.Proxy
newtype Rep (ki :: kon -> *) (phi :: Nat -> *) (code :: [[Atom kon]])
= Rep (NS (PoA ki phi) code)
data NA :: (kon -> *) -> (Nat -> *) -> Atom kon -> * where
NA_I :: (IsNat k) => phi k -> NA ki phi (I k)
NA_K :: ki k -> NA ki phi (K k)
data NP :: (k -> *) -> [k] -> * where
NP0 :: NP p '[]
(:*) :: p x -> NP p xs -> NP p (x : xs)
class IsNat (n :: Nat) where
getSNat :: Proxy n -> SNat n
instance IsNat Z where
getSNat _ = SZ
instance IsNat n => IsNat (S n) where
getSNat p = SS (getSNat $ proxyUnsuc p)
proxyUnsuc :: Proxy (S n) -> Proxy n
proxyUnsuc _ = Proxy
type PoA (ki :: kon -> *) (phi :: Nat -> *) = NP (NA ki phi)
data Atom kon
= K kon
| I Nat
data Nat = S Nat | Z
data SNat :: Nat -> * where
SZ :: SNat Z
SS :: SNat n -> SNat (S n)
data Kon = KInt
data Singl (kon :: Kon) :: * where
SInt :: Int -> Singl KInt
type family Lkup (n :: Nat) (ks :: [k]) :: k where
Lkup Z (k : ks) = k
Lkup (S n) (k : ks) = Lkup n ks
data El :: [*] -> Nat -> * where
El :: IsNat ix => Lkup ix fam -> El fam ix
data NS :: (k -> *) -> [k] -> * where
There :: NS p xs -> NS p (x : xs)
Here :: p x -> NS p (x : xs)
class Family (ki :: kon -> *) (fam :: [*]) (codes :: [[[Atom kon]]])
| fam -> ki codes , ki codes -> fam where
sfrom' :: SNat ix -> El fam ix -> Rep ki (El fam) (Lkup ix codes)
data Rose a = a :>: [Rose a]
| Leaf a
type FamRoseInt = '[Rose Int, [Rose Int]]
type CodesRoseInt =
'[ '[ '[K KInt, I (S Z)], '[K KInt]], '[ '[], '[I Z, I (S Z)]]]
pattern IdxRoseInt = SZ
pattern IdxListRoseInt = SS SZ
pat1 :: PoA Singl (El FamRoseInt) '[I Z, I (S Z)]
-> NS (PoA Singl (El FamRoseInt)) '[ '[], '[I Z, I (S Z)]]
pat1 d = There (Here d)
pat2 :: PoA Singl (El FamRoseInt) '[]
-> NS (PoA Singl (El FamRoseInt)) '[ '[], '[I Z, I (S Z)]]
pat2 d = Here d
pat3 :: PoA Singl (El FamRoseInt) '[K KInt]
-> NS (PoA Singl (El FamRoseInt)) '[ '[K KInt, I (S Z)], '[K KInt]]
pat3 d = There (Here d)
pat4 :: PoA Singl (El FamRoseInt) '[K KInt, I (S Z)]
-> NS (PoA Singl (El FamRoseInt)) '[ '[K KInt, I (S Z)], '[K KInt]]
pat4 d = Here d
instance Family Singl FamRoseInt CodesRoseInt where
sfrom' = \case IdxRoseInt -> \case El (x :>: xs) -> Rep (pat4 (NA_K (SInt x) :* (NA_I (El xs) :* NP0)))
El (Leaf x) -> Rep (pat3 (NA_K (SInt x) :* NP0))
IdxListRoseInt -> \case El [] -> Rep (pat2 NP0)
El (x:xs) -> Rep (pat1 (NA_I (El x) :* (NA_I (El xs) :* NP0)))
......@@ -316,4 +316,6 @@ test('T15005', normal, compile, ['-O'])
# we omit profiling because it affects the optimiser and makes the test fail
test('T15056', [extra_files(['T15056a.hs']), omit_ways(['profasm'])], multimod_compile, ['T15056', '-O -v0 -ddump-rule-firings'])
test('T15186', normal, multimod_compile, ['T15186', '-v0'])
test('T15453', normal, compile, ['-dcore-lint -O1'])
test('T15453', normal, compile, ['-O1'])
test('T15517', normal, compile, ['-O0'])
test('T15517a', normal, compile, ['-O0'])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment