Skip to content

Commit

Permalink
Finish draft of replacing CT with ReaderT
Browse files Browse the repository at this point in the history
  • Loading branch information
EduardoLR10 committed Aug 5, 2024
1 parent 17e9fde commit e70f995
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 74 deletions.
98 changes: 48 additions & 50 deletions src/Integrator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ import Solver
iterToTime )
import Interpolation ( interpolate )
import Memo ( memo )
import Control.Monad.Trans.Reader (reader, ask, runReaderT)
import Control.Monad.Trans.Reader
import Control.Monad.IO.Class (liftIO)

integ :: CT Double -> CT Double -> CT (CT Double)
integ diff i =
mdo y <- memo interpolate z
z <- do ps <- ask
z <- ReaderT $ \ps ->
let f = solverToFunction (method $ solver ps)
pure . liftIO $ f diff i y ps
in pure . ReaderT $ f diff i y
return y

-- | The Integrator type represents an integral with caching.
Expand All @@ -31,47 +31,45 @@ data Integrator = Integrator { initial :: CT Double, -- ^ The initial value.
}

initialize :: CT a -> CT a
initialize m = do
ps <- ask
initialize m =
ReaderT $ \ps ->
if iteration ps == 0 && getSolverStage (stage $ solver ps) == 0 then
liftIO $ runReaderT m ps
runReaderT m ps
else
let iv = interval ps
sl = solver ps
in liftIO $ runReaderT m $ ps { time = iterToTime iv sl 0 (SolverStage 0),
iteration = 0,
solver = sl { stage = SolverStage 0 }}
in runReaderT m $ ps { time = iterToTime iv sl 0 (SolverStage 0),
iteration = 0,
solver = sl { stage = SolverStage 0 }}

createInteg :: CT Double -> CT Integrator
createInteg i = do
ps <- ask
r1 <- liftIO . newIORef $ initialize i
r2 <- liftIO . newIORef $ initialize i
let integ = Integrator { initial = i,
cache = r1,
computation = r2 }
z = do ps <- ask
v <- liftIO $ readIORef (computation integ)
liftIO $ runReaderT v ps
y <- liftIO . flip runReaderT ps $ memo interpolate z
liftIO $ writeIORef (cache integ) y
pure integ
createInteg i =
ReaderT $ \ps ->
do r1 <- newIORef $ initialize i
r2 <- newIORef $ initialize i
let integ = Integrator { initial = i,
cache = r1,
computation = r2 }
z = ReaderT $ \ps ->
do v <- readIORef (computation integ)
runReaderT v ps
y <- runReaderT (memo interpolate z) ps
writeIORef (cache integ) y
return integ

readInteg :: Integrator -> CT Double
readInteg integ = do
ps <- ask
v <- liftIO $ readIORef (cache integ)
liftIO $ runReaderT v ps
readInteg integ =
ReaderT $ \ps -> flip runReaderT ps =<< readIORef (cache integ)

updateInteg :: Integrator -> CT Double -> CT ()
updateInteg integ diff =
liftIO $ writeIORef (computation integ) z
ReaderT . const $ writeIORef (computation integ) z
where i = initial integ
z = do ps <- ask
v <- liftIO $ readIORef (cache integ)
let f = solverToFunction (method $ solver ps)
liftIO $ f diff i v ps

z = ReaderT $ \ps ->
let f = solverToFunction (method $ solver ps)
in
(\y -> f diff i y ps) =<< readIORef (cache integ)
solverToFunction Euler = integEuler
solverToFunction RungeKutta2 = integRK2
solverToFunction RungeKutta4 = integRK4
Expand All @@ -89,8 +87,8 @@ integEuler diff i y ps =
sl = solver ps
ty = iterToTime iv sl (n - 1) (SolverStage 0)
psy = ps { time = ty, iteration = n - 1, solver = sl { stage = SolverStage 0} }
a <- liftIO $ runReaderT y psy
b <- liftIO $ runReaderT diff psy
a <- runReaderT y psy
b <- runReaderT diff psy
let !v = a + dt (solver ps) * b
return v

