Commit 7637810a authored by Alexander Vershilov's avatar Alexander Vershilov Committed by Austin Seipp
Browse files

Trac #9878: Have StaticPointers support dynamic loading.



Summary:
A mutex is used to protect the SPT.

unsafeLookupStaticPtr and staticPtrKeys in GHC.StaticPtr are made
monadic.

SPT entries are removed in a destructor function of modules.
Authored-by: default avatarFacundo Domínguez <facundo.dominguez@tweag.io>
Authored-by: default avatarAlexander Vershilov <alexander.vershilov@tweag.io>

Test Plan: ./validate

Reviewers: austin, simonpj, hvr

Subscribers: carter, thomie, qnikst, mboes

Differential Revision: https://phabricator.haskell.org/D587

GHC Trac Issues: #9878
parent 099b7676
......@@ -26,6 +26,20 @@
--
-- where the constants are fingerprints produced from the static forms.
--
-- There is also a finalization function for the time when the module is
-- unloaded.
--
-- > static void hs_hpc_fini_Main(void) __attribute__((destructor));
-- > static void hs_hpc_fini_Main(void) {
-- >
-- > static StgWord64 k0[2] = {16252233372134256ULL,7370534374096082ULL};
-- > hs_spt_remove(k0);
-- >
-- > static StgWord64 k1[2] = {12545634534567898ULL,5409674567544151ULL};
-- > hs_spt_remove(k1);
-- >
-- > }
--
module StaticPtrTable (sptInitCode) where
import CoreSyn
......@@ -62,6 +76,15 @@ sptInitCode this_mod entries = vcat
<> semi
| (i, (fp, (n, _))) <- zip [0..] entries
]
, text "static void hs_spt_fini_" <> ppr this_mod
<> text "(void) __attribute__((destructor));"
, text "static void hs_spt_fini_" <> ppr this_mod <> text "(void)"
, braces $ vcat $
[ text "StgWord64 k" <> int i <> text "[2] = "
<> pprFingerprint fp <> semi
$$ text "hs_spt_remove" <> parens (char 'k' <> int i) <> semi
| (i, (fp, _)) <- zip [0..] entries
]
]
where
......
......@@ -28,4 +28,12 @@
* */
void hs_spt_insert (StgWord64 key[2],void* spe_closure);
/** Removes an entry from the Static Pointer Table.
*
* This function is called from the code generated by
* compiler/deSugar/StaticPtrTable.sptInitCode
*
* */
void hs_spt_remove (StgWord64 key[2]);
#endif /* RTS_STATICPTRTABLE_H */
......@@ -24,9 +24,9 @@
--
-- To solve such concern, the references provided by this module offer a key
-- that can be used to locate the values on each process. Each process maintains
-- a global and immutable table of references which can be looked up with a
-- given key. This table is known as the Static Pointer Table. The reference can
-- a global table of references which can be looked up with a given key. This
-- table is known as the Static Pointer Table. The reference can then be
-- dereferenced to obtain the value.
--
-----------------------------------------------------------------------------
......@@ -48,7 +48,6 @@ import Foreign.Ptr (castPtr)
import GHC.Exts (addrToAny#)
import GHC.Ptr (Ptr(..), nullPtr)
import GHC.Fingerprint (Fingerprint(..))
import System.IO.Unsafe (unsafePerformIO)
-- | A reference to a value of type 'a'.
......@@ -74,8 +73,15 @@ staticKey (StaticPtr k _ _) = k
-- This function is unsafe because the program behavior is undefined if the type
-- of the returned 'StaticPtr' does not match the expected one.
--
unsafeLookupStaticPtr :: StaticKey -> Maybe (StaticPtr a)
unsafeLookupStaticPtr k = unsafePerformIO $ sptLookup k
unsafeLookupStaticPtr :: StaticKey -> IO (Maybe (StaticPtr a))
unsafeLookupStaticPtr (Fingerprint w1 w2) = do
ptr@(Ptr addr) <- withArray [w1,w2] (hs_spt_lookup . castPtr)
if (ptr == nullPtr)
then return Nothing
else case addrToAny# addr of
(# spe #) -> return (Just spe)
foreign import ccall unsafe hs_spt_lookup :: Ptr () -> IO (Ptr a)
-- | Miscelaneous information available for debugging purposes.
data StaticPtrInfo = StaticPtrInfo
......@@ -96,20 +102,9 @@ data StaticPtrInfo = StaticPtrInfo
staticPtrInfo :: StaticPtr a -> StaticPtrInfo
staticPtrInfo (StaticPtr _ n _) = n
-- | Like 'unsafeLookupStaticPtr' but evaluates in 'IO'.
sptLookup :: StaticKey -> IO (Maybe (StaticPtr a))
sptLookup (Fingerprint w1 w2) = do
ptr@(Ptr addr) <- withArray [w1,w2] (hs_spt_lookup . castPtr)
if (ptr == nullPtr)
then return Nothing
else case addrToAny# addr of
(# spe #) -> return (Just spe)
foreign import ccall unsafe hs_spt_lookup :: Ptr () -> IO (Ptr a)
-- | A list of all known keys.
staticPtrKeys :: [StaticKey]
staticPtrKeys = unsafePerformIO $ do
staticPtrKeys :: IO [StaticKey]
staticPtrKeys = do
keyCount <- hs_spt_key_count
allocaArray (fromIntegral keyCount) $ \p -> do
count <- hs_spt_keys p keyCount
......
......@@ -1420,6 +1420,7 @@ typedef struct _RtsSymbolVal {
SymI_HasProto(atomic_dec) \
SymI_HasProto(hs_spt_lookup) \
SymI_HasProto(hs_spt_insert) \
SymI_HasProto(hs_spt_remove) \
SymI_HasProto(hs_spt_keys) \
SymI_HasProto(hs_spt_key_count) \
RTS_USER_SIGNALS_SYMBOLS \
......
......@@ -8,12 +8,18 @@
*
*/
#include "Rts.h"
#include "StaticPtrTable.h"
#include "Rts.h"
#include "RtsUtils.h"
#include "Hash.h"
#include "Stable.h"
static HashTable * spt = NULL;
#ifdef THREADED_RTS
static Mutex spt_lock;
#endif
/// Hash function for the SPT.
static int hashFingerprint(HashTable *table, StgWord64 key[2]) {
// Take half of the key to compute the hash.
......@@ -28,21 +34,59 @@ static int compareFingerprint(StgWord64 ptra[2], StgWord64 ptrb[2]) {
void hs_spt_insert(StgWord64 key[2],void *spe_closure) {
// hs_spt_insert is called from constructor functions, so
// the SPT needs to be initialized here.
if (spt == NULL)
if (spt == NULL) {
spt = allocHashTable_( (HashFunction *)hashFingerprint
, (CompareFunction *)compareFingerprint
);
#ifdef THREADED_RTS
initMutex(&spt_lock);
#endif
}
StgStablePtr * entry = stgMallocBytes( sizeof(StgStablePtr)
, "hs_spt_insert: entry"
);
*entry = getStablePtr(spe_closure);
ACQUIRE_LOCK(&spt_lock);
insertHashTable(spt, (StgWord)key, entry);
RELEASE_LOCK(&spt_lock);
}
getStablePtr(spe_closure);
insertHashTable(spt, (StgWord)key, spe_closure);
static void freeSptEntry(void* entry) {
freeStablePtr(*(StgStablePtr*)entry);
stgFree(entry);
}
void hs_spt_remove(StgWord64 key[2]) {
if (spt) {
ACQUIRE_LOCK(&spt_lock);
StgStablePtr* entry = removeHashTable(spt, (StgWord)key, NULL);
RELEASE_LOCK(&spt_lock);
if (entry)
freeSptEntry(entry);
}
}
StgPtr hs_spt_lookup(StgWord64 key[2]) {
return spt ? lookupHashTable(spt, (StgWord)key) : NULL;
if (spt) {
ACQUIRE_LOCK(&spt_lock);
const StgStablePtr * entry = lookupHashTable(spt, (StgWord)key);
RELEASE_LOCK(&spt_lock);
const StgPtr ret = entry ? deRefStablePtr(*entry) : NULL;
return ret;
} else
return NULL;
}
int hs_spt_keys(StgPtr keys[], int szKeys) {
return spt ? keysHashTable(spt, (StgWord*)keys, szKeys) : 0;
if (spt) {
ACQUIRE_LOCK(&spt_lock);
const int ret = keysHashTable(spt, (StgWord*)keys, szKeys);
RELEASE_LOCK(&spt_lock);
return ret;
} else
return 0;
}
int hs_spt_key_count() {
......@@ -51,7 +95,10 @@ int hs_spt_key_count() {
void exitStaticPtrTable() {
if (spt) {
freeHashTable(spt, NULL);
freeHashTable(spt, freeSptEntry);
spt = NULL;
#ifdef THREADED_RTS
closeMutex(&spt_lock);
#endif
}
}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE StaticPointers #-}
-- | A test to use symbols produced by the static form.
......@@ -9,15 +10,15 @@ import GHC.StaticPtr
main :: IO ()
main = do
print $ lookupKey (static (id . id)) (1 :: Int)
print $ lookupKey (static method :: StaticPtr (Char -> Int)) 'a'
lookupKey (static (id . id)) >>= \f -> print $ f (1 :: Int)
lookupKey (static method :: StaticPtr (Char -> Int)) >>= \f -> print $ f 'a'
print $ deRefStaticPtr (static g)
print $ deRefStaticPtr p0 'a'
print $ deRefStaticPtr (static t_field) $ T 'b'
lookupKey :: StaticPtr a -> a
lookupKey p = case unsafeLookupStaticPtr (staticKey p) of
Just p -> deRefStaticPtr p
lookupKey :: StaticPtr a -> IO a
lookupKey p = unsafeLookupStaticPtr (staticKey p) >>= \case
Just p -> return $ deRefStaticPtr p
Nothing -> error $ "couldn't find " ++ show (staticPtrInfo p)
g :: String
......
......@@ -26,7 +26,7 @@ main = do
print z
performGC
threadDelay 1000000
let Just p = unsafeLookupStaticPtr nats_key
Just p <- unsafeLookupStaticPtr nats_key
print (deRefStaticPtr (unsafeCoerce p) !! 800 :: Integer)
-- Uncommenting the next line keeps 'nats' alive and would prevent a segfault
-- if 'nats' were garbage collected.
......
......@@ -7,10 +7,12 @@ import Data.List ((\\))
import GHC.StaticPtr
import System.Exit
main = when (not $ eqBags staticPtrKeys expected) $ do
print ("expected", expected)
print ("found", staticPtrKeys)
exitFailure
main = do
found <- staticPtrKeys
when (not $ eqBags found expected) $ do
print ("expected", expected)
print ("found", found)
exitFailure
where
expected =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment