Commit 474e535b authored by cactus's avatar cactus

In pattern synonym matchers, support unboxed continuation results (fixes #9783).

This requires ensuring the continuations have arguments by adding a dummy
Void# argument when needed. This is so that matching on a pattern synonym
is lazy even when the result is unboxed, e.g.

    pattern P = ()
    f P = 0#

In this case, without dummy arguments, the generated matcher's type would be

   $mP :: forall (r :: ?). () -> r -> r -> r

which is called in `f` at type `() -> Int# -> Int# -> Int#`,
so it would be strict, in particular, in the failure continuation
of `patError`.

We work around this by making sure both continuations have arguments:

  $mP :: forall (r :: ?). () -> (Void# -> r) -> (Void# -> r) -> r

Of course, if `P` (and thus, the success continuation) has any arguments,
we are only adding the extra dummy argument to the failure continuation.
parent f5996d91
...@@ -76,17 +76,22 @@ For each pattern synonym, we generate a single matcher function which ...@@ -76,17 +76,22 @@ For each pattern synonym, we generate a single matcher function which
implements the actual matching. For the above example, the matcher implements the actual matching. For the above example, the matcher
will have type: will have type:
$mP :: forall r t. (Eq t, Num t) $mP :: forall (r :: ?) t. (Eq t, Num t)
=> T (Maybe t) => T (Maybe t)
-> (forall b. (Show (Maybe t), Ord b) => b -> r) -> (forall b. (Show (Maybe t), Ord b) => b -> r)
-> r -> (Void# -> r)
-> r -> r
with the following implementation: with the following implementation:
$mP @r @t $dEq $dNum scrut cont fail = case scrut of $mP @r @t $dEq $dNum scrut cont fail = case scrut of
MkT @b $dShow $dOrd [x] (Just 42) -> cont @b $dShow $dOrd x MkT @b $dShow $dOrd [x] (Just 42) -> cont @b $dShow $dOrd x
_ -> fail _ -> fail Void#
The extra Void# argument for the failure continuation is needed so that
it is lazy even when the result type is unboxed. For the same reason,
if the pattern has no arguments, an extra Void# argument is added
to the success continuation as well.
For *bidirectional* pattern synonyms, we also generate a single wrapper For *bidirectional* pattern synonyms, we also generate a single wrapper
function which implements the pattern synonym in an expression function which implements the pattern synonym in an expression
...@@ -130,11 +135,19 @@ data PatSyn ...@@ -130,11 +135,19 @@ data PatSyn
-- See Note [Matchers and wrappers for pattern synonyms] -- See Note [Matchers and wrappers for pattern synonyms]
psMatcher :: Id, psMatcher :: Id,
-- Matcher function, of type -- Matcher function. If psArgs is empty, then it has type
-- forall r univ_tvs. req_theta -- forall (r :: ?) univ_tvs. req_theta
-- => res_ty -- => res_ty
-- -> (forall ex_tvs. prov_theta -> arg_tys -> r) -- -> (forall ex_tvs. prov_theta -> Void# -> r)
-- -> r -> r -- -> (Void# -> r)
-- -> r
--
-- Otherwise:
-- forall (r :: ?) univ_tvs. req_theta
-- => res_ty
-- -> (forall ex_tvs. prov_theta -> arg_tys -> r)
-- -> (Void# -> r)
-- -> r
psWrapper :: Maybe Id psWrapper :: Maybe Id
-- Nothing => uni-directional pattern synonym -- Nothing => uni-directional pattern synonym
......
...@@ -348,7 +348,7 @@ mkPatSynCase var ty alt fail = do ...@@ -348,7 +348,7 @@ mkPatSynCase var ty alt fail = do
matcher <- dsLExpr $ mkLHsWrap wrapper $ nlHsTyApp matcher [ty] matcher <- dsLExpr $ mkLHsWrap wrapper $ nlHsTyApp matcher [ty]
let MatchResult _ mkCont = match_result let MatchResult _ mkCont = match_result
cont <- mkCoreLams bndrs <$> mkCont fail cont <- mkCoreLams bndrs <$> mkCont fail
return $ mkCoreAppsDs matcher [Var var, cont, fail] return $ mkCoreAppsDs matcher [Var var, ensure_unstrict cont, make_unstrict fail]
where where
MkCaseAlt{ alt_pat = psyn, MkCaseAlt{ alt_pat = psyn,
alt_bndrs = bndrs, alt_bndrs = bndrs,
...@@ -356,6 +356,11 @@ mkPatSynCase var ty alt fail = do ...@@ -356,6 +356,11 @@ mkPatSynCase var ty alt fail = do
alt_result = match_result} = alt alt_result = match_result} = alt
matcher = patSynMatcher psyn matcher = patSynMatcher psyn
-- See Note [Matchers and wrappers for pattern synonyms] in PatSyns
-- on these extra Void# arguments
ensure_unstrict = if null (patSynArgs psyn) then make_unstrict else id
make_unstrict = Lam voidArgId
mkDataConCase :: Id -> Type -> [CaseAlt DataCon] -> MatchResult mkDataConCase :: Id -> Type -> [CaseAlt DataCon] -> MatchResult
mkDataConCase _ _ [] = panic "mkDataConCase: no alternatives" mkDataConCase _ _ [] = panic "mkDataConCase: no alternatives"
mkDataConCase var ty alts@(alt1:_) = MatchResult fail_flag mk_case mkDataConCase var ty alts@(alt1:_) = MatchResult fail_flag mk_case
......
...@@ -24,12 +24,12 @@ import Outputable ...@@ -24,12 +24,12 @@ import Outputable
import FastString import FastString
import Var import Var
import Id import Id
import IdInfo( IdDetails( VanillaId ) )
import TcBinds import TcBinds
import BasicTypes import BasicTypes
import TcSimplify import TcSimplify
import TcType import TcType
import VarSet import VarSet
import MkId
#if __GLASGOW_HASKELL__ < 709 #if __GLASGOW_HASKELL__ < 709
import Data.Monoid import Data.Monoid
#endif #endif
...@@ -129,25 +129,29 @@ tcPatSynMatcher :: Located Name ...@@ -129,25 +129,29 @@ tcPatSynMatcher :: Located Name
-> TcM (Id, LHsBinds Id) -> TcM (Id, LHsBinds Id)
-- See Note [Matchers and wrappers for pattern synonyms] in PatSyn -- See Note [Matchers and wrappers for pattern synonyms] in PatSyn
tcPatSynMatcher (L loc name) lpat args univ_tvs ex_tvs ev_binds prov_dicts req_dicts prov_theta req_theta pat_ty tcPatSynMatcher (L loc name) lpat args univ_tvs ex_tvs ev_binds prov_dicts req_dicts prov_theta req_theta pat_ty
= do { res_tv <- zonkQuantifiedTyVar =<< newFlexiTyVar liftedTypeKind = do { res_tv <- do
{ uniq <- newUnique
; let tv_name = mkInternalName uniq (mkTyVarOcc "r") loc
; return $ mkTcTyVar tv_name openTypeKind (SkolemTv False) }
; matcher_name <- newImplicitBinder name mkMatcherOcc ; matcher_name <- newImplicitBinder name mkMatcherOcc
; let res_ty = TyVarTy res_tv ; let res_ty = TyVarTy res_tv
cont_args = if null args then [voidPrimId] else args
cont_ty = mkSigmaTy ex_tvs prov_theta $ cont_ty = mkSigmaTy ex_tvs prov_theta $
mkFunTys (map varType args) res_ty mkFunTys (map varType cont_args) res_ty
fail_ty = mkFunTy voidPrimTy res_ty
; let matcher_tau = mkFunTys [pat_ty, cont_ty, res_ty] res_ty ; let matcher_tau = mkFunTys [pat_ty, cont_ty, fail_ty] res_ty
matcher_sigma = mkSigmaTy (res_tv:univ_tvs) req_theta matcher_tau matcher_sigma = mkSigmaTy (res_tv:univ_tvs) req_theta matcher_tau
matcher_id = mkExportedLocalId VanillaId matcher_name matcher_sigma matcher_id = mkVanillaGlobal matcher_name matcher_sigma
; traceTc "tcPatSynMatcher" (ppr name $$ ppr (idType matcher_id)) ; traceTc "tcPatSynMatcher" (ppr name $$ ppr (idType matcher_id))
; let matcher_lid = L loc matcher_id ; let matcher_lid = L loc matcher_id
; scrutinee <- mkId "scrut" pat_ty ; scrutinee <- mkId "scrut" pat_ty
; cont <- mkId "cont" cont_ty ; cont <- mkId "cont" cont_ty
; let cont' = nlHsApps cont $ map nlHsVar (ex_tvs ++ prov_dicts ++ args) ; let cont' = nlHsApps cont $ map nlHsVar (ex_tvs ++ prov_dicts ++ cont_args)
; fail <- mkId "fail" res_ty ; fail <- mkId "fail" fail_ty
; let fail' = nlHsVar fail ; let fail' = nlHsApps fail [nlHsVar voidPrimId]
; let args = map nlVarPat [scrutinee, cont, fail] ; let args = map nlVarPat [scrutinee, cont, fail]
lwpat = noLoc $ WildPat pat_ty lwpat = noLoc $ WildPat pat_ty
...@@ -190,9 +194,7 @@ tcPatSynMatcher (L loc name) lpat args univ_tvs ex_tvs ev_binds prov_dicts req_d ...@@ -190,9 +194,7 @@ tcPatSynMatcher (L loc name) lpat args univ_tvs ex_tvs ev_binds prov_dicts req_d
; return (matcher_id, matcher_bind) } ; return (matcher_id, matcher_bind) }
where where
mkId s ty = do mkId s ty = mkSysLocalM (fsLit s) ty
name <- newName . mkVarOccFS . fsLit $ s
return $ mkLocalId name ty
isBidirectional :: HsPatSynDir a -> Bool isBidirectional :: HsPatSynDir a -> Bool
isBidirectional Unidirectional = False isBidirectional Unidirectional = False
...@@ -248,7 +250,7 @@ mkPatSynWrapperId (L _ name) args univ_tvs ex_tvs theta pat_ty ...@@ -248,7 +250,7 @@ mkPatSynWrapperId (L _ name) args univ_tvs ex_tvs theta pat_ty
wrapper_sigma = mkSigmaTy wrapper_tvs wrapper_theta wrapper_tau wrapper_sigma = mkSigmaTy wrapper_tvs wrapper_theta wrapper_tau
; wrapper_name <- newImplicitBinder name mkDataConWrapperOcc ; wrapper_name <- newImplicitBinder name mkDataConWrapperOcc
; return $ mkExportedLocalId VanillaId wrapper_name wrapper_sigma } ; return $ mkVanillaGlobal wrapper_name wrapper_sigma }
mkPatSynWrapper :: Id mkPatSynWrapper :: Id
-> HsBind Name -> HsBind Name
......
...@@ -1097,6 +1097,7 @@ mk/ghcconfig*_inplace_bin_ghc-stage2.exe.mk ...@@ -1097,6 +1097,7 @@ mk/ghcconfig*_inplace_bin_ghc-stage2.exe.mk
/tests/patsyn/should_run/ex-prov /tests/patsyn/should_run/ex-prov
/tests/patsyn/should_run/ex-prov-run /tests/patsyn/should_run/ex-prov-run
/tests/patsyn/should_run/match /tests/patsyn/should_run/match
/tests/patsyn/should_run/match-unboxed
/tests/perf/compiler/T1969.comp.stats /tests/perf/compiler/T1969.comp.stats
/tests/perf/compiler/T3064.comp.stats /tests/perf/compiler/T3064.comp.stats
/tests/perf/compiler/T3294.comp.stats /tests/perf/compiler/T3294.comp.stats
......
{-# LANGUAGE PatternSynonyms, MagicHash #-}
module Main where
import GHC.Base
pattern P1 <- 0
pattern P2 <- 1
f :: Int -> Int#
f P1 = 42#
f P2 = 44#
main = do
print $ I# (f 0)
print $ I# (f 1)
...@@ -3,3 +3,4 @@ test('match', normal, compile_and_run, ['']) ...@@ -3,3 +3,4 @@ test('match', normal, compile_and_run, [''])
test('ex-prov-run', normal, compile_and_run, ['']) test('ex-prov-run', normal, compile_and_run, [''])
test('bidir-explicit', normal, compile_and_run, ['']) test('bidir-explicit', normal, compile_and_run, [''])
test('bidir-explicit-scope', normal, compile_and_run, ['']) test('bidir-explicit-scope', normal, compile_and_run, [''])
test('T9783', normal, compile_and_run, [''])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment