{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE CPP #-}
module HIO
(
HIO(..)
, runHIO
, unHIO
, Group(..)
, newGroup
, local
, close
, finished
, register
, Entry(..)
, Inhabitants(..)
, countingThreads
, threadCount
, incrementThreadCount
, printThreadReport
)
where
import System.IO.Unsafe (unsafePerformIO)
import Control.Monad
( ap
, when
, join )
import Control.Exception (mask, finally)
import Control.Concurrent.STM.MonadIO
( TVar(..)
, readTVar
, modifyTVar
, modifyTVar_
, newTVar
, writeTVarSTM
, writeTVar
, check
, readTVarSTM
, atomically )
import Control.Concurrent.MonadIO
( HasFork(..)
, MonadIO(..)
, ThreadId
, myThreadId
, killThread)
-- | A thread group: a counter of threads forked into the group that have
-- not yet finished, paired with the group's membership list.
type Group = (TVar Int, TVar Inhabitants)

-- | Membership state of a 'Group'.
data Inhabitants
  = Closed        -- ^ the group has been killed; new registrants kill themselves
  | Open [Entry]  -- ^ the group is live; these are its current members

-- | One member of a 'Group': either a forked thread or a nested group.
data Entry
  = Thread ThreadId
  | Group Group

-- | An 'IO' computation that runs inside a thread 'Group'. Threads it forks
-- are counted against, and registered in, that group.
newtype HIO a = HIO { inGroup :: Group -> IO a }
-- | Maps over the result; the group is passed through untouched.
instance Functor HIO where
  fmap f action = HIO $ \grp -> f <$> inGroup action grp
-- | 'pure' ignores the group. '<*>' runs the function action and then the
-- argument action, both in the same group.
instance Applicative HIO where
  pure x = HIO (const (pure x))
  mf <*> mx = HIO $ \grp -> do
    f <- inGroup mf grp
    x <- inGroup mx grp
    pure (f x)
-- | Sequencing passes the same group to both steps.
instance Monad HIO where
  action >>= next = HIO $ \grp ->
    inGroup action grp >>= \a -> inGroup (next a) grp
-- | Lifts a plain 'IO' action; the group is ignored.
instance MonadIO HIO where
  liftIO = HIO . const
-- | Forking inside 'HIO' ties the child thread to the current group:
--
-- * In the parent, the group's counter is incremented before the child is
--   started. When 'countingThreads' is set, 'threadCount' is incremented too.
-- * The child registers its 'ThreadId' in the group. If the group is already
--   'Closed', 'register' kills the child instead.
-- * The group's counter is decremented when the child's body ends, whether
--   it returns normally or with an exception.
--
-- The parent-side bookkeeping runs under 'mask'. The 'finally' that performs
-- the decrement wraps the child's entire body, registration included.
instance HasFork HIO where
  -- 'mask'/'restore' is used on every compiler: the older 'block'/'unblock'
  -- API is not imported by this module, and the former alternative branch
  -- also skipped the increment that the child's decrement relies on.
  fork hio = HIO $ \grp -> mask $ \restore -> do
    when countingThreads incrementThreadCount
    increment grp
    let child = do
          tid <- myThreadId
          register (Thread tid) grp
          -- The user's action runs via 'restore', i.e. with the masking
          -- state that was in effect outside 'mask'.
          restore (hio `inGroup` grp)
    fork (child `finally` decrement grp)
-- | Create a fresh, empty group and register it as a nested member of the
-- current group, so that killing the current group also kills it. If the
-- current group is already 'Closed', 'register' kills the calling thread.
newGroup :: HIO Group
newGroup = HIO $ \parent -> do
  child <- newPrimGroup
  register (Group child) parent
  pure child
-- | Run an 'HIO' action in the given group instead of the current one.
-- Threads it forks are counted against, and registered in, that group.
local :: Group -> HIO a -> HIO a
local grp action = liftIO $ inGroup action grp
-- | Kill every member of the group, recursively through nested groups, via
-- 'kill'. This marks the group 'Closed' and then resets its counter to zero.
--
-- The work runs in a separately forked thread that is not registered in the
-- group, so 'close' returns without waiting. The calling thread can still be
-- killed if it is itself a member of the group.
--
-- NOTE(review): killed threads run their 'finally' decrement (see the
-- 'HasFork' instance) after being interrupted. Those decrements can land
-- after the reset below, leaving the counter negative.
close :: Group -> IO ()
close grp@(counter, _) = do
  _ <- fork (kill (Group grp) >> writeTVar counter 0)
  return ()
-- | Block until 'isZero' reports that the group's thread counter has drained.
finished :: Group -> HIO ()
finished = liftIO . isZero
-- | Run an 'HIO' computation in a fresh top-level group, discarding its
-- result. Then wait until the threads forked directly into that group have
-- finished. When 'countingThreads' is set, a fork-count report is printed
-- at the end.
runHIO :: HIO b -> IO ()
runHIO hio = do
  top <- newPrimGroup
  _ <- inGroup hio top
  isZero top
  when countingThreads printThreadReport
-- | Run an 'HIO' computation as a pure value: create a fresh top-level group,
-- run the computation in it, wait for its directly forked threads, and
-- optionally print the fork report (when 'countingThreads' is set).
--
-- This uses 'unsafePerformIO'. All of those effects happen when the result
-- is first demanded, not at any point the caller can sequence. Only use it
-- on computations whose result does not depend on when the effects run.
unHIO :: HIO a -> a
unHIO hio = unsafePerformIO $ do
  top <- newPrimGroup
  result <- inGroup hio top
  isZero top
  when countingThreads printThreadReport
  return result
-- | Allocate a bare group: counter 0 and an empty 'Open' membership list.
-- The counter TVar is created first, then the membership TVar. Unlike
-- 'newGroup', the result is not registered in any parent group.
newPrimGroup :: IO Group
newPrimGroup = (,) <$> newTVar 0 <*> newTVar (Open [])
-- | Prepend an entry to a group's membership list.
--
-- If the group is already 'Closed', the entry is not recorded; instead the
-- calling thread kills itself. The decision is made in a single
-- 'atomically' block, and the self-kill action it returns runs after that
-- block completes.
register :: Entry -> Group -> IO ()
register entry (_, members) = do
  followUp <- atomically $ do
    state <- readTVarSTM members
    case state of
      Closed -> return (myThreadId >>= killThread)
      Open entries -> do
        writeTVarSTM members (Open (entry : entries))
        return (return ())
  followUp
-- | Kill a group member.
--
-- * A thread entry is sent 'killThread'.
-- * A nested group has its membership switched to 'Closed' via 'modifyTVar'.
--   If it was still 'Open', each recorded member is killed in list order,
--   recursing into further nested groups. Since 'register' prepends, this is
--   most-recently-registered first.
kill :: Entry -> IO ()
kill (Thread tid) = killThread tid
kill (Group (_, members)) = do
  (previous, _) <- modifyTVar members (const Closed)
  case previous of
    Closed -> return ()
    Open entries -> mapM_ kill entries
-- | Bookkeeping on a group's thread counter.
--
-- * 'increment' and 'decrement' adjust the counter by one, each in a single
--   'atomically' block. The new value is forced before it is written, so
--   repeated updates cannot pile up unevaluated thunks inside the 'TVar'.
-- * 'isZero' blocks, using STM 'check', until the counter is at or below 0.
increment, decrement, isZero :: Group -> IO ()
increment (counter, _) =
  atomically $ readTVarSTM counter >>= \n -> writeTVarSTM counter $! n + 1
decrement (counter, _) =
  atomically $ readTVarSTM counter >>= \n -> writeTVarSTM counter $! n - 1
isZero (counter, _) =
  -- 'close' resets the counter to 0 while the threads it killed may still be
  -- running their 'finally' decrements, which then push the counter below
  -- zero. Treating any value <= 0 as drained keeps waiters from blocking
  -- forever in that case. In normal operation the counter never goes below 0.
  atomically $ readTVarSTM counter >>= check . (<= 0)
-- | Switch for thread instrumentation. When 'True', every 'HIO' fork calls
-- 'incrementThreadCount', and 'runHIO' / 'unHIO' call 'printThreadReport'
-- once their top-level group has drained.
countingThreads :: Bool
countingThreads = True
-- | Process-wide number of threads forked through 'HIO'. It is only updated
-- when 'countingThreads' is set.
--
-- This is a top-level mutable global created with 'unsafePerformIO'. The
-- NOINLINE pragma is required: without it GHC may inline or duplicate the
-- definition, so different use sites could end up with different 'TVar's.
{-# NOINLINE threadCount #-}
threadCount :: TVar Integer
threadCount = unsafePerformIO $ newTVar 0
-- | Add one to 'threadCount' in a single 'atomically' block.
--
-- The new value is forced before it is written. The count is normally only
-- read by 'printThreadReport' at the end of a run, so a lazy write would
-- otherwise accumulate one (+1) thunk per fork until then.
incrementThreadCount :: IO ()
incrementThreadCount =
  atomically $ readTVarSTM threadCount >>= \n -> writeTVarSTM threadCount $! n + 1
-- | Read 'threadCount', then print a separator line followed by the number of
-- 'HIO' forks it recorded.
printThreadReport :: IO ()
printThreadReport = do
  forks <- readTVar threadCount
  mapM_ putStrLn
    [ "----------"
    , show forks ++ " HIO threads were forked."
    ]