diff --git a/src/Integrator.hs b/src/Integrator.hs index 0fe82ee..a40a803 100644 --- a/src/Integrator.hs +++ b/src/Integrator.hs @@ -13,15 +13,15 @@ import Solver iterToTime ) import Interpolation ( interpolate ) import Memo ( memo ) -import Control.Monad.Trans.Reader (reader, ask, runReaderT) +import Control.Monad.Trans.Reader import Control.Monad.IO.Class (liftIO) integ :: CT Double -> CT Double -> CT (CT Double) integ diff i = mdo y <- memo interpolate z - z <- do ps <- ask + z <- ReaderT $ \ps -> let f = solverToFunction (method $ solver ps) - pure . liftIO $ f diff i y ps + in pure . ReaderT $ f diff i y return y -- | The Integrator type represents an integral with caching. @@ -31,47 +31,45 @@ data Integrator = Integrator { initial :: CT Double, -- ^ The initial value. } initialize :: CT a -> CT a -initialize m = do - ps <- ask +initialize m = + ReaderT $ \ps -> if iteration ps == 0 && getSolverStage (stage $ solver ps) == 0 then - liftIO $ runReaderT m ps + runReaderT m ps else let iv = interval ps sl = solver ps - in liftIO $ runReaderT m $ ps { time = iterToTime iv sl 0 (SolverStage 0), - iteration = 0, - solver = sl { stage = SolverStage 0 }} + in runReaderT m $ ps { time = iterToTime iv sl 0 (SolverStage 0), + iteration = 0, + solver = sl { stage = SolverStage 0 }} createInteg :: CT Double -> CT Integrator -createInteg i = do - ps <- ask - r1 <- liftIO . newIORef $ initialize i - r2 <- liftIO . newIORef $ initialize i - let integ = Integrator { initial = i, - cache = r1, - computation = r2 } - z = do ps <- ask - v <- liftIO $ readIORef (computation integ) - liftIO $ runReaderT v ps - y <- liftIO . flip runReaderT ps $ memo interpolate z - liftIO $ writeIORef (cache integ) y - pure integ +createInteg i = + ReaderT $ \ps -> + do r1 <- newIORef $ initialize i + r2 <- newIORef $ initialize i + let integ = Integrator { initial = i, + cache = r1, + computation = r2 } + z = ReaderT $ \ps -> + do v <- readIORef (computation integ) + runReaderT v ps + y <- runReaderT (memo interpolate z) ps + writeIORef (cache integ) y + return integ readInteg :: Integrator -> CT Double -readInteg integ = do - ps <- ask - v <- liftIO $ readIORef (cache integ) - liftIO $ runReaderT v ps +readInteg integ = + ReaderT $ \ps -> flip runReaderT ps =<< readIORef (cache integ) updateInteg :: Integrator -> CT Double -> CT () updateInteg integ diff = - liftIO $ writeIORef (computation integ) z + ReaderT . const $ writeIORef (computation integ) z where i = initial integ - z = do ps <- ask - v <- liftIO $ readIORef (cache integ) - let f = solverToFunction (method $ solver ps) - liftIO $ f diff i v ps - + z = ReaderT $ \ps -> + let f = solverToFunction (method $ solver ps) + in + (\y -> f diff i y ps) =<< readIORef (cache integ) + solverToFunction Euler = integEuler solverToFunction RungeKutta2 = integRK2 solverToFunction RungeKutta4 = integRK4 @@ -89,8 +87,8 @@ integEuler diff i y ps = sl = solver ps ty = iterToTime iv sl (n - 1) (SolverStage 0) psy = ps { time = ty, iteration = n - 1, solver = sl { stage = SolverStage 0} } - a <- liftIO $ runReaderT y psy - b <- liftIO $ runReaderT diff psy + a <- runReaderT y psy + b <- runReaderT diff psy let !v = a + dt (solver ps) * b return v @@ -112,9 +110,9 @@ integRK2 f i y ps = psy = ps { time = ty, iteration = n - 1, solver = sl { stage = SolverStage 0 }} ps1 = psy ps2 = ps { time = t2, iteration = n - 1, solver = sl { stage = SolverStage 1 }} - vy <- liftIO $ runReaderT y psy - k1 <- liftIO $ runReaderT f ps1 - k2 <- liftIO $ runReaderT f ps2 + vy <- runReaderT y psy + k1 <- runReaderT f ps1 + k2 <- runReaderT f ps2 let !v = vy + dt sl / 2.0 * (k1 + k2) return v SolverStage 1 -> do @@ -125,8 +123,8 @@ integRK2 f i y ps = t1 = ty psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }} ps1 = psy - vy <- liftIO $ runReaderT y psy - k1 <- liftIO $ runReaderT f ps1 + vy <- runReaderT y psy + k1 <- runReaderT f ps1 let !v = vy + dt sl * k1 return v _ -> @@ -154,11 +152,11 @@ integRK4 f i y ps = ps2 = ps { time = t2, iteration = n - 1, solver = sl { stage = SolverStage 1 }} ps3 = ps { time = t3, iteration = n - 1, solver = sl { stage = SolverStage 2 }} ps4 = ps { time = t4, iteration = n - 1, solver = sl { stage = SolverStage 3 }} - vy <- liftIO $ runReaderT y psy - k1 <- liftIO $ runReaderT f ps1 - k2 <- liftIO $ runReaderT f ps2 - k3 <- liftIO $ runReaderT f ps3 - k4 <- liftIO $ runReaderT f ps4 + vy <- runReaderT y psy + k1 <- runReaderT f ps1 + k2 <- runReaderT f ps2 + k3 <- runReaderT f ps3 + k4 <- runReaderT f ps4 let !v = vy + dt sl / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4) return v SolverStage 1 -> do @@ -169,8 +167,8 @@ integRK4 f i y ps = t1 = ty psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }} ps1 = psy - vy <- liftIO $ runReaderT y psy - k1 <- liftIO $ runReaderT f ps1 + vy <- runReaderT y psy + k1 <- runReaderT f ps1 let !v = vy + dt sl / 2.0 * k1 return v SolverStage 2 -> do @@ -181,8 +179,8 @@ integRK4 f i y ps = t2 = iterToTime iv sl n (SolverStage 1) psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }} ps2 = ps { time = t2, iteration = n, solver = sl { stage = SolverStage 1 }} - vy <- liftIO $ runReaderT y psy - k2 <- liftIO $ runReaderT f ps2 + vy <- runReaderT y psy + k2 <- runReaderT f ps2 let !v = vy + dt sl / 2.0 * k2 return v SolverStage 3 -> do @@ -193,8 +191,8 @@ integRK4 f i y ps = t3 = iterToTime iv sl n (SolverStage 2) psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }} ps3 = ps { time = t3, iteration = n, solver = sl { stage = SolverStage 2 }} - vy <- liftIO $ runReaderT y psy - k3 <- liftIO $ runReaderT f ps3 + vy <- runReaderT y psy + k3 <- runReaderT f ps3 let !v = vy + dt sl * k3 return v _ -> diff --git a/src/Interpolation.hs b/src/Interpolation.hs index 0ef4da2..e83f8f0 100644 --- a/src/Interpolation.hs +++ b/src/Interpolation.hs @@ -8,7 +8,7 @@ import Solver Stage(SolverStage, Interpolate), getSolverStage, iterToTime ) -import Control.Monad.Trans.Reader (reader, ask, runReaderT) +import Control.Monad.Trans.Reader ( ReaderT(ReaderT, runReaderT) ) import Control.Monad.IO.Class (liftIO) -- | Function to solve floating point approximations @@ -18,8 +18,8 @@ neighborhood sl t t' = -- | Discretize the computation in the integration time points. discrete :: CT a -> CT a -discrete m = do - ps <- ask +discrete m = + ReaderT $ \ps -> let st = getSolverStage $ stage (solver ps) r | st == 0 = runReaderT m ps | st > 0 = let iv = interval ps @@ -36,14 +36,14 @@ discrete m = do in runReaderT m $ ps { time = iterToTime iv sl n' (SolverStage 0), iteration = n', solver = sl { stage = SolverStage 0} } - liftIO r + in r -- | Interpolate the computation based on the integration time points only. interpolate :: CT Double -> CT Double -interpolate m = do - ps <- ask +interpolate m = + ReaderT $ \ps -> case stage $ solver ps of - SolverStage _ -> liftIO $ runReaderT m ps + SolverStage _ -> runReaderT m ps Interpolate -> let iv = interval ps sl = solver ps diff --git a/src/Memo.hs b/src/Memo.hs index ada30fe..c305daa 100644 --- a/src/Memo.hs +++ b/src/Memo.hs @@ -9,7 +9,7 @@ import Simulation import Data.IORef import Data.Array import Data.Array.IO -import Control.Monad.Trans.Reader (reader, ask, runReaderT) +import Control.Monad.Trans.Reader import Control.Monad.IO.Class (liftIO) -- -- | The 'Memo' class specifies a type for which an array can be created. @@ -30,17 +30,16 @@ instance (MArray IOUArray e IO) => UMemo e where -- the specified interpolation and being aware of the Runge-Kutta method. memo :: UMemo e => (CT e -> CT e) -> CT e -> CT (CT e) -memo tr m = do - ps <- ask +memo tr m = + ReaderT $ \ps -> do let sl = solver ps iv = interval ps (SolverStage stl, SolverStage stu) = stageBnds sl (nl, nu) = iterationBnds iv (dt sl) - arr <- liftIO $ newMemoUArray_ ((stl, nl), (stu, nu)) - nref <- liftIO $ newIORef 0 - stref <- liftIO $ newIORef 0 - let r = do - ps <- ask + arr <- newMemoUArray_ ((stl, nl), (stu, nu)) + nref <- newIORef 0 + stref <- newIORef 0 + let r ps = do let sl = solver ps iv = interval ps n = iteration ps @@ -62,23 +61,22 @@ memo tr m = do loop (n' + 1) 0 else do writeIORef stref (st' + 1) loop n' (st' + 1) - n' <- liftIO $ readIORef nref - st' <- liftIO $ readIORef stref - liftIO $ loop n' st' - pure . tr $ r + n' <- readIORef nref + st' <- readIORef stref + loop n' st' + pure . tr . ReaderT $ r -- | Memoize and order the computation in the integration time points using -- the specified interpolation and without knowledge of the Runge-Kutta method. memo0 :: Memo e => (CT e -> CT e) -> CT e -> CT (CT e) -memo0 tr m = do - ps <- ask +memo0 tr m = + ReaderT $ \ps -> do let iv = interval ps bnds = iterationBnds iv (dt (solver ps)) arr <- liftIO $ newMemoArray_ bnds nref <- liftIO $ newIORef 0 - let r = do - ps <- ask + let r ps = do let sl = solver ps iv = interval ps n = iteration ps @@ -96,4 +94,4 @@ memo0 tr m = do loop (n' + 1) n' <- liftIO $ readIORef nref liftIO $ loop n' - pure . tr $ r + pure . tr . ReaderT $ r