Skip to content
Snippets Groups Projects
Commit 860c678c authored by Ryan Scott's avatar Ryan Scott
Browse files

Patch futhark-0.21.8

See ghc/ghc#21319.
parent e5ab8050
No related branches found
No related tags found
No related merge requests found
Pipeline #49841 passed
diff --git a/src/Futhark/IR/Mem.hs b/src/Futhark/IR/Mem.hs
index 577454f..9845796 100644
--- a/src/Futhark/IR/Mem.hs
+++ b/src/Futhark/IR/Mem.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
@@ -564,6 +565,7 @@ matchRetTypeToResult rettype result = do
matchReturnType rettype (map resSubExp result) result_ts
matchFunctionReturnType ::
+ forall rep inner.
(Mem rep inner, TC.Checkable rep) =>
[FunReturns] ->
Result ->
@@ -572,6 +574,7 @@ matchFunctionReturnType rettype result = do
matchRetTypeToResult rettype result
mapM_ (checkResultSubExp . resSubExp) result
where
+ checkResultSubExp :: SubExp -> TC.TypeM rep ()
checkResultSubExp Constant {} =
return ()
checkResultSubExp (Var v) = do
@@ -989,6 +992,7 @@ subExpReturns (Constant v) =
-- | The return information of an expression. This can be seen as the
-- "return type with memory annotations" of the expression.
expReturns ::
+ forall m rep inner.
( Monad m,
LocalScope rep m,
Mem rep inner
@@ -1074,6 +1078,7 @@ expReturns (WithAcc inputs lam) =
-- think WithAcc should perhaps have a return annotation like If.
pure (extReturns $ staticShapes $ drop num_accs $ lambdaReturnType lam)
where
+ inputReturns :: WithAccInput rep -> m [ExpReturns]
inputReturns (_, arrs, _) = mapM varReturns arrs
num_accs = length inputs
diff --git a/src/Futhark/IR/SegOp.hs b/src/Futhark/IR/SegOp.hs
index ef1ab6c..bce04e8 100644
--- a/src/Futhark/IR/SegOp.hs
+++ b/src/Futhark/IR/SegOp.hs
@@ -1525,12 +1525,14 @@ bottomUpSegOp (vtable, used) (Pat kpes) dec segop = Simplify $ do
--- Memory
kernelBodyReturns ::
+ forall rep inner m somerep.
(Mem rep inner, HasScope rep m, Monad m) =>
KernelBody somerep ->
[ExpReturns] ->
m [ExpReturns]
kernelBodyReturns = zipWithM correct . kernelBodyResult
where
+ correct :: KernelResult -> ExpReturns -> m ExpReturns
correct (WriteReturns _ _ arr _) _ = varReturns arr
correct _ ret = return ret
diff --git a/src/Futhark/Optimise/DoubleBuffer.hs b/src/Futhark/Optimise/DoubleBuffer.hs
index f94b427..55fdac4 100644
--- a/src/Futhark/Optimise/DoubleBuffer.hs
+++ b/src/Futhark/Optimise/DoubleBuffer.hs
@@ -259,7 +259,7 @@ isArrayIn :: VName -> Param FParamMem -> Bool
isArrayIn x (Param _ _ (MemArray _ _ _ (ArrayIn y _))) = x == y
isArrayIn _ _ = False
-optimiseLoopBySwitching :: Constraints rep inner => OptimiseLoop rep
+optimiseLoopBySwitching :: forall rep inner. Constraints rep inner => OptimiseLoop rep
optimiseLoopBySwitching (Pat pes) merge (Body _ body_stms body_res) = do
((pat', merge', body'), outer_stms) <- runBuilder $ do
((buffered, body_stms'), (pes', merge', body_res')) <-
@@ -308,6 +308,10 @@ optimiseLoopBySwitching (Pat pes) merge (Body _ body_stms body_res) = do
([pe], [(param, arg)], [res])
)
+ maybeCopyInitial ::
+ M.Map VName VName ->
+ (Param FParamMem, SubExp) ->
+ Builder rep (Param FParamMem, SubExp)
maybeCopyInitial buffered (param@(Param _ _ (MemArray _ _ _ (ArrayIn mem _))), Var arg)
| Just mem' <- mem `M.lookup` buffered = do
arg_info <- lookupMemInfo arg
@@ -413,12 +417,17 @@ doubleBufferMergeParams ctx_and_res bound_in_loop =
_ -> pure NoBuffer
allocStms ::
+ forall rep inner.
Constraints rep inner =>
[(FParam rep, SubExp)] ->
[DoubleBuffer] ->
DoubleBufferM rep ([(FParam rep, SubExp)], [Stm rep])
allocStms merge = runWriterT . zipWithM allocation merge
where
+ allocation ::
+ (Param FParamMem, SubExp) ->
+ DoubleBuffer ->
+ WriterT [Stm rep] (DoubleBufferM rep) (Param FParamMem, SubExp)
allocation m@(Param attrs pname _, _) (BufferAlloc name size space b) = do
stms <- lift $
runBuilder_ $ do
diff --git a/src/Futhark/Pass/ExplicitAllocations.hs b/src/Futhark/Pass/ExplicitAllocations.hs
index 47a4e71..717ae14 100644
--- a/src/Futhark/Pass/ExplicitAllocations.hs
+++ b/src/Futhark/Pass/ExplicitAllocations.hs
@@ -368,6 +368,7 @@ allocInFParam param pspace =
return param {paramDec = MemAcc acc ispace ts u}
allocInMergeParams ::
+ forall fromrep torep inner a.
(Allocable fromrep torep inner) =>
[(FParam fromrep, SubExp)] ->
( [(FParam torep, SubExp)] ->
@@ -396,6 +397,12 @@ allocInMergeParams merge m = do
param_names = namesFromList $ map (paramName . fst) merge
anyIsLoopParam names = names `namesIntersect` param_names
+ scalarRes ::
+ DeclType ->
+ Space ->
+ IxFun ->
+ SubExp ->
+ WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
scalarRes param_t v_mem_space v_ixfun (Var res) = do
-- Try really hard to avoid copying needlessly, but the result
-- _must_ be in ScalarSpace and have the right index function.
@@ -472,6 +479,16 @@ allocInMergeParams merge m = do
)
allocInMergeParam (mergeparam, se) = doDefault mergeparam se =<< lift askDefaultSpace
+ doDefault ::
+ Param DeclType ->
+ SubExp ->
+ Space ->
+ WriterT ([Param FParamMem], [Param FParamMem])
+ (AllocM fromrep torep)
+ ( Param FParamMem,
+ SubExp,
+ SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
+ )
doDefault mergeparam se space = do
mergeparam' <- allocInFParam mergeparam space
return (mergeparam', se, linearFuncallArg (paramType mergeparam) space)
@@ -731,6 +748,7 @@ allocInLambda params body =
mkLambda params . allocInStms (bodyStms body) $ pure $ bodyResult body
allocInExp ::
+ forall fromrep torep inner.
(Allocable fromrep torep inner) =>
Exp fromrep ->
AllocM fromrep torep (Exp torep)
@@ -823,6 +841,7 @@ allocInExp (If cond tbranch0 fbranch0 (IfDec rets ifsort)) = do
allocInExp (WithAcc inputs bodylam) =
WithAcc <$> mapM onInput inputs <*> onLambda bodylam
where
+ onLambda :: Lambda fromrep -> AllocM fromrep torep (Lambda torep)
onLambda lam = do
params <- forM (lambdaParams lam) $ \(Param attrs pv t) ->
case t of
@@ -831,9 +850,15 @@ allocInExp (WithAcc inputs bodylam) =
_ -> error $ "Unexpected WithAcc lambda param: " ++ pretty (Param attrs pv t)
allocInLambda params (lambdaBody lam)
+ onInput :: WithAccInput fromrep -> AllocM fromrep torep (WithAccInput torep)
onInput (shape, arrs, op) =
(shape,arrs,) <$> traverse (onOp shape arrs) op
+ onOp ::
+ Shape ->
+ [VName] ->
+ (Lambda fromrep, [SubExp]) ->
+ AllocM fromrep torep (Lambda torep, [SubExp])
onOp accshape arrs (lam, nes) = do
let num_vs = length (lambdaReturnType lam)
num_is = shapeRank accshape
@@ -853,6 +878,11 @@ allocInExp (WithAcc inputs bodylam) =
Param attrs p . MemArray pt shape u . ArrayIn mem . IxFun.slice ixfun $
fmap pe64 $ Slice $ is ++ map sliceDim (shapeDims shape)
+ onXParam ::
+ [DimIndex SubExp] ->
+ Param Type ->
+ VName ->
+ AllocM fromrep torep (Param LParamMem)
onXParam _ (Param attrs p (Prim t)) _ =
pure $ Param attrs p (MemPrim t)
onXParam is (Param attrs p (Array pt shape u)) arr = do
@@ -861,6 +891,11 @@ allocInExp (WithAcc inputs bodylam) =
onXParam _ p _ =
error $ "Cannot handle MkAcc param: " ++ pretty p
+ onYParam ::
+ [DimIndex SubExp] ->
+ Param Type ->
+ VName ->
+ AllocM fromrep torep (Param LParamMem)
onYParam _ (Param attrs p (Prim t)) _ =
pure $ Param attrs p $ MemPrim t
onYParam is (Param attrs p (Array pt shape u)) arr = do
@@ -903,6 +938,7 @@ subExpIxFun Constant {} = return Nothing
subExpIxFun (Var v) = lookupIxFun v
addResCtxInIfBody ::
+ forall fromrep torep inner.
(Allocable fromrep torep inner) =>
[ExtType] ->
Body torep ->
@@ -922,6 +958,11 @@ addResCtxInIfBody ifrets (Body _ stms res) spaces substs = buildBody $ do
numCtxNeeded Array {} (Just (_, m)) = length m + 1
numCtxNeeded _ _ = 0
+ helper ::
+ Int ->
+ (Result, [BodyReturns], Result, [BodyReturns]) ->
+ (ExtType, SubExpRes, Maybe (ExtIxFun, [TPrimExp Int64 VName]), Maybe Space, Int) ->
+ AllocM fromrep torep (Result, [BodyReturns], Result, [BodyReturns])
helper
num_new_ctx
(ctx_acc, ctx_rets_acc, res_acc, res_rets_acc)
@@ -979,6 +1020,7 @@ addResCtxInIfBody ifrets (Body _ stms res) spaces substs = buildBody $ do
adjustExtPE k = fmap (adjustExt k)
mkSpaceOks ::
+ forall torep inner m.
(Mem torep inner, LocalScope torep m) =>
Int ->
Body torep ->
@@ -986,6 +1028,7 @@ mkSpaceOks ::
mkSpaceOks num_vals (Body _ stms res) =
inScopeOf stms $ mapM (mkSpaceOK . resSubExp) $ takeLast num_vals res
where
+ mkSpaceOK :: SubExp -> m (Maybe Space)
mkSpaceOK (Var v) = do
v_info <- lookupMemInfo v
case v_info of
@@ -1084,6 +1127,7 @@ mkLetNamesB'' names e = do
nohints = map (const NoHint) names
simplifiable ::
+ forall rep inner.
( Engine.SimplifiableRep rep,
ExpDec rep ~ (),
BodyDec rep ~ (),
@@ -1100,6 +1144,12 @@ simplifiable innerUsage simplifyInnerOp =
mkBodyS' _ stms res = return $ mkWiseBody () stms res
+ protectOp ::
+ forall inner'.
+ SubExp ->
+ Pat (Engine.VarWisdom, LetDec rep) ->
+ MemOp inner' ->
+ Maybe (Builder (Engine.Wise rep) ())
protectOp taken pat (Alloc size space) = Just $ do
tbody <- resultBodyM [size]
fbody <- resultBodyM [intConst Int64 0]
diff --git a/src/Futhark/Pass/ExtractKernels/Interchange.hs b/src/Futhark/Pass/ExtractKernels/Interchange.hs
index cedc438..c3fe5db 100644
--- a/src/Futhark/Pass/ExtractKernels/Interchange.hs
+++ b/src/Futhark/Pass/ExtractKernels/Interchange.hs
@@ -46,7 +46,7 @@ seqLoopStm (SeqLoop _ pat merge form body) =
Let pat (defAux ()) $ DoLoop merge form body
interchangeLoop ::
- (MonadBuilder m, LocalScope SOACS m) =>
+ (MonadBuilder m, LocalScope SOACS m, Rep m ~ SOACS) =>
(VName -> Maybe VName) ->
SeqLoop ->
LoopNesting ->
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment