Skip to content

Commit

Permalink
Add draft of refactor CT to ReaderT
Browse files Browse the repository at this point in the history
  • Loading branch information
EduardoLR10 committed Jul 17, 2024
1 parent 52c092b commit 0186642
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 354 deletions.
2 changes: 2 additions & 0 deletions fact.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ library
, criterion-measurement >=0.2.2.0
, monadlist >=0.0.2
, mtl >=2.3.1
, transformers >= 0.6.1.1
default-language: Haskell2010

executable fact-exe
Expand All @@ -70,5 +71,6 @@ executable fact-exe
, criterion-measurement >=0.2.2.0
, monadlist >=0.0.2
, mtl >=2.3.1
, transformers >= 0.6.1.1
, fact
default-language: Haskell2010
1 change: 1 addition & 0 deletions src/Benchmarks.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Benchmarks where

import Examples.ChemicalReaction
import Examples.Lorenz
import Examples.Sine
import Driver
import CT
import IO
Expand Down
106 changes: 30 additions & 76 deletions src/CT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,17 @@
-- Stability : stable
-- Tested with: GHC 8.10.7
-- |

{-# LANGUAGE FlexibleInstances #-}
module CT
(CT(..),
Parameters(..)) where

import Control.Monad
import Control.Monad.Fix
import Control.Monad.Trans.Reader ( ReaderT )

import Types
import Types ( Iteration )

import Solver
import Simulation
import Solver ( Solver )
import Simulation ( Interval )

-- | It defines the simulation time appended with additional information.
data Parameters = Parameters { interval :: Interval, -- ^ the simulation interval
Expand All @@ -57,83 +56,38 @@ data Parameters = Parameters { interval :: Interval, -- ^ the simulation interv
iteration :: Iteration -- ^ the current iteration
} deriving (Eq, Show)

newtype CT a = CT {apply :: Parameters -> IO a}

instance Functor CT where
fmap f (CT da) = CT $ \ps -> fmap f (da ps)

instance Applicative CT where
pure a = CT $ const (pure a)
(CT df) <*> (CT da) = CT $ \ps -> do f <- df ps
fmap f (da ps)

appComposition :: CT (a -> b) -> CT a -> CT b
appComposition (CT df) (CT da)
= CT $ \ps -> df ps >>= \f -> fmap f (da ps)

instance Monad CT where
return = pure
(CT m) >>= k = CT $ \ps -> do a <- m ps
k a `apply` ps

instance MonadFix CT where
-- mfix :: (a -> m a) -> m a
mfix f =
CT $ \ps -> mfix ((`apply` ps) . f)

returnD :: a -> CT a
returnD a = CT $ const (return a)

bindD :: (a -> CT b ) -> CT a -> CT b
bindD k (CT m) =
CT $ \ps -> m ps >>= \a -> (\(CT m') -> m' ps) $ k a

bindD' :: (a -> CT b ) -> CT a -> CT b
bindD' k (CT m) = CT $ \ps -> do
a <- m ps
k a `apply` ps

instance Eq (CT a) where
x == y = error "<< Can't compare dynamics >>"

instance Show (CT a) where
showsPrec _ x = showString "<< CT >>"

unaryOP :: (a -> b) -> CT a -> CT b
unaryOP = fmap

binaryOP :: (a -> b -> c) -> CT a -> CT b -> CT c
binaryOP func da db = fmap func da <*> db
type CT a = ReaderT Parameters IO a

instance (Num a) => Num (CT a) where
x + y = binaryOP (+) x y
x - y = binaryOP (-) x y
x * y = binaryOP (*) x y
negate = unaryOP negate
abs = unaryOP abs
signum = unaryOP signum
x + y = (+) <$> x <*> y
x - y = (-) <$> x <*> y
x * y = (*) <$> x <*> y
negate = fmap negate
abs = fmap abs
signum = fmap signum
fromInteger i = return $ fromInteger i

instance (Fractional a) => Fractional (CT a) where
x / y = binaryOP (/) x y
recip = unaryOP recip
x / y = (/) <$> x <*> y
recip = fmap recip
fromRational t = return $ fromRational t

instance (Floating a) => Floating (CT a) where
pi = return pi
exp = unaryOP exp
log = unaryOP log
sqrt = unaryOP sqrt
x ** y = binaryOP (**) x y
sin = unaryOP sin
cos = unaryOP cos
tan = unaryOP tan
asin = unaryOP asin
acos = unaryOP acos
atan = unaryOP atan
sinh = unaryOP sinh
cosh = unaryOP cosh
tanh = unaryOP tanh
asinh = unaryOP asinh
acosh = unaryOP acosh
atanh = unaryOP atanh
exp = fmap exp
log = fmap log
sqrt = fmap sqrt
x ** y = (**) <$> x <*> y
sin = fmap sin
cos = fmap cos
tan = fmap tan
asin = fmap asin
acos = fmap acos
atan = fmap atan
sinh = fmap sinh
cosh = fmap cosh
tanh = fmap tanh
asinh = fmap asinh
acosh = fmap acosh
atanh = fmap atanh
56 changes: 29 additions & 27 deletions src/Driver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import CT
import Solver
import Simulation
import Types
import Control.Monad.Trans.Reader (reader, ask, runReaderT)
import Control.Monad.IO.Class (liftIO)

type Model a = CT (CT a)

Expand All @@ -12,43 +14,43 @@ epslon = 0.00001
-- | Run the simulation and return the result in the last
-- time point using the specified simulation specs.
runCTFinal :: Model a -> Double -> Solver -> IO a
runCTFinal (CT m) t sl =
do d <- m Parameters { interval = Interval 0 t,
time = 0,
iteration = 0,
solver = sl { stage = SolverStage 0 }}
runCTFinal m t sl =
do d <- runReaderT m $ Parameters { interval = Interval 0 t,
time = 0,
iteration = 0,
solver = sl { stage = SolverStage 0 }}
subRunCTFinal d t sl

-- | Auxiliary functions to runCTFinal
subRunCTFinal :: CT a -> Double -> Solver -> IO a
subRunCTFinal (CT m) t sl =
do let iv = Interval 0 t
n = iterationHiBnd iv (dt sl)
disct = iterToTime iv sl n (SolverStage 0)
x = m Parameters { interval = iv,
time = disct,
iteration = n,
solver = sl { stage = SolverStage 0 }}
if disct - t < epslon
then x
else m Parameters { interval = iv,
time = t,
iteration = n,
solver = sl { stage = Interpolate }}
subRunCTFinal m t sl = do
let iv = Interval 0 t
n = iterationHiBnd iv (dt sl)
disct = iterToTime iv sl n (SolverStage 0)
x = runReaderT m $ Parameters { interval = iv,
time = disct,
iteration = n,
solver = sl { stage = SolverStage 0 }}
if disct - t < epslon
then x
else runReaderT m $ Parameters { interval = iv,
time = t,
iteration = n,
solver = sl { stage = Interpolate }}

-- | Run the simulation and return the results in all
-- integration time points using the specified simulation specs.
runCT :: Model a -> Double -> Solver -> IO [a]
runCT (CT m) t sl = do
d <- m Parameters { interval = Interval 0 t,
time = 0,
iteration = 0,
solver = sl { stage = SolverStage 0}}
runCT m t sl = do
d <- runReaderT m $ Parameters { interval = Interval 0 t,
time = 0,
iteration = 0,
solver = sl { stage = SolverStage 0}}
sequence $ subRunCT d t sl

-- | Auxiliary functions to runCT
subRunCT :: CT a -> Double -> Solver -> [IO a]
subRunCT (CT m) t sl = do
subRunCT m t sl = do
let iv = Interval 0 t
(nl, nu) = iterationBnds iv (dt sl)
parameterize n =
Expand All @@ -65,8 +67,8 @@ subRunCT (CT m) t sl = do
iteration = nu,
solver = sl {stage = Interpolate}}
endTime = iterToTime iv sl nu (SolverStage 0)
values = map (m . parameterize) [nl .. nu]
values = map (runReaderT m . parameterize) [nl .. nu]
if endTime - t < epslon
then values
else init values ++ [m ps]
else init values ++ [runReaderT m ps]

5 changes: 2 additions & 3 deletions src/Examples/Sine.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{-# LANGUAGE RecordWildCards #-}
module Examples.Sine where

import Driver
Expand All @@ -11,8 +10,8 @@ import Data.List
import Simulation

sineSolv = Solver { dt = 0.01,
method = RungeKutta4,
stage = SolverStage 0 }
method = Euler,
stage = SolverStage 0 }

sineModel :: Model [Double]
sineModel =
Expand Down
40 changes: 20 additions & 20 deletions src/Examples/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,25 @@ predicate initialCondition =
if initialCondition >= 20
then example2 else example1

demux :: Predicate Double Result -> HybridModel Result
demux predicate (initialCondition, _) p = do
let m = predicate initialCondition
model <- m (pure initialCondition) `apply` p
head <$> model `apply` p
-- demux :: Predicate Double Result -> HybridModel Result
-- demux predicate (initialCondition, _) p = do
-- let m = predicate initialCondition
-- model <- m (pure initialCondition) `apply` p
-- head <$> model `apply` p

hybrid :: (MonadPlus p, Monad m) => (a -> Parameters -> m a) -> a -> Double -> Solver -> m (p a)
hybrid f z t sl =
do let iv = Interval 0 t
(nl, nu) = iterationBnds iv (dt sl)
parameterise n = Parameters { interval = iv,
time = iterToTime iv sl n (SolverStage 0),
iteration = 1,
solver = sl { stage = SolverStage 0 }}
ps = map parameterise [nl..nu]
scanM f z ps
-- hybrid :: (MonadPlus p, Monad m) => (a -> Parameters -> m a) -> a -> Double -> Solver -> m (p a)
-- hybrid f z t sl =
-- do let iv = Interval 0 t
-- (nl, nu) = iterationBnds iv (dt sl)
-- parameterise n = Parameters { interval = iv,
-- time = iterToTime iv sl n (SolverStage 0),
-- iteration = 1,
-- solver = sl { stage = SolverStage 0 }}
-- ps = map parameterise [nl..nu]
-- scanM f z ps

test = do
t <- hybrid (demux predicate) (1, "initial") 40 sineSolv2
case t of
[] -> fail "Something went wrong during hybrid simulation"
list -> print list
-- test = do
-- t <- hybrid (demux predicate) (1, "initial") 40 sineSolv2
-- case t of
-- [] -> fail "Something went wrong during hybrid simulation"
-- list -> print list
Loading

0 comments on commit 0186642

Please sign in to comment.