{-# LANGUAGE ViewPatterns #-}
-- | Functions for computing retainers
module GHC.Debug.Retainers(findRetainersOf, findRetainersOfConstructor, findRetainersOfConstructorExact, findRetainersOfInfoTable, findRetainers, addLocationToStack, displayRetainerStack, addLocationToStack', displayRetainerStack') where

import GHC.Debug.Client
import Control.Monad.State
import GHC.Debug.Trace
import GHC.Debug.Types.Graph
import Control.Monad

import qualified Data.Set as Set
import Control.Monad.RWS

-- | Record one retainer path, respecting the remaining path budget.
--
-- The budget is 'Nothing' for "no limit" or @Just n@ for at most @n@ more
-- paths. When the budget is exhausted, the path is dropped and the state is
-- returned unchanged. "Exhausted" means zero, or negative if a negative limit
-- was supplied; without this, a negative limit would never reach zero and
-- would act as "unlimited". Otherwise the path is prepended to the collected
-- paths and the budget is decremented.
addOne :: [a] -> (Maybe Int, [[a]]) -> (Maybe Int, [[a]])
addOne _ st@(Just n, _) | n <= 0 = st
addOne path (budget, paths) = (subtract 1 <$> budget, path : paths)

-- | Find up to @limit@ paths from the roots @cps@ to any closure whose
-- pointer appears in @bads@.
findRetainersOf :: Maybe Int
                -> [ClosurePtr]
                -> [ClosurePtr]
                -> DebugM [[ClosurePtr]]
findRetainersOf limit cps bads = findRetainers limit cps isTarget
  where
    -- Build the set once so each membership test is logarithmic.
    targets = Set.fromList bads
    -- Only the pointer is consulted; the closure contents are ignored.
    isTarget ptr _ = pure (Set.member ptr targets)

-- | Find up to @limit@ paths from @rroots@ to any constructor closure whose
-- constructor name, as read from its 'ConstrDesc', equals @con_name@.
findRetainersOfConstructor :: Maybe Int -> [ClosurePtr] -> String -> DebugM [[ClosurePtr]]
findRetainersOfConstructor limit rroots con_name = findRetainers limit rroots matchesCon
  where
    -- Only constructor closures can match; any other closure is rejected
    -- without dereferencing anything.
    matchesCon _ (noSize -> ConstrClosure _ _ _ desc) = do
      ConstrDesc _ _ name <- dereferenceConDesc desc
      pure (name == con_name)
    matchesCon _ _ = pure False

-- | Find up to @limit@ paths from @rroots@ to any closure whose info table's
-- source information reports the name @clos_name@. Closures for which no
-- source information is available never match.
findRetainersOfConstructorExact :: Maybe Int -> [ClosurePtr] -> String -> DebugM [[ClosurePtr]]
findRetainersOfConstructorExact limit rroots clos_name = findRetainers limit rroots hasName
  where
    hasName _ sc = do
      srcInfo <- getSourceInfo (tableId (info (noSize sc)))
      pure (maybe False ((== clos_name) . infoName) srcInfo)

-- | Find up to @limit@ paths from @rroots@ to any closure whose info table
-- pointer is @info_ptr@.
findRetainersOfInfoTable :: Maybe Int -> [ClosurePtr] -> InfoTablePtr -> DebugM [[ClosurePtr]]
findRetainersOfInfoTable limit rroots info_ptr = findRetainers limit rroots usesTable
  where
    -- Pure check on the closure's info table; no extra debuggee requests.
    usesTable _ = pure . (== info_ptr) . tableId . info . noSize

-- | From the given roots, find paths to closures satisfying the predicate.
--
-- Note: this function can be quite slow! The first argument limits how many
-- paths are collected; you should normally set it to a small number such as
-- 10. 'Nothing' means no limit; a limit of zero or less collects nothing.
--
-- Each returned path starts with the matching closure, followed by the
-- closures through which the traversal reached it, innermost first. Paths
-- are returned most recently found first. Weak closures are skipped entirely:
-- they are neither tested nor traced through.
findRetainers :: Maybe Int -> [ClosurePtr] -> (ClosurePtr -> SizedClosure -> DebugM Bool) -> DebugM [[ClosurePtr]]
findRetainers limit rroots p =
  (\(_, (_, paths), _) -> paths) <$> runRWST (traceFromM funcs rroots) [] (limit, [])
  where
    funcs = TraceFunctions
      { papTrace = const (return ())
      , srtTrace = const (return ())
      , stackTrace = const (return ())
      , closTrace = closAccum
      , visitedVal = const (return ())
      , conDescTrace = const (return ())
      }

    -- Reader: the closures between the roots and the current one (innermost
    -- first). State: the remaining path budget and the paths collected so far.
    closAccum  :: ClosurePtr
               -> SizedClosure
               -> RWST [ClosurePtr] () (Maybe Int, [[ClosurePtr]]) DebugM ()
               -> RWST [ClosurePtr] () (Maybe Int, [[ClosurePtr]]) DebugM ()
    closAccum _ (noSize -> WeakClosure {}) _ = return ()
    closAccum cp sc k = do
      (lim, _) <- get
      -- Once the budget is spent, nothing more can be recorded, so stop here
      -- rather than running the (possibly expensive) predicate and tracing on.
      unless (exhausted lim) $ do
        b <- lift $ p cp sc
        when b $ do
          ctx <- ask
          modify' (addOne (cp : ctx))
        -- Keep tracing below this closure even after a match, so further
        -- matches reachable through it are reported as separate paths.
        local (cp:) k

    -- A budget of zero or less means no more paths may be recorded.
    exhausted :: Maybe Int -> Bool
    exhausted (Just n) = n <= 0
    exhausted Nothing  = False

-- | Dereference each closure on a retainer path, then pair it with the source
-- information of its info table when any is available.
addLocationToStack :: [ClosurePtr] -> DebugM [(SizedClosureP, Maybe SourceInformation)]
addLocationToStack r = do
  closures <- dereferenceClosures r >>= mapM dereferenceToClosurePtr
  zip closures <$> mapM (getSourceInfo . tableId . info . noSize) closures

-- | Like 'addLocationToStack', but each result also carries the original
-- closure pointer it was dereferenced from.
addLocationToStack' :: [ClosurePtr] -> DebugM [(ClosurePtr, SizedClosureP, Maybe SourceInformation)]
addLocationToStack' r = do
  closures <- mapM dereferenceToClosurePtr =<< dereferenceClosures r
  locations <- mapM sourceOf closures
  pure (zip3 r closures locations)
  where
    -- Source information is looked up via the closure's info table pointer.
    sourceOf c = getSourceInfo (tableId (info (noSize c)))

-- | Print retainer stacks to stdout, one numbered section per entry.
--
-- Each entry pairs a label with a stack of closures and their optional
-- source locations. The label is printed with 'print', so it appears in
-- quoted string syntax. Each closure is rendered with 'ppClosure',
-- followed by its location in angle brackets, or @<nl>@ when no location is
-- known.
displayRetainerStack :: [(String, [(SizedClosureP, Maybe SourceInformation)])] -> IO ()
displayRetainerStack rs = zipWithM_ do_one [0 :: Int ..] rs
  where
    disp (d, l) =
      ppClosure (\_ -> show) 0 (noSize d) ++ " <" ++ maybe "nl" tdisplay l ++ ">"
    tdisplay sl =
      infoName sl ++ ":" ++ infoType sl ++ ":" ++ infoModule sl ++ ":" ++ infoPosition sl
    do_one k (l, stack) = do
      putStrLn (show k ++ "-------------------------------------")
      print l
      -- mapM_ rather than mapM: the per-line results are never used.
      mapM_ (putStrLn . disp) stack

-- | Print retainer stacks to stdout, one numbered section per entry, with
-- each closure line prefixed by the closure's pointer.
--
-- The label of each entry is printed with 'print', so it appears in quoted
-- string syntax. Each closure is rendered as @ptr: closure <location>@, with
-- @<nl>@ when no source location is known.
displayRetainerStack' :: [(String, [(ClosurePtr, SizedClosureP, Maybe SourceInformation)])] -> IO ()
displayRetainerStack' rs = zipWithM_ do_one [0 :: Int ..] rs
  where
    disp (p, d, l) =
      show p ++ ": " ++ ppClosure (\_ -> show) 0 (noSize d) ++ " <" ++ maybe "nl" tdisplay l ++ ">"
    tdisplay sl =
      infoName sl ++ ":" ++ infoType sl ++ ":" ++ infoModule sl ++ ":" ++ infoPosition sl
    do_one k (l, stack) = do
      putStrLn (show k ++ "-------------------------------------")
      print l
      -- mapM_ rather than mapM: the per-line results are never used.
      mapM_ (putStrLn . disp) stack
