{-|
Module : HIO
Description : Hierarchical IO

This is a re-implementation of the @HIO@ module from Galois's @orc@ package,
which I am interested in experimenting with. I wanted to re-implement two of
the three modules in the original @orc@ package myself, so rather than depend
on the whole package only to discard two thirds of it, I have reproduced its
@HIO@ module here.
-}

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveFunctor #-}

{-# LANGUAGE CPP #-}

module HIO
  ( -- * Hierarchical IO
    HIO(..)
  , runHIO
  , unHIO
    -- * Thread groups
  , Group(..)
  , newGroup
  , local
  , close
  , finished
  , register
    -- * Auxiliary types
  , Entry(..)
  , Inhabitants(..)
    -- * Profiling HIO
  , countingThreads
  , threadCount
  , incrementThreadCount
  , printThreadReport
  )
where

import System.IO.Unsafe
import Control.Applicative
import Control.Monad
import Control.Exception
import Control.Concurrent.STM.MonadIO
import Control.Concurrent.MonadIO

-- * Preliminary: HIO, Hierarchical I/O

-- | A thread 'Group' accounts for its inhabitants, which may be threads or
-- other 'Group's. The pair holds a counter of live threads forked into the
-- group and the group's current 'Inhabitants'.
type Group = (TVar Int, TVar Inhabitants)

-- | The membership state of a 'Group'.
data Inhabitants
  = Closed
    -- ^ The group has been shut down: it holds no entries, and a thread that
    -- tries to register with it terminates itself instead.
  | Open [Entry]
    -- ^ The group accepts new threads and groups; the list holds every entry
    -- registered so far, most recent first.

-- | A single member of a 'Group'.
data Entry
  = Thread ThreadId  -- ^ A forked thread.
  | Group Group      -- ^ A nested thread group.

-- | 'HIO' is 'IO' augmented with an environment that records the current
-- thread 'Group'. Threads forked from 'HIO' are registered in that group, so
-- they can be tracked and killed en masse when an ancestor group is closed.
--
-- Thanks to the 'MonadIO' instance, arbitrary 'IO' actions may be embedded.
-- Such actions should remain killable, however: while an action masks
-- asynchronous exceptions, closing its group cannot interrupt it.
newtype HIO a = HIO
  { inGroup :: Group -> IO a
    -- ^ Runs the computation with the given 'Group' as its environment.
  }

-- | Maps over the result of the underlying 'IO' action; the group
-- environment is passed through untouched.
instance Functor HIO where
    fmap f m = HIO $ \w -> f <$> inGroup m w

-- | Sequencing runs both steps in the same group environment.
--
-- 'return' is defined as 'pure'. This is the canonical layout expected by
-- GHC's @-Wnoncanonical-monad-instances@, and it keeps working on compilers
-- where 'return' is still a 'Monad' method.
instance Monad HIO where
    return = pure
    m >>= k = HIO $ \w -> do
        x <- m `inGroup` w
        k x `inGroup` w

-- | 'pure' ignores the group environment. '<*>' runs the function
-- computation first, then the argument, both in the current group.
instance Applicative HIO where
    pure x = HIO $ \_ -> pure x
    (<*>)  = ap

-- | Embeds an 'IO' action by ignoring the group environment.
instance MonadIO HIO where
    liftIO = HIO . const

-- | Forking from 'HIO' keeps the current group's accounting exact:
--
-- * With asynchronous exceptions masked, the parent bumps the global fork
--   counter (when 'countingThreads' is set). It also increments the group's
--   live-thread counter /before/ forking, so a concurrent 'finished' or
--   'runHIO' cannot observe zero while the child is still starting up.
--
-- * The child inherits the masked state and registers itself in the group,
--   killing itself if the group is already closed. Only then does it unmask
--   to run the user's computation.
--
-- * However the child ends, 'finally' decrements the group counter again,
--   so every decrement is paired with the increment above.
instance HasFork HIO where
#ifdef __GHC_BLOCK_DEPRECATED__
    fork hio = HIO $ \w -> mask $ \restore -> do
        when countingThreads incrementThreadCount
        increment w
        let child = do
                tid <- myThreadId
                register (Thread tid) w
                restore (hio `inGroup` w)
        fork (child `finally` decrement w)
#else
    -- Fallback for compilers without 'mask' ('block' / 'unblock' were removed
    -- in base 4.7). It must perform the same 'increment' as the branch
    -- above. Without it, the 'finally' below drives the counter negative,
    -- and 'finished' / 'runHIO' can return too early or block forever.
    fork hio = HIO $ \w -> block $ do
        when countingThreads incrementThreadCount
        increment w
        let child = block $ do
                tid <- myThreadId
                register (Thread tid) w
                unblock (hio `inGroup` w)
        fork (child `finally` decrement w)
#endif

-- | Creates a new, empty thread group and registers it as a child of the
-- current environment's group, so closing the current group also closes the
-- new one. If the current group is already closed, the calling thread
-- terminates itself instead.
newGroup :: HIO Group
newGroup = HIO $ \parent -> do
    child <- newPrimGroup
    register (Group child) parent
    return child

-- | Runs an 'HIO' computation with the given 'Group' as its environment,
-- in place of the current one.
local :: Group -> HIO a -> HIO a
local w = liftIO . (`inGroup` w)

-- | Closes a 'Group'. A freshly forked thread does the work:
--
-- * it kills every thread registered in the group or in any nested group,
--   marking each group 'Closed' so it rejects later registrations;
-- * it then resets the group's thread counter to zero.
--
-- 'close' itself returns immediately, without waiting for the killing to
-- finish. Closing an already-closed group kills nothing, but the counter is
-- still reset.
--
-- NOTE(review): killed threads decrement the counter in their own cleanup
-- handlers, which may run after the reset, leaving the counter below zero.
-- Confirm whether the explicit reset is still needed.
close :: Group -> IO ()
close g@(c, _) = void $ fork (kill (Group g) >> writeTVar c 0)

-- | Blocks until the given 'Group' is finished executing, i.e. until its
-- live-thread counter reads zero.
finished :: Group -> HIO ()
finished = liftIO . isZero

-- | Runs a 'HIO' computation inside a new thread group that has no parent.
-- It then blocks until that root group's live-thread counter reaches zero,
-- i.e. until every thread forked directly into it has finished. Threads
-- forked into nested groups are counted by those groups instead.
-- If 'countingThreads' is 'True', it then prints some debugging information
-- about the threads run.
runHIO :: HIO b -> IO ()
runHIO hio = void (runRootGroup hio)

-- | Unsafely extracts the underlying result value from the 'HIO' monad by
-- running it exactly as 'runHIO' does, inside 'unsafePerformIO'. The usual
-- 'unsafePerformIO' caveats apply: the computation (including any thread
-- report) runs whenever the result is first demanded, and the optimiser may
-- share or duplicate the call.
unHIO :: HIO a -> a
unHIO hio = unsafePerformIO (runRootGroup hio)

-- | Shared core of 'runHIO' and 'unHIO'. It runs the computation in a fresh
-- root group and waits for that group's counter to reach zero. It prints the
-- thread report when 'countingThreads' is set, then returns the result.
runRootGroup :: HIO a -> IO a
runRootGroup hio = do
    w <- newPrimGroup
    r <- hio `inGroup` w
    isZero w
    when countingThreads printThreadReport
    return r

-- | Allocates a fresh, open, empty thread group with its thread counter at
-- zero. The group is not registered with any parent.
newPrimGroup :: IO Group
newPrimGroup = (,) <$> newTVar 0 <*> newTVar (Open [])

-- | Adds an entry (a thread or a nested group) to a 'Group'. If the group
-- has already been closed, the entry is not added and the /calling/ thread
-- kills itself instead (suicide).
register :: Entry -> Group -> IO ()
register entry (_, inhabitants) = join (atomically transaction)
  where
    -- Decide inside the transaction, act after it commits: the returned
    -- action is either a no-op or the self-kill.
    transaction = do
        current <- readTVarSTM inhabitants
        case current of
            Open entries -> do
                writeTVarSTM inhabitants (Open (entry : entries))
                return (return ())
            Closed -> return (myThreadId >>= killThread)

-- | Recursively kills a thread/group entry.
--
-- * A thread is sent 'killThread'.
-- * A group is atomically marked 'Closed'. If it was open, each of its
--   entries is killed in turn.
--
-- Killing an already-closed group does nothing.
kill :: Entry -> IO ()
kill (Thread tid) = killThread tid
kill (Group (_, t)) = do
    (old, _) <- modifyTVar t (const Closed)
    case old of
        Closed       -> return ()
        Open entries -> mapM_ kill entries

-- | Helpers for a group's live-thread counter. 'increment' and 'decrement'
-- adjust it by one. 'isZero' blocks (via STM 'check', i.e. retry) until the
-- counter reads zero.
increment, decrement, isZero :: Group -> IO ()
increment (c, _) = adjustCount c (+ 1)
decrement (c, _) = adjustCount c (subtract 1)
isZero    (c, _) = atomically $ readTVarSTM c >>= check . (== 0)

-- | Applies @f@ to a counter in a single transaction. The new value is
-- forced before it is written, so frequent forks and exits do not pile up
-- a chain of unevaluated thunks inside the 'TVar'.
adjustCount :: TVar Int -> (Int -> Int) -> IO ()
adjustCount c f = atomically $ do
    n <- readTVarSTM c
    writeTVarSTM c $! f n

-- * Profiling HIO

-- | Profiling switch. When 'True', every 'fork' from 'HIO' increments
-- 'threadCount', and 'runHIO' / 'unHIO' print 'printThreadReport' once
-- their root group has finished.
countingThreads :: Bool
countingThreads = True

-- | Global count of threads forked from 'HIO' while 'countingThreads' is
-- enabled. This is a top-level 'TVar' created with 'unsafePerformIO'. The
-- NOINLINE pragma is required so that GHC allocates exactly one 'TVar',
-- rather than inlining (and thereby duplicating) the allocation at use
-- sites.
{-# NOINLINE threadCount #-}
threadCount :: TVar Integer
threadCount = unsafePerformIO (newTVar 0)

-- | Adds one to 'threadCount'. The new value is forced before it is
-- written. Otherwise a long-running program that forks many threads would
-- accumulate a chain of @(+1)@ thunks that only 'printThreadReport'
-- evaluates.
incrementThreadCount :: IO ()
incrementThreadCount = atomically $ do
    n <- readTVarSTM threadCount
    writeTVarSTM threadCount $! n + 1

-- | Prints a separator line, followed by the number of threads recorded in
-- 'threadCount'.
printThreadReport :: IO ()
printThreadReport =
    readTVar threadCount >>= \forked ->
        mapM_ putStrLn
            [ "----------"
            , show forked ++ " HIO threads were forked."
            ]