diff --git a/patches/futhark-0.21.13.patch b/patches/futhark-0.22.1.patch similarity index 70% rename from patches/futhark-0.21.13.patch rename to patches/futhark-0.22.1.patch index 2b8dc8f4a6b2b1d533009d6880984a7e9c13485b..1e79fe23f1f4768ee1bf084f3e22cce1ae104e73 100644 --- a/patches/futhark-0.21.13.patch +++ b/patches/futhark-0.22.1.patch @@ -18,7 +18,7 @@ index 009d263..ad29459 100644 tell $ buildFGBody body pure body diff --git a/src/Futhark/IR/Mem.hs b/src/Futhark/IR/Mem.hs -index fff7d7f..47039f9 100644 +index 97793b3..57e1a6e 100644 --- a/src/Futhark/IR/Mem.hs +++ b/src/Futhark/IR/Mem.hs @@ -3,6 +3,7 @@ @@ -29,7 +29,7 @@ index fff7d7f..47039f9 100644 {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -@@ -566,6 +567,7 @@ matchRetTypeToResult rettype result = do +@@ -567,6 +568,7 @@ matchRetTypeToResult rettype result = do matchReturnType rettype (map resSubExp result) result_ts matchFunctionReturnType :: @@ -37,7 +37,7 @@ index fff7d7f..47039f9 100644 (Mem rep inner, TC.Checkable rep) => [FunReturns] -> Result -> -@@ -574,6 +576,7 @@ matchFunctionReturnType rettype result = do +@@ -575,6 +577,7 @@ matchFunctionReturnType rettype result = do matchRetTypeToResult rettype result mapM_ (checkResultSubExp . resSubExp) result where @@ -45,15 +45,15 @@ index fff7d7f..47039f9 100644 checkResultSubExp Constant {} = pure () checkResultSubExp (Var v) = do -@@ -996,6 +999,7 @@ subExpReturns (Constant v) = +@@ -1014,6 +1017,7 @@ subExpReturns (Constant v) = -- | The return information of an expression. This can be seen as the -- "return type with memory annotations" of the expression. expReturns :: + forall m rep inner. - ( Monad m, - LocalScope rep m, - Mem rep inner -@@ -1086,6 +1090,7 @@ expReturns (WithAcc inputs lam) = + (LocalScope rep m, Mem rep inner) => + Exp rep -> + m [ExpReturns] +@@ -1093,6 +1097,7 @@ expReturns (WithAcc inputs lam) = -- think WithAcc should perhaps have a return annotation like If. pure (extReturns $ staticShapes $ drop num_accs $ lambdaReturnType lam) where @@ -62,7 +62,7 @@ index fff7d7f..47039f9 100644 num_accs = length inputs diff --git a/src/Futhark/IR/SOACS/Simplify.hs b/src/Futhark/IR/SOACS/Simplify.hs -index 086c695..50d0ff4 100644 +index 8fc0097..a5e70ff 100644 --- a/src/Futhark/IR/SOACS/Simplify.hs +++ b/src/Futhark/IR/SOACS/Simplify.hs @@ -3,6 +3,7 @@ @@ -73,7 +73,7 @@ index 086c695..50d0ff4 100644 {-# OPTIONS_GHC -fno-warn-orphans #-} module Futhark.IR.SOACS.Simplify -@@ -777,7 +778,7 @@ arrayOps = mconcat . map onStm . stmsToList . bodyStms +@@ -761,7 +762,7 @@ arrayOps = mconcat . map onStm . stmsToList . bodyStms tell $ arrayOps $ lambdaBody lam pure lam walker = @@ -82,7 +82,7 @@ index 086c695..50d0ff4 100644 { walkOnBody = const $ modify . (<>) . arrayOps, walkOnOp = modify . (<>) . onOp } -@@ -800,7 +801,7 @@ replaceArrayOps substs (Body _ stms res) = +@@ -784,7 +785,7 @@ replaceArrayOps substs (Body _ stms res) = fromArrayOp op' onExp _ cs e = (cs, mapExp mapper e) mapper = @@ -92,10 +92,10 @@ index 086c695..50d0ff4 100644 mapOnOp = pure . onOp } diff --git a/src/Futhark/IR/SegOp.hs b/src/Futhark/IR/SegOp.hs -index b845421..b1741df 100644 +index cbdf647..a3154ce 100644 --- a/src/Futhark/IR/SegOp.hs +++ b/src/Futhark/IR/SegOp.hs -@@ -1562,12 +1562,14 @@ bottomUpSegOp (vtable, used) (Pat kpes) dec segop = Simplify $ do +@@ -1477,12 +1477,14 @@ bottomUpSegOp (vtable, used) (Pat kpes) dec segop = Simplify $ do --- Memory kernelBodyReturns :: @@ -200,8 +200,29 @@ index 1a0c518..30aa20d 100644 allocation m@(Param attrs pname _, _) (BufferAlloc name size space b) = do stms <- lift $ runBuilder_ $ do +diff --git a/src/Futhark/Optimise/EntryPointMem.hs b/src/Futhark/Optimise/EntryPointMem.hs +index c082833..0c766bd 100644 +--- a/src/Futhark/Optimise/EntryPointMem.hs ++++ b/src/Futhark/Optimise/EntryPointMem.hs +@@ -58,7 +58,7 @@ optimiseFun consts_table fd = + let substs = mconcat $ map (mkSubst . resSubExp) res + in Body dec stms $ substituteNames substs res + +-entryPointMem :: Mem rep inner => Pass rep rep ++entryPointMem :: forall rep inner. Mem rep inner => Pass rep rep + entryPointMem = + Pass + { passName = "Entry point memory optimisation", +@@ -66,6 +66,7 @@ entryPointMem = + passFunction = intraproceduralTransformationWithConsts pure onFun + } + where ++ onFun :: Stms rep -> FunDef rep -> PassM (FunDef rep) + onFun consts fd = pure $ optimiseFun (mkTable consts) fd + + -- | The pass for GPU representation. diff --git a/src/Futhark/Optimise/GenRedOpt.hs b/src/Futhark/Optimise/GenRedOpt.hs -index 7e293f0..5f845ec 100644 +index 431a06c..c87f578 100644 --- a/src/Futhark/Optimise/GenRedOpt.hs +++ b/src/Futhark/Optimise/GenRedOpt.hs @@ -1,4 +1,5 @@ @@ -286,7 +307,7 @@ index 1c86fba..8533d55 100644 tapBottomUp :: ForwardingM rep a -> ForwardingM rep (a, BottomUp rep) tapBottomUp m = do diff --git a/src/Futhark/Optimise/InPlaceLowering/SubstituteIndices.hs b/src/Futhark/Optimise/InPlaceLowering/SubstituteIndices.hs -index 858e94e..1f46335 100644 +index 23b96da..96e720f 100644 --- a/src/Futhark/Optimise/InPlaceLowering/SubstituteIndices.hs +++ b/src/Futhark/Optimise/InPlaceLowering/SubstituteIndices.hs @@ -1,5 +1,7 @@ @@ -297,7 +318,7 @@ index 858e94e..1f46335 100644 -- | This module exports facilities for transforming array accesses in -- a list of 'Stm's (intended to be the bindings in a body). The -@@ -88,7 +90,7 @@ substituteIndicesInStm substs (Let pat rep e) = do +@@ -87,7 +89,7 @@ substituteIndicesInStm substs (Let pat rep e) = do addStm $ Let pat rep e' pure substs @@ -306,7 +327,7 @@ index 858e94e..1f46335 100644 (MonadBuilder m, Buildable (Rep m), Aliased (Rep m)) => IndexSubstitutions -> Exp (Rep m) -> -@@ -108,7 +110,7 @@ substituteIndicesInExp substs (Op op) = do +@@ -107,7 +109,7 @@ substituteIndicesInExp substs (Op op) = do substituteIndicesInExp substs e = do substs' <- copyAnyConsumed e let substitute = @@ -345,11 +366,11 @@ index 40145bf..8138e94 100644 } diff --git a/src/Futhark/Optimise/Simplify/Rules.hs b/src/Futhark/Optimise/Simplify/Rules.hs -index 1e76897..a40e830 100644 +index 6f762c1..bf4ab89 100644 --- a/src/Futhark/Optimise/Simplify/Rules.hs +++ b/src/Futhark/Optimise/Simplify/Rules.hs -@@ -2,6 +2,8 @@ - {-# LANGUAGE OverloadedStrings #-} +@@ -1,6 +1,8 @@ + {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ScopedTypeVariables #-} @@ -357,7 +378,7 @@ index 1e76897..a40e830 100644 -- | This module defines a collection of simplification rules, as per -- "Futhark.Optimise.Simplify.Rule". They are used in the -@@ -403,7 +405,7 @@ withAccTopDown vtable (Let pat aux (WithAcc inputs lam)) = Simplify . auxing aux +@@ -198,7 +200,7 @@ withAccTopDown vtable (Let pat aux (WithAcc inputs lam)) = Simplify . auxing aux pure $ Just x withAccTopDown _ _ = Skip @@ -366,7 +387,7 @@ index 1e76897..a40e830 100644 elimUpdates get_rid_of = flip runState mempty . onBody where onBody body = do -@@ -419,7 +421,7 @@ elimUpdates get_rid_of = flip runState mempty . onBody +@@ -214,7 +216,7 @@ elimUpdates get_rid_of = flip runState mempty . onBody onExp = mapExpM mapper where mapper = @@ -376,7 +397,7 @@ index 1e76897..a40e830 100644 mapOnBody = \_ body -> onBody body } diff --git a/src/Futhark/Optimise/Sink.hs b/src/Futhark/Optimise/Sink.hs -index 5be2e89..67a82a7 100644 +index 4910a27..989664a 100644 --- a/src/Futhark/Optimise/Sink.hs +++ b/src/Futhark/Optimise/Sink.hs @@ -1,6 +1,8 @@ @@ -388,7 +409,7 @@ index 5be2e89..67a82a7 100644 -- | "Sinking" is conceptually the opposite of hoisting. The idea is -- to take code that looks like this: -@@ -142,7 +144,7 @@ optimiseLoop onOp vtable sinking (merge, form, body0) +@@ -145,7 +147,7 @@ optimiseLoop onOp vtable sinking (merge, form, body0) stm = Let pat aux e in stm <| stms @@ -397,7 +418,7 @@ index 5be2e89..67a82a7 100644 Constraints rep => Sinker rep (Op rep) -> SymbolTable rep -> -@@ -203,7 +205,7 @@ optimiseStms onOp init_vtable init_sinking all_stms free_in_res = +@@ -209,7 +211,7 @@ optimiseStms onOp init_vtable init_sinking all_stms free_in_res = where vtable' = ST.insertStm stm vtable mapper = @@ -407,7 +428,7 @@ index 5be2e89..67a82a7 100644 let (body', sunk) = optimiseBody diff --git a/src/Futhark/Optimise/TileLoops.hs b/src/Futhark/Optimise/TileLoops.hs -index 0f71f19..99cb8de 100644 +index fabadfd..2fe01d3 100644 --- a/src/Futhark/Optimise/TileLoops.hs +++ b/src/Futhark/Optimise/TileLoops.hs @@ -1,6 +1,7 @@ @@ -476,7 +497,7 @@ index 8e7e43e..99176c5 100644 mapOnOp = mapSOACM soac_mapper } diff --git a/src/Futhark/Pass/ExpandAllocations.hs b/src/Futhark/Pass/ExpandAllocations.hs -index 0538f91..8c9278c 100644 +index f077c74..22a284a 100644 --- a/src/Futhark/Pass/ExpandAllocations.hs +++ b/src/Futhark/Pass/ExpandAllocations.hs @@ -1,6 +1,7 @@ @@ -487,7 +508,7 @@ index 0538f91..8c9278c 100644 -- | Expand allocations inside of maps when possible. module Futhark.Pass.ExpandAllocations (expandAllocations) where -@@ -111,7 +112,7 @@ transformStm (Let pat aux e) = do +@@ -106,7 +107,7 @@ transformStm (Let pat aux e) = do pure $ stms <> oneStm (Let pat aux e') where transform = @@ -496,7 +517,7 @@ index 0538f91..8c9278c 100644 { mapOnBody = \scope -> localScope scope . transformBody } -@@ -392,7 +393,7 @@ extractStmAllocations user bound_outside bound_kernel stm = do +@@ -387,7 +388,7 @@ extractStmAllocations user bound_outside bound_kernel stm = do pure $ Just $ stm {stmExp = e} where expMapper user' = @@ -505,7 +526,7 @@ index 0538f91..8c9278c 100644 { mapOnBody = const $ onBody user', mapOnOp = onOp user' } -@@ -699,7 +700,7 @@ offsetMemoryInExp (DoLoop merge form body) = do +@@ -692,7 +693,7 @@ offsetMemoryInExp (DoLoop merge form body) = do offsetMemoryInExp e = mapExpM recurse e where recurse = @@ -514,176 +535,8 @@ index 0538f91..8c9278c 100644 { mapOnBody = \bscope -> localScope bscope . offsetMemoryInBody, mapOnBranchType = offsetMemoryInBodyReturns, mapOnOp = onOp -diff --git a/src/Futhark/Pass/ExplicitAllocations.hs b/src/Futhark/Pass/ExplicitAllocations.hs -index 9277011..5bfc0ae 100644 ---- a/src/Futhark/Pass/ExplicitAllocations.hs -+++ b/src/Futhark/Pass/ExplicitAllocations.hs -@@ -369,6 +369,7 @@ allocInFParam param pspace = - pure param {paramDec = MemAcc acc ispace ts u} - - allocInMergeParams :: -+ forall fromrep torep inner a. - (Allocable fromrep torep inner) => - [(FParam fromrep, SubExp)] -> - ( [(FParam torep, SubExp)] -> -@@ -397,6 +398,12 @@ allocInMergeParams merge m = do - param_names = namesFromList $ map (paramName . fst) merge - anyIsLoopParam names = names `namesIntersect` param_names - -+ scalarRes :: -+ DeclType -> -+ Space -> -+ IxFun -> -+ SubExp -> -+ WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp - scalarRes param_t v_mem_space v_ixfun (Var res) = do - -- Try really hard to avoid copying needlessly, but the result - -- _must_ be in ScalarSpace and have the right index function. -@@ -473,6 +480,16 @@ allocInMergeParams merge m = do - ) - allocInMergeParam (mergeparam, se) = doDefault mergeparam se =<< lift askDefaultSpace - -+ doDefault :: -+ Param DeclType -> -+ SubExp -> -+ Space -> -+ WriterT ([Param FParamMem], [Param FParamMem]) -+ (AllocM fromrep torep) -+ ( Param FParamMem, -+ SubExp, -+ SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp -+ ) - doDefault mergeparam se space = do - mergeparam' <- allocInFParam mergeparam space - pure (mergeparam', se, linearFuncallArg (paramType mergeparam) space) -@@ -737,6 +754,7 @@ allocInLambda params body = - mkLambda params . allocInStms (bodyStms body) $ pure $ bodyResult body - - allocInExp :: -+ forall fromrep torep inner. - (Allocable fromrep torep inner) => - Exp fromrep -> - AllocM fromrep torep (Exp torep) -@@ -829,6 +847,7 @@ allocInExp (If cond tbranch0 fbranch0 (IfDec rets ifsort)) = do - allocInExp (WithAcc inputs bodylam) = - WithAcc <$> mapM onInput inputs <*> onLambda bodylam - where -+ onLambda :: Lambda fromrep -> AllocM fromrep torep (Lambda torep) - onLambda lam = do - params <- forM (lambdaParams lam) $ \(Param attrs pv t) -> - case t of -@@ -837,9 +856,15 @@ allocInExp (WithAcc inputs bodylam) = - _ -> error $ "Unexpected WithAcc lambda param: " ++ pretty (Param attrs pv t) - allocInLambda params (lambdaBody lam) - -+ onInput :: WithAccInput fromrep -> AllocM fromrep torep (WithAccInput torep) - onInput (shape, arrs, op) = - (shape,arrs,) <$> traverse (onOp shape arrs) op - -+ onOp :: -+ Shape -> -+ [VName] -> -+ (Lambda fromrep, [SubExp]) -> -+ AllocM fromrep torep (Lambda torep, [SubExp]) - onOp accshape arrs (lam, nes) = do - let num_vs = length (lambdaReturnType lam) - num_is = shapeRank accshape -@@ -861,6 +886,11 @@ allocInExp (WithAcc inputs bodylam) = - Slice $ - is ++ map sliceDim (shapeDims shape) - -+ onXParam :: -+ [DimIndex SubExp] -> -+ Param Type -> -+ VName -> -+ AllocM fromrep torep (Param LParamMem) - onXParam _ (Param attrs p (Prim t)) _ = - pure $ Param attrs p (MemPrim t) - onXParam is (Param attrs p (Array pt shape u)) arr = do -@@ -869,6 +899,11 @@ allocInExp (WithAcc inputs bodylam) = - onXParam _ p _ = - error $ "Cannot handle MkAcc param: " ++ pretty p - -+ onYParam :: -+ [DimIndex SubExp] -> -+ Param Type -> -+ VName -> -+ AllocM fromrep torep (Param LParamMem) - onYParam _ (Param attrs p (Prim t)) _ = - pure $ Param attrs p $ MemPrim t - onYParam is (Param attrs p (Array pt shape u)) arr = do -@@ -911,6 +946,7 @@ subExpIxFun Constant {} = pure Nothing - subExpIxFun (Var v) = lookupIxFun v - - addResCtxInIfBody :: -+ forall fromrep torep inner. - (Allocable fromrep torep inner) => - [ExtType] -> - Body torep -> -@@ -930,6 +966,11 @@ addResCtxInIfBody ifrets (Body _ stms res) spaces substs = buildBody $ do - numCtxNeeded Array {} (Just (_, m)) = length m + 1 - numCtxNeeded _ _ = 0 - -+ helper :: -+ Int -> -+ (Result, [BodyReturns], Result, [BodyReturns]) -> -+ (ExtType, SubExpRes, Maybe (ExtIxFun, [TPrimExp Int64 VName]), Maybe Space, Int) -> -+ AllocM fromrep torep (Result, [BodyReturns], Result, [BodyReturns]) - helper - num_new_ctx - (ctx_acc, ctx_rets_acc, res_acc, res_rets_acc) -@@ -989,6 +1030,7 @@ addResCtxInIfBody ifrets (Body _ stms res) spaces substs = buildBody $ do - adjustExtPE k = fmap (adjustExt k) - - mkSpaceOks :: -+ forall torep inner m. - (Mem torep inner, LocalScope torep m) => - Int -> - Body torep -> -@@ -996,6 +1038,7 @@ mkSpaceOks :: - mkSpaceOks num_vals (Body _ stms res) = - inScopeOf stms $ mapM (mkSpaceOK . resSubExp) $ takeLast num_vals res - where -+ mkSpaceOK :: SubExp -> m (Maybe Space) - mkSpaceOK (Var v) = do - v_info <- lookupMemInfo v - case v_info of -@@ -1094,6 +1137,7 @@ mkLetNamesB'' names e = do - nohints = map (const NoHint) names - - simplifiable :: -+ forall rep inner. - ( Engine.SimplifiableRep rep, - ExpDec rep ~ (), - BodyDec rep ~ (), -@@ -1110,6 +1154,12 @@ simplifiable innerUsage simplifyInnerOp = - - mkBodyS' _ stms res = pure $ mkWiseBody () stms res - -+ protectOp :: -+ forall inner'. -+ SubExp -> -+ Pat (Engine.VarWisdom, LetDec rep) -> -+ MemOp inner' -> -+ Maybe (Builder (Engine.Wise rep) ()) - protectOp taken pat (Alloc size space) = Just $ do - tbody <- resultBodyM [size] - fbody <- resultBodyM [intConst Int64 0] -diff --git a/src/Futhark/Pass/ExtractKernels/Interchange.hs b/src/Futhark/Pass/ExtractKernels/Interchange.hs -index b3e93fc..1f0fa5b 100644 ---- a/src/Futhark/Pass/ExtractKernels/Interchange.hs -+++ b/src/Futhark/Pass/ExtractKernels/Interchange.hs -@@ -46,7 +46,7 @@ seqLoopStm (SeqLoop _ pat merge form body) = - Let pat (defAux ()) $ DoLoop merge form body - - interchangeLoop :: -- (MonadBuilder m, LocalScope SOACS m) => -+ (MonadBuilder m, LocalScope SOACS m, Rep m ~ SOACS) => - (VName -> Maybe VName) -> - SeqLoop -> - LoopNesting -> diff --git a/src/Futhark/Pass/KernelBabysitting.hs b/src/Futhark/Pass/KernelBabysitting.hs -index 13d092f..209dc12 100644 +index 3fc8867..5f293eb 100644 --- a/src/Futhark/Pass/KernelBabysitting.hs +++ b/src/Futhark/Pass/KernelBabysitting.hs @@ -1,5 +1,6 @@ @@ -693,7 +546,7 @@ index 13d092f..209dc12 100644 -- | Do various kernel optimisations - mostly related to coalescing. module Futhark.Pass.KernelBabysitting (babysitKernels) where -@@ -105,7 +106,7 @@ transformStm expmap (Let pat aux e) = do +@@ -104,7 +105,7 @@ transformStm expmap (Let pat aux e) = do transform :: ExpMap -> Mapper GPU GPU BabysitM transform expmap = @@ -702,7 +555,7 @@ index 13d092f..209dc12 100644 transformKernelBody :: ExpMap -> -@@ -212,7 +213,7 @@ traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f (Kerne +@@ -194,7 +195,7 @@ traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f (Kerne onOp _ op = pure op mapper ctx =