Expand All @@ -112,9 +110,9 @@ integRK2 f i y ps =
psy = ps { time = ty, iteration = n - 1, solver = sl { stage = SolverStage 0 }}
ps1 = psy
ps2 = ps { time = t2, iteration = n - 1, solver = sl { stage = SolverStage 1 }}
vy <- liftIO $ runReaderT y psy
k1 <- liftIO $ runReaderT f ps1
k2 <- liftIO $ runReaderT f ps2
vy <- runReaderT y psy
k1 <- runReaderT f ps1
k2 <- runReaderT f ps2
let !v = vy + dt sl / 2.0 * (k1 + k2)
return v
SolverStage 1 -> do
Expand All @@ -125,8 +123,8 @@ integRK2 f i y ps =
t1 = ty
psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }}
ps1 = psy
vy <- liftIO $ runReaderT y psy
k1 <- liftIO $ runReaderT f ps1
vy <- runReaderT y psy
k1 <- runReaderT f ps1
let !v = vy + dt sl * k1
return v
_ ->
Expand Down Expand Up @@ -154,11 +152,11 @@ integRK4 f i y ps =
ps2 = ps { time = t2, iteration = n - 1, solver = sl { stage = SolverStage 1 }}
ps3 = ps { time = t3, iteration = n - 1, solver = sl { stage = SolverStage 2 }}
ps4 = ps { time = t4, iteration = n - 1, solver = sl { stage = SolverStage 3 }}
vy <- liftIO $ runReaderT y psy
k1 <- liftIO $ runReaderT f ps1
k2 <- liftIO $ runReaderT f ps2
k3 <- liftIO $ runReaderT f ps3
k4 <- liftIO $ runReaderT f ps4
vy <- runReaderT y psy
k1 <- runReaderT f ps1
k2 <- runReaderT f ps2
k3 <- runReaderT f ps3
k4 <- runReaderT f ps4
let !v = vy + dt sl / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
return v
SolverStage 1 -> do
Expand All @@ -169,8 +167,8 @@ integRK4 f i y ps =
t1 = ty
psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }}
ps1 = psy
vy <- liftIO $ runReaderT y psy
k1 <- liftIO $ runReaderT f ps1
vy <- runReaderT y psy
k1 <- runReaderT f ps1
let !v = vy + dt sl / 2.0 * k1
return v
SolverStage 2 -> do
Expand All @@ -181,8 +179,8 @@ integRK4 f i y ps =
t2 = iterToTime iv sl n (SolverStage 1)
psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }}
ps2 = ps { time = t2, iteration = n, solver = sl { stage = SolverStage 1 }}
vy <- liftIO $ runReaderT y psy
k2 <- liftIO $ runReaderT f ps2
vy <- runReaderT y psy
k2 <- runReaderT f ps2
let !v = vy + dt sl / 2.0 * k2
return v
SolverStage 3 -> do
Expand All @@ -193,8 +191,8 @@ integRK4 f i y ps =
t3 = iterToTime iv sl n (SolverStage 2)
psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }}
ps3 = ps { time = t3, iteration = n, solver = sl { stage = SolverStage 2 }}
vy <- liftIO $ runReaderT y psy
k3 <- liftIO $ runReaderT f ps3
vy <- runReaderT y psy
k3 <- runReaderT f ps3
let !v = vy + dt sl * k3
return v
_ ->
Expand Down
14 changes: 7 additions & 7 deletions src/Interpolation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import Solver
Stage(SolverStage, Interpolate),
getSolverStage,
iterToTime )
import Control.Monad.Trans.Reader (reader, ask, runReaderT)
import Control.Monad.Trans.Reader ( ReaderT(ReaderT, runReaderT) )
import Control.Monad.IO.Class (liftIO)

-- | Function to solve floating point approximations
Expand All @@ -18,8 +18,8 @@ neighborhood sl t t' =

-- | Discretize the computation in the integration time points.
discrete :: CT a -> CT a
discrete m = do
ps <- ask
discrete m =
ReaderT $ \ps ->
let st = getSolverStage $ stage (solver ps)
r | st == 0 = runReaderT m ps
| st > 0 = let iv = interval ps
Expand All @@ -36,14 +36,14 @@ discrete m = do
in runReaderT m $ ps { time = iterToTime iv sl n' (SolverStage 0),
iteration = n',
solver = sl { stage = SolverStage 0} }
liftIO r
in r

-- | Interpolate the computation based on the integration time points only.
interpolate :: CT Double -> CT Double
interpolate m = do
ps <- ask
interpolate m =
ReaderT $ \ps ->
case stage $ solver ps of
SolverStage _ -> liftIO $ runReaderT m ps
SolverStage _ -> runReaderT m ps
Interpolate ->
let iv = interval ps
sl = solver ps
Expand Down
32 changes: 15 additions & 17 deletions src/Memo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Simulation
import Data.IORef
import Data.Array
import Data.Array.IO
import Control.Monad.Trans.Reader (reader, ask, runReaderT)
import Control.Monad.Trans.Reader
import Control.Monad.IO.Class (liftIO)

-- -- | The 'Memo' class specifies a type for which an array can be created.
Expand All @@ -30,17 +30,16 @@ instance (MArray IOUArray e IO) => UMemo e where
-- the specified interpolation and being aware of the Runge-Kutta method.
memo :: UMemo e => (CT e -> CT e) -> CT e
-> CT (CT e)
memo tr m = do
ps <- ask
memo tr m =
ReaderT $ \ps -> do
let sl = solver ps
iv = interval ps
(SolverStage stl, SolverStage stu) = stageBnds sl
(nl, nu) = iterationBnds iv (dt sl)
arr <- liftIO $ newMemoUArray_ ((stl, nl), (stu, nu))
nref <- liftIO $ newIORef 0
stref <- liftIO $ newIORef 0
let r = do
ps <- ask
arr <- newMemoUArray_ ((stl, nl), (stu, nu))
nref <- newIORef 0
stref <- newIORef 0
let r ps = do
let sl = solver ps
iv = interval ps
n = iteration ps
Expand All @@ -62,23 +61,22 @@ memo tr m = do
loop (n' + 1) 0
else do writeIORef stref (st' + 1)
loop n' (st' + 1)
n' <- liftIO $ readIORef nref
st' <- liftIO $ readIORef stref
liftIO $ loop n' st'
pure . tr $ r
n' <- readIORef nref
st' <- readIORef stref
loop n' st'
pure . tr . ReaderT $ r

-- | Memoize and order the computation in the integration time points using
-- the specified interpolation and without knowledge of the Runge-Kutta method.
memo0 :: Memo e => (CT e -> CT e) -> CT e
-> CT (CT e)
memo0 tr m = do
ps <- ask
memo0 tr m =
ReaderT $ \ps -> do
let iv = interval ps
bnds = iterationBnds iv (dt (solver ps))
arr <- liftIO $ newMemoArray_ bnds
nref <- liftIO $ newIORef 0
let r = do
ps <- ask
let r ps = do
let sl = solver ps
iv = interval ps
n = iteration ps
Expand All @@ -96,4 +94,4 @@ memo0 tr m = do
loop (n' + 1)
n' <- liftIO $ readIORef nref
liftIO $ loop n'
pure . tr $ r
pure . tr . ReaderT $ r

0 comments on commit e70f995

Please sign in to comment.