Commit 00cbbab3 authored by eir@cis.upenn.edu's avatar eir@cis.upenn.edu
Browse files

Refactor the typechecker to use ExpTypes.

The idea here is described in [wiki:Typechecker]. Briefly,
this refactor keeps solid track of "synthesis" mode vs
"checking" in GHC's bidirectional type-checking algorithm.
When in synthesis mode, the expected type is just an IORef
to write to.

In addition, this patch does a significant reworking of
RebindableSyntax, allowing much more freedom in the types
of the rebindable operators. For example, we can now have
`negate :: Int -> Bool` and
`(>>=) :: m a -> (forall x. a x -> m b) -> m b`. The magic
is in tcSyntaxOp.

This addresses tickets #11397, #11452, and #11458.

Tests:
  typecheck/should_compile/{RebindHR,RebindNegate,T11397,T11458}
  th/T11452
parent 2899aa58
......@@ -31,6 +31,7 @@ import Id
import ConLike
import DataCon
import Name
import FamInstEnv
import TysWiredIn
import TyCon
import SrcLoc
......@@ -148,7 +149,8 @@ type PmResult = ( [[LPat Id]]
checkSingle :: Id -> Pat Id -> DsM PmResult
checkSingle var p = do
let lp = [noLoc p]
vec <- liftUs (translatePat p)
fam_insts <- dsGetFamInstEnvs
vec <- liftUs (translatePat fam_insts p)
vsa <- initial_uncovered [var]
(c,d,us') <- patVectProc False (vec,[]) vsa -- no guards
us <- pruneVSA us'
......@@ -171,7 +173,8 @@ checkMatches oversimplify vars matches
return ([], [], missing')
go (m:ms) missing = do
clause <- liftUs (translateMatch m)
fam_insts <- dsGetFamInstEnvs
clause <- liftUs (translateMatch fam_insts m)
(c, d, us ) <- patVectProc oversimplify clause missing
(rs, is, us') <- go ms us
return $ case (c,d) of
......@@ -209,7 +212,8 @@ noFailingGuards clauses = sum [ countPatVecs gvs | (_, gvs) <- clauses ]
computeNoGuards :: [LMatch Id (LHsExpr Id)] -> PmM Int
computeNoGuards matches = do
matches' <- mapM (liftUs . translateMatch) matches
fam_insts <- dsGetFamInstEnvs
matches' <- mapM (liftUs . translateMatch fam_insts) matches
return (noFailingGuards matches')
maximum_failing_guards :: Int
......@@ -264,46 +268,47 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit }
-- -----------------------------------------------------------------------
-- * Transform (Pat Id) into of (PmPat Id)
translatePat :: Pat Id -> UniqSM PatVec
translatePat pat = case pat of
translatePat :: FamInstEnvs -> Pat Id -> UniqSM PatVec
translatePat fam_insts pat = case pat of
WildPat ty -> mkPmVarsSM [ty]
VarPat id -> return [PmVar (unLoc id)]
ParPat p -> translatePat (unLoc p)
ParPat p -> translatePat fam_insts (unLoc p)
LazyPat _ -> mkPmVarsSM [hsPatType pat] -- like a variable
-- ignore strictness annotations for now
BangPat p -> translatePat (unLoc p)
BangPat p -> translatePat fam_insts (unLoc p)
AsPat lid p -> do
-- Note [Translating As Patterns]
ps <- translatePat (unLoc p)
ps <- translatePat fam_insts (unLoc p)
let [e] = map valAbsToPmExpr (coercePatVec ps)
g = PmGrd [PmVar (unLoc lid)] e
return (ps ++ [g])
SigPatOut p _ty -> translatePat (unLoc p)
SigPatOut p _ty -> translatePat fam_insts (unLoc p)
-- See Note [Translate CoPats]
CoPat wrapper p ty
| isIdHsWrapper wrapper -> translatePat p
| WpCast co <- wrapper, isReflexiveCo co -> translatePat p
| isIdHsWrapper wrapper -> translatePat fam_insts p
| WpCast co <- wrapper, isReflexiveCo co -> translatePat fam_insts p
| otherwise -> do
ps <- translatePat p
ps <- translatePat fam_insts p
(xp,xe) <- mkPmId2FormsSM ty
let g = mkGuard ps (HsWrap wrapper (unLoc xe))
return [xp,g]
-- (n + k) ===> x (True <- x >= k) (n <- x-k)
NPlusKPat (L _ n) k ge minus -> do
(xp, xe) <- mkPmId2FormsSM (idType n)
let ke = L (getLoc k) (HsOverLit (unLoc k))
g1 = mkGuard [truePattern] (OpApp xe (noLoc ge) no_fixity ke)
g2 = mkGuard [PmVar n] (OpApp xe (noLoc minus) no_fixity ke)
NPlusKPat (L _ n) k1 k2 ge minus ty -> do
(xp, xe) <- mkPmId2FormsSM ty
let ke1 = L (getLoc k1) (HsOverLit (unLoc k1))
ke2 = L (getLoc k1) (HsOverLit k2)
g1 = mkGuard [truePattern] (unLoc $ nlHsSyntaxApps ge [xe, ke1])
g2 = mkGuard [PmVar n] (unLoc $ nlHsSyntaxApps minus [xe, ke2])
return [xp, g1, g2]
-- (fun -> pat) ===> x (pat <- fun x)
ViewPat lexpr lpat arg_ty -> do
ps <- translatePat (unLoc lpat)
ps <- translatePat fam_insts (unLoc lpat)
-- See Note [Guards and Approximation]
case all cantFailPattern ps of
True -> do
......@@ -316,15 +321,18 @@ translatePat pat = case pat of
-- list
ListPat ps ty Nothing -> do
foldr (mkListPatVec ty) [nilPattern ty] <$> translatePatVec (map unLoc ps)
foldr (mkListPatVec ty) [nilPattern ty] <$> translatePatVec fam_insts (map unLoc ps)
-- overloaded list
ListPat lpats elem_ty (Just (pat_ty, _to_list))
| Just e_ty <- splitListTyConApp_maybe pat_ty, elem_ty `eqType` e_ty ->
| Just e_ty <- splitListTyConApp_maybe pat_ty
, (_, norm_elem_ty) <- normaliseType fam_insts Nominal elem_ty
-- elem_ty is frequently something like `Item [Int]`, but we prefer `Int`
, norm_elem_ty `eqType` e_ty ->
-- We have to ensure that the element types are exactly the same.
-- Otherwise, one may give an instance IsList [Int] (more specific than
-- the default IsList [a]) with a different implementation for `toList'
translatePat (ListPat lpats e_ty Nothing)
translatePat fam_insts (ListPat lpats e_ty Nothing)
| otherwise -> do
-- See Note [Guards and Approximation]
var <- mkPmVarSM pat_ty
......@@ -345,29 +353,29 @@ translatePat pat = case pat of
, pat_tvs = ex_tvs
, pat_dicts = dicts
, pat_args = ps } -> do
args <- translateConPatVec arg_tys ex_tvs con ps
args <- translateConPatVec fam_insts arg_tys ex_tvs con ps
return [PmCon { pm_con_con = con
, pm_con_arg_tys = arg_tys
, pm_con_tvs = ex_tvs
, pm_con_dicts = dicts
, pm_con_args = args }]
NPat (L _ ol) mb_neg _eq -> translateNPat ol mb_neg
NPat (L _ ol) mb_neg _eq ty -> translateNPat fam_insts ol mb_neg ty
LitPat lit
-- If it is a string then convert it to a list of characters
| HsString src s <- lit ->
foldr (mkListPatVec charTy) [nilPattern charTy] <$>
translatePatVec (map (LitPat . HsChar src) (unpackFS s))
translatePatVec fam_insts (map (LitPat . HsChar src) (unpackFS s))
| otherwise -> return [mkLitPattern lit]
PArrPat ps ty -> do
tidy_ps <- translatePatVec (map unLoc ps)
tidy_ps <- translatePatVec fam_insts (map unLoc ps)
let fake_con = parrFakeCon (length ps)
return [vanillaConPattern fake_con [ty] (concat tidy_ps)]
TuplePat ps boxity tys -> do
tidy_ps <- translatePatVec (map unLoc ps)
tidy_ps <- translatePatVec fam_insts (map unLoc ps)
let tuple_con = tupleDataCon boxity (length ps)
return [vanillaConPattern tuple_con tys (concat tidy_ps)]
......@@ -378,33 +386,35 @@ translatePat pat = case pat of
SigPatIn {} -> panic "Check.translatePat: SigPatIn"
-- | Translate an overloaded literal (see `tidyNPat' in deSugar/MatchLit.hs)
translateNPat :: HsOverLit Id -> Maybe (SyntaxExpr Id) -> UniqSM PatVec
translateNPat (OverLit val False _ ty) mb_neg
| isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg
= translatePat (LitPat (HsString src s))
| isIntTy ty, HsIntegral src i <- val
= translatePat (mk_num_lit HsInt src i)
| isWordTy ty, HsIntegral src i <- val
= translatePat (mk_num_lit HsWordPrim src i)
translateNPat :: FamInstEnvs
-> HsOverLit Id -> Maybe (SyntaxExpr Id) -> Type -> UniqSM PatVec
translateNPat fam_insts (OverLit val False _ ty) mb_neg outer_ty
| not type_change, isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg
= translatePat fam_insts (LitPat (HsString src s))
| not type_change, isIntTy ty, HsIntegral src i <- val
= translatePat fam_insts (mk_num_lit HsInt src i)
| not type_change, isWordTy ty, HsIntegral src i <- val
= translatePat fam_insts (mk_num_lit HsWordPrim src i)
where
type_change = not (outer_ty `eqType` ty)
mk_num_lit c src i = LitPat $ case mb_neg of
Nothing -> c src i
Just _ -> c src (-i)
translateNPat ol mb_neg
translateNPat _ ol mb_neg _
= return [PmLit { pm_lit_lit = PmOLit (isJust mb_neg) ol }]
-- | Translate a list of patterns (Note: each pattern is translated
-- to a pattern vector but we do not concatenate the results).
translatePatVec :: [Pat Id] -> UniqSM [PatVec]
translatePatVec pats = mapM translatePat pats
translatePatVec :: FamInstEnvs -> [Pat Id] -> UniqSM [PatVec]
translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats
translateConPatVec :: [Type] -> [TyVar]
translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar]
-> DataCon -> HsConPatDetails Id -> UniqSM PatVec
translateConPatVec _univ_tys _ex_tvs _ (PrefixCon ps)
= concat <$> translatePatVec (map unLoc ps)
translateConPatVec _univ_tys _ex_tvs _ (InfixCon p1 p2)
= concat <$> translatePatVec (map unLoc [p1,p2])
translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _))
translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps)
= concat <$> translatePatVec fam_insts (map unLoc ps)
translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2)
= concat <$> translatePatVec fam_insts (map unLoc [p1,p2])
translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _))
-- Nothing matched. Make up some fresh term variables
| null fs = mkPmVarsSM arg_tys
-- The data constructor was not defined using record syntax. For the
......@@ -417,7 +427,7 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _))
| matched_lbls `subsetOf` orig_lbls
= ASSERT(length orig_lbls == length arg_tys)
let translateOne (lbl, ty) = case lookup lbl matched_pats of
Just p -> translatePat p
Just p -> translatePat fam_insts p
Nothing -> mkPmVarsSM [ty]
in concatMapM translateOne (zip orig_lbls arg_tys)
-- The fields that appear are not in the correct order. Make up fresh
......@@ -426,7 +436,7 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _))
| otherwise = do
arg_var_pats <- mkPmVarsSM arg_tys
translated_pats <- forM matched_pats $ \(x,pat) -> do
pvec <- translatePat pat
pvec <- translatePat fam_insts pat
return (x, pvec)
let zipped = zip orig_lbls [ x | PmVar x <- arg_var_pats ]
......@@ -453,10 +463,10 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _))
| x == y = subsetOf xs ys
| otherwise = subsetOf (x:xs) ys
translateMatch :: LMatch Id (LHsExpr Id) -> UniqSM (PatVec,[PatVec])
translateMatch (L _ (Match _ lpats _ grhss)) = do
pats' <- concat <$> translatePatVec pats
guards' <- mapM translateGuards guards
translateMatch :: FamInstEnvs -> LMatch Id (LHsExpr Id) -> UniqSM (PatVec,[PatVec])
translateMatch fam_insts (L _ (Match _ lpats _ grhss)) = do
pats' <- concat <$> translatePatVec fam_insts pats
guards' <- mapM (translateGuards fam_insts) guards
return (pats', guards')
where
extractGuards :: LGRHS Id (LHsExpr Id) -> [GuardStmt Id]
......@@ -469,9 +479,9 @@ translateMatch (L _ (Match _ lpats _ grhss)) = do
-- * Transform source guards (GuardStmt Id) to PmPats (Pattern)
-- | Translate a list of guard statements to a pattern vector
translateGuards :: [GuardStmt Id] -> UniqSM PatVec
translateGuards guards = do
all_guards <- concat <$> mapM translateGuard guards
translateGuards :: FamInstEnvs -> [GuardStmt Id] -> UniqSM PatVec
translateGuards fam_insts guards = do
all_guards <- concat <$> mapM (translateGuard fam_insts) guards
return (replace_unhandled all_guards)
-- It should have been (return $ all_guards) but it is too expressive.
-- Since the term oracle does not handle all constraints we generate,
......@@ -509,24 +519,24 @@ cantFailPattern (PmGrd pv _e)
cantFailPattern _ = False
-- | Translate a guard statement to Pattern
translateGuard :: GuardStmt Id -> UniqSM PatVec
translateGuard (BodyStmt e _ _ _) = translateBoolGuard e
translateGuard (LetStmt binds) = translateLet (unLoc binds)
translateGuard (BindStmt p e _ _) = translateBind p e
translateGuard (LastStmt {}) = panic "translateGuard LastStmt"
translateGuard (ParStmt {}) = panic "translateGuard ParStmt"
translateGuard (TransStmt {}) = panic "translateGuard TransStmt"
translateGuard (RecStmt {}) = panic "translateGuard RecStmt"
translateGuard (ApplicativeStmt {}) = panic "translateGuard ApplicativeLastStmt"
translateGuard :: FamInstEnvs -> GuardStmt Id -> UniqSM PatVec
translateGuard _ (BodyStmt e _ _ _) = translateBoolGuard e
translateGuard _ (LetStmt binds) = translateLet (unLoc binds)
translateGuard fam_insts (BindStmt p e _ _ _) = translateBind fam_insts p e
translateGuard _ (LastStmt {}) = panic "translateGuard LastStmt"
translateGuard _ (ParStmt {}) = panic "translateGuard ParStmt"
translateGuard _ (TransStmt {}) = panic "translateGuard TransStmt"
translateGuard _ (RecStmt {}) = panic "translateGuard RecStmt"
translateGuard _ (ApplicativeStmt {}) = panic "translateGuard ApplicativeLastStmt"
-- | Translate let-bindings
translateLet :: HsLocalBinds Id -> UniqSM PatVec
translateLet _binds = return []
-- | Translate a pattern guard
translateBind :: LPat Id -> LHsExpr Id -> UniqSM PatVec
translateBind (L _ p) e = do
ps <- translatePat p
translateBind :: FamInstEnvs -> LPat Id -> LHsExpr Id -> UniqSM PatVec
translateBind fam_insts (L _ p) e = do
ps <- translatePat fam_insts p
return [mkGuard ps (unLoc e)]
-- | Translate a boolean guard
......@@ -600,7 +610,8 @@ below is the *right thing to do*:
The case with literals is a bit different. a literal @l@ should be translated
to @x (True <- x == from l)@. Since we want to have better warnings for
overloaded literals as it is a very common feature, we treat them differently.
They are mainly covered in Note [Undecidable Equality on Overloaded Literals].
They are mainly covered in Note [Undecidable Equality on Overloaded Literals]
in PmExpr.
4. N+K Patterns & Pattern Synonyms
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -845,9 +856,6 @@ coercePmPat (PmCon { pm_con_con = con, pm_con_arg_tys = arg_tys
, pm_con_args = coercePatVec args }]
coercePmPat (PmGrd {}) = [] -- drop the guards
no_fixity :: a -- TODO: Can we retrieve the fixity from the operator name?
no_fixity = panic "Check: no fixity"
-- Get all constructors in the family (including given)
allConstructors :: DataCon -> [DataCon]
allConstructors = tyConDataCons . dataConTyCon
......@@ -1101,7 +1109,7 @@ cMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps
-- CLitLit
cMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of
-- See Note [Undecidable Equality for Overloaded Literals]
-- See Note [Undecidable Equality for Overloaded Literals] in PmExpr
True -> va `mkCons` covered us gvsa ps vsa -- match
False -> Empty -- mismatch
......@@ -1172,7 +1180,7 @@ uMatcher us gvsa ( p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps
-- ULitLit
uMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of
-- See Note [Undecidable Equality for Overloaded Literals]
-- See Note [Undecidable Equality for Overloaded Literals] in PmExpr
True -> va `mkCons` uncovered us gvsa ps vsa -- match
False -> va `mkCons` vsa -- mismatch
......@@ -1256,7 +1264,7 @@ dMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps
-- DLitLit
dMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of
-- See Note [Undecidable Equality for Overloaded Literals]
-- See Note [Undecidable Equality for Overloaded Literals] in PmExpr
True -> va `mkCons` divergent us gvsa ps vsa -- match
False -> Empty -- mismatch
......@@ -1331,10 +1339,12 @@ genCaseTmCs2 :: Maybe (LHsExpr Id) -- Scrutinee
-> [Id] -- MatchVars (should have length 1)
-> DsM (Bag SimpleEq)
genCaseTmCs2 Nothing _ _ = return emptyBag
genCaseTmCs2 (Just scr) [p] [var] = liftUs $ do
[e] <- map valAbsToPmExpr . coercePatVec <$> translatePat p
let scr_e = lhsExprToPmExpr scr
return $ listToBag [(var, e), (var, scr_e)]
genCaseTmCs2 (Just scr) [p] [var] = do
fam_insts <- dsGetFamInstEnvs
liftUs $ do
[e] <- map valAbsToPmExpr . coercePatVec <$> translatePat fam_insts p
let scr_e = lhsExprToPmExpr scr
return $ listToBag [(var, e), (var, scr_e)]
genCaseTmCs2 _ _ _ = panic "genCaseTmCs2: HsCase"
-- | Generate a simple equality when checking a case expression:
......
......@@ -592,8 +592,9 @@ addTickHsExpr (ExplicitList ty wit es) =
(addTickWit wit)
(mapM (addTickLHsExpr) es)
where addTickWit Nothing = return Nothing
addTickWit (Just fln) = do fln' <- addTickHsExpr fln
return (Just fln')
addTickWit (Just fln)
= do fln' <- addTickSyntaxExpr hpcSrcSpan fln
return (Just fln')
addTickHsExpr (ExplicitPArr ty es) =
liftM2 ExplicitPArr
(return ty)
......@@ -621,7 +622,7 @@ addTickHsExpr (ArithSeq ty wit arith_seq) =
(addTickWit wit)
(addTickArithSeqInfo arith_seq)
where addTickWit Nothing = return Nothing
addTickWit (Just fl) = do fl' <- addTickHsExpr fl
addTickWit (Just fl) = do fl' <- addTickSyntaxExpr hpcSrcSpan fl
return (Just fl')
-- We might encounter existing ticks (multiple Coverage passes)
......@@ -732,12 +733,13 @@ addTickStmt _isGuard (LastStmt e noret ret) = do
(addTickLHsExpr e)
(pure noret)
(addTickSyntaxExpr hpcSrcSpan ret)
addTickStmt _isGuard (BindStmt pat e bind fail) = do
liftM4 BindStmt
addTickStmt _isGuard (BindStmt pat e bind fail ty) = do
liftM5 BindStmt
(addTickLPat pat)
(addTickLHsExprRHS e)
(addTickSyntaxExpr hpcSrcSpan bind)
(addTickSyntaxExpr hpcSrcSpan fail)
(return ty)
addTickStmt isGuard (BodyStmt e bind' guard' ty) = do
liftM4 BodyStmt
(addTick isGuard e)
......@@ -747,11 +749,12 @@ addTickStmt isGuard (BodyStmt e bind' guard' ty) = do
addTickStmt _isGuard (LetStmt (L l binds)) = do
liftM (LetStmt . L l)
(addTickHsLocalBinds binds)
addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr) = do
liftM3 ParStmt
addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr ty) = do
liftM4 ParStmt
(mapM (addTickStmtAndBinders isGuard) pairs)
(addTickSyntaxExpr hpcSrcSpan mzipExpr)
(unLoc <$> addTickLHsExpr (L hpcSrcSpan mzipExpr))
(addTickSyntaxExpr hpcSrcSpan bindExpr)
(return ty)
addTickStmt isGuard (ApplicativeStmt args mb_join body_ty) = do
args' <- mapM (addTickApplicativeArg isGuard) args
return (ApplicativeStmt args' mb_join body_ty)
......@@ -765,7 +768,7 @@ addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts
t_u <- addTickLHsExprRHS using
t_f <- addTickSyntaxExpr hpcSrcSpan returnExpr
t_b <- addTickSyntaxExpr hpcSrcSpan bindExpr
t_m <- addTickSyntaxExpr hpcSrcSpan liftMExpr
L _ t_m <- addTickLHsExpr (L hpcSrcSpan liftMExpr)
return $ stmt { trS_stmts = t_s, trS_by = t_y, trS_using = t_u
, trS_ret = t_f, trS_bind = t_b, trS_fmap = t_m }
......@@ -792,7 +795,7 @@ addTickApplicativeArg isGuard (op, arg) =
addTickArg (ApplicativeArgMany stmts ret pat) =
ApplicativeArgMany
<$> addTickLStmts isGuard stmts
<*> addTickSyntaxExpr hpcSrcSpan ret
<*> (unLoc <$> addTickLHsExpr (L hpcSrcSpan ret))
<*> addTickLPat pat
addTickStmtAndBinders :: Maybe (Bool -> BoxLabel) -> ParStmtBlock Id Id
......@@ -837,9 +840,9 @@ addTickIPBind (IPBind nm e) =
-- There is no location here, so we might need to use a context location??
addTickSyntaxExpr :: SrcSpan -> SyntaxExpr Id -> TM (SyntaxExpr Id)
addTickSyntaxExpr pos x = do
addTickSyntaxExpr pos syn@(SyntaxExpr { syn_expr = x }) = do
L _ x' <- addTickLHsExpr (L pos x)
return $ x'
return $ syn { syn_expr = x' }
-- we do not walk into patterns.
addTickLPat :: LPat Id -> TM (LPat Id)
addTickLPat pat = return pat
......@@ -951,12 +954,13 @@ addTickLCmdStmts' lstmts res
binders = collectLStmtsBinders lstmts
addTickCmdStmt :: Stmt Id (LHsCmd Id) -> TM (Stmt Id (LHsCmd Id))
addTickCmdStmt (BindStmt pat c bind fail) = do
liftM4 BindStmt
addTickCmdStmt (BindStmt pat c bind fail ty) = do
liftM5 BindStmt
(addTickLPat pat)
(addTickLHsCmd c)
(return bind)
(return fail)
(return ty)
addTickCmdStmt (LastStmt c noret ret) = do
liftM3 LastStmt
(addTickLHsCmd c)
......
......@@ -25,7 +25,7 @@ import qualified HsUtils
-- So WATCH OUT; check each use of split*Ty functions.
-- Sigh. This is a pain.
import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds )
import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds, dsSyntaxExpr )
import TcType
import TcEvidence
......@@ -465,9 +465,8 @@ dsCmd ids local_vars stack_ty res_ty (HsCmdIf mb_fun cond then_cmd else_cmd)
core_right = mk_right_expr then_ty else_ty (buildEnvStack else_ids stack_id)
core_if <- case mb_fun of
Just fun -> do { core_fun <- dsExpr fun
; matchEnvStack env_ids stack_id $
mkCoreApps core_fun [core_cond, core_left, core_right] }
Just fun -> do { fun_apps <- dsSyntaxExpr fun [core_cond, core_left, core_right]
; matchEnvStack env_ids stack_id fun_apps }
Nothing -> matchEnvStack env_ids stack_id $
mkIfThenElse core_cond core_left core_right
......@@ -782,7 +781,7 @@ dsCmdStmt ids local_vars out_ids (BodyStmt cmd _ _ c_ty) env_ids = do
-- It would be simpler and more consistent to do this using second,
-- but that's likely to be defined in terms of first.
dsCmdStmt ids local_vars out_ids (BindStmt pat cmd _ _) env_ids = do
dsCmdStmt ids local_vars out_ids (BindStmt pat cmd _ _ _) env_ids = do
let pat_ty = hsLPatType pat
(core_cmd, fv_cmd, env_ids1) <- dsfixCmd ids local_vars unitTy pat_ty cmd
let pat_vars = mkVarSet (collectPatBinders pat)
......@@ -1142,8 +1141,8 @@ collectl (L _ pat) bndrs
collectEvBinders ds
++ foldr collectl bndrs (hsConPatArgs ps)
go (LitPat _) = bndrs
go (NPat _ _ _) = bndrs
go (NPlusKPat (L _ n) _ _ _) = n : bndrs
go (NPat {}) = bndrs
go (NPlusKPat (L _ n) _ _ _ _ _) = n : bndrs
go (SigPatIn pat _) = collectl pat bndrs
go (SigPatOut pat _) = collectl pat bndrs
......
......@@ -8,7 +8,8 @@ Desugaring exporessions.
{-# LANGUAGE CPP #-}
module DsExpr ( dsExpr, dsLExpr, dsLocalBinds, dsValBinds, dsLit ) where
module DsExpr ( dsExpr, dsLExpr, dsLocalBinds
, dsValBinds, dsLit, dsSyntaxExpr ) where
#include "HsVersions.h"
......@@ -221,7 +222,8 @@ dsExpr (HsWrap co_fn e)
; return wrapped_e }
dsExpr (NegApp expr neg_expr)
= App <$> dsExpr neg_expr <*> dsLExpr expr
= do { expr' <- dsLExpr expr
; dsSyntaxExpr neg_expr [expr'] }
dsExpr (HsLam a_Match)
= uncurry mkLams <$> matchWrapper LambdaExpr Nothing a_Match
......@@ -354,8 +356,7 @@ dsExpr (HsIf mb_fun guard_expr then_expr else_expr)
; b1 <- dsLExpr then_expr
; b2 <- dsLExpr else_expr
; case mb_fun of
Just fun -> do { core_fun <- dsExpr fun
; return (mkCoreApps core_fun [pred,b1,b2]) }
Just fun -> dsSyntaxExpr fun [pred, b1, b2]
Nothing -> return $ mkIfThenElse pred b1 b2 }
dsExpr (HsMultiIf res_ty alts)
......@@ -398,10 +399,8 @@ dsExpr (ExplicitPArr ty xs) = do
dsExpr (ArithSeq expr witness seq)
= case witness of
Nothing -> dsArithSeq expr seq
Just fl -> do {
; fl' <- dsExpr fl
; newArithSeq <- dsArithSeq expr seq
; return (App fl' newArithSeq)}
Just fl -> do { newArithSeq <- dsArithSeq expr seq
; dsSyntaxExpr fl [newArithSeq] }
dsExpr (PArrSeq expr (FromTo from to))
= mkApps <$> dsExpr expr <*> mapM dsLExpr [from, to]
......@@ -741,6 +740,16 @@ dsExpr (HsRecFld {}) = panic "dsExpr:HsRecFld"
dsExpr (HsTypeOut {})
= panic "dsExpr: tried to desugar a naked type application argument (HsTypeOut)"
------------------------------
dsSyntaxExpr :: SyntaxExpr Id -> [CoreExpr] -> DsM CoreExpr
dsSyntaxExpr (SyntaxExpr { syn_expr = expr
, syn_arg_wraps = arg_wraps
, syn_res_wrap = res_wrap })
arg_exprs
= do { args <- zipWithM dsHsWrapper arg_wraps arg_exprs
; fun <- dsExpr expr
; dsHsWrapper res_wrap $ mkApps fun args }
findField :: [LHsRecField Id arg] -> Name -> [arg]
findField rbinds sel
= [hsRecFieldArg fld | L _ fld <- rbinds
......@@ -832,10 +841,9 @@ dsExplicitList elt_ty Nothing xs
; return (foldr (App . App (Var c)) folded_suffix prefix) }
dsExplicitList elt_ty (Just fln) xs
= do { fln' <- dsExpr fln
; list <- dsExplicitList elt_ty Nothing xs
= do { list <- dsExplicitList elt_ty Nothing xs
; dflags <- getDynFlags
; return (App (App fln' (mkIntExprInt dflags (length xs))) list) }
; dsSyntaxExpr fln [mkIntExprInt dflags (length xs), list] }
spanTail :: (a -> Bool) -> [a] -> ([a], [a])
spanTail f xs = (reverse rejected, reverse satisfying)
......@@ -882,25 +890,21 @@ dsDo stmts
go _ (BodyStmt rhs then_expr _ _) stmts
= do { rhs2 <- dsLExpr rhs
; warnDiscardedDoBindings rhs (exprType rhs2)
; then_expr2 <- dsExpr then_expr
; rest <- goL stmts
; return (mkApps then_expr2 [rhs2, rest]) }
; dsSyntaxExpr then_expr [rhs2, rest] }
go _ (LetStmt (L _ binds)) stmts
= do { rest <- goL stmts
; dsLocalBinds binds rest }
go _ (BindStmt pat rhs bind_op fail_op) stmts
go _ (BindStmt pat rhs bind_op fail_op res1_ty) stmts
= do { body <- goL stmts
; rhs' <- dsLExpr rhs
; bind_op' <- dsExpr bind_op
; var <- selectSimpleMatchVarL pat
; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2
res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
res1_ty (cantFailMatchResult body)
; match_code <- handle_failure pat match fail_op
; return (mkApps bind_op' [rhs', Lam var match_code]) }
; dsSyntaxExpr bind_op [rhs', Lam var match_code] }
go _ (ApplicativeStmt args mb_join body_ty) stmts
= do {
......@@ -915,7 +919,6 @@ dsDo stmts
arg_tys = map hsLPatType pats
; rhss' <- sequence rhss
; ops' <- mapM dsExpr (map fst args)
; let body' = noLoc $ HsDo DoExpr (noLoc stmts) body_ty
......@@ -926,30 +929,30 @@ dsDo stmts