Commit 1990bb0d authored by Simon Peyton Jones's avatar Simon Peyton Jones Committed by David Feuer
Browse files

Make Specialise work with casts

With my upcoming early-inlining patch it turned out that Specialise
was getting stuck on casts.  This patch fixes it; see Note
[Account for casts in binding] in Specialise.

Reviewers: austin, goldfire, bgamari

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D3192
parent 29b57238
......@@ -34,7 +34,7 @@ module CoreSubst (
-- ** Simple expression optimiser
simpleOptPgm, simpleOptExpr, simpleOptExprWith,
exprIsConApp_maybe, exprIsLiteral_maybe, exprIsLambda_maybe,
pushCoArg, pushCoValArg, pushCoTyArg
pushCoArg, pushCoValArg, pushCoTyArg, collectBindersPushingCo
) where
#include "HsVersions.h"
......@@ -1614,7 +1614,7 @@ exprIsLambda_maybe _ _e
Here we implement the "push rules" from FC papers:
* The push-argument ules, where we can move a coercion past an argument.
* The push-argument rules, where we can move a coercion past an argument.
We have
(fun |> co) arg
and we want to transform it to
......@@ -1687,7 +1687,7 @@ pushCoValArg co
= Just (mkRepReflCo arg, mkRepReflCo res)
| isFunTy tyL
, [_, _, co1, co2] <- decomposeCo 4 co
, (co1, co2) <- decomposeFunCo co
-- If co :: (tyL1 -> tyL2) ~ (tyR1 -> tyR2)
-- then co1 :: tyL1 ~ tyR1
-- co2 :: tyL2 ~ tyR2
......@@ -1711,7 +1711,7 @@ pushCoercionIntoLambda in_scope x e co
, Pair s1s2 t1t2 <- coercionKind co
, Just (_s1,_s2) <- splitFunTy_maybe s1s2
, Just (t1,_t2) <- splitFunTy_maybe t1t2
= let [_rep1, _rep2, co1, co2] = decomposeCo 4 co
= let (co1, co2) = decomposeFunCo co
-- Should we optimize the coercions here?
-- Otherwise they might not match too well
x' = x `setIdType` t1
......@@ -1784,3 +1784,57 @@ pushCoDataCon dc dc_args co
where
Pair from_ty to_ty = coercionKind co
collectBindersPushingCo :: CoreExpr -> ([Var], CoreExpr)
-- Collect lambda binders, pushing coercions inside if possible
-- E.g. (\x.e) |> g g :: <Int> -> blah
-- = (\x. e |> Nth 1 g)
--
-- That is,
--
-- collectBindersPushingCo ((\x.e) |> g) === ([x], e |> Nth 1 g)
collectBindersPushingCo e
= go [] e
where
-- Peel off lambdas until we hit a cast.
go :: [Var] -> CoreExpr -> ([Var], CoreExpr)
-- The accumulator is in reverse order
go bs (Lam b e) = go (b:bs) e
go bs (Cast e co) = go_c bs e co
go bs e = (reverse bs, e)
-- We are in a cast; peel off casts until we hit a lambda.
go_c :: [Var] -> CoreExpr -> Coercion -> ([Var], CoreExpr)
-- (go_c bs e c) is same as (go bs e (e |> c))
go_c bs (Cast e co1) co2 = go_c bs e (co1 `mkTransCo` co2)
go_c bs (Lam b e) co = go_lam bs b e co
go_c bs e co = (reverse bs, mkCast e co)
-- We are in a lambda under a cast; peel off lambdas and build a
-- new coercion for the body.
go_lam :: [Var] -> Var -> CoreExpr -> Coercion -> ([Var], CoreExpr)
-- (go_lam bs b e c) is same as (go_c bs (\b.e) c)
go_lam bs b e co
| isTyVar b
, let Pair tyL tyR = coercionKind co
, ASSERT( isForAllTy tyL )
isForAllTy tyR
, isReflCo (mkNthCo 0 co) -- See Note [collectBindersPushingCo]
= go_c (b:bs) e (mkInstCo co (mkNomReflCo (mkTyVarTy b)))
| isId b
, let Pair tyL tyR = coercionKind co
, ASSERT( isFunTy tyL) isFunTy tyR
, (co_arg, co_res) <- decomposeFunCo co
, isReflCo co_arg -- See Note [collectBindersPushingCo]
= go_c (b:bs) e co_res
| otherwise = (reverse bs, mkCast (Lam b e) co)
{- Note [collectBindersPushingCo]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We just look for coercions of form
<type> -> blah
(and similarly for foralls) to keep this function simple. We could do
more elaborate stuff, but it'd involve substitution etc.
-}
......@@ -1153,8 +1153,8 @@ specCalls :: Maybe Module -- Just this_mod => specialising imported fn
specCalls mb_mod env rules_for_me calls_for_me fn rhs
-- The first case is the interesting one
| rhs_tyvars `lengthIs` n_tyvars -- Rhs of fn's defn has right number of big lambdas
&& rhs_ids `lengthAtLeast` n_dicts -- and enough dict args
| rhs_tyvars `lengthIs` n_tyvars -- Rhs of fn's defn has right number of big lambdas
&& rhs_bndrs1 `lengthAtLeast` n_dicts -- and enough dict args
&& notNull calls_for_me -- And there are some calls to specialise
&& not (isNeverActive (idInlineActivation fn))
-- Don't specialise NOINLINE things
......@@ -1178,7 +1178,7 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
return ([], [], emptyUDs)
where
_trace_doc = sep [ ppr rhs_tyvars, ppr n_tyvars
, ppr rhs_ids, ppr n_dicts
, ppr rhs_bndrs, ppr n_dicts
, ppr (idInlineActivation fn) ]
fn_type = idType fn
......@@ -1194,11 +1194,12 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
-- Figure out whether the function has an INLINE pragma
-- See Note [Inline specialisations]
(rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs
rhs_dict_ids = take n_dicts rhs_ids
body = mkLams (drop n_dicts rhs_ids) rhs_body
-- Glue back on the non-dict lambdas
(rhs_bndrs, rhs_body) = CoreSubst.collectBindersPushingCo rhs
-- See Note [Account for casts in binding]
(rhs_tyvars, rhs_bndrs1) = span isTyVar rhs_bndrs
(rhs_dict_ids, rhs_bndrs2) = splitAt n_dicts rhs_bndrs1
body = mkLams rhs_bndrs2 rhs_body
-- Glue back on the non-dict lambdas
already_covered :: DynFlags -> [CoreExpr] -> Bool
already_covered dflags args -- Note [Specialisations already covered]
......@@ -1350,7 +1351,23 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
; return (Just ((spec_f_w_arity, spec_rhs), final_uds, spec_env_rule)) } }
{- Note [Evidence foralls]
{- Note [Account for casts in binding]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider
f :: Eq a => a -> IO ()
{-# INLINABLE f
StableUnf = (/\a \(d:Eq a) (x:a). blah) |> g
#-}
f = ...
In f's stable unfolding we have done some modest simplification which
has pushed the cast to the outside. (I wonder if this is the Right
Thing, but it's what happens now; see SimplUtils Note [Casts and
lambdas].) Now that stable unfolding must be specialised, so we want
to push the cast back inside. It would be terrible if the cast
defeated specialisation! Hence the use of collectBindersPushingCo.
Note [Evidence foralls]
~~~~~~~~~~~~~~~~~~~~~~~~~~
Suppose (Trac #12212) that we are specialising
f :: forall a b. (Num a, F a ~ F b) => blah
......
......@@ -48,7 +48,7 @@ module Coercion (
mapStepResult, unwrapNewTypeStepper,
topNormaliseNewType_maybe, topNormaliseTypeX,
decomposeCo, getCoVar_maybe,
decomposeCo, decomposeFunCo, getCoVar_maybe,
splitTyConAppCo_maybe,
splitAppCo_maybe,
splitFunCo_maybe,
......@@ -293,8 +293,20 @@ ppr_co_ax_branch ppr_rhs
Destructing coercions
%* *
%************************************************************************
Note [Function coercions]
~~~~~~~~~~~~~~~~~~~~~~~~~
Remember that
(->) :: forall r1 r2. TYPE r1 -> TYPE r2 -> TYPE LiftedRep
Hence
FunCo r co1 co2 :: (s1->t1) ~r (s2->t2)
is short for
TyConAppCo (->) co_rep1 co_rep2 co1 co2
where co_rep1, co_rep2 are the coercions on the representations.
-}
-- | This breaks a 'Coercion' with type @T A B C ~ T D E F@ into
-- a list of 'Coercion's of kinds @A ~ D@, @B ~ E@ and @E ~ F@. Hence:
--
......@@ -304,6 +316,16 @@ decomposeCo arity co
= [mkNthCo n co | n <- [0..(arity-1)] ]
-- Remember, Nth is zero-indexed
decomposeFunCo :: Coercion -> (Coercion, Coercion)
-- Expects co :: (s1 -> t1) ~ (s2 -> t2)
-- Returns (co1 :: s1~s2, co2 :: t1~t2)
-- See Note [Function coercions] for the "2" and "3"
decomposeFunCo co = ASSERT2( all_ok, ppr co )
(mkNthCo 2 co, mkNthCo 3 co)
where
Pair s1t1 s2t2 = coercionKind co
all_ok = isFunTy s1t1 && isFunTy s2t2
-- | Attempts to obtain the type variable underlying a 'Coercion'
getCoVar_maybe :: Coercion -> Maybe CoVar
getCoVar_maybe (CoVarCo cv) = Just cv
......@@ -554,7 +576,7 @@ mkNomReflCo = mkReflCo Nominal
mkTyConAppCo :: HasDebugCallStack => Role -> TyCon -> [Coercion] -> Coercion
mkTyConAppCo r tc cos
| tc `hasKey` funTyConKey
, [_rep1, _rep2, co1, co2] <- cos
, [_rep1, _rep2, co1, co2] <- cos -- See Note [Function coercions]
= -- (a :: TYPE ra) -> (b :: TYPE rb) ~ (c :: TYPE rc) -> (d :: TYPE rd)
-- rep1 :: ra ~ rc rep2 :: rb ~ rd
-- co1 :: a ~ c co2 :: b ~ d
......@@ -882,14 +904,26 @@ mkNthCo n (Refl r ty)
mkNthCo 0 (ForAllCo _ kind_co _) = kind_co
-- If co :: (forall a1:k1. t1) ~ (forall a2:k2. t2)
-- then (nth 0 co :: k1 ~ k2)
mkNthCo n (TyConAppCo _ _ arg_cos) = arg_cos `getNth` n
mkNthCo n co@(FunCo _ arg res)
-- See Note [Function coercions]
-- If FunCo _ arg_co res_co :: (s1:TYPE sk1 -> s2:TYPE sk2)
-- ~ (t1:TYPE tk1 -> t2:TYPE tk2)
-- Then we want to behave as if co was
-- TyConAppCo argk_co resk_co arg_co res_co
-- where
-- argk_co :: sk1 ~ tk1 = mkNthCo 0 (mkKindCo arg_co)
-- resk_co :: sk2 ~ tk2 = mkNthCo 0 (mkKindCo res_co)
-- i.e. mkRuntimeRepCo
= case n of
0 -> mkRuntimeRepCo arg
1 -> mkRuntimeRepCo res
2 -> arg
3 -> res
_ -> pprPanic "mkNthCo(FunCo)" (ppr n $$ ppr co)
mkNthCo n (TyConAppCo _ _ arg_cos) = arg_cos `getNth` n
mkNthCo n co = NthCo n co
mkLRCo :: LeftOrRight -> Coercion -> Coercion
......@@ -937,8 +971,10 @@ mkKindCo co
-- generally, calling coercionKind during coercion creation is a bad idea,
-- as it can lead to exponential behavior. But, we don't have nested mkKindCos,
-- so it's OK here.
, typeKind ty1 `eqType` typeKind ty2
= Refl Nominal (typeKind ty1)
, let tk1 = typeKind ty1
tk2 = typeKind ty2
, tk1 `eqType` tk2
= Refl Nominal tk1
| otherwise
= KindCo co
......
......@@ -843,7 +843,7 @@ unify_ty (CoercionTy co1) (CoercionTy co2) kco
-> do { b <- tvBindFlagL cv
; if b == BindMe
then do { checkRnEnvRCo co2
; let [_, _, co_l, co_r] = decomposeCo 4 kco
; let (co_l, co_r) = decomposeFunCo kco
-- cv :: t1 ~ t2
-- co2 :: s1 ~ s2
-- co_l :: t1 ~ s1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment