Commit b1d1c652 authored by Michal Terepeta's avatar Michal Terepeta Committed by Ben Gamari

Support MO_{Add,Sub}IntC and MO_Add2 in the LLVM backend

This includes:

- Adding a new LlvmType constructor called LMStructU that represents an
  unpacked struct (this is necessary since LLVM intrinsics such as
  llvm.sadd.with.overflow.* return an unpacked struct).

- Modifications to LlvmCodeGen.CodeGen to generate the LLVM
  instructions for the primops.

- Modifications to StgCmmPrim to actually use those three instructions
  if we use the LLVM backend (so far they were only used for NCG).

Test Plan: validate

Reviewers: austin, rwbarton, bgamari

Reviewed By: bgamari

Subscribers: thomie, bgamari

Differential Revision: https://phabricator.haskell.org/D991

GHC Trac Issues: #9430
parent 69beef56
...@@ -811,13 +811,16 @@ callishPrimOpSupported dflags op ...@@ -811,13 +811,16 @@ callishPrimOpSupported dflags op
WordQuotRem2Op | ncg && x86ish -> Left (MO_U_QuotRem2 (wordWidth dflags)) WordQuotRem2Op | ncg && x86ish -> Left (MO_U_QuotRem2 (wordWidth dflags))
| otherwise -> Right (genericWordQuotRem2Op dflags) | otherwise -> Right (genericWordQuotRem2Op dflags)
WordAdd2Op | ncg && x86ish -> Left (MO_Add2 (wordWidth dflags)) WordAdd2Op | (ncg && x86ish)
|| llvm -> Left (MO_Add2 (wordWidth dflags))
| otherwise -> Right genericWordAdd2Op | otherwise -> Right genericWordAdd2Op
IntAddCOp | ncg && x86ish -> Left (MO_AddIntC (wordWidth dflags)) IntAddCOp | (ncg && x86ish)
|| llvm -> Left (MO_AddIntC (wordWidth dflags))
| otherwise -> Right genericIntAddCOp | otherwise -> Right genericIntAddCOp
IntSubCOp | ncg && x86ish -> Left (MO_SubIntC (wordWidth dflags)) IntSubCOp | (ncg && x86ish)
|| llvm -> Left (MO_SubIntC (wordWidth dflags))
| otherwise -> Right genericIntSubCOp | otherwise -> Right genericIntSubCOp
WordMul2Op | ncg && x86ish -> Left (MO_U_Mul2 (wordWidth dflags)) WordMul2Op | ncg && x86ish -> Left (MO_U_Mul2 (wordWidth dflags))
...@@ -828,7 +831,9 @@ callishPrimOpSupported dflags op ...@@ -828,7 +831,9 @@ callishPrimOpSupported dflags op
ncg = case hscTarget dflags of ncg = case hscTarget dflags of
HscAsm -> True HscAsm -> True
_ -> False _ -> False
llvm = case hscTarget dflags of
HscLlvm -> True
_ -> False
x86ish = case platformArch (targetPlatform dflags) of x86ish = case platformArch (targetPlatform dflags) of
ArchX86 -> True ArchX86 -> True
ArchX86_64 -> True ArchX86_64 -> True
......
...@@ -208,6 +208,14 @@ data LlvmExpression ...@@ -208,6 +208,14 @@ data LlvmExpression
-} -}
| Extract LlvmVar LlvmVar | Extract LlvmVar LlvmVar
{- |
Extract a scalar element from a structure
* val: The structure
* idx: The index of the scalar within the structure
Corresponds to "extractvalue" instruction.
-}
| ExtractV LlvmVar Int
{- | {- |
Insert a scalar element into a vector Insert a scalar element into a vector
* val: The source vector * val: The source vector
......
...@@ -239,6 +239,7 @@ ppLlvmExpression expr ...@@ -239,6 +239,7 @@ ppLlvmExpression expr
Cast op from to -> ppCast op from to Cast op from to -> ppCast op from to
Compare op left right -> ppCmpOp op left right Compare op left right -> ppCmpOp op left right
Extract vec idx -> ppExtract vec idx Extract vec idx -> ppExtract vec idx
ExtractV struct idx -> ppExtractV struct idx
Insert vec elt idx -> ppInsert vec elt idx Insert vec elt idx -> ppInsert vec elt idx
GetElemPtr inb ptr indexes -> ppGetElementPtr inb ptr indexes GetElemPtr inb ptr indexes -> ppGetElementPtr inb ptr indexes
Load ptr -> ppLoad ptr Load ptr -> ppLoad ptr
...@@ -430,6 +431,12 @@ ppExtract vec idx = ...@@ -430,6 +431,12 @@ ppExtract vec idx =
<+> ppr (getVarType vec) <+> ppName vec <> comma <+> ppr (getVarType vec) <+> ppName vec <> comma
<+> ppr idx <+> ppr idx
-- | Render an @extractvalue@ instruction: the keyword, then the type and name
-- of the aggregate variable, a comma, and finally the index of the field to
-- pull out.
ppExtractV :: LlvmVar -> Int -> SDoc
ppExtractV struct idx = text "extractvalue" <+> aggregate <> comma <+> ppr idx
  where
    -- The typed operand, e.g. the struct's type followed by its name.
    aggregate = ppr (getVarType struct) <+> ppName struct
ppInsert :: LlvmVar -> LlvmVar -> LlvmVar -> SDoc ppInsert :: LlvmVar -> LlvmVar -> LlvmVar -> SDoc
ppInsert vec elt idx = ppInsert vec elt idx =
text "insertelement" text "insertelement"
......
...@@ -50,7 +50,8 @@ data LlvmType ...@@ -50,7 +50,8 @@ data LlvmType
| LMVector Int LlvmType -- ^ A vector of 'LlvmType' | LMVector Int LlvmType -- ^ A vector of 'LlvmType'
| LMLabel -- ^ A 'LlvmVar' can represent a label (address) | LMLabel -- ^ A 'LlvmVar' can represent a label (address)
| LMVoid -- ^ Void type | LMVoid -- ^ Void type
| LMStruct [LlvmType] -- ^ Structure type | LMStruct [LlvmType] -- ^ Packed structure type
| LMStructU [LlvmType] -- ^ Unpacked structure type
| LMAlias LlvmAlias -- ^ A type alias | LMAlias LlvmAlias -- ^ A type alias
| LMMetadata -- ^ LLVM Metadata | LMMetadata -- ^ LLVM Metadata
...@@ -70,6 +71,7 @@ instance Outputable LlvmType where ...@@ -70,6 +71,7 @@ instance Outputable LlvmType where
ppr (LMLabel ) = text "label" ppr (LMLabel ) = text "label"
ppr (LMVoid ) = text "void" ppr (LMVoid ) = text "void"
ppr (LMStruct tys ) = text "<{" <> ppCommaJoin tys <> text "}>" ppr (LMStruct tys ) = text "<{" <> ppCommaJoin tys <> text "}>"
ppr (LMStructU tys ) = text "{" <> ppCommaJoin tys <> text "}"
ppr (LMMetadata ) = text "metadata" ppr (LMMetadata ) = text "metadata"
ppr (LMFunction (LlvmFunctionDecl _ _ _ r varg p _)) ppr (LMFunction (LlvmFunctionDecl _ _ _ r varg p _))
...@@ -326,6 +328,16 @@ llvmWidthInBits dflags (LMVector n ty) = n * llvmWidthInBits dflags ty ...@@ -326,6 +328,16 @@ llvmWidthInBits dflags (LMVector n ty) = n * llvmWidthInBits dflags ty
llvmWidthInBits _ LMLabel = 0 llvmWidthInBits _ LMLabel = 0
llvmWidthInBits _ LMVoid = 0 llvmWidthInBits _ LMVoid = 0
llvmWidthInBits dflags (LMStruct tys) = sum $ map (llvmWidthInBits dflags) tys llvmWidthInBits dflags (LMStruct tys) = sum $ map (llvmWidthInBits dflags) tys
llvmWidthInBits _ (LMStructU _) =
-- It's not trivial to calculate the bit width of the unpacked structs,
-- since they will be aligned depending on the specified datalayout (
-- http://llvm.org/docs/LangRef.html#data-layout ). One way we could support
-- this could be to make the LlvmCodeGen.Ppr.moduleLayout be a data type
-- that exposes the alignment information. However, currently the only place
-- we use unpacked structs is LLVM intrinsics that return them (e.g.,
-- llvm.sadd.with.overflow.*), so we don't actually need to compute their
-- bit width.
panic "llvmWidthInBits: not implemented for LMStructU"
llvmWidthInBits _ (LMFunction _) = 0 llvmWidthInBits _ (LMFunction _) = 0
llvmWidthInBits dflags (LMAlias (_,t)) = llvmWidthInBits dflags t llvmWidthInBits dflags (LMAlias (_,t)) = llvmWidthInBits dflags t
llvmWidthInBits _ LMMetadata = panic "llvmWidthInBits: Meta-data has no runtime representation!" llvmWidthInBits _ LMMetadata = panic "llvmWidthInBits: Meta-data has no runtime representation!"
......
...@@ -30,6 +30,7 @@ import Platform ...@@ -30,6 +30,7 @@ import Platform
import OrdList import OrdList
import UniqSupply import UniqSupply
import Unique import Unique
import Util
import Data.List ( nub ) import Data.List ( nub )
import Data.Maybe ( catMaybes ) import Data.Maybe ( catMaybes )
...@@ -255,6 +256,20 @@ genCall t@(PrimTarget op) [] args ...@@ -255,6 +256,20 @@ genCall t@(PrimTarget op) [] args
`appOL` stmts4 `snocOL` call `appOL` stmts4 `snocOL` call
return (stmts, top1 ++ top2) return (stmts, top1 ++ top2)
-- Handle the MO_{Add,Sub}IntC separately. LLVM versions return a record from
-- which we need to extract the actual values.
genCall t@(PrimTarget (MO_AddIntC w)) [dstV, dstO] [lhs, rhs] =
genCallWithOverflow t w [dstV, dstO] [lhs, rhs]
genCall t@(PrimTarget (MO_SubIntC w)) [dstV, dstO] [lhs, rhs] =
genCallWithOverflow t w [dstV, dstO] [lhs, rhs]
-- Similar to MO_{Add,Sub}IntC, but MO_Add2 expects the first element of the
-- return tuple to be the overflow bit and the second element to contain the
-- actual result of the addition. So we still use genCallWithOverflow but swap
-- the return registers.
genCall t@(PrimTarget (MO_Add2 w)) [dstO, dstV] [lhs, rhs] =
genCallWithOverflow t w [dstV, dstO] [lhs, rhs]
-- Handle all other foreign calls and prim ops. -- Handle all other foreign calls and prim ops.
genCall target res args = do genCall target res args = do
...@@ -360,6 +375,68 @@ genCall target res args = do ...@@ -360,6 +375,68 @@ genCall target res args = do
return (allStmts `snocOL` s2 `snocOL` s3 return (allStmts `snocOL` s2 `snocOL` s3
`appOL` retStmt, top1 ++ top2) `appOL` retStmt, top1 ++ top2)
-- | Generate a call to an LLVM intrinsic that performs an arithmetic operation
-- with an overflow bit (i.e., returns a struct containing the actual result of
-- the operation and an overflow bit). This function also extracts the overflow
-- bit and zero-extends it (all the corresponding Cmm PrimOps represent the
-- overflow "bit" as a usual Int# or Word#).
--
-- The formals must be given as @[dstV, dstO]@: @dstV@ receives the result of
-- the operation and @dstO@ receives the zero-extended overflow flag. The
-- actuals must be exactly the two operands @[lhs, rhs]@. Any other target or
-- argument shape is a compiler bug and panics.
genCallWithOverflow
    :: ForeignTarget -> Width -> [CmmFormal] -> [CmmActual] -> LlvmM StmtData
genCallWithOverflow t@(PrimTarget op) w [dstV, dstO] [lhs, rhs] = do
    -- So far this was only tested for the following three CallishMachOps.
    MASSERT( (op `elem` [MO_Add2 w, MO_AddIntC w, MO_SubIntC w]) )
    let width = widthToLlvmInt w
    -- This will do most of the work of generating the call to the intrinsic and
    -- extracting the values from the struct.
    (value, overflowBit, (stmts, top)) <-
        genCallExtract t w (lhs, rhs) (width, i1)
    -- value is i<width>, but overflowBit is i1, so we need to cast (Cmm expects
    -- both to be i<width>)
    (overflow, zext) <- doExpr width $ Cast LM_Zext overflowBit width
    dstRegV <- getCmmReg (CmmLocal dstV)
    dstRegO <- getCmmReg (CmmLocal dstO)
    let storeV = Store value dstRegV
        storeO = Store overflow dstRegO
    return (stmts `snocOL` zext `snocOL` storeV `snocOL` storeO, top)
genCallWithOverflow _ _ _ _ =
    -- Name this function (not genCallExtract) so the panic points at the
    -- place that actually rejected the call.
    panic "genCallWithOverflow: wrong ForeignTarget or number of arguments"
-- | Helper for 'genCallWithOverflow': emits the call to the LLVM intrinsic
-- that implements the given primop and then pulls both fields of the returned
-- two-element unpacked struct out into separate 'LlvmVar's.
--
-- Returns the first field, the second field, and the statements / top-level
-- declarations that were generated along the way.
genCallExtract
    :: ForeignTarget            -- ^ PrimOp
    -> Width                    -- ^ Width of the operands.
    -> (CmmActual, CmmActual)   -- ^ Actual arguments.
    -> (LlvmType, LlvmType)     -- ^ LLVM types of the returned struct.
    -> LlvmM (LlvmVar, LlvmVar, StmtData)
genCallExtract target@(PrimTarget op) w (argA, argB) (llvmTypeA, llvmTypeB) = do
    let operandTy     = widthToLlvmInt w
        paramTys      = [operandTy, operandTy]
        structTy      = LMStructU [llvmTypeA, llvmTypeB]
        (_, argHints) = foreignTargetHints target
        hintedArgs    = zip [argA, argB] argHints

    -- Evaluate the operands, then cast them to the intrinsic's parameter type.
    (rawVars, argStmts, top1) <- arg_vars hintedArgs ([], nilOL, [])
    (castedVars, castStmts)   <- castVars (zip rawVars paramTys)

    -- Look up the intrinsic and call it.
    intrinsicName   <- cmmPrimOpFunctions op
    (fptr, _, top2) <- getInstrinct intrinsicName structTy paramTys
    -- We use StdCall for primops. See also the last case of genCall.
    (structVar, callStmt) <- doExpr structTy (Call StdCall fptr castedVars [])

    -- The call yields a two-element struct; "extractvalue" is needed to get
    -- each element out of it.
    (fieldA, extractA) <- doExpr llvmTypeA (ExtractV structVar 0)
    (fieldB, extractB) <- doExpr llvmTypeB (ExtractV structVar 1)

    let allStmts = argStmts `appOL` castStmts
                            `snocOL` callStmt
                            `snocOL` extractA
                            `snocOL` extractB
    return (fieldA, fieldB, (allStmts, top1 ++ top2))
genCallExtract _ _ _ _ =
    panic "genCallExtract: unsupported ForeignTarget"
-- Handle simple function call that only need simple type casting, of the form: -- Handle simple function call that only need simple type casting, of the form:
-- truncate arg >>= \a -> call(a) >>= zext -- truncate arg >>= \a -> call(a) >>= zext
-- --
...@@ -534,12 +611,16 @@ cmmPrimOpFunctions mop = do ...@@ -534,12 +611,16 @@ cmmPrimOpFunctions mop = do
(MO_Prefetch_Data _ )-> fsLit "llvm.prefetch" (MO_Prefetch_Data _ )-> fsLit "llvm.prefetch"
MO_AddIntC w -> fsLit $ "llvm.sadd.with.overflow."
++ showSDoc dflags (ppr $ widthToLlvmInt w)
MO_SubIntC w -> fsLit $ "llvm.ssub.with.overflow."
++ showSDoc dflags (ppr $ widthToLlvmInt w)
MO_Add2 w -> fsLit $ "llvm.uadd.with.overflow."
++ showSDoc dflags (ppr $ widthToLlvmInt w)
MO_S_QuotRem {} -> unsupported MO_S_QuotRem {} -> unsupported
MO_U_QuotRem {} -> unsupported MO_U_QuotRem {} -> unsupported
MO_U_QuotRem2 {} -> unsupported MO_U_QuotRem2 {} -> unsupported
MO_Add2 {} -> unsupported
MO_AddIntC {} -> unsupported
MO_SubIntC {} -> unsupported
MO_U_Mul2 {} -> unsupported MO_U_Mul2 {} -> unsupported
MO_WriteBarrier -> unsupported MO_WriteBarrier -> unsupported
MO_Touch -> unsupported MO_Touch -> unsupported
......
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
module Main where
import GHC.Exts
-- | Apply a two-result 'Int#' primop to the given boxed arguments and compare
-- both results with the expected pair. Yields 'Nothing' when they agree and a
-- description of the mismatch otherwise.
checkI
  :: (Int, Int)                          -- ^ expected results
  -> (Int# -> Int# -> (# Int#, Int# #))  -- ^ primop
  -> Int                                 -- ^ first argument
  -> Int                                 -- ^ second argument
  -> Maybe String                        -- ^ maybe error
checkI (expX, expY) op (I# a) (I# b) =
  case op a b of
    (# x, y #) ->
      let gotX = I# x
          gotY = I# y
      in if gotX == expX && gotY == expY
           then Nothing
           else Just (concat
                  [ "Expected ", show expX, " and ", show expY
                  , " but got ", show gotX, " and ", show gotY
                  ])
-- | Apply a two-result 'Word#' primop to the given boxed arguments and compare
-- both results with the expected pair. Yields 'Nothing' when they agree and a
-- description of the mismatch otherwise.
checkW
  :: (Word, Word)                            -- ^ expected results
  -> (Word# -> Word# -> (# Word#, Word# #))  -- ^ primop
  -> Word                                    -- ^ first argument
  -> Word                                    -- ^ second argument
  -> Maybe String                            -- ^ maybe error
checkW (expX, expY) op (W# a) (W# b) =
  case op a b of
    (# x, y #) -> verdict (W# x) (W# y)
  where
    -- Compare the boxed results against the expectation.
    verdict gotX gotY
      | gotX == expX && gotY == expY = Nothing
      | otherwise = Just (mismatch gotX gotY)

    mismatch gotX gotY =
      "Expected " ++ show expX ++ " and " ++ show expY
        ++ " but got " ++ show gotX ++ " and " ++ show gotY
-- | Turn the outcome of a single check into an action: do nothing when the
-- check passed ('Nothing'), otherwise abort with an error that names the
-- failing case.
check :: String -> Maybe String -> IO ()
check s outcome =
  case outcome of
    Nothing  -> return ()
    Just err -> error (concat ["Error for ", s, ": ", err])
-- | Exercise 'addIntC#', 'subIntC#' and 'plusWord2#' at the boundaries of
-- their argument ranges. Each case states the expected pair of results; the
-- first mismatch aborts the program via 'check' with the label of the failing
-- case.
main :: IO ()
main = do
  -- addIntC#: no overflow. Results are (sum, overflow flag).
  check "addIntC# maxBound 0" $ checkI (maxBound, 0) addIntC# maxBound 0
  check "addIntC# 0 maxBound" $ checkI (maxBound, 0) addIntC# 0 maxBound
  -- addIntC#: overflows.
  check "addIntC# maxBound 1" $ checkI (minBound, 1) addIntC# maxBound 1
  check "addIntC# 1 maxBound" $ checkI (minBound, 1) addIntC# 1 maxBound
  check "addIntC# maxBound 2" $ checkI (minBound + 1, 1) addIntC# maxBound 2
  check "addIntC# 2 maxBound" $ checkI (minBound + 1, 1) addIntC# 2 maxBound
  check "addIntC# minBound minBound" $
    checkI (0, 1) addIntC# minBound minBound

  -- subIntC#: no overflow. Results are (difference, overflow flag).
  check "subIntC# minBound 0" $ checkI (minBound, 0) subIntC# minBound 0
  -- subIntC#: overflows.
  check "subIntC# minBound 1" $ checkI (maxBound, 1) subIntC# minBound 1
  -- The label must match the call (second argument is 2), otherwise a failure
  -- here would be reported under the name of the previous case.
  check "subIntC# minBound 2" $ checkI (maxBound - 1, 1) subIntC# minBound 2
  check "subIntC# 0 minBound" $ checkI (minBound, 1) subIntC# 0 minBound
  -- subIntC#: close to the boundary, but no overflow.
  check "subIntC# -1 minBound" $ checkI (maxBound, 0) subIntC# (-1) minBound
  check "subIntC# minBound -1" $
    checkI (minBound + 1, 0) subIntC# minBound (-1)

  -- plusWord2#: no carry. Note that the order of results is different: the
  -- carry comes first, the sum second.
  check "plusWord2# maxBound 0" $ checkW (0, maxBound) plusWord2# maxBound 0
  check "plusWord2# 0 maxBound" $ checkW (0, maxBound) plusWord2# 0 maxBound
  -- plusWord2#: carries.
  check "plusWord2# maxBound 1" $
    checkW (1, minBound) plusWord2# maxBound 1
  check "plusWord2# 1 maxBound" $
    checkW (1, minBound) plusWord2# 1 maxBound
  check "plusWord2# maxBound 2" $
    checkW (1, minBound + 1) plusWord2# maxBound 2
  check "plusWord2# 2 maxBound" $
    checkW (1, minBound + 1) plusWord2# 2 maxBound
test('T6135', normal, compile_and_run, ['']) test('T6135', normal, compile_and_run, [''])
test('T7689', normal, compile_and_run, ['']) test('T7689', normal, compile_and_run, [''])
# The test is using unboxed tuples, so omit ghci
test('T9430', omit_ways(['ghci']), compile_and_run, [''])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment