Commit 392210bf authored by Richard Eisenberg, committed by Ben Gamari
Browse files

GHCi support for levity-polymorphic join points

Fixes #16509.

See Note [Levity-polymorphic join points] in ByteCodeGen,
which tells the full story.

This commit also adds some comments and cleans up some code
in the byte-code generator, as I was exploring around trying
to understand it.

test case: ghci/scripts/T16509
parent 8e60e3f0
...@@ -156,7 +156,11 @@ assembleOneBCO hsc_env pbco = do ...@@ -156,7 +156,11 @@ assembleOneBCO hsc_env pbco = do
return ubco' return ubco'
assembleBCO :: DynFlags -> ProtoBCO Name -> IO UnlinkedBCO assembleBCO :: DynFlags -> ProtoBCO Name -> IO UnlinkedBCO
assembleBCO dflags (ProtoBCO nm instrs bitmap bsize arity _origin _malloced) = do assembleBCO dflags (ProtoBCO { protoBCOName = nm
, protoBCOInstrs = instrs
, protoBCOBitmap = bitmap
, protoBCOBitmapSize = bsize
, protoBCOArity = arity }) = do
-- pass 1: collect up the offsets of the local labels. -- pass 1: collect up the offsets of the local labels.
let asm = mapM_ (assembleI dflags) instrs let asm = mapM_ (assembleI dflags) instrs
......
...@@ -26,6 +26,7 @@ import Platform ...@@ -26,6 +26,7 @@ import Platform
import Name import Name
import MkId import MkId
import Id import Id
import Var ( updateVarType )
import ForeignCall import ForeignCall
import HscTypes import HscTypes
import CoreUtils import CoreUtils
...@@ -61,7 +62,6 @@ import Data.Char ...@@ -61,7 +62,6 @@ import Data.Char
import UniqSupply import UniqSupply
import Module import Module
import Control.Arrow ( second )
import Control.Exception import Control.Exception
import Data.Array import Data.Array
...@@ -90,7 +90,7 @@ byteCodeGen hsc_env this_mod binds tycs mb_modBreaks ...@@ -90,7 +90,7 @@ byteCodeGen hsc_env this_mod binds tycs mb_modBreaks
(const ()) $ do (const ()) $ do
-- Split top-level binds into strings and others. -- Split top-level binds into strings and others.
-- See Note [generating code for top-level string literal bindings]. -- See Note [generating code for top-level string literal bindings].
let (strings, flatBinds) = partitionEithers $ do let (strings, flatBinds) = partitionEithers $ do -- list monad
(bndr, rhs) <- flattenBinds binds (bndr, rhs) <- flattenBinds binds
return $ case exprIsTickedString_maybe rhs of return $ case exprIsTickedString_maybe rhs of
Just str -> Left (bndr, str) Just str -> Left (bndr, str)
...@@ -181,29 +181,13 @@ coreExprToBCOs hsc_env this_mod expr ...@@ -181,29 +181,13 @@ coreExprToBCOs hsc_env this_mod expr
where dflags = hsc_dflags hsc_env where dflags = hsc_dflags hsc_env
-- The regular freeVars function gives more information than is useful to -- The regular freeVars function gives more information than is useful to
-- us here. simpleFreeVars does the impedance matching. -- us here. We need only the free variables, not everything in an FVAnn.
-- Historical note: At one point FVAnn was more sophisticated than just
-- a set. Now it isn't. So this function is much simpler. Keeping it around
-- so that if someone changes FVAnn, they will get a nice type error right
-- here.
simpleFreeVars :: CoreExpr -> AnnExpr Id DVarSet simpleFreeVars :: CoreExpr -> AnnExpr Id DVarSet
simpleFreeVars = go . freeVars simpleFreeVars = freeVars
where
go :: AnnExpr Id FVAnn -> AnnExpr Id DVarSet
go (ann, e) = (freeVarsOfAnn ann, go' e)
go' :: AnnExpr' Id FVAnn -> AnnExpr' Id DVarSet
go' (AnnVar id) = AnnVar id
go' (AnnLit lit) = AnnLit lit
go' (AnnLam bndr body) = AnnLam bndr (go body)
go' (AnnApp fun arg) = AnnApp (go fun) (go arg)
go' (AnnCase scrut bndr ty alts) = AnnCase (go scrut) bndr ty (map go_alt alts)
go' (AnnLet bind body) = AnnLet (go_bind bind) (go body)
go' (AnnCast expr (ann, co)) = AnnCast (go expr) (freeVarsOfAnn ann, co)
go' (AnnTick tick body) = AnnTick tick (go body)
go' (AnnType ty) = AnnType ty
go' (AnnCoercion co) = AnnCoercion co
go_alt (con, args, expr) = (con, args, go expr)
go_bind (AnnNonRec bndr rhs) = AnnNonRec bndr (go rhs)
go_bind (AnnRec pairs) = AnnRec (map (second go) pairs)
-- ----------------------------------------------------------------------------- -- -----------------------------------------------------------------------------
-- Compilation schema for the bytecode generator -- Compilation schema for the bytecode generator
...@@ -256,6 +240,7 @@ mkProtoBCO ...@@ -256,6 +240,7 @@ mkProtoBCO
-> name -> name
-> BCInstrList -> BCInstrList
-> Either [AnnAlt Id DVarSet] (AnnExpr Id DVarSet) -> Either [AnnAlt Id DVarSet] (AnnExpr Id DVarSet)
-- ^ original expression; for debugging only
-> Int -> Int
-> Word16 -> Word16
-> [StgWord] -> [StgWord]
...@@ -368,6 +353,9 @@ schemeR fvs (nm, rhs) ...@@ -368,6 +353,9 @@ schemeR fvs (nm, rhs)
-} -}
= schemeR_wrk fvs nm rhs (collect rhs) = schemeR_wrk fvs nm rhs (collect rhs)
-- If an expression is a lambda (after applying bcView), return the
-- list of arguments to the lambda (in R-to-L order) and the
-- underlying expression
collect :: AnnExpr Id DVarSet -> ([Var], AnnExpr' Id DVarSet) collect :: AnnExpr Id DVarSet -> ([Var], AnnExpr' Id DVarSet)
collect (_, e) = go [] e collect (_, e) = go [] e
where where
...@@ -382,8 +370,8 @@ collect (_, e) = go [] e ...@@ -382,8 +370,8 @@ collect (_, e) = go [] e
schemeR_wrk schemeR_wrk
:: [Id] :: [Id]
-> Id -> Id
-> AnnExpr Id DVarSet -> AnnExpr Id DVarSet -- expression e, for debugging only
-> ([Var], AnnExpr' Var DVarSet) -> ([Var], AnnExpr' Var DVarSet) -- result of collect on e
-> BcM (ProtoBCO Name) -> BcM (ProtoBCO Name)
schemeR_wrk fvs nm original_body (args, body) schemeR_wrk fvs nm original_body (args, body)
= do = do
...@@ -508,8 +496,16 @@ schemeE d s p e@(AnnLit lit) = returnUnboxedAtom d s p e (typeArgRep (litera ...@@ -508,8 +496,16 @@ schemeE d s p e@(AnnLit lit) = returnUnboxedAtom d s p e (typeArgRep (litera
schemeE d s p e@(AnnCoercion {}) = returnUnboxedAtom d s p e V schemeE d s p e@(AnnCoercion {}) = returnUnboxedAtom d s p e V
schemeE d s p e@(AnnVar v) schemeE d s p e@(AnnVar v)
-- See Note [Levity-polymorphic join points], step 3.
| isLPJoinPoint v = schemeT d s p $
AnnApp (bogus_fvs, AnnVar (protectLPJoinPointId v))
(bogus_fvs, AnnVar voidPrimId)
-- schemeT will call splitApp, dropping the fvs.
| isUnliftedType (idType v) = returnUnboxedAtom d s p e (bcIdArgRep v) | isUnliftedType (idType v) = returnUnboxedAtom d s p e (bcIdArgRep v)
| otherwise = schemeT d s p e | otherwise = schemeT d s p e
where
bogus_fvs = pprPanic "schemeE bogus_fvs" (ppr v)
schemeE d s p (AnnLet (AnnNonRec x (_,rhs)) (_,body)) schemeE d s p (AnnLet (AnnNonRec x (_,rhs)) (_,body))
| (AnnVar v, args_r_to_l) <- splitApp rhs, | (AnnVar v, args_r_to_l) <- splitApp rhs,
...@@ -534,19 +530,22 @@ schemeE d s p (AnnLet binds (_,body)) = do ...@@ -534,19 +530,22 @@ schemeE d s p (AnnLet binds (_,body)) = do
fvss = map (fvsToEnv p' . fst) rhss fvss = map (fvsToEnv p' . fst) rhss
-- See Note [Levity-polymorphic join points], step 2.
(xs',rhss') = zipWithAndUnzip protectLPJoinPointBind xs rhss
-- Sizes of free vars -- Sizes of free vars
size_w = trunc16W . idSizeW dflags size_w = trunc16W . idSizeW dflags
sizes = map (\rhs_fvs -> sum (map size_w rhs_fvs)) fvss sizes = map (\rhs_fvs -> sum (map size_w rhs_fvs)) fvss
-- the arity of each rhs -- the arity of each rhs
arities = map (genericLength . fst . collect) rhss arities = map (genericLength . fst . collect) rhss'
-- This p', d' defn is safe because all the items being pushed -- This p', d' defn is safe because all the items being pushed
-- are ptrs, so all have size 1 word. d' and p' reflect the stack -- are ptrs, so all have size 1 word. d' and p' reflect the stack
-- after the closures have been allocated in the heap (but not -- after the closures have been allocated in the heap (but not
-- filled in), and pointers to them parked on the stack. -- filled in), and pointers to them parked on the stack.
offsets = mkStackOffsets d (genericReplicate n_binds (wordSize dflags)) offsets = mkStackOffsets d (genericReplicate n_binds (wordSize dflags))
p' = Map.insertList (zipE xs offsets) p p' = Map.insertList (zipE xs' offsets) p
d' = d + wordsToBytes dflags n_binds d' = d + wordsToBytes dflags n_binds
zipE = zipEqual "schemeE" zipE = zipEqual "schemeE"
...@@ -587,7 +586,7 @@ schemeE d s p (AnnLet binds (_,body)) = do ...@@ -587,7 +586,7 @@ schemeE d s p (AnnLet binds (_,body)) = do
compile_binds = compile_binds =
[ compile_bind d' fvs x rhs size arity (trunc16W n) [ compile_bind d' fvs x rhs size arity (trunc16W n)
| (fvs, x, rhs, size, arity, n) <- | (fvs, x, rhs, size, arity, n) <-
zip6 fvss xs rhss sizes arities [n_binds, n_binds-1 .. 1] zip6 fvss xs' rhss' sizes arities [n_binds, n_binds-1 .. 1]
] ]
body_code <- schemeE d' s p' body body_code <- schemeE d' s p' body
thunk_codes <- sequence compile_binds thunk_codes <- sequence compile_binds
...@@ -681,6 +680,30 @@ schemeE _ _ _ expr ...@@ -681,6 +680,30 @@ schemeE _ _ _ expr
= pprPanic "ByteCodeGen.schemeE: unhandled case" = pprPanic "ByteCodeGen.schemeE: unhandled case"
(pprCoreExpr (deAnnotate' expr)) (pprCoreExpr (deAnnotate' expr))
-- | Decide whether an 'Id' is a levity-polymorphic join point (LPJP).
-- See Note [Levity-polymorphic join points], step 1.
--
-- The answer is 'True' only when both of these hold:
--
--   * 'isJoinId' accepts the 'Id', and
--   * 'isLiftedType_maybe' returns 'Nothing' for the 'Id''s type, i.e. it
--     answers neither "lifted" nor "unlifted".
--
-- The type is only inspected once the 'Id' is known to be a join point.
isLPJoinPoint :: Id -> Bool
isLPJoinPoint x
  | isJoinId x = isNothing (isLiftedType_maybe (idType x))
  | otherwise  = False
-- | Rewrite one let-binding so that a levity-polymorphic join point is
-- "protected" behind an extra argument.
-- See Note [Levity-polymorphic join points], step 2.
--
-- For a binder accepted by 'isLPJoinPoint', the result pairs
--
--   * the binder with its type adjusted by 'protectLPJoinPointId', and
--   * the right-hand side wrapped in a lambda over 'voidArgId'. The wrapper
--     is annotated with the same free-variable set as the original
--     right-hand side.
--
-- Any other binding is returned unchanged.
protectLPJoinPointBind :: Id -> AnnExpr Id DVarSet -> (Id, AnnExpr Id DVarSet)
protectLPJoinPointBind bndr bind_rhs@(rhs_fvs, _) =
  if isLPJoinPoint bndr
    then (protectLPJoinPointId bndr, (rhs_fvs, AnnLam voidArgId bind_rhs))
    else (bndr, bind_rhs)
-- | Change the type of a levity-polymorphic join point from @ty@ to
-- @Void# -> ty@ ('voidPrimTy' is GHC's name for @Void#@).
-- See Note [Levity-polymorphic join points].
--
-- Precondition (checked by the ASSERT): the 'Id' satisfies 'isLPJoinPoint'.
-- Only the type is touched; the update goes through 'updateVarType'.
protectLPJoinPointId :: Id -> Id
protectLPJoinPointId lpjp
  = ASSERT( isLPJoinPoint lpjp )
    updateVarType prepend_void_arg lpjp
  where
    -- Build the visible function type @Void# -> res_ty@.
    prepend_void_arg res_ty = mkVisFunTy voidPrimTy res_ty
{- {-
Ticked Expressions Ticked Expressions
------------------ ------------------
...@@ -689,6 +712,41 @@ schemeE _ _ _ expr ...@@ -689,6 +712,41 @@ schemeE _ _ _ expr
the code. When we find such a thing, we pull out the useful information, the code. When we find such a thing, we pull out the useful information,
and then compile the code as if it was just the expression E. and then compile the code as if it was just the expression E.
Note [Levity-polymorphic join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A join point variable is essentially a goto-label: it is, for example,
never used as an argument to another function, and it is called only
in tail position. See Note [Join points] and Note [Invariants on join points],
both in CoreSyn. Because join points do not compile to true, red-blooded
variables (with, e.g., registers allocated to them), they are allowed
to be levity-polymorphic. (See invariant #6 in Note [Invariants on join points]
in CoreSyn.)
However, in this byte-code generator, join points *are* treated just as
ordinary variables. There is no check whether a binding is for a join point
or not; they are all treated uniformly. (Perhaps there is a missed optimization
opportunity here, but that is beyond the scope of my (Richard E's) Thursday.)
We thus must have *some* strategy for dealing with levity-polymorphic join
points (LPJPs), because we cannot have a levity-polymorphic variable.
(Not having such a strategy led to #16509, which panicked in the isUnliftedType
check in the AnnVar case of schemeE.) Here is the strategy:
1. Detect LPJPs. This is done in isLPJoinPoint.
2. When binding an LPJP, add a `\ (_ :: Void#) ->` to its RHS, and modify the
type to tack on a `Void# ->`. (Void# is written voidPrimTy within GHC.)
Note that functions are never levity-polymorphic, so this transformation
changes an LPJP to a non-levity-polymorphic join point. This is done
in protectLPJoinPointBind, called from the AnnLet case of schemeE.
3. At an occurrence of an LPJP, add an application to void# (called voidPrimId),
being careful to note the new type of the LPJP. This is done in the AnnVar
case of schemeE, with help from protectLPJoinPointId.
It's a bit hacky, but it works well in practice and is local. I suspect the
Right Fix is to take advantage of join points as goto-labels.
-} -}
-- Compile code to do a tail call. Specifically, push the fn, -- Compile code to do a tail call. Specifically, push the fn,
......
...@@ -45,7 +45,7 @@ data ProtoBCO a ...@@ -45,7 +45,7 @@ data ProtoBCO a
protoBCOBitmap :: [StgWord], protoBCOBitmap :: [StgWord],
protoBCOBitmapSize :: Word16, protoBCOBitmapSize :: Word16,
protoBCOArity :: Int, protoBCOArity :: Int,
-- what the BCO came from -- what the BCO came from, for debugging only
protoBCOExpr :: Either [AnnAlt Id DVarSet] (AnnExpr Id DVarSet), protoBCOExpr :: Either [AnnAlt Id DVarSet] (AnnExpr Id DVarSet),
-- malloc'd pointers -- malloc'd pointers
protoBCOFFIs :: [FFIInfo] protoBCOFFIs :: [FFIInfo]
...@@ -179,7 +179,13 @@ data BCInstr ...@@ -179,7 +179,13 @@ data BCInstr
-- Printing bytecode instructions -- Printing bytecode instructions
instance Outputable a => Outputable (ProtoBCO a) where instance Outputable a => Outputable (ProtoBCO a) where
ppr (ProtoBCO name instrs bitmap bsize arity origin ffis) ppr (ProtoBCO { protoBCOName = name
, protoBCOInstrs = instrs
, protoBCOBitmap = bitmap
, protoBCOBitmapSize = bsize
, protoBCOArity = arity
, protoBCOExpr = origin
, protoBCOFFIs = ffis })
= (text "ProtoBCO" <+> ppr name <> char '#' <> int arity = (text "ProtoBCO" <+> ppr name <> char '#' <> int arity
<+> text (show ffis) <> colon) <+> text (show ffis) <> colon)
$$ nest 3 (case origin of $$ nest 3 (case origin of
......
...@@ -64,7 +64,7 @@ isNvUnaryType ty ...@@ -64,7 +64,7 @@ isNvUnaryType ty
= False = False
-- INVARIANT: the result list is never empty. -- INVARIANT: the result list is never empty.
typePrimRepArgs :: Type -> [PrimRep] typePrimRepArgs :: HasDebugCallStack => Type -> [PrimRep]
typePrimRepArgs ty typePrimRepArgs ty
| [] <- reps | [] <- reps
= [VoidRep] = [VoidRep]
......
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
-- NOTE(review): presumably this is the module loaded by the ghci/scripts/T16509
-- test named in the commit message -- confirm. If so, keep the definitions
-- below exactly as written: the precise shape of the view pattern and the
-- pattern guard may be what makes the byte-code generator hit the
-- levity-polymorphic join point case, and a "cleaner" equivalent might not.
module PatternPanic where
-- | Unidirectional pattern synonym over pairs of 'Int's. It matches when
-- the view 'isSameRef' applied to the first component yields 'True' and the
-- second component matches the literal 0.
pattern TestPat :: (Int, Int)
pattern TestPat <- (isSameRef -> True, 0)
-- | 'True' when the argument matches the literal 0, 'False' otherwise.
-- The first equation tests this with a pattern guard (@0 <- e@) rather than
-- a literal pattern in the function head.
isSameRef :: Int -> Bool
isSameRef e | 0 <- e = True
isSameRef _ = False
...@@ -295,6 +295,7 @@ test('T11606', normal, ghci_script, ['T11606.script']) ...@@ -295,6 +295,7 @@ test('T11606', normal, ghci_script, ['T11606.script'])
test('T16089', normal, ghci_script, ['T16089.script']) test('T16089', normal, ghci_script, ['T16089.script'])
test('T14828', normal, ghci_script, ['T14828.script']) test('T14828', normal, ghci_script, ['T14828.script'])
test('T16376', normal, ghci_script, ['T16376.script']) test('T16376', normal, ghci_script, ['T16376.script'])
test('T16509', normal, ghci_script, ['T16509.script'])
test('T16527', normal, ghci_script, ['T16527.script']) test('T16527', normal, ghci_script, ['T16527.script'])
test('T16569', normal, ghci_script, ['T16569.script']) test('T16569', normal, ghci_script, ['T16569.script'])
test('T16767', normal, ghci_script, ['T16767.script']) test('T16767', normal, ghci_script, ['T16767.script'])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment