From 2a7fcf3c2a50bc863bf285901ecee2ff898f78db Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Mon, 4 Nov 2024 15:43:56 -0800 Subject: [PATCH 01/14] Unneccessary include --- src/DataBase/DataBase.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DataBase/DataBase.cc b/src/DataBase/DataBase.cc index 3ad7073e8..117ea69b0 100644 --- a/src/DataBase/DataBase.cc +++ b/src/DataBase/DataBase.cc @@ -13,7 +13,6 @@ #include "Material/EquationOfState.hh" #include "Utilities/testBoxIntersection.hh" #include "Utilities/safeInv.hh" -#include "State.hh" #include "Hydro/HydroFieldNames.hh" #include "Utilities/globalBoundingVolumes.hh" #include "Utilities/globalNodeIDs.hh" From 31d640136e189b78f45d9d7399079d5a589bed27 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Wed, 6 Nov 2024 09:51:40 -0800 Subject: [PATCH 02/14] Updating to use new State interface methods --- src/ArtificialConduction/ArtificialConduction.cc | 4 ++-- src/ArtificialViscosity/TensorCRKSPHViscosity.cc | 2 +- src/ArtificialViscosity/VonNeumanViscosity.cc | 2 +- src/CRKSPH/CRKSPHEvaluateDerivatives.cc | 4 ++-- src/CRKSPH/CRKSPHHydroBase.cc | 6 +++--- src/CRKSPH/CRKSPHHydroBaseRZ.cc | 4 ++-- src/CRKSPH/SolidCRKSPHHydroBase.cc | 4 ++-- src/CRKSPH/SolidCRKSPHHydroBaseRZ.cc | 4 ++-- src/DEM/SolidBoundary/CircularPlaneSolidBoundary.cc | 6 +++--- src/DEM/SolidBoundary/ClippedSphereSolidBoundary.cc | 6 +++--- src/DEM/SolidBoundary/CylinderSolidBoundary.cc | 6 +++--- src/DEM/SolidBoundary/InfinitePlaneSolidBoundary.cc | 6 +++--- src/DEM/SolidBoundary/RectangularPlaneSolidBoundary.cc | 4 ++-- src/DEM/SolidBoundary/SphereSolidBoundary.cc | 4 ++-- src/DataBase/CopyStateInline.hh | 4 ++-- src/FSISPH/SolidFSISPHEvaluateDerivatives.cc | 4 ++-- src/FSISPH/SolidFSISPHHydroBase.cc | 4 ++-- src/GSPH/GSPHEvaluateDerivatives.cc | 4 ++-- src/GSPH/GenericRiemannHydro.cc | 4 ++-- src/GSPH/MFMEvaluateDerivatives.cc | 4 ++-- src/GSPH/MFVEvaluateDerivatives.cc | 6 +++--- src/GSPH/MFVHydroBase.cc | 2 +- .../CompatibleMFVSpecificThermalEnergyPolicy.cc | 6 +++--- .../CompatibleDifferenceSpecificThermalEnergyPolicy.cc | 4 ++-- src/Hydro/NonSymmetricSpecificThermalEnergyPolicy.cc | 2 +- src/Hydro/RZNonSymmetricSpecificThermalEnergyPolicy.cc | 2 +- src/Hydro/SpecificThermalEnergyPolicy.cc | 2 +- src/RK/RKCorrections.cc | 4 ++-- src/RK/ReproducingKernel.cc | 9 ++++++--- src/RK/ReproducingKernel.hh | 3 ++- src/RK/ReproducingKernelMethods.cc | 10 +++++++--- src/RK/ReproducingKernelMethods.hh | 3 ++- src/SPH/PSPHHydroBase.cc | 2 +- src/SPH/SPHHydroBase.cc | 4 ++-- src/SPH/SPHHydroBaseRZ.cc | 2 +- src/SPH/SolidSPHHydroBase.cc | 2 +- src/SPH/SolidSPHHydroBaseRZ.cc | 2 +- src/SPH/SolidSphericalSPHHydroBase.cc | 2 +- src/SPH/SphericalSPHHydroBase.cc | 2 +- src/VoronoiCells/SubPointPressureHourglassControl.cc | 2 +- 40 files changed, 83 insertions(+), 74 deletions(-) diff --git a/src/ArtificialConduction/ArtificialConduction.cc b/src/ArtificialConduction/ArtificialConduction.cc index 413eb19cd..318793efe 100644 --- a/src/ArtificialConduction/ArtificialConduction.cc +++ b/src/ArtificialConduction/ArtificialConduction.cc @@ -121,10 +121,10 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, ReproducingKernel WR; auto maxOrder = RKOrder::ZerothOrder; if (useRK) { - const auto& rkOrders = state.template getAny>(RKFieldNames::rkOrders); + const auto& rkOrders = state.template get>(RKFieldNames::rkOrders); CHECK(not rkOrders.empty()); const auto maxOrder = *rkOrders.rbegin(); - WR = state.template getAny>(RKFieldNames::reproducingKernel(maxOrder)); + WR = state.template get>(RKFieldNames::reproducingKernel(maxOrder)); } // The connectivity map diff --git a/src/ArtificialViscosity/TensorCRKSPHViscosity.cc b/src/ArtificialViscosity/TensorCRKSPHViscosity.cc index 67a4fa120..b6a805d1f 100644 --- a/src/ArtificialViscosity/TensorCRKSPHViscosity.cc +++ b/src/ArtificialViscosity/TensorCRKSPHViscosity.cc @@ -184,7 +184,7 @@ calculateSigmaAndGradDivV(const DataBase& dataBase, const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto rho = state.fields(HydroFieldNames::massDensity, 0.0); const auto H = state.fields(HydroFieldNames::H, SymTensor::zero); - const auto WR = state.template getAny>(RKFieldNames::reproducingKernel(order)); + const auto WR = state.template get>(RKFieldNames::reproducingKernel(order)); const auto corrections = state.fields(RKFieldNames::rkCorrections(order), RKCoefficients()); const auto& connectivityMap = dataBase.connectivityMap(); diff --git a/src/ArtificialViscosity/VonNeumanViscosity.cc b/src/ArtificialViscosity/VonNeumanViscosity.cc index 83bb3e99b..367313e07 100644 --- a/src/ArtificialViscosity/VonNeumanViscosity.cc +++ b/src/ArtificialViscosity/VonNeumanViscosity.cc @@ -83,7 +83,7 @@ initialize(const DataBase& dataBase, const auto pressure = state.fields(HydroFieldNames::pressure, 0.0); const auto soundSpeed = state.fields(HydroFieldNames::soundSpeed, 0.0); const auto vol = mass/massDensity; - const auto WR = state.template getAny>(RKFieldNames::reproducingKernel(order)); + const auto WR = state.template get>(RKFieldNames::reproducingKernel(order)); const auto corrections = state.fields(RKFieldNames::rkCorrections(order), RKCoefficients()); // We'll compute the higher-accuracy RK gradient. diff --git a/src/CRKSPH/CRKSPHEvaluateDerivatives.cc b/src/CRKSPH/CRKSPHEvaluateDerivatives.cc index fa1c72d26..2d99a67ae 100644 --- a/src/CRKSPH/CRKSPHEvaluateDerivatives.cc +++ b/src/CRKSPH/CRKSPHEvaluateDerivatives.cc @@ -16,7 +16,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto& Q = this->artificialViscosity(); // The kernels and such. - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(mOrder)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(mOrder)); // A few useful constants we'll use in the following loop. //const double tiny = 1.0e-30; @@ -65,7 +65,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(DxDt.size() == numNodeLists); CHECK(DrhoDt.size() == numNodeLists); diff --git a/src/CRKSPH/CRKSPHHydroBase.cc b/src/CRKSPH/CRKSPHHydroBase.cc index 06047630c..13f8b60c6 100644 --- a/src/CRKSPH/CRKSPHHydroBase.cc +++ b/src/CRKSPH/CRKSPHHydroBase.cc @@ -263,7 +263,7 @@ registerDerivatives(DataBase& dataBase, derivs.enroll(mDspecificThermalEnergyDt); derivs.enroll(mDvDx); derivs.enroll(mInternalDvDx); - derivs.enrollAny(HydroFieldNames::pairAccelerations, mPairAccelerations); + derivs.enroll(HydroFieldNames::pairAccelerations, mPairAccelerations); } //------------------------------------------------------------------------------ @@ -282,7 +282,7 @@ preStepInitialize(const DataBase& dataBase, if (mDensityUpdate == MassDensityType::RigorousSumDensity or mDensityUpdate == MassDensityType::VoronoiCellDensity) { auto massDensity = state.fields(HydroFieldNames::massDensity, 0.0); - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(mOrder)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(mOrder)); const auto& W = WR.kernel(); const auto& connectivityMap = dataBase.connectivityMap(); const auto mass = state.fields(HydroFieldNames::mass, 0.0); @@ -311,7 +311,7 @@ initialize(const typename Dimension::Scalar time, State& state, StateDerivatives& derivs) { // Initialize the artificial viscosity - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(mOrder)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(mOrder)); auto& Q = this->artificialViscosity(); Q.initialize(dataBase, state, diff --git a/src/CRKSPH/CRKSPHHydroBaseRZ.cc b/src/CRKSPH/CRKSPHHydroBaseRZ.cc index 53387f123..3dc37ca08 100644 --- a/src/CRKSPH/CRKSPHHydroBaseRZ.cc +++ b/src/CRKSPH/CRKSPHHydroBaseRZ.cc @@ -219,7 +219,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, // The kernels and such. //const auto order = this->correctionOrder(); - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(mOrder)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(mOrder)); // A few useful constants we'll use in the following loop. //const auto tiny = 1.0e-30; @@ -263,7 +263,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(DxDt.size() == numNodeLists); CHECK(DrhoDt.size() == numNodeLists); diff --git a/src/CRKSPH/SolidCRKSPHHydroBase.cc b/src/CRKSPH/SolidCRKSPHHydroBase.cc index be82ef925..6e9402385 100644 --- a/src/CRKSPH/SolidCRKSPHHydroBase.cc +++ b/src/CRKSPH/SolidCRKSPHHydroBase.cc @@ -258,7 +258,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, // The kernels and such. const auto order = this->correctionOrder(); - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(order)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(order)); // A few useful constants we'll use in the following loop. //const double tiny = 1.0e-30; @@ -318,7 +318,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); CHECK(DxDt.size() == numNodeLists); diff --git a/src/CRKSPH/SolidCRKSPHHydroBaseRZ.cc b/src/CRKSPH/SolidCRKSPHHydroBaseRZ.cc index 9a6068eb4..b8b3b951b 100644 --- a/src/CRKSPH/SolidCRKSPHHydroBaseRZ.cc +++ b/src/CRKSPH/SolidCRKSPHHydroBaseRZ.cc @@ -275,7 +275,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, // The kernels and such. const auto order = this->correctionOrder(); - const auto& WR = state.template getAny>(RKFieldNames::reproducingKernel(order)); + const auto& WR = state.template get>(RKFieldNames::reproducingKernel(order)); // A few useful constants we'll use in the following loop. //const double tiny = 1.0e-30; @@ -334,7 +334,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); CHECK(DxDt.size() == numNodeLists); diff --git a/src/DEM/SolidBoundary/CircularPlaneSolidBoundary.cc b/src/DEM/SolidBoundary/CircularPlaneSolidBoundary.cc index efda0cf5e..7569c16ca 100644 --- a/src/DEM/SolidBoundary/CircularPlaneSolidBoundary.cc +++ b/src/DEM/SolidBoundary/CircularPlaneSolidBoundary.cc @@ -63,9 +63,9 @@ registerState(DataBase& dataBase, const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; const auto normalKey = boundaryKey +"_normal"; - state.enrollAny(pointKey,mPoint); - state.enrollAny(pointKey,mVelocity); - state.enrollAny(pointKey,mNormal); + state.enroll(pointKey,mPoint); + state.enroll(pointKey,mVelocity); + state.enroll(pointKey,mNormal); } template diff --git a/src/DEM/SolidBoundary/ClippedSphereSolidBoundary.cc b/src/DEM/SolidBoundary/ClippedSphereSolidBoundary.cc index 0be718247..6ea0bfd7f 100644 --- a/src/DEM/SolidBoundary/ClippedSphereSolidBoundary.cc +++ b/src/DEM/SolidBoundary/ClippedSphereSolidBoundary.cc @@ -87,9 +87,9 @@ registerState(DataBase& dataBase, const auto clipPointKey = boundaryKey +"_clipPoint"; const auto velocityKey = boundaryKey +"_velocity"; - state.enrollAny(pointKey,mCenter); - state.enrollAny(clipPointKey,mClipPoint); - state.enrollAny(pointKey,mVelocity); + state.enroll(pointKey,mCenter); + state.enroll(clipPointKey,mClipPoint); + state.enroll(pointKey,mVelocity); } diff --git a/src/DEM/SolidBoundary/CylinderSolidBoundary.cc b/src/DEM/SolidBoundary/CylinderSolidBoundary.cc index 8280ef1c4..53ad556cc 100644 --- a/src/DEM/SolidBoundary/CylinderSolidBoundary.cc +++ b/src/DEM/SolidBoundary/CylinderSolidBoundary.cc @@ -64,9 +64,9 @@ registerState(DataBase& dataBase, const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; //const auto normalKey = boundaryKey +"_normal"; - state.enrollAny(pointKey,mPoint); - state.enrollAny(pointKey,mVelocity); - //state.enrollAny(pointKey,mNormal); + state.enroll(pointKey,mPoint); + state.enroll(pointKey,mVelocity); + //state.enroll(pointKey,mNormal); } template diff --git a/src/DEM/SolidBoundary/InfinitePlaneSolidBoundary.cc b/src/DEM/SolidBoundary/InfinitePlaneSolidBoundary.cc index eb29ba5c5..5276f1541 100644 --- a/src/DEM/SolidBoundary/InfinitePlaneSolidBoundary.cc +++ b/src/DEM/SolidBoundary/InfinitePlaneSolidBoundary.cc @@ -54,9 +54,9 @@ registerState(DataBase& dataBase, const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; const auto normalKey = boundaryKey +"_normal"; - state.enrollAny(pointKey,mPoint); - state.enrollAny(velocityKey,mVelocity); - state.enrollAny(normalKey,mNormal); + state.enroll(pointKey,mPoint); + state.enroll(velocityKey,mVelocity); + state.enroll(normalKey,mNormal); } template diff --git a/src/DEM/SolidBoundary/RectangularPlaneSolidBoundary.cc b/src/DEM/SolidBoundary/RectangularPlaneSolidBoundary.cc index 9e38fc6e2..eff75c3e6 100644 --- a/src/DEM/SolidBoundary/RectangularPlaneSolidBoundary.cc +++ b/src/DEM/SolidBoundary/RectangularPlaneSolidBoundary.cc @@ -57,8 +57,8 @@ registerState(DataBase& dataBase, const auto boundaryKey = "RectangularPlaneSolidBoundary_" + std::to_string(std::abs(this->uniqueIndex())); const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; - state.enrollAny(pointKey,mPoint); - state.enrollAny(velocityKey,mVelocity); + state.enroll(pointKey,mPoint); + state.enroll(velocityKey,mVelocity); } template void diff --git a/src/DEM/SolidBoundary/SphereSolidBoundary.cc b/src/DEM/SolidBoundary/SphereSolidBoundary.cc index 3a600a1a1..205aaee85 100644 --- a/src/DEM/SolidBoundary/SphereSolidBoundary.cc +++ b/src/DEM/SolidBoundary/SphereSolidBoundary.cc @@ -62,8 +62,8 @@ registerState(DataBase& dataBase, const auto pointKey = boundaryKey +"_point"; const auto velocityKey = boundaryKey +"_velocity"; - state.enrollAny(pointKey,mCenter); - state.enrollAny(pointKey,mVelocity); + state.enroll(pointKey,mCenter); + state.enroll(pointKey,mVelocity); } diff --git a/src/DataBase/CopyStateInline.hh b/src/DataBase/CopyStateInline.hh index 2782a067f..a70425343 100644 --- a/src/DataBase/CopyStateInline.hh +++ b/src/DataBase/CopyStateInline.hh @@ -42,10 +42,10 @@ update(const KeyType& key, REQUIRE(key == mCopyStateName); // The state we're updating - ValueType& f = state.template getAny(key); + ValueType& f = state.template get(key); // The master state we're copying - const ValueType& fmaster = state.template getAny(mMasterStateName); + const ValueType& fmaster = state.template get(mMasterStateName); // Copy the master state using the assignment operator f = fmaster; diff --git a/src/FSISPH/SolidFSISPHEvaluateDerivatives.cc b/src/FSISPH/SolidFSISPHEvaluateDerivatives.cc index fde9b5bd5..6f568dda2 100644 --- a/src/FSISPH/SolidFSISPHEvaluateDerivatives.cc +++ b/src/FSISPH/SolidFSISPHEvaluateDerivatives.cc @@ -147,8 +147,8 @@ secondDerivativesLoop(const typename Dimension::Scalar time, auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); - auto& pairDepsDt = derivatives.getAny(HydroFieldNames::pairWork, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); + auto& pairDepsDt = derivatives.get(HydroFieldNames::pairWork, vector()); CHECK(M.size() == numNodeLists); CHECK(localM.size() == numNodeLists); diff --git a/src/FSISPH/SolidFSISPHHydroBase.cc b/src/FSISPH/SolidFSISPHHydroBase.cc index d7c920e73..0afac53b4 100644 --- a/src/FSISPH/SolidFSISPHHydroBase.cc +++ b/src/FSISPH/SolidFSISPHHydroBase.cc @@ -445,8 +445,8 @@ registerDerivatives(DataBase& dataBase, CHECK(not derivs.registered(mDvDt)); - derivs.enrollAny(HydroFieldNames::pairAccelerations, mPairAccelerations); - derivs.enrollAny(HydroFieldNames::pairWork, mPairDepsDt); + derivs.enroll(HydroFieldNames::pairAccelerations, mPairAccelerations); + derivs.enroll(HydroFieldNames::pairWork, mPairDepsDt); derivs.enroll(plasticStrainRate); derivs.enroll(mXSPHDeltaV); diff --git a/src/GSPH/GSPHEvaluateDerivatives.cc b/src/GSPH/GSPHEvaluateDerivatives.cc index 4644684d4..6573c7223 100644 --- a/src/GSPH/GSPHEvaluateDerivatives.cc +++ b/src/GSPH/GSPHEvaluateDerivatives.cc @@ -72,8 +72,8 @@ evaluateDerivatives(const typename Dimension::Scalar time, auto DvDt = derivatives.fields(HydroFieldNames::hydroAcceleration, Vector::zero); auto DepsDt = derivatives.fields(IncrementState::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); auto DvDx = derivatives.fields(HydroFieldNames::velocityGradient, Tensor::zero); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); - auto& pairDepsDt = derivatives.getAny(HydroFieldNames::pairWork, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); + auto& pairDepsDt = derivatives.get(HydroFieldNames::pairWork, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto newRiemannDpDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannPressureGradient,Vector::zero); auto newRiemannDvDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannVelocityGradient,Tensor::zero); diff --git a/src/GSPH/GenericRiemannHydro.cc b/src/GSPH/GenericRiemannHydro.cc index bd05e621e..af593b15a 100644 --- a/src/GSPH/GenericRiemannHydro.cc +++ b/src/GSPH/GenericRiemannHydro.cc @@ -302,8 +302,8 @@ registerDerivatives(DataBase& dataBase, derivs.enroll(mDspecificThermalEnergyDt); derivs.enroll(mDvDx); derivs.enroll(mM); - derivs.enrollAny(HydroFieldNames::pairAccelerations, mPairAccelerations); - derivs.enrollAny(HydroFieldNames::pairWork, mPairDepsDt); + derivs.enroll(HydroFieldNames::pairAccelerations, mPairAccelerations); + derivs.enroll(HydroFieldNames::pairWork, mPairDepsDt); } //------------------------------------------------------------------------------ diff --git a/src/GSPH/MFMEvaluateDerivatives.cc b/src/GSPH/MFMEvaluateDerivatives.cc index c68d4adfd..32dc20d3d 100644 --- a/src/GSPH/MFMEvaluateDerivatives.cc +++ b/src/GSPH/MFMEvaluateDerivatives.cc @@ -71,8 +71,8 @@ evaluateDerivatives(const typename Dimension::Scalar time, auto DvDt = derivatives.fields(HydroFieldNames::hydroAcceleration, Vector::zero); auto DepsDt = derivatives.fields(IncrementState::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); auto DvDx = derivatives.fields(HydroFieldNames::velocityGradient, Tensor::zero); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); - auto& pairDepsDt = derivatives.getAny(HydroFieldNames::pairWork, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); + auto& pairDepsDt = derivatives.get(HydroFieldNames::pairWork, vector()); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto newRiemannDpDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannPressureGradient,Vector::zero); auto newRiemannDvDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannVelocityGradient,Tensor::zero); diff --git a/src/GSPH/MFVEvaluateDerivatives.cc b/src/GSPH/MFVEvaluateDerivatives.cc index 53a17a73b..1e48ae27b 100644 --- a/src/GSPH/MFVEvaluateDerivatives.cc +++ b/src/GSPH/MFVEvaluateDerivatives.cc @@ -89,9 +89,9 @@ secondDerivativesLoop(const typename Dimension::Scalar time, //auto HStretchTensor = derivatives.fields("HStretchTensor", SymTensor::zero); auto newRiemannDpDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannPressureGradient,Vector::zero); auto newRiemannDvDx = derivatives.fields(ReplaceState::prefix() + GSPHFieldNames::RiemannVelocityGradient,Tensor::zero); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); - auto& pairDepsDt = derivatives.getAny(HydroFieldNames::pairWork, vector()); - auto& pairMassFlux = derivatives.getAny(GSPHFieldNames::pairMassFlux, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); + auto& pairDepsDt = derivatives.get(HydroFieldNames::pairWork, vector()); + auto& pairMassFlux = derivatives.get(GSPHFieldNames::pairMassFlux, vector()); CHECK(DrhoDx.size() == numNodeLists); CHECK(M.size() == numNodeLists); diff --git a/src/GSPH/MFVHydroBase.cc b/src/GSPH/MFVHydroBase.cc index be68ef1fa..e60aea910 100644 --- a/src/GSPH/MFVHydroBase.cc +++ b/src/GSPH/MFVHydroBase.cc @@ -228,7 +228,7 @@ registerDerivatives(DataBase& dataBase, derivs.enroll(mDmomentumDt); derivs.enroll(mDvolumeDt); //derivs.enroll(mHStretchTensor); - derivs.enrollAny(GSPHFieldNames::pairMassFlux, mPairMassFlux); + derivs.enroll(GSPHFieldNames::pairMassFlux, mPairMassFlux); } //------------------------------------------------------------------------------ diff --git a/src/GSPH/Policies/CompatibleMFVSpecificThermalEnergyPolicy.cc b/src/GSPH/Policies/CompatibleMFVSpecificThermalEnergyPolicy.cc index eb7a33a27..4f2a4a958 100644 --- a/src/GSPH/Policies/CompatibleMFVSpecificThermalEnergyPolicy.cc +++ b/src/GSPH/Policies/CompatibleMFVSpecificThermalEnergyPolicy.cc @@ -84,9 +84,9 @@ update(const KeyType& key, const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto DmassDt = derivs.fields(IncrementState::prefix() + HydroFieldNames::mass, 0.0); const auto DmomentumDt = derivs.fields(IncrementState::prefix() + GSPHFieldNames::momentum, Vector::zero); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); - const auto& pairDepsDt = derivs.getAny(HydroFieldNames::pairWork, vector()); - const auto& pairMassFlux = derivs.getAny(GSPHFieldNames::pairMassFlux, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); + const auto& pairDepsDt = derivs.get(HydroFieldNames::pairWork, vector()); + const auto& pairMassFlux = derivs.get(GSPHFieldNames::pairMassFlux, vector()); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); diff --git a/src/Hydro/CompatibleDifferenceSpecificThermalEnergyPolicy.cc b/src/Hydro/CompatibleDifferenceSpecificThermalEnergyPolicy.cc index 154900215..c761a4ac9 100644 --- a/src/Hydro/CompatibleDifferenceSpecificThermalEnergyPolicy.cc +++ b/src/Hydro/CompatibleDifferenceSpecificThermalEnergyPolicy.cc @@ -79,8 +79,8 @@ update(const KeyType& key, const auto mass = state.fields(HydroFieldNames::mass, Scalar()); const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto acceleration = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); - const auto& pairDepsDt = derivs.getAny(HydroFieldNames::pairWork, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); + const auto& pairDepsDt = derivs.get(HydroFieldNames::pairWork, vector()); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); const auto npairs = pairs.size(); diff --git a/src/Hydro/NonSymmetricSpecificThermalEnergyPolicy.cc b/src/Hydro/NonSymmetricSpecificThermalEnergyPolicy.cc index 8646392df..66827862f 100644 --- a/src/Hydro/NonSymmetricSpecificThermalEnergyPolicy.cc +++ b/src/Hydro/NonSymmetricSpecificThermalEnergyPolicy.cc @@ -82,7 +82,7 @@ update(const KeyType& key, const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto acceleration = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); const auto eps0 = state.fields(HydroFieldNames::specificThermalEnergy + "0", Scalar()); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); const auto DepsDt0 = derivs.fields(IncrementState >::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); diff --git a/src/Hydro/RZNonSymmetricSpecificThermalEnergyPolicy.cc b/src/Hydro/RZNonSymmetricSpecificThermalEnergyPolicy.cc index 90ff42f2b..0519846c2 100644 --- a/src/Hydro/RZNonSymmetricSpecificThermalEnergyPolicy.cc +++ b/src/Hydro/RZNonSymmetricSpecificThermalEnergyPolicy.cc @@ -94,7 +94,7 @@ update(const KeyType& key, const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto acceleration = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); const auto eps0 = state.fields(HydroFieldNames::specificThermalEnergy + "0", Scalar()); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); const auto DepsDt0 = derivs.fields(IncrementState >::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); diff --git a/src/Hydro/SpecificThermalEnergyPolicy.cc b/src/Hydro/SpecificThermalEnergyPolicy.cc index c856e579f..51fd1691d 100644 --- a/src/Hydro/SpecificThermalEnergyPolicy.cc +++ b/src/Hydro/SpecificThermalEnergyPolicy.cc @@ -79,7 +79,7 @@ update(const KeyType& key, const auto mass = state.fields(HydroFieldNames::mass, Scalar()); const auto velocity = state.fields(HydroFieldNames::velocity, Vector::zero); const auto DvDt = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); - const auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + const auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); const auto DepsDt0 = derivs.fields(IncrementState >::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); const auto& connectivityMap = mDataBasePtr->connectivityMap(); const auto& pairs = connectivityMap.nodePairList(); diff --git a/src/RK/RKCorrections.cc b/src/RK/RKCorrections.cc index 5efdaff8d..a29314140 100644 --- a/src/RK/RKCorrections.cc +++ b/src/RK/RKCorrections.cc @@ -155,9 +155,9 @@ RKCorrections:: registerState(DataBase& dataBase, State& state) { // Stuff RKCorrections owns - state.enrollAny(RKFieldNames::rkOrders, mOrders); + state.enroll(RKFieldNames::rkOrders, mOrders); for (auto order: mOrders) { - state.enrollAny(RKFieldNames::reproducingKernel(order), mWR[order]); + state.enroll(RKFieldNames::reproducingKernel(order), mWR[order]); state.enroll(mCorrections[order]); } state.enroll(mVolume); diff --git a/src/RK/ReproducingKernel.cc b/src/RK/ReproducingKernel.cc index 09c1a7a3b..c2c3ad26c 100644 --- a/src/RK/ReproducingKernel.cc +++ b/src/RK/ReproducingKernel.cc @@ -46,11 +46,14 @@ operator=(const ReproducingKernel& rhs) { } //------------------------------------------------------------------------------ -// Destructor +// Equivalence //------------------------------------------------------------------------------ template +bool ReproducingKernel:: -~ReproducingKernel() { -} +operator==(const ReproducingKernel& rhs) const { + return (ReproducingKernelMethods::operator==(rhs) and + *mWptr == *(rhs.mWptr)); +} } diff --git a/src/RK/ReproducingKernel.hh b/src/RK/ReproducingKernel.hh index cfaf073ec..1296b32c6 100644 --- a/src/RK/ReproducingKernel.hh +++ b/src/RK/ReproducingKernel.hh @@ -24,7 +24,8 @@ public: ReproducingKernel(); ReproducingKernel(const ReproducingKernel& rhs); ReproducingKernel& operator=(const ReproducingKernel& rhs); - ~ReproducingKernel(); + virtual ~ReproducingKernel() {} + bool operator==(const ReproducingKernel& rhs) const; // Base kernel calls Scalar evaluateBaseKernel(const Vector& x, diff --git a/src/RK/ReproducingKernelMethods.cc b/src/RK/ReproducingKernelMethods.cc index ef85e2339..99470955b 100644 --- a/src/RK/ReproducingKernelMethods.cc +++ b/src/RK/ReproducingKernelMethods.cc @@ -226,11 +226,15 @@ operator=(const ReproducingKernelMethods& rhs) { } //------------------------------------------------------------------------------ -// Destructor +// Equivalence //------------------------------------------------------------------------------ template +bool ReproducingKernelMethods:: -~ReproducingKernelMethods() { -} +operator==(const ReproducingKernelMethods& rhs) const { + return (mOrder == rhs.mOrder and + mGradCorrectionsSize == rhs.mGradCorrectionsSize and + mHessCorrectionsSize == rhs.mHessCorrectionsSize); +} } diff --git a/src/RK/ReproducingKernelMethods.hh b/src/RK/ReproducingKernelMethods.hh index d7ffd41e0..d4a9908d8 100644 --- a/src/RK/ReproducingKernelMethods.hh +++ b/src/RK/ReproducingKernelMethods.hh @@ -26,7 +26,8 @@ public: ReproducingKernelMethods(); ReproducingKernelMethods(const ReproducingKernelMethods& rhs); ReproducingKernelMethods& operator=(const ReproducingKernelMethods& rhs); - ~ReproducingKernelMethods(); + virtual ~ReproducingKernelMethods() {} + bool operator==(const ReproducingKernelMethods& rhs) const; // Build a transformation operator TransformationMatrix transformationMatrix(const Tensor& T, diff --git a/src/SPH/PSPHHydroBase.cc b/src/SPH/PSPHHydroBase.cc index c6b4901c1..d5794eba8 100644 --- a/src/SPH/PSPHHydroBase.cc +++ b/src/SPH/PSPHHydroBase.cc @@ -296,7 +296,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(rhoSum.size() == numNodeLists); diff --git a/src/SPH/SPHHydroBase.cc b/src/SPH/SPHHydroBase.cc index 41a06699b..2c38438f0 100644 --- a/src/SPH/SPHHydroBase.cc +++ b/src/SPH/SPHHydroBase.cc @@ -354,7 +354,7 @@ registerDerivatives(DataBase& dataBase, derivs.enroll(mGradRho); derivs.enroll(mM); derivs.enroll(mLocalM); - derivs.enrollAny(HydroFieldNames::pairAccelerations, mPairAccelerations); + derivs.enroll(HydroFieldNames::pairAccelerations, mPairAccelerations); TIME_END("SPHregisterDerivs"); } @@ -645,7 +645,7 @@ evaluateDerivatives(const typename Dimension::Scalar time, auto maxViscousPressure = derivs.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivs.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivs.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivs.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivs.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(rhoSum.size() == numNodeLists); diff --git a/src/SPH/SPHHydroBaseRZ.cc b/src/SPH/SPHHydroBaseRZ.cc index ecf3b36b7..6484480da 100644 --- a/src/SPH/SPHHydroBaseRZ.cc +++ b/src/SPH/SPHHydroBaseRZ.cc @@ -244,7 +244,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(rhoSum.size() == numNodeLists); diff --git a/src/SPH/SolidSPHHydroBase.cc b/src/SPH/SolidSPHHydroBase.cc index 9d996eace..27afce77f 100644 --- a/src/SPH/SolidSPHHydroBase.cc +++ b/src/SPH/SolidSPHHydroBase.cc @@ -371,7 +371,7 @@ evaluateDerivatives(const typename Dimension::Scalar /*time*/, auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto rhoSumCorrection = derivatives.fields(HydroFieldNames::massDensityCorrection, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); diff --git a/src/SPH/SolidSPHHydroBaseRZ.cc b/src/SPH/SolidSPHHydroBaseRZ.cc index 0a4eeb52e..7ea9c1d86 100644 --- a/src/SPH/SolidSPHHydroBaseRZ.cc +++ b/src/SPH/SolidSPHHydroBaseRZ.cc @@ -295,7 +295,7 @@ evaluateDerivatives(const Dim<2>::Scalar /*time*/, auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto rhoSumCorrection = derivatives.fields(HydroFieldNames::massDensityCorrection, 0.0); auto viscousWork = derivatives.fields(HydroFieldNames::viscousWork, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); diff --git a/src/SPH/SolidSphericalSPHHydroBase.cc b/src/SPH/SolidSphericalSPHHydroBase.cc index 31e3e0813..b78cbc1b2 100644 --- a/src/SPH/SolidSphericalSPHHydroBase.cc +++ b/src/SPH/SolidSphericalSPHHydroBase.cc @@ -294,7 +294,7 @@ evaluateDerivatives(const Dim<1>::Scalar /*time*/, auto maxViscousPressure = derivatives.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivatives.fields(HydroFieldNames::effectiveViscousPressure, 0.0); auto rhoSumCorrection = derivatives.fields(HydroFieldNames::massDensityCorrection, 0.0); - auto& pairAccelerations = derivatives.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivatives.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivatives.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivatives.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); auto DSDt = derivatives.fields(IncrementState::prefix() + SolidFieldNames::deviatoricStress, SymTensor::zero); diff --git a/src/SPH/SphericalSPHHydroBase.cc b/src/SPH/SphericalSPHHydroBase.cc index 843ed7ef6..d6591f2e6 100644 --- a/src/SPH/SphericalSPHHydroBase.cc +++ b/src/SPH/SphericalSPHHydroBase.cc @@ -248,7 +248,7 @@ evaluateDerivatives(const Dim<1>::Scalar time, auto localM = derivs.fields("local " + HydroFieldNames::M_SPHCorrection, Tensor::zero); auto maxViscousPressure = derivs.fields(HydroFieldNames::maxViscousPressure, 0.0); auto effViscousPressure = derivs.fields(HydroFieldNames::effectiveViscousPressure, 0.0); - auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); auto XSPHWeightSum = derivs.fields(HydroFieldNames::XSPHWeightSum, 0.0); auto XSPHDeltaV = derivs.fields(HydroFieldNames::XSPHDeltaV, Vector::zero); CHECK(rhoSum.size() == numNodeLists); diff --git a/src/VoronoiCells/SubPointPressureHourglassControl.cc b/src/VoronoiCells/SubPointPressureHourglassControl.cc index 9e5174d47..578202237 100644 --- a/src/VoronoiCells/SubPointPressureHourglassControl.cc +++ b/src/VoronoiCells/SubPointPressureHourglassControl.cc @@ -341,7 +341,7 @@ evaluateDerivatives(const Scalar time, auto DvDt = derivs.fields(HydroFieldNames::hydroAcceleration, Vector::zero); auto DepsDt = derivs.fields(IncrementState::prefix() + HydroFieldNames::specificThermalEnergy, 0.0); auto DxDt = derivs.fields(IncrementState::prefix() + HydroFieldNames::position, Vector::zero); - auto& pairAccelerations = derivs.getAny(HydroFieldNames::pairAccelerations, vector()); + auto& pairAccelerations = derivs.get(HydroFieldNames::pairAccelerations, vector()); CHECK(DvDt.size() == numNodeLists); CHECK(DepsDt.size() == numNodeLists); From bf5c93e9f3575401733a7f0017ab62bac3716d36 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Wed, 6 Nov 2024 09:52:11 -0800 Subject: [PATCH 03/14] First pass at new State(Base) interface replacing boost::any with std::variant --- src/DataBase/State.cc | 206 ++++++++--------- src/DataBase/State.hh | 50 ++-- src/DataBase/StateBase.cc | 381 ++++++++++++++++--------------- src/DataBase/StateBase.hh | 148 ++++++------ src/DataBase/StateBaseInline.hh | 112 ++++----- src/DataBase/StateDerivatives.cc | 59 +++-- src/DataBase/StateDerivatives.hh | 27 ++- src/DataBase/StateInline.hh | 79 ++----- src/PYB11/DataBase/StateBase.py | 48 ++-- 9 files changed, 531 insertions(+), 579 deletions(-) diff --git a/src/DataBase/State.cc b/src/DataBase/State.cc index 0859d195d..186c28d07 100644 --- a/src/DataBase/State.cc +++ b/src/DataBase/State.cc @@ -94,9 +94,7 @@ State(DataBase& dataBase, mPolicyMap(), mTimeAdvanceOnly(false) { // Iterate over the physics packages, and have them register their state. - for (PackageIterator itr = physicsPackages.begin(); - itr != physicsPackages.end(); - ++itr) (*itr)->registerState(dataBase, *this); + for (auto pkg: physicsPackages) pkg->registerState(dataBase, *this); } //------------------------------------------------------------------------------ @@ -111,9 +109,7 @@ State(DataBase& dataBase, mPolicyMap(), mTimeAdvanceOnly(false) { // Iterate over the physics packages, and have them register their state. - for (PackageIterator itr = physicsPackageBegin; - itr != physicsPackageEnd; - ++itr) (*itr)->registerState(dataBase, *this); + for (auto pkg: range(physicsPackageBegin, physicsPackageEnd)) pkg->registerState(dataBase, *this); } //------------------------------------------------------------------------------ @@ -160,6 +156,105 @@ operator==(const StateBase& rhs) const { return StateBase::operator==(rhs); } +//------------------------------------------------------------------------------ +// The set of keys for all registered policies. +//------------------------------------------------------------------------------ +template +vector::KeyType> +State:: +policyKeys() const { + vector result; + for (const auto itr: mPolicyMap) result.push_back(itr.first); + ENSURE(result.size() == mPolicyMap.size()); + return result; +} + +//------------------------------------------------------------------------------ +// Return the policy for the given key. +//------------------------------------------------------------------------------ +template +typename State::PolicyPointer +State:: +policy(const typename State::KeyType& key) const { + KeyType fieldKey, nodeKey; + this->splitFieldKey(key, fieldKey, nodeKey); + const auto outerItr = mPolicyMap.find(fieldKey); + if (outerItr == mPolicyMap.end()) return PolicyPointer(); + // VERIFY2(outerItr != mPolicyMap.end(), + // "State ERROR: attempted to retrieve non-existent policy for key " << key); + const auto& key2policies = outerItr->second; + const auto innerItr = key2policies.find(key); + if (innerItr == key2policies.end()) return PolicyPointer(); + // VERIFY2(innerItr != policies.end(), + // "State ERROR: attempted to retrieve non-existent policy for key " << key); + return innerItr->second; +} + +//------------------------------------------------------------------------------ +// Return all the policies for the given field key. +//------------------------------------------------------------------------------ +template +std::map::KeyType, typename State::PolicyPointer> +State:: +policies(const typename State::KeyType& fieldKey) const { + const auto outerItr = mPolicyMap.find(fieldKey); + if (outerItr == mPolicyMap.end()) return std::map(); + // VERIFY2(outerItr != mPolicyMap.end(), + // "State ERROR: attempted to retrieve non-existent policy for key " << key); + return outerItr->second; +} + +//------------------------------------------------------------------------------ +// Remove the policy associated with the given key. +//------------------------------------------------------------------------------ +template +void +State:: +removePolicy(const typename State::KeyType& key) { + KeyType fieldKey, nodeKey; + this->splitFieldKey(key, fieldKey, nodeKey); + typename PolicyMapType::iterator outerItr = mPolicyMap.find(fieldKey); + VERIFY2(outerItr != mPolicyMap.end(), + "State ERROR: attempted to remove non-existent policy for field key " << fieldKey); + std::map& policies = outerItr->second; + typename std::map::iterator innerItr = policies.find(key); + if (innerItr == policies.end()) { + cerr << "State ERROR: attempted to remove non-existent policy for inner key " << key << endl + << "Known keys are: " << endl; + for (auto itr = policies.begin(); itr != policies.end(); ++itr) cerr << " --> " << itr->first << endl; + VERIFY(innerItr != policies.end()); + } + policies.erase(innerItr); + if (policies.size() == 0) mPolicyMap.erase(outerItr); +} + +//------------------------------------------------------------------------------ +// Remove the policy associated with a Field. +//------------------------------------------------------------------------------ +template +void +State:: +removePolicy(FieldBase& field) { + this->removePolicy(StateBase::key(field)); +} + +//------------------------------------------------------------------------------ +// Remove the policy associated with a FieldList. +//------------------------------------------------------------------------------ +template +void +State:: +removePolicy(FieldListBase& fieldList, + const bool clonePerField) { + if (clonePerField) { + for (auto fieldPtrItr = fieldList.begin_base(); + fieldPtrItr < fieldList.end_base(); + ++fieldPtrItr) this->removePolicy(**fieldPtrItr); + } else { + this->removePolicy(StateBase::key(fieldList)); + } +} + //------------------------------------------------------------------------------ // Update the state with the given derivatives object, according to the per // state field policies. @@ -273,104 +368,5 @@ update(StateDerivatives& derivs, } } -//------------------------------------------------------------------------------ -// The set of keys for all registered policies. -//------------------------------------------------------------------------------ -template -vector::KeyType> -State:: -policyKeys() const { - vector result; - for (const auto itr: mPolicyMap) result.push_back(itr.first); - ENSURE(result.size() == mPolicyMap.size()); - return result; -} - -//------------------------------------------------------------------------------ -// Return the policy for the given key. -//------------------------------------------------------------------------------ -template -typename State::PolicyPointer -State:: -policy(const typename State::KeyType& key) const { - KeyType fieldKey, nodeKey; - this->splitFieldKey(key, fieldKey, nodeKey); - const auto outerItr = mPolicyMap.find(fieldKey); - if (outerItr == mPolicyMap.end()) return PolicyPointer(); - // VERIFY2(outerItr != mPolicyMap.end(), - // "State ERROR: attempted to retrieve non-existent policy for key " << key); - const auto& key2policies = outerItr->second; - const auto innerItr = key2policies.find(key); - if (innerItr == key2policies.end()) return PolicyPointer(); - // VERIFY2(innerItr != policies.end(), - // "State ERROR: attempted to retrieve non-existent policy for key " << key); - return innerItr->second; -} - -//------------------------------------------------------------------------------ -// Return all the policies for the given field key. -//------------------------------------------------------------------------------ -template -std::map::KeyType, typename State::PolicyPointer> -State:: -policies(const typename State::KeyType& fieldKey) const { - const auto outerItr = mPolicyMap.find(fieldKey); - if (outerItr == mPolicyMap.end()) return std::map(); - // VERIFY2(outerItr != mPolicyMap.end(), - // "State ERROR: attempted to retrieve non-existent policy for key " << key); - return outerItr->second; -} - -//------------------------------------------------------------------------------ -// Remove the policy associated with the given key. -//------------------------------------------------------------------------------ -template -void -State:: -removePolicy(const typename State::KeyType& key) { - KeyType fieldKey, nodeKey; - this->splitFieldKey(key, fieldKey, nodeKey); - typename PolicyMapType::iterator outerItr = mPolicyMap.find(fieldKey); - VERIFY2(outerItr != mPolicyMap.end(), - "State ERROR: attempted to remove non-existent policy for field key " << fieldKey); - std::map& policies = outerItr->second; - typename std::map::iterator innerItr = policies.find(key); - if (innerItr == policies.end()) { - cerr << "State ERROR: attempted to remove non-existent policy for inner key " << key << endl - << "Known keys are: " << endl; - for (auto itr = policies.begin(); itr != policies.end(); ++itr) cerr << " --> " << itr->first << endl; - VERIFY(innerItr != policies.end()); - } - policies.erase(innerItr); - if (policies.size() == 0) mPolicyMap.erase(outerItr); -} - -//------------------------------------------------------------------------------ -// Remove the policy associated with a Field. -//------------------------------------------------------------------------------ -template -void -State:: -removePolicy(FieldBase& field) { - this->removePolicy(StateBase::key(field)); -} - -//------------------------------------------------------------------------------ -// Remove the policy associated with a FieldList. -//------------------------------------------------------------------------------ -template -void -State:: -removePolicy(FieldListBase& fieldList, - const bool clonePerField) { - if (clonePerField) { - for (auto fieldPtrItr = fieldList.begin_base(); - fieldPtrItr < fieldList.end_base(); - ++fieldPtrItr) this->removePolicy(**fieldPtrItr); - } else { - this->removePolicy(StateBase::key(fieldList)); - } -} - } diff --git a/src/DataBase/State.hh b/src/DataBase/State.hh index 59b7dbb49..b34f02087 100644 --- a/src/DataBase/State.hh +++ b/src/DataBase/State.hh @@ -40,6 +40,9 @@ public: using PackageIterator = typename PackageList::iterator; using PolicyPointer = typename std::shared_ptr>; + // Promote base overloaded enroll methods + using StateBase::enroll; + // Constructors, destructor. State(); State(DataBase& dataBase, PackageList& physicsPackages); @@ -52,15 +55,22 @@ public: // Assignment. State& operator=(const State& rhs); - // Override the base method. + // Override the base equivalence operator virtual bool operator==(const StateBase& rhs) const override; - // Update the registered state according to the policies. - void update(StateDerivatives& derivs, - const double multiplier, - const double t, - const double dt); + //........................................................................... + // Enroll state with update policies + void enroll(FieldBase& field, PolicyPointer policy); + // Enroll the given FieldList and associated update policy + // This method queries the "clonePerField" method of the policy, and + // if true enrolls each Field in the FieldList with a copy of the policy. + // Otherwise the FieldList is enrolled as a single entity, and the policy is + // assumed to handle a FieldList as a whole. + void enroll(FieldListBase& fieldList, PolicyPointer policy); + + //........................................................................... + // Policies // Enroll a policy by itself. void enroll(const KeyType& key, PolicyPointer policy); @@ -70,23 +80,6 @@ public: void removePolicy(FieldListBase& field, const bool clonePerField); - // Enroll the given Field and associated update policy - void enroll(FieldBase& field, PolicyPointer policy); - - // Enroll the given FieldList and associated update policy - // This method queries the "clonePerField" method of the policy, and - // if true enrolls each Field in the FieldList with a copy of the policy. - // Otherwise the FieldList is enrolled directly as normal, and the policy is - // assumed to handle a FieldList directly. - void enroll(FieldListBase& fieldList, PolicyPointer policy); - - // The base class method for just registering a field. - virtual void enroll(FieldBase& field) override; - virtual void enroll(std::shared_ptr>& fieldPtr) override; - - // The base class method for just registering a field list. - virtual void enroll(FieldListBase& fieldList) override; - // The full set of keys for all policies. std::vector policyKeys() const; @@ -100,10 +93,17 @@ public: template PolicyPointer policy(const Field& field) const; + //........................................................................... + // Update the registered state according to the policies. + void update(StateDerivatives& derivs, + const double multiplier, + const double t, + const double dt); + // Optionally trip a flag indicating policies should time advance only -- no replacing state! // This is useful when you're trying to cheat and reuse derivatives from a prior advance. - bool timeAdvanceOnly() const; - void timeAdvanceOnly(const bool x); + bool timeAdvanceOnly() const { return mTimeAdvanceOnly; } + void timeAdvanceOnly(const bool x) { mTimeAdvanceOnly = x; } private: //--------------------------- Private Interface ---------------------------// diff --git a/src/DataBase/StateBase.cc b/src/DataBase/StateBase.cc index 4ae5e7c99..e7c4a24d1 100644 --- a/src/DataBase/StateBase.cc +++ b/src/DataBase/StateBase.cc @@ -10,6 +10,8 @@ #include "Field/FieldList.hh" #include "Neighbor/ConnectivityMap.hh" #include "Mesh/Mesh.hh" +#include "RK/RKCorrectionParams.hh" +#include "RK/ReproducingKernel.hh" #include "Utilities/DBC.hh" #include @@ -21,10 +23,12 @@ using std::endl; using std::min; using std::max; using std::abs; +using std::sort; namespace Spheral { -// namespace { +namespace { + // //------------------------------------------------------------------------------ // // Helper for copying a type, used in copyState // //------------------------------------------------------------------------------ @@ -39,7 +43,23 @@ namespace Spheral { // } // } -// } +//------------------------------------------------------------------------------ +// Template to downselect comparison in our variant types +//------------------------------------------------------------------------------ +template bool safeCompare(T1& x, const T1& y) { return x == y; } +template bool safeCompare(T1& x, const T2& y) { VERIFY2(false, "Bad comparison!"); return false; } + +//------------------------------------------------------------------------------ +// Template to downselect assignment in our variant types +//------------------------------------------------------------------------------ +template void safeAssign(T1& x, const T1& y) { x = y; } +template void safeAssign(T1& x, const T2& y) { VERIFY2(false, "Bad assignment!"); } + +// Helper with overloading in std::visit +template struct overload : Ts... { using Ts::operator()...; }; +template overload(Ts...) -> overload; + +} //------------------------------------------------------------------------------ // Default constructor. @@ -47,8 +67,11 @@ namespace Spheral { template StateBase:: StateBase(): - mStorage(), - mCache(), + mFieldStorage(), + mFieldCache(), + mMiscStorage(), + mMiscCache(), + mNodeListPtrs(), mConnectivityMapPtr(), mMeshPtr(new MeshType()) { } @@ -59,8 +82,10 @@ StateBase(): template StateBase:: StateBase(const StateBase& rhs): - mStorage(rhs.mStorage), - mCache(), + mFieldStorage(rhs.mFieldStorage), + mFieldCache(), + mMiscStorage(rhs.mMiscStorage), + mMiscCache(), mNodeListPtrs(rhs.mNodeListPtrs), mConnectivityMapPtr(rhs.mConnectivityMapPtr), mMeshPtr(rhs.mMeshPtr) { @@ -82,8 +107,10 @@ StateBase& StateBase:: operator=(const StateBase& rhs) { if (this != &rhs) { - mStorage = rhs.mStorage; - mCache = CacheType(); + mFieldStorage = rhs.mFieldStorage; + mFieldCache = FieldCacheType(); + mMiscStorage = rhs.mMiscStorage; + mMiscCache = MiscCacheType(); mNodeListPtrs = rhs.mNodeListPtrs; mConnectivityMapPtr = rhs.mConnectivityMapPtr; mMeshPtr = rhs.mMeshPtr; @@ -98,71 +125,95 @@ template bool StateBase:: operator==(const StateBase& rhs) const { - if (mStorage.size() != rhs.mStorage.size()) { - cerr << "Storage sizes don't match." << endl; + + // Compare raw sizes + if (mFieldStorage.size() != rhs.mFieldStorage.size()) { + cerr << "Field storage sizes don't match." << endl; return false; } - vector lhsKeys = keys(); - vector rhsKeys = rhs.keys(); - if (lhsKeys.size() != rhsKeys.size()) { - cerr << "Keys sizes don't match." << endl; + if (mMiscStorage.size() != rhs.mMiscStorage.size()) { + cerr << "Miscellaneous storage sizes don't match." << endl; return false; } - sort(lhsKeys.begin(), lhsKeys.end()); - sort(rhsKeys.begin(), rhsKeys.end()); + + // Keys + auto lhsKeys = keys(); + auto rhsKeys = rhs.keys(); if (lhsKeys != rhsKeys) { cerr << "Keys don't match." << endl; return false; } - // Walk the keys, and rely on the virtual overloaded - // Field::operator==(FieldBase) to do the right thing! - // We are also relying here on the fact that std::map with a given - // set of keys will always result in the same order. - bool result = true; - typename StorageType::const_iterator lhsItr, rhsItr; - for (rhsItr = rhs.mStorage.begin(), lhsItr = mStorage.begin(); - rhsItr != rhs.mStorage.end(); - ++rhsItr, ++lhsItr) { - try { - auto lhsPtr = boost::any_cast*>(lhsItr->second); - auto rhsPtr = boost::any_cast*>(rhsItr->second); - if (*lhsPtr != *rhsPtr) { - cerr << "Fields for " << lhsItr->first << " don't match." << endl; - result = false; + // Compare fields + { + auto lhsitr = mFieldStorage.begin(); + auto rhsitr = rhs.mFieldStorage.begin(); + for (; lhsitr != mFieldStorage.end(); ++lhsitr, ++rhsitr) { + CHECK(rhsitr != rhs.mFieldStorage.end()); + CHECK(lhsitr->first == rhsitr->first); + if (*(lhsitr->second) != *(rhsitr->second)) { + cerr << "Fields don't match for key " << lhsitr->first << endl; + return false; } - } catch (const boost::bad_any_cast&) { - try { - auto lhsPtr = boost::any_cast*>(lhsItr->second); - auto rhsPtr = boost::any_cast*>(rhsItr->second); - if (*lhsPtr != *rhsPtr) { - cerr << "vector for " << lhsItr->first << " don't match." << endl; - result = false; - } - } catch (const boost::bad_any_cast&) { - try { - auto lhsPtr = boost::any_cast(lhsItr->second); - auto rhsPtr = boost::any_cast(rhsItr->second); - if (*lhsPtr != *rhsPtr) { - cerr << "Vector for " << lhsItr->first << " don't match." << endl; - result = false; - } - } catch (const boost::bad_any_cast&) { - try { - auto lhsPtr = boost::any_cast(lhsItr->second); - auto rhsPtr = boost::any_cast(rhsItr->second); - if (*lhsPtr != *rhsPtr) { - cerr << "Scalar for " << lhsItr->first << " don't match." << endl; - result = false; - } - } catch (const boost::bad_any_cast&) { - std::cerr << "StateBase::operator== WARNING: unable to compare values for " << lhsItr->first << "\n"; - } - } + } + } + + // Compare the miscellaneous objects + { + auto lhsitr = mMiscStorage.begin(); + auto rhsitr = rhs.mMiscStorage.begin(); + for (; lhsitr != mMiscStorage.end(); ++lhsitr, ++rhsitr) { + CHECK(rhsitr != rhs.mMiscStorage.end()); + CHECK(lhsitr->first == rhsitr->first); + auto result = std::visit([](auto& x, auto& y) -> bool { return safeCompare(x, y); }, *(lhsitr->second), *(rhsitr->second)); + if (not result) { + cerr << "State does not match for key " << lhsitr->first << endl; + return false; } } } - return result; + + return true; +} + +//------------------------------------------------------------------------------ +// Enroll a Field +//------------------------------------------------------------------------------ +template +void +StateBase:: +enroll(FieldBase& field) { + const auto key = this->key(field); + mFieldStorage[key] = &field; + mNodeListPtrs.insert(field.nodeListPtr()); + // std::cerr << "StateBase::enroll field: " << key << " at " << &field << std::endl; + ENSURE(std::find(mNodeListPtrs.begin(), mNodeListPtrs.end(), field.nodeListPtr()) != mNodeListPtrs.end()); +} + +//------------------------------------------------------------------------------ +// Enroll a Field (shared_ptr). +//------------------------------------------------------------------------------ +template +void +StateBase:: +enroll(std::shared_ptr>& fieldPtr) { + const auto key = this->key(*fieldPtr); + mFieldStorage[key] = fieldPtr.get(); + mNodeListPtrs.insert(fieldPtr->nodeListPtr()); + mFieldCache.push_back(fieldPtr); + ENSURE(std::find(mNodeListPtrs.begin(), mNodeListPtrs.end(), fieldPtr->nodeListPtr()) != mNodeListPtrs.end()); +} + +//------------------------------------------------------------------------------ +// Add the fields from a FieldList. +//------------------------------------------------------------------------------ +template +void +StateBase:: +enroll(FieldListBase& fieldList) { + for (auto* fptr: range(fieldList.begin_base(), fieldList.end_base())) { + this->enroll(*fptr); + } } //------------------------------------------------------------------------------ @@ -172,7 +223,8 @@ template bool StateBase:: registered(const StateBase::KeyType& key) const { - return (mStorage.find(key) != mStorage.end()); + return (mFieldStorage.find(key) != mFieldStorage.end() or + mMiscStorage.find(key) != mMiscStorage.end()); } //------------------------------------------------------------------------------ @@ -182,9 +234,8 @@ template bool StateBase:: registered(const FieldBase& field) const { - const KeyType key = this->key(field); - typename StorageType::const_iterator itr = mStorage.find(key); - return (itr != mStorage.end()); + const auto key = this->key(field); + return mFieldStorage.find(key) != mFieldStorage.end(); } //------------------------------------------------------------------------------ @@ -206,73 +257,47 @@ bool StateBase:: fieldNameRegistered(const FieldName& name) const { KeyType fieldName, nodeListName; - auto itr = mStorage.begin(); - while (itr != mStorage.end()) { - splitFieldKey(itr->first, fieldName, nodeListName); + for (auto [key, valptr]: mFieldStorage) { + splitFieldKey(key, fieldName, nodeListName); if (fieldName == name) return true; - ++itr; } return false; } //------------------------------------------------------------------------------ -// Enroll a field. -//------------------------------------------------------------------------------ -template -void -StateBase:: -enroll(FieldBase& field) { - const KeyType key = this->key(field); - boost::any fieldptr; - fieldptr = &field; - mStorage[key] = fieldptr; - mNodeListPtrs.insert(field.nodeListPtr()); - // std::cerr << "StateBase::enroll field: " << key << " at " << &field << std::endl; - ENSURE(&(this->getAny>(key)) == &field); - ENSURE(find(mNodeListPtrs.begin(), mNodeListPtrs.end(), field.nodeListPtr()) != mNodeListPtrs.end()); -} - -//------------------------------------------------------------------------------ -// Enroll a field (shared_ptr). +// Return the full set of known keys. //------------------------------------------------------------------------------ template -void +std::vector::KeyType> StateBase:: -enroll(std::shared_ptr>& fieldPtr) { - const KeyType key = this->key(*fieldPtr); - mStorage[key] = fieldPtr.get(); - mNodeListPtrs.insert(fieldPtr->nodeListPtr()); - mFieldCache.push_back(fieldPtr); - ENSURE(find(mNodeListPtrs.begin(), mNodeListPtrs.end(), fieldPtr->nodeListPtr()) != mNodeListPtrs.end()); +keys() const { + vector result; + for (auto itr = mFieldStorage.begin(); itr != mFieldStorage.end(); ++itr) result.push_back(itr->first); + for (auto itr = mMiscStorage.begin(); itr != mMiscStorage.end(); ++itr) result.push_back(itr->first); + return result; } //------------------------------------------------------------------------------ -// Add the fields from a FieldList. +// Return the full set of Field Keys (mangled with NodeList names) //------------------------------------------------------------------------------ template -void +std::vector::KeyType> StateBase:: -enroll(FieldListBase& fieldList) { - for (auto itr = fieldList.begin_base(); - itr != fieldList.end_base(); - ++itr) { - this->enroll(**itr); - } +fullFieldKeys() const { + vector result; + for (auto itr = mFieldStorage.begin(); itr != mFieldStorage.end(); ++itr) result.push_back(itr->first); + return result; } //------------------------------------------------------------------------------ -// Return the full set of known keys. +// Return the set of non-field keys. //------------------------------------------------------------------------------ template std::vector::KeyType> StateBase:: -keys() const { +miscKeys() const { vector result; - result.reserve(mStorage.size()); - for (auto itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) result.push_back(itr->first); - ENSURE(result.size() == mStorage.size()); + for (auto itr = mMiscStorage.begin(); itr != mMiscStorage.end(); ++itr) result.push_back(itr->first); return result; } @@ -282,15 +307,12 @@ keys() const { template std::vector::FieldName> StateBase:: -fieldKeys() const { +fieldNames() const { KeyType fieldName, nodeListName; - vector::FieldName> result; - result.reserve(mStorage.size()); - for (typename StorageType::const_iterator itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) { + vector result; + for (auto itr = mFieldStorage.begin(); itr != mFieldStorage.end(); ++itr) { splitFieldKey(itr->first, fieldName, nodeListName); - if (fieldName != "" and nodeListName != "") result.push_back(fieldName); + result.push_back(fieldName); } // Remove any duplicates. This will happen when we've stored the same field @@ -384,58 +406,41 @@ void StateBase:: assign(const StateBase& rhs) { - // Extract the keys for each state, and verify they line up. - REQUIRE(mStorage.size() == rhs.mStorage.size()); - vector lhsKeys = keys(); - vector rhsKeys = rhs.keys(); - REQUIRE(lhsKeys.size() == rhsKeys.size()); - sort(lhsKeys.begin(), lhsKeys.end()); - sort(rhsKeys.begin(), rhsKeys.end()); - REQUIRE(lhsKeys == rhsKeys); - - // Walk the keys, and rely on the underlying type to know how to copy itself. - for (typename StorageType::const_iterator itr = rhs.mStorage.begin(); - itr != rhs.mStorage.end(); - ++itr) { - auto& anylhs = mStorage[itr->first]; - const auto& anyrhs = itr->second; - try { - auto lhsptr = boost::any_cast*>(anylhs); - const auto rhsptr = boost::any_cast*>(anyrhs); - *lhsptr = *rhsptr; - } catch(const boost::bad_any_cast&) { - try { - auto lhsptr = boost::any_cast*>(anylhs); - const auto rhsptr = boost::any_cast*>(anyrhs); - *lhsptr = *rhsptr; - } catch(const boost::bad_any_cast&) { - try { - auto lhsptr = boost::any_cast(anylhs); - const auto rhsptr = boost::any_cast(anyrhs); - *lhsptr = *rhsptr; - } catch(const boost::bad_any_cast&) { - try { - auto lhsptr = boost::any_cast(anylhs); - const auto rhsptr = boost::any_cast(anyrhs); - *lhsptr = *rhsptr; - } catch(const boost::bad_any_cast&) { - // We'll assume other things don't need to be assigned... - // VERIFY2(false, "StateBase::assign ERROR: unknown type for key " << itr->first << "\n"); - } - } - } + // Fields + { + CHECK(mFieldStorage.size() == rhs.mFieldStorage.size()); + auto lhsitr = mFieldStorage.begin(); + auto rhsitr = rhs.mFieldStorage.begin(); + for (; lhsitr != mFieldStorage.end(); ++lhsitr, ++rhsitr) { + CHECK(rhsitr != rhs.mFieldStorage.end()); + CHECK(lhsitr->first == rhsitr->first); + *(lhsitr->second) = *(rhsitr->second); + } + } + + // Miscellaneous state + { + // Depend on assignment working for our AllowedTypes + CHECK(mMiscStorage.size() == rhs.mMiscStorage.size()); + auto lhsitr = mMiscStorage.begin(); + auto rhsitr = rhs.mMiscStorage.begin(); + for (; lhsitr != mMiscStorage.end(); ++lhsitr, ++rhsitr) { + CHECK(rhsitr != rhs.mMiscStorage.end()); + CHECK(lhsitr->first == rhsitr->first); + std::visit([](auto& lhsval, auto& rhsval) { safeAssign(lhsval, rhsval); }, *(lhsitr->second), *(rhsitr->second)); } } + // Copy the connectivity (by reference). This thing is too // big to carry around separate copies! - if (rhs.mConnectivityMapPtr != NULL) { + if (rhs.mConnectivityMapPtr != nullptr) { mConnectivityMapPtr = rhs.mConnectivityMapPtr; } else { mConnectivityMapPtr = ConnectivityMapPtr(); } // Copy the mesh. - if (rhs.mMeshPtr != NULL) { + if (rhs.mMeshPtr != nullptr) { mMeshPtr = MeshPtr(new MeshType()); *mMeshPtr = *(rhs.mMeshPtr); } else { @@ -452,41 +457,37 @@ StateBase:: copyState() { // Remove any pre-existing stuff. - mCache = CacheType(); mFieldCache = FieldCacheType(); + mMiscCache = MiscCacheType(); - // Walk the registered state and copy it to our local cache. - for (auto itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) { - boost::any anythingPtr = itr->second; - - // Is this a Field? - try { - auto ptr = boost::any_cast*>(anythingPtr); - mFieldCache.push_back(ptr->clone()); - itr->second = mFieldCache.back().get(); - - } catch (const boost::bad_any_cast&) { - try { - auto ptr = boost::any_cast*>(anythingPtr); - auto clone = std::shared_ptr>(new vector(*ptr)); - mCache.push_back(clone); - itr->second = clone.get(); - - } catch (const boost::bad_any_cast&) { - try { - auto ptr = boost::any_cast(anythingPtr); - auto clone = std::shared_ptr(new Vector(*ptr)); - mCache.push_back(clone); - itr->second = clone.get(); - - } catch (const boost::bad_any_cast&) { - // We'll assume other things don't need to be copied... - // VERIFY2(false, "StateBase::copyState ERROR: unrecognized type for " << itr->first << "\n"); - } - } - } + // Fields + for (auto itr = mFieldStorage.begin(); itr != mFieldStorage.end(); ++itr) { + auto clone = itr->second->clone(); + mFieldCache.push_back(clone); + itr->second = clone.get(); + } + + // Misc + for (auto itr = mMiscStorage.begin(); itr != mMiscStorage.end(); ++itr) { + std::visit(overload{[](const Scalar& x) { return std::make_shared(x); }, + [](const Vector& x) { return std::make_shared(x); }, + [](const Tensor& x) { return std::make_shared(x); }, + [](const SymTensor& x) { return std::make_shared(x); }, + [](const vector& x) { return std::make_shared(x); }, + [](const vector& x) { return std::make_shared(x); }, + [](const vector& x) { return std::make_shared(x); }, + [](const vector& x) { return std::make_shared(x); }, + [](const set& x) { return std::make_shared(x); }, + [](const set& x) { return std::make_shared(x); }, + [](const ReproducingKernel& x) { return std::make_shared(x); } + }, *(itr->second)); + // [&](auto* xptr) { + // // auto clone = makeClone(*xptr); + // auto clone = std::shared_ptr(makeClone(*xptr)); // new AllowedType(*xptr)); + // // auto clone = std::make_shared(*xptr); + // // mMiscCache.push_back(clone); + // // itr->second = clone.get(); + // }, itr->second); } } diff --git a/src/DataBase/StateBase.hh b/src/DataBase/StateBase.hh index 355b6fbf5..23f26ba0b 100644 --- a/src/DataBase/StateBase.hh +++ b/src/DataBase/StateBase.hh @@ -18,8 +18,6 @@ #include "Field/FieldBase.hh" -#include "boost/any.hpp" - #include #include #include @@ -27,18 +25,18 @@ #include #include #include - -#include "Field/FieldBase.hh" +#include namespace Spheral { // Forward declaration. template class NodeList; -template class FieldListBase; -template class Field; -template class FieldList; +template class Field; +template class FieldList; template class ConnectivityMap; template class Mesh; +template class ReproducingKernel; +enum class RKOrder : int; template class StateBase { @@ -46,22 +44,35 @@ class StateBase { public: //--------------------------- Public Interface ---------------------------// // Useful typedefs - typedef typename Dimension::Scalar Scalar; - typedef typename Dimension::Vector Vector; - typedef typename Dimension::Vector3d Vector3d; - typedef typename Dimension::Tensor Tensor; - typedef typename Dimension::SymTensor SymTensor; - typedef typename Dimension::ThirdRankTensor ThirdRankTensor; - typedef typename Dimension::FourthRankTensor FourthRankTensor; - typedef typename Dimension::FifthRankTensor FifthRankTensor; - typedef typename Spheral::ConnectivityMap ConnectivityMapType; - typedef typename Spheral::Mesh MeshType; - - typedef std::shared_ptr ConnectivityMapPtr; - typedef std::shared_ptr MeshPtr; - - typedef std::string KeyType; - typedef typename FieldBase::FieldName FieldName; + using Scalar = typename Dimension::Scalar; + using Vector = typename Dimension::Vector; + using Vector3d = typename Dimension::Vector3d; + using Tensor = typename Dimension::Tensor; + using SymTensor = typename Dimension::SymTensor; + using ThirdRankTensor = typename Dimension::ThirdRankTensor; + using FourthRankTensor = typename Dimension::FourthRankTensor; + using FifthRankTensor = typename Dimension::FifthRankTensor; + using ConnectivityMapType = typename Spheral::ConnectivityMap; + using MeshType = typename Spheral::Mesh; + + using ConnectivityMapPtr = std::shared_ptr; + using MeshPtr = std::shared_ptr; + + using KeyType = std::string; + using FieldName = typename FieldBase::FieldName; + + // The allowed miscellaneous types beyond Fields and FieldLists State can handle + using AllowedType = std::variant, + std::vector, + std::vector, + std::vector, + std::set, + std::set, + ReproducingKernel>; // Constructors, destructor. StateBase(); @@ -75,64 +86,52 @@ public: virtual bool operator==(const StateBase& rhs) const; //............................................................................ - // Test if the specified Field or key is currently registered. - bool registered(const KeyType& key) const; - bool registered(const FieldBase& field) const; - bool registered(const FieldListBase& fieldList) const; - bool fieldNameRegistered(const FieldName& fieldName) const; + // Enroll state + virtual void enroll(FieldBase& field); + virtual void enroll(std::shared_ptr>& fieldPtr); + virtual void enroll(FieldListBase& fieldList); + template void enroll(const KeyType& key, T& thing); // T has to be one of AllowedTypes //............................................................................ - // Enroll a Field. - virtual void enroll(FieldBase& field); - virtual void enroll(std::shared_ptr>& fieldPtr); + // Access Fields + template Field& field(const KeyType& key) const; + template Field& field(const KeyType& key, + const Value& dummy) const; - // Return the field for the given key. - template - Field& field(const KeyType& key, - const Value& dummy) const; - - // Return all the fields of the given Value. - template - std::vector*> allFields(const Value& dummy) const; - - // This version is for when providing a dummy Value type is not possible/practical. - // Using this form however meand using the cumbersome syntax: state.template field(key) - template - Field& field(const KeyType& key) const; + // Get all registered fields of the given data type + template std::vector*> allFields(const Value& dummy) const; //............................................................................ - // Enroll a FieldList. - virtual void enroll(FieldListBase& fieldList); - - // Return FieldLists constructed from all registered Fields with the given name. - template - FieldList fields(const std::string& name, - const Value& dummy) const; - - // This version is for when providing a dummy Value type is not possible/practical. - // Using this form however meand using the cumbersome syntax: state.template fields(key) - template - FieldList fields(const std::string& name) const; + // Access FieldLists + template FieldList fields(const std::string& name) const; + template FieldList fields(const std::string& name, + const Value& dummy) const; //............................................................................ - // Enroll an arbitrary type - template - void enrollAny(const KeyType& key, Value& thing); - - // Return an arbitrary type (held by any) - template - Value& getAny(const KeyType& key) const; + // Access an arbitrary type + template Value& get(const KeyType& key) const; + template Value& get(const KeyType& key, const Value& dummy) const; - template - Value& getAny(const KeyType& key, const Value& dummy) const; + //............................................................................ + // Test if the specified Field or key is currently registered. + bool registered(const KeyType& key) const; + bool registered(const FieldBase& field) const; + bool registered(const FieldListBase& fieldList) const; + bool fieldNameRegistered(const FieldName& fieldName) const; //............................................................................ - // Return the complete set of keys registered. + // Return the complete set of keys registered std::vector keys() const; + // The field keys including mangling with NodeList names + std::vector fullFieldKeys() const; + + // The non-field (miscellaneous) keys + std::vector miscKeys() const; + // Return the set of known field names (unencoded from our internal mangling // convention with the NodeList name). - std::vector fieldKeys() const; + std::vector fieldNames() const; //............................................................................ // A state object can carry around a reference to a ConnectivityMap. @@ -172,14 +171,17 @@ public: protected: //--------------------------- Protected Interface ---------------------------// - typedef std::map StorageType; - typedef std::list>> FieldCacheType; - typedef std::list CacheType; + using FieldStorageType = std::map*>; + using FieldCacheType = std::list>>; + + using MiscStorageType = std::map; + using MiscCacheType = std::list>; // Protected data. - StorageType mStorage; - CacheType mCache; - FieldCacheType mFieldCache; + FieldStorageType mFieldStorage; + FieldCacheType mFieldCache; + MiscStorageType mMiscStorage; + MiscCacheType mMiscCache; std::set*> mNodeListPtrs; ConnectivityMapPtr mConnectivityMapPtr; MeshPtr mMeshPtr; diff --git a/src/DataBase/StateBaseInline.hh b/src/DataBase/StateBaseInline.hh index 6a8bb41cf..113bcfcb5 100644 --- a/src/DataBase/StateBaseInline.hh +++ b/src/DataBase/StateBaseInline.hh @@ -1,27 +1,47 @@ #include "boost/algorithm/string.hpp" #include "DataBase/UpdatePolicyBase.hh" +#include "RK/RKCorrectionParams.hh" +#include "RK/ReproducingKernel.hh" #include "Mesh/Mesh.hh" +#include "Utilities/range.hh" #include "Utilities/DBC.hh" namespace Spheral { +//------------------------------------------------------------------------------ +// Enroll an arbitrary type +// Must be one of the supported types in StateBase::AllowedType +//------------------------------------------------------------------------------ +template +template +inline +void +StateBase:: +enroll(const KeyType& key, T& thing) { + mMiscStorage[key] = &thing; +} + //------------------------------------------------------------------------------ // Return the Field for the given key. //------------------------------------------------------------------------------ template template +inline Field& StateBase:: -field(const typename StateBase::KeyType& key) const { - try { - return dynamic_cast&>(this->getAny>(key)); - } catch (...) { - VERIFY2(false,"StateBase ERROR: unable to extract field for key " << key << "\n"); - } +field(const KeyType& key) const { + auto itr = mFieldStorage.find(key); + VERIFY2(itr != mFieldStorage.end(), "StateBase ERROR: failed lookup for Field " << key); + auto* fbasePtr = itr->second; + auto* resultPtr = dynamic_cast*>(fbasePtr); + VERIFY2(resultPtr != nullptr, + "StateBase::field ERROR: field type incorrect for key " << key); + return *resultPtr; } template template +inline Field& StateBase:: field(const typename StateBase::KeyType& key, @@ -40,15 +60,9 @@ StateBase:: allFields(const Value&) const { std::vector*> result; KeyType fieldName, nodeListName; - for (auto itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) { - try { - Field* ptr = dynamic_cast*>(boost::any_cast*>(itr->second)); - if (ptr != 0) result.push_back(ptr); - } catch (...) { - // The field must have been the wrong type. - } + for (auto [key, valptr]: mFieldStorage) { + auto* ptr = dynamic_cast*>(valptr); + if (ptr != nullptr) result.push_back(ptr); } return result; } @@ -64,13 +78,13 @@ StateBase:: fields(const std::string& name) const { FieldList result; KeyType fieldName, nodeListName; - for (auto itr = mStorage.begin(); - itr != mStorage.end(); - ++itr) { - splitFieldKey(itr->first, fieldName, nodeListName); + for (auto [key, valptr]: mFieldStorage) { + splitFieldKey(key, fieldName, nodeListName); if (fieldName == name) { CHECK(nodeListName != ""); - result.appendField(this->template field(itr->first)); + auto* fptr = dynamic_cast*>(valptr); + CHECK(valptr != nullptr); + result.appendField(*fptr); } } return result; @@ -86,40 +100,45 @@ fields(const std::string& name, const Value& dummy) const { } //------------------------------------------------------------------------------ -// Enroll an arbitrary type +// Extract an arbitrary type //------------------------------------------------------------------------------ template template -void +inline +Value& StateBase:: -enrollAny(const typename StateBase::KeyType& key, Value& thing) { - mStorage[key] = &thing; +get(const typename StateBase::KeyType& key) const { + auto itr = mMiscStorage.find(key); + VERIFY2(itr != mMiscStorage.end(), "StateBase ERROR: failed lookup for key " << key); + auto* resultPtr = std::get_if(itr->second); + VERIFY2(resultPtr != nullptr, "StateBase::get ERROR: unable to extract Value for " << key << "\n"); + return *resultPtr; } -//------------------------------------------------------------------------------ -// Extract an arbitrary type -//------------------------------------------------------------------------------ +// Same thing passing a dummy argument to help with template type template template +inline Value& StateBase:: -getAny(const typename StateBase::KeyType& key) const { - try { - Value& result = *boost::any_cast(mStorage.find(key)->second); - return result; - } catch (const boost::bad_any_cast&) { - VERIFY2(false, "StateBase::getAny ERROR: unable to extract Value for " << key << "\n"); - } +get(const typename StateBase::KeyType& key, + const Value&) const { + return this->get(key); } -// Same thing passing a dummy argument to help with template type +//------------------------------------------------------------------------------ +// Assign the Fields matching the given name of this State object to be equal to +// the values in another. +//------------------------------------------------------------------------------ template template -Value& +inline +void StateBase:: -getAny(const typename StateBase::KeyType& key, - const Value&) const { - return this->getAny(key); +assignFields(const StateBase& rhs, const std::string name) { + auto lhsfields = this->fields(name, Value()); + auto rhsfields = rhs.fields(name, Value()); + lhsfields.assignFields(rhsfields); } //------------------------------------------------------------------------------ @@ -145,21 +164,6 @@ key(const FieldListBase& fieldList) { return buildFieldKey((*fieldList.begin_base())->name(), UpdatePolicyBase::wildcard()); } -//------------------------------------------------------------------------------ -// Assign the Fields matching the given name of this State object to be equal to -// the values in another. -//------------------------------------------------------------------------------ -template -template -inline -void -StateBase:: -assignFields(const StateBase& rhs, const std::string name) { - auto lhsfields = this->fields(name, Value()); - auto rhsfields = rhs.fields(name, Value()); - lhsfields.assignFields(rhsfields); -} - //------------------------------------------------------------------------------ // Internal methods to encode the convention for combining Field and NodeList // names into a single unique key. diff --git a/src/DataBase/StateDerivatives.cc b/src/DataBase/StateDerivatives.cc index e7e54ee23..f2f914254 100644 --- a/src/DataBase/StateDerivatives.cc +++ b/src/DataBase/StateDerivatives.cc @@ -20,6 +20,14 @@ using std::abs; namespace Spheral { +namespace { + +// Helper with overloading in std::visit +template struct overload : Ts... { using Ts::operator()...; }; +template overload(Ts...) -> overload; + +} + //------------------------------------------------------------------------------ // Default constructor. //------------------------------------------------------------------------------ @@ -41,11 +49,7 @@ StateDerivatives(DataBase& dataBase, StateBase(), mCalculatedNodePairs(), mNumSignificantNeighbors() { - - // Iterate over the physics packages, and have them register their derivatives. - for (PackageIterator itr = physicsPackages.begin(); - itr != physicsPackages.end(); - ++itr) (*itr)->registerDerivatives(dataBase, *this); + for (auto pkg: physicsPackages) pkg->registerDerivatives(dataBase, *this); } //------------------------------------------------------------------------------ @@ -59,11 +63,7 @@ StateDerivatives(DataBase& dataBase, StateBase(), mCalculatedNodePairs(), mNumSignificantNeighbors() { - - // Iterate over the physics packages, and have them register their derivatives. - for (PackageIterator itr = physicsPackageBegin; - itr != physicsPackageEnd; - ++itr) (*itr)->registerDerivatives(dataBase, *this); + for (auto pkg: range(physicsPackageBegin, physicsPackageEnd)) pkg->registerDerivatives(dataBase, *this); } //------------------------------------------------------------------------------ @@ -159,29 +159,22 @@ StateDerivatives:: Zero() { // Walk the state fields and zero them. - for (typename StateBase::StorageType::iterator itr = this->mStorage.begin(); - itr != this->mStorage.end(); - ++itr) { - - try { - auto ptr = boost::any_cast*>(itr->second); - ptr->Zero(); - - } catch (const boost::bad_any_cast&) { - try { - auto ptr = boost::any_cast*>(itr->second); - ptr->clear(); - - } catch (const boost::bad_any_cast&) { - try { - auto ptr = boost::any_cast*>(itr->second); - ptr->clear(); - - } catch (const boost::bad_any_cast&) { - VERIFY2(false, "StateDerivatives::Zero ERROR: unknown type for key " << itr->first << "\n"); - } - } - } + for (auto [key, fptr]: mFieldStorage) fptr->Zero(); + + // Same thing for the miscellaeneous types + for (auto [key, mptr]: mMiscStorage) { + std::visit(overload{[](Scalar& x) { x = 0.0; }, + [](Vector& x) { x = Vector::zero; }, + [](Tensor& x) { x = Tensor::zero; }, + [](SymTensor& x) { x = SymTensor::zero; }, + [](vector& x) { x.clear(); }, + [](vector& x) { x.clear(); }, + [](vector& x) { x.clear(); }, + [](vector& x) { x.clear(); }, + [](set& x) { x.clear(); }, + [](set& x) { }, + [](ReproducingKernel& x) { } + }, *mptr); } // Reinitialize the node pair interaction information. diff --git a/src/DataBase/StateDerivatives.hh b/src/DataBase/StateDerivatives.hh index 64a4dc729..e92012adf 100644 --- a/src/DataBase/StateDerivatives.hh +++ b/src/DataBase/StateDerivatives.hh @@ -26,16 +26,16 @@ class StateDerivatives: public StateBase { public: //--------------------------- Public Interface ---------------------------// // Useful typedefs - typedef typename Dimension::Scalar Scalar; - typedef typename Dimension::Vector Vector; - typedef typename Dimension::Vector3d Vector3d; - typedef typename Dimension::Tensor Tensor; - typedef typename Dimension::SymTensor SymTensor; + using Scalar = typename Dimension::Scalar; + using Vector = typename Dimension::Vector; + using Vector3d = typename Dimension::Vector3d; + using Tensor = typename Dimension::Tensor; + using SymTensor = typename Dimension::SymTensor; - typedef std::vector*> PackageList; - typedef typename PackageList::iterator PackageIterator; + using PackageList = std::vector*>; + using PackageIterator = typename PackageList::iterator; - typedef typename StateBase::KeyType KeyType; + using KeyType = typename StateBase::KeyType; // Constructors, destructor. StateDerivatives(); @@ -73,17 +73,16 @@ private: //--------------------------- Private Interface ---------------------------// // Map for storing information about pairs of nodes that have already been // calculated. - typedef std::map, - std::vector > > CalculatedPairType; + using CalculatedPairType = std::map, + std::vector>>; CalculatedPairType mCalculatedNodePairs; // Map for maintaining the number of significant neighbors per node. - typedef std::map, int> SignificantNeighborMapType; - + using SignificantNeighborMapType = std::map, int>; SignificantNeighborMapType mNumSignificantNeighbors; - using typename StateBase::StorageType; - using StateBase::mStorage; + using StateBase::mFieldStorage; + using StateBase::mMiscStorage; }; } diff --git a/src/DataBase/StateInline.hh b/src/DataBase/StateInline.hh index e8b5e58c3..9d39e3919 100644 --- a/src/DataBase/StateInline.hh +++ b/src/DataBase/StateInline.hh @@ -1,19 +1,5 @@ namespace Spheral { -//------------------------------------------------------------------------------ -// Enroll the given policy. -//------------------------------------------------------------------------------ -template -inline -void -State:: -enroll(const typename State::KeyType& key, - typename State::PolicyPointer polptr) { - KeyType fieldKey, nodeKey; - this->splitFieldKey(key, fieldKey, nodeKey); - mPolicyMap[fieldKey][key] = polptr; -} - //------------------------------------------------------------------------------ // Enroll the given field and policy. //------------------------------------------------------------------------------ @@ -22,9 +8,9 @@ inline void State:: enroll(FieldBase& field, - typename State::PolicyPointer polptr) { + typename State::PolicyPointer policy) { this->enroll(field); - this->enroll(this->key(field), polptr); + this->enroll(this->key(field), policy); } //------------------------------------------------------------------------------ @@ -35,51 +21,32 @@ inline void State:: enroll(FieldListBase& fieldList, - typename State::PolicyPointer polptr) { - if (polptr->clonePerField()) { + typename State::PolicyPointer policy) { + if (policy->clonePerField()) { // std::cerr << "Registering FieldList " << this->key(fieldList) << " with cloning policy" << std::endl; - for (auto bitr = fieldList.begin_base(); bitr < fieldList.end_base(); ++bitr) { - this->enroll(**bitr, polptr); + for (auto fptr: range(fieldList.begin_base(), fieldList.end_base())) { + this->enroll(*fptr, policy); } } else { // std::cerr << "Registering FieldList " << this->key(fieldList) << " with SINGLE policy" << std::endl; // this->enroll(this->key(fieldList), fieldList); this->enroll(fieldList); // enrolls each field without a policy - this->enroll(this->key(fieldList), polptr); + this->enroll(this->key(fieldList), policy); } } //------------------------------------------------------------------------------ -// Enroll the given field. -//------------------------------------------------------------------------------ -template -inline -void -State:: -enroll(FieldBase& field) { - StateBase::enroll(field); -} - -//------------------------------------------------------------------------------ -// Enroll the given field shared_pointer. -//------------------------------------------------------------------------------ -template -inline -void -State:: -enroll(std::shared_ptr>& fieldPtr) { - StateBase::enroll(fieldPtr); -} - -//------------------------------------------------------------------------------ -// Enroll the given field list. +// Enroll the given policy. //------------------------------------------------------------------------------ template inline void State:: -enroll(FieldListBase& fieldList) { - StateBase::enroll(fieldList); +enroll(const typename State::KeyType& key, + typename State::PolicyPointer policy) { + KeyType fieldKey, nodeKey; + this->splitFieldKey(key, fieldKey, nodeKey); + mPolicyMap[fieldKey][key] = policy; } //------------------------------------------------------------------------------ @@ -95,24 +62,4 @@ policy(const Field& field) const { return this->policy(key); } -//------------------------------------------------------------------------------ -// Optionally trip a flag indicating policies should time advance only -- no replacing state! -// This is useful when you're trying to cheat and reuse derivatives from a prior advance. -//------------------------------------------------------------------------------ -template -inline -bool -State:: -timeAdvanceOnly() const { - return mTimeAdvanceOnly; -} - -template -inline -void -State:: -timeAdvanceOnly(const bool x) { - mTimeAdvanceOnly = x; -} - } diff --git a/src/PYB11/DataBase/StateBase.py b/src/PYB11/DataBase/StateBase.py index bf774d9a8..03de427d1 100644 --- a/src/PYB11/DataBase/StateBase.py +++ b/src/PYB11/DataBase/StateBase.py @@ -7,17 +7,17 @@ class StateBase: PYB11typedefs = """ - typedef typename %(Dimension)s::Scalar Scalar; - typedef typename %(Dimension)s::Vector Vector; - typedef typename %(Dimension)s::Tensor Tensor; - typedef typename %(Dimension)s::SymTensor SymTensor; - typedef typename %(Dimension)s::ThirdRankTensor ThirdRankTensor; - typedef typename %(Dimension)s::FourthRankTensor FourthRankTensor; - typedef typename %(Dimension)s::FifthRankTensor FifthRankTensor; - typedef typename %(Dimension)s::FacetedVolume FacetedVolume; - typedef typename StateBase<%(Dimension)s>::KeyType KeyType; - typedef typename StateBase<%(Dimension)s>::FieldName FieldName; - typedef typename StateBase<%(Dimension)s>::MeshPtr MeshPtr; + using Scalar = typename %(Dimension)s::Scalar; + using Vector = typename %(Dimension)s::Vector; + using Tensor = typename %(Dimension)s::Tensor; + using SymTensor = typename %(Dimension)s::SymTensor; + using ThirdRankTensor = typename %(Dimension)s::ThirdRankTensor; + using FourthRankTensor = typename %(Dimension)s::FourthRankTensor; + using FifthRankTensor = typename %(Dimension)s::FifthRankTensor; + using FacetedVolume = typename %(Dimension)s::FacetedVolume; + using KeyType = typename StateBase<%(Dimension)s>::KeyType; + using FieldName = typename StateBase<%(Dimension)s>::FieldName; + using MeshPtr = typename StateBase<%(Dimension)s>::MeshPtr; """ #........................................................................... @@ -100,9 +100,19 @@ def keys(self): return "std::vector" @PYB11const - def fieldKeys(self): - "The set of Field names for the state in the StateBase" - return "std::vector" + def fullFieldKeys(self): + "The set of Field names (with NodeList mangling) for the state in the StateBase" + return "std::vector" + + @PYB11const + def fieldNames(self): + "The set of unique Field names for the state in the StateBase (no NodeList mangling)" + return "std::vector" + + @PYB11const + def miscKeys(self): + "The set of names for non-Fields in the StateBase" + return "std::vector" def enrollConnectivityMap(self, connectivityMapPtr = "std::shared_ptr>"): @@ -228,9 +238,9 @@ def allFields(self, allRKCoefficientsFields = PYB11TemplateMethod(allFields, "RKCoefficients<%(Dimension)s>") #........................................................................... - # enrollAny/getAny + # enroll/get @PYB11template("Value") - def enrollAny(self, + def enroll(self, key = "const KeyType&", thing = "%(Value)s&"): "Enroll a type of %(Value)s." @@ -239,13 +249,13 @@ def enrollAny(self, @PYB11template("Value") @PYB11const @PYB11returnpolicy("reference_internal") - def getAny(self, + def get(self, key = "const KeyType&"): "Return a stored type of %(Value)s" return "%(Value)s&" - enrollVectorVector = PYB11TemplateMethod(enrollAny, "std::vector", pyname="enrollAny") - getVectorVector = PYB11TemplateMethod(getAny, "std::vector", pyname="getAny") + enrollVectorVector = PYB11TemplateMethod(enroll, "std::vector", pyname="enroll") + getVectorVector = PYB11TemplateMethod(get, "std::vector", pyname="get") #........................................................................... # assignFields From 053b403022c3898663cf0d5dcce5b260fcb6b729 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Fri, 8 Nov 2024 11:06:05 -0800 Subject: [PATCH 04/14] Converting to std::any for our storage in StateBase again, but using a visitor pattern to manipulate and check all types are handled (rather than our old tree of try/catch craziness). --- src/DEM/IncrementPairFieldList.cc | 2 +- src/DataBase/StateBase.cc | 319 +++++++++++++++------------ src/DataBase/StateBase.hh | 40 +--- src/DataBase/StateBaseInline.hh | 48 ++-- src/DataBase/StateDerivatives.cc | 60 +++-- src/DataBase/StateDerivatives.hh | 3 +- src/Hydro/SphericalPositionPolicy.cc | 2 +- 7 files changed, 261 insertions(+), 213 deletions(-) diff --git a/src/DEM/IncrementPairFieldList.cc b/src/DEM/IncrementPairFieldList.cc index e6169f859..371116a60 100644 --- a/src/DEM/IncrementPairFieldList.cc +++ b/src/DEM/IncrementPairFieldList.cc @@ -50,7 +50,7 @@ update(const KeyType& key, // Find all the available matching derivative FieldList keys. const auto incrementKey = prefix() + fieldKey; // cerr << "IncrementPairFieldList: [" << fieldKey << "] [" << incrementKey << "] : " << endl; - const auto allkeys = derivs.fieldKeys(); + const auto allkeys = derivs.fullFieldKeys(); vector incrementKeys; for (const auto& key: allkeys) { // if (std::regex_search(key, std::regex("^" + incrementKey))) { diff --git a/src/DataBase/StateBase.cc b/src/DataBase/StateBase.cc index e7c4a24d1..e3fd67335 100644 --- a/src/DataBase/StateBase.cc +++ b/src/DataBase/StateBase.cc @@ -16,6 +16,7 @@ #include #include + using std::vector; using std::cout; using std::cerr; @@ -24,11 +25,67 @@ using std::min; using std::max; using std::abs; using std::sort; +using std::shared_ptr; +using std::make_shared; +using std::any; +using std::any_cast; namespace Spheral { namespace { +//------------------------------------------------------------------------------ +// Collect visitor methods to apply to std::any object holders +//------------------------------------------------------------------------------ +// 2 args +template +class AnyVisitor2 { +public: + using VisitorFunc = std::function; + + RETURNT visit(ARG1 value1, ARG2 value2) const { + auto it = mVisitors.find(std::type_index(value1.type())); + if (it != mVisitors.end()) { + return it->second(value1, value2); + } + VERIFY2(false, "AnyVisitor ERROR in StateBase: unable to process unknown data"); + } + + template + void addVisitor(VisitorFunc visitor) { + mVisitors[std::type_index(typeid(T))] = visitor; + } + + +private: + std::unordered_map mVisitors; +}; + +//.............................................................................. +// 4 args +template +class AnyVisitor4 { +public: + using VisitorFunc = std::function; + + RETURNT visit(ARG1 value1, ARG2 value2, ARG3 value3, ARG4 value4) const { + auto it = mVisitors.find(std::type_index(value1.type())); + if (it != mVisitors.end()) { + return it->second(value1, value2, value3, value4); + } + VERIFY2(false, "AnyVisitor ERROR in StateBase: unable to process unknown data"); + } + + template + void addVisitor(VisitorFunc visitor) { + mVisitors[std::type_index(typeid(T))] = visitor; + } + + +private: + std::unordered_map mVisitors; +}; + // //------------------------------------------------------------------------------ // // Helper for copying a type, used in copyState // //------------------------------------------------------------------------------ @@ -43,6 +100,20 @@ namespace { // } // } +//------------------------------------------------------------------------------ +// Template for generic cloning during copyState +//------------------------------------------------------------------------------ +template +void +genericClone(std::any& x, + const std::string& key, + typename std::map& storage, + typename std::list& cache) { + auto clone = std::make_shared(*std::any_cast(x)); + cache.push_back(clone); + storage[key] = clone.get(); +} + //------------------------------------------------------------------------------ // Template to downselect comparison in our variant types //------------------------------------------------------------------------------ @@ -55,7 +126,16 @@ template bool safeCompare(T1& x, const T2& y) { VERIFY template void safeAssign(T1& x, const T1& y) { x = y; } template void safeAssign(T1& x, const T2& y) { VERIFY2(false, "Bad assignment!"); } +template T1& safePointer(T1* xptr, const T1* yptr) { return yptr; } +template T1& safePointer(T1* xptr, const T2* yptr) { VERIFY2(false, "Bad assignment!"); return xptr; } + +//------------------------------------------------------------------------------ +//------------------------------------------------------------------------------ +template std::shared_ptr safeClone(const T1& x, const ResultT& dummy) { return std::make_shared(x); } + +//------------------------------------------------------------------------------ // Helper with overloading in std::visit +//------------------------------------------------------------------------------ template struct overload : Ts... { using Ts::operator()...; }; template overload(Ts...) -> overload; @@ -67,57 +147,13 @@ template overload(Ts...) -> overload; template StateBase:: StateBase(): - mFieldStorage(), - mFieldCache(), - mMiscStorage(), - mMiscCache(), + mStorage(), + mCache(), mNodeListPtrs(), mConnectivityMapPtr(), mMeshPtr(new MeshType()) { } -//------------------------------------------------------------------------------ -// Copy constructor. -//------------------------------------------------------------------------------ -template -StateBase:: -StateBase(const StateBase& rhs): - mFieldStorage(rhs.mFieldStorage), - mFieldCache(), - mMiscStorage(rhs.mMiscStorage), - mMiscCache(), - mNodeListPtrs(rhs.mNodeListPtrs), - mConnectivityMapPtr(rhs.mConnectivityMapPtr), - mMeshPtr(rhs.mMeshPtr) { -} - -//------------------------------------------------------------------------------ -// Destructor. -//------------------------------------------------------------------------------ -template -StateBase:: -~StateBase() { -} - -//------------------------------------------------------------------------------ -// Assignment. -//------------------------------------------------------------------------------ -template -StateBase& -StateBase:: -operator=(const StateBase& rhs) { - if (this != &rhs) { - mFieldStorage = rhs.mFieldStorage; - mFieldCache = FieldCacheType(); - mMiscStorage = rhs.mMiscStorage; - mMiscCache = MiscCacheType(); - mNodeListPtrs = rhs.mNodeListPtrs; - mConnectivityMapPtr = rhs.mConnectivityMapPtr; - mMeshPtr = rhs.mMeshPtr; - } - return *this; -} - //------------------------------------------------------------------------------ // Test if the internal state is equal. //------------------------------------------------------------------------------ @@ -127,12 +163,8 @@ StateBase:: operator==(const StateBase& rhs) const { // Compare raw sizes - if (mFieldStorage.size() != rhs.mFieldStorage.size()) { - cerr << "Field storage sizes don't match." << endl; - return false; - } - if (mMiscStorage.size() != rhs.mMiscStorage.size()) { - cerr << "Miscellaneous storage sizes don't match." << endl; + if (mStorage.size() != rhs.mStorage.size()) { + cerr << "Storage sizes don't match." << endl; return false; } @@ -144,32 +176,30 @@ operator==(const StateBase& rhs) const { return false; } - // Compare fields - { - auto lhsitr = mFieldStorage.begin(); - auto rhsitr = rhs.mFieldStorage.begin(); - for (; lhsitr != mFieldStorage.end(); ++lhsitr, ++rhsitr) { - CHECK(rhsitr != rhs.mFieldStorage.end()); - CHECK(lhsitr->first == rhsitr->first); - if (*(lhsitr->second) != *(rhsitr->second)) { - cerr << "Fields don't match for key " << lhsitr->first << endl; - return false; - } - } - } - - // Compare the miscellaneous objects - { - auto lhsitr = mMiscStorage.begin(); - auto rhsitr = rhs.mMiscStorage.begin(); - for (; lhsitr != mMiscStorage.end(); ++lhsitr, ++rhsitr) { - CHECK(rhsitr != rhs.mMiscStorage.end()); - CHECK(lhsitr->first == rhsitr->first); - auto result = std::visit([](auto& x, auto& y) -> bool { return safeCompare(x, y); }, *(lhsitr->second), *(rhsitr->second)); - if (not result) { - cerr << "State does not match for key " << lhsitr->first << endl; - return false; - } + // Build up a visitor to compare each type of state data we support holding + AnyVisitor2 EQUAL; + EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); + EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); + EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); + EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); + EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); + EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); + EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); + EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); + EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); + EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); + EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); + EQUAL.addVisitor*>([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); + + // Apply the equality visitor to all the stored State data + auto lhsitr = mStorage.begin(); + auto rhsitr = rhs.mStorage.begin(); + for (; lhsitr != mStorage.end(); ++lhsitr, ++rhsitr) { + CHECK(rhsitr != rhs.mStorage.end()); + CHECK(lhsitr->first == rhsitr->first); + if (not EQUAL.visit(lhsitr->second, rhsitr->second)) { + cerr << "States don't match for key " << lhsitr->first << endl; + return false; } } @@ -184,7 +214,7 @@ void StateBase:: enroll(FieldBase& field) { const auto key = this->key(field); - mFieldStorage[key] = &field; + mStorage[key] = &field; mNodeListPtrs.insert(field.nodeListPtr()); // std::cerr << "StateBase::enroll field: " << key << " at " << &field << std::endl; ENSURE(std::find(mNodeListPtrs.begin(), mNodeListPtrs.end(), field.nodeListPtr()) != mNodeListPtrs.end()); @@ -198,9 +228,9 @@ void StateBase:: enroll(std::shared_ptr>& fieldPtr) { const auto key = this->key(*fieldPtr); - mFieldStorage[key] = fieldPtr.get(); + mStorage[key] = fieldPtr.get(); mNodeListPtrs.insert(fieldPtr->nodeListPtr()); - mFieldCache.push_back(fieldPtr); + mCache.push_back(fieldPtr); ENSURE(std::find(mNodeListPtrs.begin(), mNodeListPtrs.end(), fieldPtr->nodeListPtr()) != mNodeListPtrs.end()); } @@ -223,8 +253,7 @@ template bool StateBase:: registered(const StateBase::KeyType& key) const { - return (mFieldStorage.find(key) != mFieldStorage.end() or - mMiscStorage.find(key) != mMiscStorage.end()); + return mStorage.find(key) != mStorage.end(); } //------------------------------------------------------------------------------ @@ -235,7 +264,7 @@ bool StateBase:: registered(const FieldBase& field) const { const auto key = this->key(field); - return mFieldStorage.find(key) != mFieldStorage.end(); + return this->registered(key); } //------------------------------------------------------------------------------ @@ -257,7 +286,7 @@ bool StateBase:: fieldNameRegistered(const FieldName& name) const { KeyType fieldName, nodeListName; - for (auto [key, valptr]: mFieldStorage) { + for (auto [key, valptr]: mStorage) { splitFieldKey(key, fieldName, nodeListName); if (fieldName == name) return true; } @@ -272,8 +301,7 @@ std::vector::KeyType> StateBase:: keys() const { vector result; - for (auto itr = mFieldStorage.begin(); itr != mFieldStorage.end(); ++itr) result.push_back(itr->first); - for (auto itr = mMiscStorage.begin(); itr != mMiscStorage.end(); ++itr) result.push_back(itr->first); + for (auto itr = mStorage.begin(); itr != mStorage.end(); ++itr) result.push_back(itr->first); return result; } @@ -285,7 +313,9 @@ std::vector::KeyType> StateBase:: fullFieldKeys() const { vector result; - for (auto itr = mFieldStorage.begin(); itr != mFieldStorage.end(); ++itr) result.push_back(itr->first); + for (auto [key, aptr]: mStorage) { + if (std::any_cast*>(aptr) != nullptr) result.push_back(key); + } return result; } @@ -297,7 +327,9 @@ std::vector::KeyType> StateBase:: miscKeys() const { vector result; - for (auto itr = mMiscStorage.begin(); itr != mMiscStorage.end(); ++itr) result.push_back(itr->first); + for (auto [key, aptr]: mStorage) { + if (std::any_cast*>(aptr) == nullptr) result.push_back(key); + } return result; } @@ -308,11 +340,10 @@ template std::vector::FieldName> StateBase:: fieldNames() const { - KeyType fieldName, nodeListName; vector result; - for (auto itr = mFieldStorage.begin(); itr != mFieldStorage.end(); ++itr) { - splitFieldKey(itr->first, fieldName, nodeListName); - result.push_back(fieldName); + for (auto [key, aptr]: mStorage) { + auto* fptr = std::any_cast*>(aptr); + if (fptr != nullptr) result.push_back(fptr->name()); } // Remove any duplicates. This will happen when we've stored the same field @@ -406,28 +437,31 @@ void StateBase:: assign(const StateBase& rhs) { - // Fields - { - CHECK(mFieldStorage.size() == rhs.mFieldStorage.size()); - auto lhsitr = mFieldStorage.begin(); - auto rhsitr = rhs.mFieldStorage.begin(); - for (; lhsitr != mFieldStorage.end(); ++lhsitr, ++rhsitr) { - CHECK(rhsitr != rhs.mFieldStorage.end()); - CHECK(lhsitr->first == rhsitr->first); - *(lhsitr->second) = *(rhsitr->second); - } - } - - // Miscellaneous state - { - // Depend on assignment working for our AllowedTypes - CHECK(mMiscStorage.size() == rhs.mMiscStorage.size()); - auto lhsitr = mMiscStorage.begin(); - auto rhsitr = rhs.mMiscStorage.begin(); - for (; lhsitr != mMiscStorage.end(); ++lhsitr, ++rhsitr) { - CHECK(rhsitr != rhs.mMiscStorage.end()); - CHECK(lhsitr->first == rhsitr->first); - std::visit([](auto& lhsval, auto& rhsval) { safeAssign(lhsval, rhsval); }, *(lhsitr->second), *(rhsitr->second)); + // Build a visitor that knows how to assign each of our datatypes + AnyVisitor2 ASSIGN; + ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); + ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); + ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); + ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); + ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); + ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); + ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); + ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); + ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); + ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); + ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); + ASSIGN.addVisitor*>([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); + + // Apply the assignment visitor to all the stored State data + auto lhsitr = mStorage.begin(); + auto rhsitr = rhs.mStorage.begin(); + for (; lhsitr != mStorage.end(); ++lhsitr, ++rhsitr) { + CHECK(rhsitr != rhs.mStorage.end()); + CHECK(lhsitr->first == rhsitr->first); + try { + ASSIGN.visit(lhsitr->second, rhsitr->second); + } catch(...) { + CHECK(false); } } @@ -457,37 +491,30 @@ StateBase:: copyState() { // Remove any pre-existing stuff. - mFieldCache = FieldCacheType(); - mMiscCache = MiscCacheType(); - - // Fields - for (auto itr = mFieldStorage.begin(); itr != mFieldStorage.end(); ++itr) { - auto clone = itr->second->clone(); - mFieldCache.push_back(clone); - itr->second = clone.get(); - } - - // Misc - for (auto itr = mMiscStorage.begin(); itr != mMiscStorage.end(); ++itr) { - std::visit(overload{[](const Scalar& x) { return std::make_shared(x); }, - [](const Vector& x) { return std::make_shared(x); }, - [](const Tensor& x) { return std::make_shared(x); }, - [](const SymTensor& x) { return std::make_shared(x); }, - [](const vector& x) { return std::make_shared(x); }, - [](const vector& x) { return std::make_shared(x); }, - [](const vector& x) { return std::make_shared(x); }, - [](const vector& x) { return std::make_shared(x); }, - [](const set& x) { return std::make_shared(x); }, - [](const set& x) { return std::make_shared(x); }, - [](const ReproducingKernel& x) { return std::make_shared(x); } - }, *(itr->second)); - // [&](auto* xptr) { - // // auto clone = makeClone(*xptr); - // auto clone = std::shared_ptr(makeClone(*xptr)); // new AllowedType(*xptr)); - // // auto clone = std::make_shared(*xptr); - // // mMiscCache.push_back(clone); - // // itr->second = clone.get(); - // }, itr->second); + mCache = CacheType(); + + // Build a visitor to clone each type of state data + AnyVisitor4 CLONE; + CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { + auto clone = std::any_cast*>(x)->clone(); + cache.push_back(clone); + storage[key] = clone.get(); + }); + CLONE.addVisitor ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + + // Clone all our stored data to cache + for (auto& [key, anyvalptr]: mStorage) { + CLONE.visit(anyvalptr, key, mStorage, mCache); } } diff --git a/src/DataBase/StateBase.hh b/src/DataBase/StateBase.hh index 23f26ba0b..21be4fc5c 100644 --- a/src/DataBase/StateBase.hh +++ b/src/DataBase/StateBase.hh @@ -18,6 +18,8 @@ #include "Field/FieldBase.hh" +#include +#include #include #include #include @@ -25,7 +27,6 @@ #include #include #include -#include namespace Spheral { @@ -61,26 +62,11 @@ public: using KeyType = std::string; using FieldName = typename FieldBase::FieldName; - // The allowed miscellaneous types beyond Fields and FieldLists State can handle - using AllowedType = std::variant, - std::vector, - std::vector, - std::vector, - std::set, - std::set, - ReproducingKernel>; - // Constructors, destructor. StateBase(); - StateBase(const StateBase& rhs); - virtual ~StateBase(); - - // Assignment. - StateBase& operator=(const StateBase& rhs); + StateBase(const StateBase& rhs) = default; + StateBase& operator=(const StateBase& rhs) = default; + virtual ~StateBase() {} // Test if two StateBases have equivalent fields. virtual bool operator==(const StateBase& rhs) const; @@ -90,7 +76,7 @@ public: virtual void enroll(FieldBase& field); virtual void enroll(std::shared_ptr>& fieldPtr); virtual void enroll(FieldListBase& fieldList); - template void enroll(const KeyType& key, T& thing); // T has to be one of AllowedTypes + template void enroll(const KeyType& key, T& thing); //............................................................................ // Access Fields @@ -99,6 +85,7 @@ public: const Value& dummy) const; // Get all registered fields of the given data type + template std::vector*> allFields() const; template std::vector*> allFields(const Value& dummy) const; //............................................................................ @@ -171,17 +158,12 @@ public: protected: //--------------------------- Protected Interface ---------------------------// - using FieldStorageType = std::map*>; - using FieldCacheType = std::list>>; - - using MiscStorageType = std::map; - using MiscCacheType = std::list>; + using StorageType = std::map; + using CacheType = std::list; // Protected data. - FieldStorageType mFieldStorage; - FieldCacheType mFieldCache; - MiscStorageType mMiscStorage; - MiscCacheType mMiscCache; + StorageType mStorage; + CacheType mCache; std::set*> mNodeListPtrs; ConnectivityMapPtr mConnectivityMapPtr; MeshPtr mMeshPtr; diff --git a/src/DataBase/StateBaseInline.hh b/src/DataBase/StateBaseInline.hh index 113bcfcb5..a982c6719 100644 --- a/src/DataBase/StateBaseInline.hh +++ b/src/DataBase/StateBaseInline.hh @@ -18,7 +18,7 @@ inline void StateBase:: enroll(const KeyType& key, T& thing) { - mMiscStorage[key] = &thing; + mStorage[key] = &thing; } //------------------------------------------------------------------------------ @@ -30,13 +30,11 @@ inline Field& StateBase:: field(const KeyType& key) const { - auto itr = mFieldStorage.find(key); - VERIFY2(itr != mFieldStorage.end(), "StateBase ERROR: failed lookup for Field " << key); - auto* fbasePtr = itr->second; - auto* resultPtr = dynamic_cast*>(fbasePtr); - VERIFY2(resultPtr != nullptr, + FieldBase& fbase = this->template get>(key); + auto* fptr = dynamic_cast*>(&fbase); + VERIFY2(fptr != nullptr, "StateBase::field ERROR: field type incorrect for key " << key); - return *resultPtr; + return *fptr; } template @@ -57,16 +55,28 @@ template inline std::vector*> StateBase:: -allFields(const Value&) const { +allFields() const { std::vector*> result; KeyType fieldName, nodeListName; - for (auto [key, valptr]: mFieldStorage) { - auto* ptr = dynamic_cast*>(valptr); - if (ptr != nullptr) result.push_back(ptr); + for (auto [key, aptr]: mStorage) { + auto* fbptr = std::any_cast*>(aptr); + if (fbptr != nullptr) { + auto* fptr = dynamic_cast*>(fbptr); + if (fptr != nullptr) result.push_back(fptr); + } } return result; } +template +template +inline +std::vector*> +StateBase:: +allFields(const Value&) const { + return this->template allFields(); +} + //------------------------------------------------------------------------------ // Return a FieldList containing all registered fields of the given name. //------------------------------------------------------------------------------ @@ -78,13 +88,15 @@ StateBase:: fields(const std::string& name) const { FieldList result; KeyType fieldName, nodeListName; - for (auto [key, valptr]: mFieldStorage) { + for (auto [key, aptr]: mStorage) { splitFieldKey(key, fieldName, nodeListName); if (fieldName == name) { CHECK(nodeListName != ""); - auto* fptr = dynamic_cast*>(valptr); - CHECK(valptr != nullptr); - result.appendField(*fptr); + auto* fbptr = std::any_cast*>(aptr); + if (fbptr != nullptr) { + auto fptr = dynamic_cast*>(fbptr); + if (fptr != nullptr) result.appendField(*fptr); + } } } return result; @@ -108,9 +120,9 @@ inline Value& StateBase:: get(const typename StateBase::KeyType& key) const { - auto itr = mMiscStorage.find(key); - VERIFY2(itr != mMiscStorage.end(), "StateBase ERROR: failed lookup for key " << key); - auto* resultPtr = std::get_if(itr->second); + auto itr = mStorage.find(key); + VERIFY2(itr != mStorage.end(), "StateBase ERROR: failed lookup for key " << key); + auto* resultPtr = std::any_cast(itr->second); VERIFY2(resultPtr != nullptr, "StateBase::get ERROR: unable to extract Value for " << key << "\n"); return *resultPtr; } diff --git a/src/DataBase/StateDerivatives.cc b/src/DataBase/StateDerivatives.cc index f2f914254..157ad4664 100644 --- a/src/DataBase/StateDerivatives.cc +++ b/src/DataBase/StateDerivatives.cc @@ -22,6 +22,33 @@ namespace Spheral { namespace { +//------------------------------------------------------------------------------ +// Collect visitor methods to apply to std::any object holders +//------------------------------------------------------------------------------ +// 2 args +template +class AnyVisitor2 { +public: + using VisitorFunc = std::function; + + RETURNT visit(ARG1 value1, ARG2 value2) const { + auto it = mVisitors.find(std::type_index(value1.type())); + if (it != mVisitors.end()) { + return it->second(value1, value2); + } + VERIFY2(false, "AnyVisitor ERROR in StateBase: unable to process unknown data"); + } + + template + void addVisitor(VisitorFunc visitor) { + mVisitors[std::type_index(typeid(T))] = visitor; + } + + +private: + std::unordered_map mVisitors; +}; + // Helper with overloading in std::visit template struct overload : Ts... { using Ts::operator()...; }; template overload(Ts...) -> overload; @@ -158,23 +185,24 @@ void StateDerivatives:: Zero() { + // Build a visitor to zero each data type + AnyVisitor2 ZERO; + ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->Zero(); }); + ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = 0.0; }); + ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = Vector::zero; }); + ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = Tensor::zero; }); + ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = SymTensor::zero; }); + ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); + ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); + ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); + ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); + ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); + ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); + ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { }); + // Walk the state fields and zero them. - for (auto [key, fptr]: mFieldStorage) fptr->Zero(); - - // Same thing for the miscellaeneous types - for (auto [key, mptr]: mMiscStorage) { - std::visit(overload{[](Scalar& x) { x = 0.0; }, - [](Vector& x) { x = Vector::zero; }, - [](Tensor& x) { x = Tensor::zero; }, - [](SymTensor& x) { x = SymTensor::zero; }, - [](vector& x) { x.clear(); }, - [](vector& x) { x.clear(); }, - [](vector& x) { x.clear(); }, - [](vector& x) { x.clear(); }, - [](set& x) { x.clear(); }, - [](set& x) { }, - [](ReproducingKernel& x) { } - }, *mptr); + for (auto [key, anyvalptr]: mStorage) { + ZERO.visit(anyvalptr, anyvalptr); } // Reinitialize the node pair interaction information. diff --git a/src/DataBase/StateDerivatives.hh b/src/DataBase/StateDerivatives.hh index e92012adf..5bf40ccc1 100644 --- a/src/DataBase/StateDerivatives.hh +++ b/src/DataBase/StateDerivatives.hh @@ -81,8 +81,7 @@ private: using SignificantNeighborMapType = std::map, int>; SignificantNeighborMapType mNumSignificantNeighbors; - using StateBase::mFieldStorage; - using StateBase::mMiscStorage; + using StateBase::mStorage; }; } diff --git a/src/Hydro/SphericalPositionPolicy.cc b/src/Hydro/SphericalPositionPolicy.cc index 392d3b6b4..d6cca6d5a 100644 --- a/src/Hydro/SphericalPositionPolicy.cc +++ b/src/Hydro/SphericalPositionPolicy.cc @@ -64,7 +64,7 @@ update(const KeyType& key, // Find all the available matching derivative Field keys. const auto incrementKey = prefix() + fieldKey; - const auto allkeys = derivs.fieldKeys(); + const auto allkeys = derivs.fullFieldKeys(); vector incrementKeys; for (const auto& key: allkeys) { if (key.compare(0, incrementKey.size(), incrementKey) == 0) { From c62e2e87aa30bf968963bc6562af91d8f5947b64 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Fri, 8 Nov 2024 16:47:54 -0800 Subject: [PATCH 05/14] Still working on testing --- src/DataBase/State.hh | 11 ++++++--- src/DataBase/StateBase.cc | 2 +- src/DataBase/StateBaseInline.hh | 1 + src/DataBase/StateInline.hh | 42 +++++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/DataBase/State.hh b/src/DataBase/State.hh index b34f02087..6f5c5f493 100644 --- a/src/DataBase/State.hh +++ b/src/DataBase/State.hh @@ -40,9 +40,6 @@ public: using PackageIterator = typename PackageList::iterator; using PolicyPointer = typename std::shared_ptr>; - // Promote base overloaded enroll methods - using StateBase::enroll; - // Constructors, destructor. State(); State(DataBase& dataBase, PackageList& physicsPackages); @@ -105,6 +102,14 @@ public: bool timeAdvanceOnly() const { return mTimeAdvanceOnly; } void timeAdvanceOnly(const bool x) { mTimeAdvanceOnly = x; } + //........................................................................... + // Expose the StateBase enroll methods + using StateBase::enroll; + // virtual void enroll(FieldBase& field) override { StateBase::enroll(field); } + // virtual void enroll(std::shared_ptr>& fieldPtr) override { StateBase::enroll(fieldPtr); } + // virtual void enroll(FieldListBase& fieldList) override { StateBase::enroll(fieldList); } + template void enroll(const KeyType& key, T& thing); + private: //--------------------------- Private Interface ---------------------------// using PolicyMapType = std::map>; diff --git a/src/DataBase/StateBase.cc b/src/DataBase/StateBase.cc index e3fd67335..661430a7b 100644 --- a/src/DataBase/StateBase.cc +++ b/src/DataBase/StateBase.cc @@ -73,7 +73,7 @@ class AnyVisitor4 { if (it != mVisitors.end()) { return it->second(value1, value2, value3, value4); } - VERIFY2(false, "AnyVisitor ERROR in StateBase: unable to process unknown data"); + VERIFY2(false, "AnyVisitor ERROR in StateBase: unable to process unknown data of typeid " << std::quoted(value1.type().name())); } template diff --git a/src/DataBase/StateBaseInline.hh b/src/DataBase/StateBaseInline.hh index a982c6719..afc144205 100644 --- a/src/DataBase/StateBaseInline.hh +++ b/src/DataBase/StateBaseInline.hh @@ -18,6 +18,7 @@ inline void StateBase:: enroll(const KeyType& key, T& thing) { + std::cerr << "StateBase::enroll " << key << std::endl; mStorage[key] = &thing; } diff --git a/src/DataBase/StateInline.hh b/src/DataBase/StateInline.hh index 9d39e3919..8ea18e4c1 100644 --- a/src/DataBase/StateInline.hh +++ b/src/DataBase/StateInline.hh @@ -1,5 +1,35 @@ namespace Spheral { +//------------------------------------------------------------------------------ +// Functors in a detail namespace to help with partial specialization +//------------------------------------------------------------------------------ +namespace Detail { + +template +struct EnrollAny { + void operator()(State& state, + const typename State::KeyType& key, + T& thing) { + dynamic_cast*>(&state)->enroll(key, thing); + } +}; + +template +struct EnrollAny> { + void operator()(State& state, + const typename State::KeyType& key, + std::shared_ptr& thing) { + auto UPP = std::dynamic_pointer_cast>(thing); + if (UPP) { + state.enroll(key, UPP); + } else { + dynamic_cast*>(&state)->enroll(key, thing); + } + } +}; + +} + //------------------------------------------------------------------------------ // Enroll the given field and policy. //------------------------------------------------------------------------------ @@ -62,4 +92,16 @@ policy(const Field& field) const { return this->policy(key); } +//------------------------------------------------------------------------------ +// Enroll an arbitrary type +//------------------------------------------------------------------------------ +template +template +inline +void +State:: +enroll(const KeyType& key, T& thing) { + Detail::EnrollAny()(*this, key, thing); +} + } From fd7ca80f84b754e8c2030fc621a1c94ed20cf3b3 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Mon, 11 Nov 2024 11:33:39 -0800 Subject: [PATCH 06/14] Looks like std::any_cast with pointer types can still throw -- I thought it was not supposed to. We still have to put those in try/catch blocks. --- src/DataBase/StateBase.cc | 17 +++++++++++++---- src/DataBase/StateBaseInline.hh | 33 +++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/DataBase/StateBase.cc b/src/DataBase/StateBase.cc index 661430a7b..f49c650e4 100644 --- a/src/DataBase/StateBase.cc +++ b/src/DataBase/StateBase.cc @@ -314,7 +314,10 @@ StateBase:: fullFieldKeys() const { vector result; for (auto [key, aptr]: mStorage) { - if (std::any_cast*>(aptr) != nullptr) result.push_back(key); + try { + if (std::any_cast*>(aptr) != nullptr) result.push_back(key); + } catch (const std::bad_any_cast& e) { + } } return result; } @@ -328,7 +331,10 @@ StateBase:: miscKeys() const { vector result; for (auto [key, aptr]: mStorage) { - if (std::any_cast*>(aptr) == nullptr) result.push_back(key); + try { + if (std::any_cast*>(aptr) == nullptr) result.push_back(key); + } catch(const std::bad_any_cast& e) { + } } return result; } @@ -342,8 +348,11 @@ StateBase:: fieldNames() const { vector result; for (auto [key, aptr]: mStorage) { - auto* fptr = std::any_cast*>(aptr); - if (fptr != nullptr) result.push_back(fptr->name()); + try { + auto* fptr = std::any_cast*>(aptr); + if (fptr != nullptr) result.push_back(fptr->name()); + } catch (const std::bad_any_cast& e) { + } } // Remove any duplicates. This will happen when we've stored the same field diff --git a/src/DataBase/StateBaseInline.hh b/src/DataBase/StateBaseInline.hh index afc144205..e17750466 100644 --- a/src/DataBase/StateBaseInline.hh +++ b/src/DataBase/StateBaseInline.hh @@ -18,7 +18,7 @@ inline void StateBase:: enroll(const KeyType& key, T& thing) { - std::cerr << "StateBase::enroll " << key << std::endl; + // std::cerr << "StateBase::enroll " << key << std::endl; mStorage[key] = &thing; } @@ -60,10 +60,13 @@ allFields() const { std::vector*> result; KeyType fieldName, nodeListName; for (auto [key, aptr]: mStorage) { - auto* fbptr = std::any_cast*>(aptr); - if (fbptr != nullptr) { - auto* fptr = dynamic_cast*>(fbptr); - if (fptr != nullptr) result.push_back(fptr); + try { + auto* fbptr = std::any_cast*>(aptr); + if (fbptr != nullptr) { + auto* fptr = dynamic_cast*>(fbptr); + if (fptr != nullptr) result.push_back(fptr); + } + } catch(const std::bad_any_cast& e) { } } return result; @@ -93,10 +96,13 @@ fields(const std::string& name) const { splitFieldKey(key, fieldName, nodeListName); if (fieldName == name) { CHECK(nodeListName != ""); - auto* fbptr = std::any_cast*>(aptr); - if (fbptr != nullptr) { - auto fptr = dynamic_cast*>(fbptr); - if (fptr != nullptr) result.appendField(*fptr); + try { + auto* fbptr = std::any_cast*>(aptr); + if (fbptr != nullptr) { + auto fptr = dynamic_cast*>(fbptr); + if (fptr != nullptr) result.appendField(*fptr); + } + } catch(const std::bad_any_cast& e) { } } } @@ -123,9 +129,12 @@ StateBase:: get(const typename StateBase::KeyType& key) const { auto itr = mStorage.find(key); VERIFY2(itr != mStorage.end(), "StateBase ERROR: failed lookup for key " << key); - auto* resultPtr = std::any_cast(itr->second); - VERIFY2(resultPtr != nullptr, "StateBase::get ERROR: unable to extract Value for " << key << "\n"); - return *resultPtr; + try { + auto* resultPtr = std::any_cast(itr->second); + return *resultPtr; + } catch(const std::bad_any_cast& e) { + VERIFY2(false, "StateBase::get ERROR: unable to extract Value for " << key << "\n"); + } } // Same thing passing a dummy argument to help with template type From 1992cbefeee9f18d58b68611445153bc0ea7d1cf Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Mon, 11 Nov 2024 11:34:36 -0800 Subject: [PATCH 07/14] Making SphericalPositionUpdate a per Field policy --- src/Hydro/SphericalPositionPolicy.cc | 30 +++++++++++----------------- src/Hydro/SphericalPositionPolicy.hh | 3 +++ 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/Hydro/SphericalPositionPolicy.cc b/src/Hydro/SphericalPositionPolicy.cc index d6cca6d5a..42193a4ae 100644 --- a/src/Hydro/SphericalPositionPolicy.cc +++ b/src/Hydro/SphericalPositionPolicy.cc @@ -56,33 +56,27 @@ update(const KeyType& key, // Get the field name portion of the key. KeyType fieldKey, nodeListKey; StateBase::splitFieldKey(key, fieldKey, nodeListKey); - REQUIRE(nodeListKey == UpdatePolicyBase::wildcard()); // Get the state we're updating. - auto f = state.fields(fieldKey, Vector::zero); - const auto numNodeLists = f.size(); + auto f = state.field(key, Vector::zero); // Find all the available matching derivative Field keys. const auto incrementKey = prefix() + fieldKey; - const auto allkeys = derivs.fullFieldKeys(); - vector incrementKeys; + const auto allkeys = derivs.keys(); + KeyType dfKey, dfNodeListKey; for (const auto& key: allkeys) { - if (key.compare(0, incrementKey.size(), incrementKey) == 0) { - incrementKeys.push_back(key); - } - } - CHECK(not incrementKeys.empty()); + StateBase::splitFieldKey(key, dfKey, dfNodeListKey); + if (dfNodeListKey == nodeListKey and + dfKey.compare(0, incrementKey.size(), incrementKey) == 0) { - // Update by each of our derivative fields. - for (const auto& key: incrementKeys) { - const auto df = derivs.fields(key, Vector::zero); - CHECK(df.size() == f.size()); - for (auto k = 0u; k != numNodeLists; ++k) { - const auto n = f[k]->numInternalElements(); - for (auto i = 0u; i != n; ++i) { + // This delta field matches the base of increment key, so apply it. + const auto& df = derivs.field(key, Vector::zero); + const auto n = f.numInternalElements(); +#pragma omp parallel for + for (auto i = 0u; i < n; ++i) { // This is where we diverge from the standard IncrementState. Ensure we cannot cross to // negative radius. - f(k,i) = std::max(0.5*f(k,i), f(k,i) + multiplier*(df(k, i))); + f(i) = std::max(0.5*f(i), f(i) + multiplier*(df(i))); } } } diff --git a/src/Hydro/SphericalPositionPolicy.hh b/src/Hydro/SphericalPositionPolicy.hh index 8b121b057..b8b8aeb5c 100644 --- a/src/Hydro/SphericalPositionPolicy.hh +++ b/src/Hydro/SphericalPositionPolicy.hh @@ -35,6 +35,9 @@ public: const double t, const double dt); + // Should this policy be cloned per Field when registering for a FieldList? + virtual bool clonePerField() const { return true; } + // Equivalence. virtual bool operator==(const UpdatePolicyBase& rhs) const; From a7929b90091022f0b297a47c4707eb46c98a2f72 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Mon, 11 Nov 2024 15:39:03 -0800 Subject: [PATCH 08/14] Made AnyVisitor a generic template (in terms of the visitor method arguments) and converted all our uses in the State objects to use this class, now stored in Utilities. Also updated the RELEASE_NOTES. --- RELEASE_NOTES.md | 3 + src/DataBase/StateBase.cc | 98 ++------------------------------ src/DataBase/StateBase.hh | 1 - src/DataBase/StateDerivatives.cc | 38 +------------ src/Utilities/AnyVisitor.hh | 40 +++++++++++++ src/Utilities/CMakeLists.txt | 1 + 6 files changed, 50 insertions(+), 131 deletions(-) create mode 100644 src/Utilities/AnyVisitor.hh diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 793d944e1..302d85d08 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -28,6 +28,9 @@ Notable changes include: * Physics packages can indicate if they require Voronoi cell information be available. If so, a new package which computes and updates the Voronoi information is automatically added to the package list by the SpheralController (similar to how the Reproducing Kernel corrections are handled). + * Cleaned up use of std::any in State objects using a visitor pattern to be rigorous ensuring all state entries are handled properly + during assignement, equality, and cloning operations. This is intended to help ensure our Physics advance during time integration + is correct. * Build changes / improvements: * Distributed source directory must always be built now. diff --git a/src/DataBase/StateBase.cc b/src/DataBase/StateBase.cc index f49c650e4..e71e2f0ed 100644 --- a/src/DataBase/StateBase.cc +++ b/src/DataBase/StateBase.cc @@ -12,6 +12,7 @@ #include "Mesh/Mesh.hh" #include "RK/RKCorrectionParams.hh" #include "RK/ReproducingKernel.hh" +#include "Utilities/AnyVisitor.hh" #include "Utilities/DBC.hh" #include @@ -34,72 +35,6 @@ namespace Spheral { namespace { -//------------------------------------------------------------------------------ -// Collect visitor methods to apply to std::any object holders -//------------------------------------------------------------------------------ -// 2 args -template -class AnyVisitor2 { -public: - using VisitorFunc = std::function; - - RETURNT visit(ARG1 value1, ARG2 value2) const { - auto it = mVisitors.find(std::type_index(value1.type())); - if (it != mVisitors.end()) { - return it->second(value1, value2); - } - VERIFY2(false, "AnyVisitor ERROR in StateBase: unable to process unknown data"); - } - - template - void addVisitor(VisitorFunc visitor) { - mVisitors[std::type_index(typeid(T))] = visitor; - } - - -private: - std::unordered_map mVisitors; -}; - -//.............................................................................. -// 4 args -template -class AnyVisitor4 { -public: - using VisitorFunc = std::function; - - RETURNT visit(ARG1 value1, ARG2 value2, ARG3 value3, ARG4 value4) const { - auto it = mVisitors.find(std::type_index(value1.type())); - if (it != mVisitors.end()) { - return it->second(value1, value2, value3, value4); - } - VERIFY2(false, "AnyVisitor ERROR in StateBase: unable to process unknown data of typeid " << std::quoted(value1.type().name())); - } - - template - void addVisitor(VisitorFunc visitor) { - mVisitors[std::type_index(typeid(T))] = visitor; - } - - -private: - std::unordered_map mVisitors; -}; - -// //------------------------------------------------------------------------------ -// // Helper for copying a type, used in copyState -// //------------------------------------------------------------------------------ -// template -// T* -// extractType(boost::any& anyT) { -// try { -// T* result = boost::any_cast(anyT); -// return result; -// } catch (boost::any_cast_error) { -// return NULL; -// } -// } - //------------------------------------------------------------------------------ // Template for generic cloning during copyState //------------------------------------------------------------------------------ @@ -114,31 +49,6 @@ genericClone(std::any& x, storage[key] = clone.get(); } -//------------------------------------------------------------------------------ -// Template to downselect comparison in our variant types -//------------------------------------------------------------------------------ -template bool safeCompare(T1& x, const T1& y) { return x == y; } -template bool safeCompare(T1& x, const T2& y) { VERIFY2(false, "Bad comparison!"); return false; } - -//------------------------------------------------------------------------------ -// Template to downselect assignment in our variant types -//------------------------------------------------------------------------------ -template void safeAssign(T1& x, const T1& y) { x = y; } -template void safeAssign(T1& x, const T2& y) { VERIFY2(false, "Bad assignment!"); } - -template T1& safePointer(T1* xptr, const T1* yptr) { return yptr; } -template T1& safePointer(T1* xptr, const T2* yptr) { VERIFY2(false, "Bad assignment!"); return xptr; } - -//------------------------------------------------------------------------------ -//------------------------------------------------------------------------------ -template std::shared_ptr safeClone(const T1& x, const ResultT& dummy) { return std::make_shared(x); } - -//------------------------------------------------------------------------------ -// Helper with overloading in std::visit -//------------------------------------------------------------------------------ -template struct overload : Ts... { using Ts::operator()...; }; -template overload(Ts...) -> overload; - } //------------------------------------------------------------------------------ @@ -177,7 +87,7 @@ operator==(const StateBase& rhs) const { } // Build up a visitor to compare each type of state data we support holding - AnyVisitor2 EQUAL; + AnyVisitor EQUAL; EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); @@ -447,7 +357,7 @@ StateBase:: assign(const StateBase& rhs) { // Build a visitor that knows how to assign each of our datatypes - AnyVisitor2 ASSIGN; + AnyVisitor ASSIGN; ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); @@ -503,7 +413,7 @@ copyState() { mCache = CacheType(); // Build a visitor to clone each type of state data - AnyVisitor4 CLONE; + AnyVisitor CLONE; CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { auto clone = std::any_cast*>(x)->clone(); cache.push_back(clone); diff --git a/src/DataBase/StateBase.hh b/src/DataBase/StateBase.hh index 21be4fc5c..5dc899b13 100644 --- a/src/DataBase/StateBase.hh +++ b/src/DataBase/StateBase.hh @@ -19,7 +19,6 @@ #include "Field/FieldBase.hh" #include -#include #include #include #include diff --git a/src/DataBase/StateDerivatives.cc b/src/DataBase/StateDerivatives.cc index 157ad4664..5e527ab2b 100644 --- a/src/DataBase/StateDerivatives.cc +++ b/src/DataBase/StateDerivatives.cc @@ -9,6 +9,7 @@ #include "DataBase.hh" #include "Physics/Physics.hh" #include "Field/Field.hh" +#include "Utilities/AnyVisitor.hh" using std::vector; using std::cout; @@ -20,41 +21,6 @@ using std::abs; namespace Spheral { -namespace { - -//------------------------------------------------------------------------------ -// Collect visitor methods to apply to std::any object holders -//------------------------------------------------------------------------------ -// 2 args -template -class AnyVisitor2 { -public: - using VisitorFunc = std::function; - - RETURNT visit(ARG1 value1, ARG2 value2) const { - auto it = mVisitors.find(std::type_index(value1.type())); - if (it != mVisitors.end()) { - return it->second(value1, value2); - } - VERIFY2(false, "AnyVisitor ERROR in StateBase: unable to process unknown data"); - } - - template - void addVisitor(VisitorFunc visitor) { - mVisitors[std::type_index(typeid(T))] = visitor; - } - - -private: - std::unordered_map mVisitors; -}; - -// Helper with overloading in std::visit -template struct overload : Ts... { using Ts::operator()...; }; -template overload(Ts...) -> overload; - -} - //------------------------------------------------------------------------------ // Default constructor. //------------------------------------------------------------------------------ @@ -186,7 +152,7 @@ StateDerivatives:: Zero() { // Build a visitor to zero each data type - AnyVisitor2 ZERO; + AnyVisitor ZERO; ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->Zero(); }); ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = 0.0; }); ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = Vector::zero; }); diff --git a/src/Utilities/AnyVisitor.hh b/src/Utilities/AnyVisitor.hh new file mode 100644 index 000000000..ef037ad1c --- /dev/null +++ b/src/Utilities/AnyVisitor.hh @@ -0,0 +1,40 @@ +//---------------------------------Spheral++----------------------------------// +// Collect visitor methods to apply to std::any object holders +// +// This allows us to use the visitor pattern with containers of std::any +// obfuscated objects similarly to the std::variant pattern. +//----------------------------------------------------------------------------// +#ifndef __Spheral_AnyVisitor__ +#define __Spheral_AnyVisitor__ + +#include +#include + +namespace Spheral { + +template +class AnyVisitor { +public: + using VisitorFunc = std::function; + + template + RETURNT visit(T value, EXTRAARGS&&... extraargs) const { + auto it = mVisitors.find(std::type_index(value.type())); + if (it != mVisitors.end()) { + return it->second(value, extraargs...); + } + VERIFY2(false, "AnyVisitor ERROR: unable to process unknown data type " << std::quoted(value.type().name())); + } + + template + void addVisitor(VisitorFunc visitor) { + mVisitors[std::type_index(typeid(T))] = visitor; + } + +private: + std::unordered_map mVisitors; +}; + +} + +#endif diff --git a/src/Utilities/CMakeLists.txt b/src/Utilities/CMakeLists.txt index 75439b880..471901587 100644 --- a/src/Utilities/CMakeLists.txt +++ b/src/Utilities/CMakeLists.txt @@ -131,6 +131,7 @@ set(Utilities_headers timingUtilities.hh uniform_random.hh uniform_random_Inline.hh + AnyVisitor.hh ) spheral_install_python_files(fitspline.py) From c1669628cc1e9b0de36dbb738ca2874cb936e70a Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Tue, 12 Nov 2024 15:35:19 -0800 Subject: [PATCH 09/14] Unnecessary include --- src/Utilities/AnyVisitor.hh | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Utilities/AnyVisitor.hh b/src/Utilities/AnyVisitor.hh index ef037ad1c..66e63c89b 100644 --- a/src/Utilities/AnyVisitor.hh +++ b/src/Utilities/AnyVisitor.hh @@ -7,7 +7,6 @@ #ifndef __Spheral_AnyVisitor__ #define __Spheral_AnyVisitor__ -#include #include namespace Spheral { From a1e766b80b9de926cd449d384e0c9a76c5c71b48 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Wed, 13 Nov 2024 11:12:10 -0800 Subject: [PATCH 10/14] Upping BlueOS compiler version to gcc 10.2.1 --- scripts/devtools/spec-list.json | 6 +++--- .../spack/configs/blueos_3_ppc64le_ib/compilers.yaml | 10 +++++----- .../spack/configs/blueos_3_ppc64le_ib/packages.yaml | 2 ++ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/scripts/devtools/spec-list.json b/scripts/devtools/spec-list.json index 59f681eb8..395d5d0cd 100644 --- a/scripts/devtools/spec-list.json +++ b/scripts/devtools/spec-list.json @@ -7,9 +7,9 @@ ] , "blueos_3_ppc64le_ib_p9": [ - "gcc@8.3.1", - "gcc@8.3.1+cuda~mpi cuda_arch=70", - "gcc@8.3.1+cuda cuda_arch=70" + "gcc@10.2.1", + "gcc@10.2.1+cuda~mpi cuda_arch=70", + "gcc@10.2.1+cuda cuda_arch=70" ] } } diff --git a/scripts/spack/configs/blueos_3_ppc64le_ib/compilers.yaml b/scripts/spack/configs/blueos_3_ppc64le_ib/compilers.yaml index b876f15bc..84c1b8b70 100644 --- a/scripts/spack/configs/blueos_3_ppc64le_ib/compilers.yaml +++ b/scripts/spack/configs/blueos_3_ppc64le_ib/compilers.yaml @@ -13,12 +13,12 @@ compilers: environment: {} extra_rpaths: [] - compiler: - spec: gcc@8.3.1 + spec: gcc@10.2.1 paths: - cc: /usr/tce/packages/gcc/gcc-8.3.1/bin/gcc - cxx: /usr/tce/packages/gcc/gcc-8.3.1/bin/g++ - f77: /usr/tce/packages/gcc/gcc-8.3.1/bin/gfortran - fc: /usr/tce/packages/gcc/gcc-8.3.1/bin/gfortran + cc: /usr/tce/packages/gcc/gcc-10.2.1/bin/gcc + cxx: /usr/tce/packages/gcc/gcc-10.2.1/bin/g++ + f77: /usr/tce/packages/gcc/gcc-10.2.1/bin/gfortran + fc: /usr/tce/packages/gcc/gcc-10.2.1/bin/gfortran flags: {} operating_system: rhel7 target: ppc64le diff --git a/scripts/spack/configs/blueos_3_ppc64le_ib/packages.yaml b/scripts/spack/configs/blueos_3_ppc64le_ib/packages.yaml index 7a8d3d6bf..54f6fd19c 100644 --- a/scripts/spack/configs/blueos_3_ppc64le_ib/packages.yaml +++ b/scripts/spack/configs/blueos_3_ppc64le_ib/packages.yaml @@ -39,6 +39,8 @@ packages: - 10.1.243 buildable: false externals: + - spec: cuda@11.4.1+allow-unsupported-compilers + prefix: /usr/tce/packages/cuda/cuda-11.4.1 - spec: cuda@11.1.0~allow-unsupported-compilers prefix: /usr/tce/packages/cuda/cuda-11.1.0 - spec: cuda@11.0.2~allow-unsupported-compilers From ef542bdc83b488b4963ba57103f5e5ed7db0bde2 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Wed, 13 Nov 2024 13:34:14 -0800 Subject: [PATCH 11/14] Updating hardwired compiler version --- .gitlab/os.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab/os.yml b/.gitlab/os.yml index 9bc4b8146..e7a6a8fda 100644 --- a/.gitlab/os.yml +++ b/.gitlab/os.yml @@ -18,7 +18,7 @@ .on_blueos_3_ppc64: variables: ARCH: 'blueos_3_ppc64le_ib_p9' - GCC_VERSION: '8.3.1' + GCC_VERSION: '10.2.1' CLANG_VERSION: '9.0.0' SPHERAL_BUILDS_DIR: /p/gpfs1/sphapp/spheral-ci-builds extends: [.sys_config] From 306b6975dfe7517b1e7d21e06706a52cecb68871 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Wed, 13 Nov 2024 15:16:45 -0800 Subject: [PATCH 12/14] Swithed from storing pointers to std::reference_wrapper in our State storage types -- seems to work well --- src/DataBase/StateBase.cc | 110 ++++++++++++++++--------------- src/DataBase/StateBaseInline.hh | 30 ++++----- src/DataBase/StateDerivatives.cc | 34 +++++----- 3 files changed, 86 insertions(+), 88 deletions(-) diff --git a/src/DataBase/StateBase.cc b/src/DataBase/StateBase.cc index e71e2f0ed..699d458c9 100644 --- a/src/DataBase/StateBase.cc +++ b/src/DataBase/StateBase.cc @@ -44,9 +44,9 @@ genericClone(std::any& x, const std::string& key, typename std::map& storage, typename std::list& cache) { - auto clone = std::make_shared(*std::any_cast(x)); + auto clone = std::make_shared(std::any_cast>(x).get()); cache.push_back(clone); - storage[key] = clone.get(); + storage[key] = std::ref(*clone); } } @@ -88,18 +88,18 @@ operator==(const StateBase& rhs) const { // Build up a visitor to compare each type of state data we support holding AnyVisitor EQUAL; - EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); - EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); - EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); - EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); - EQUAL.addVisitor ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast(x) == *std::any_cast(y); }); - EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); - EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); - EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); - EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); - EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); - EQUAL.addVisitor*> ([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); - EQUAL.addVisitor*>([](const std::any& x, const std::any& y) -> bool { return *std::any_cast*>(x) == *std::any_cast*>(y); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>(x).get() == std::any_cast>(y).get(); }); + EQUAL.addVisitor> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>(x).get() == std::any_cast>(y).get(); }); + EQUAL.addVisitor> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>(x).get() == std::any_cast>(y).get(); }); + EQUAL.addVisitor> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>(x).get() == std::any_cast>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>> ([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); + EQUAL.addVisitor>>([](const std::any& x, const std::any& y) -> bool { return std::any_cast>>(x).get() == std::any_cast>>(y).get(); }); // Apply the equality visitor to all the stored State data auto lhsitr = mStorage.begin(); @@ -124,7 +124,7 @@ void StateBase:: enroll(FieldBase& field) { const auto key = this->key(field); - mStorage[key] = &field; + mStorage[key] = std::ref(field); mNodeListPtrs.insert(field.nodeListPtr()); // std::cerr << "StateBase::enroll field: " << key << " at " << &field << std::endl; ENSURE(std::find(mNodeListPtrs.begin(), mNodeListPtrs.end(), field.nodeListPtr()) != mNodeListPtrs.end()); @@ -138,7 +138,7 @@ void StateBase:: enroll(std::shared_ptr>& fieldPtr) { const auto key = this->key(*fieldPtr); - mStorage[key] = fieldPtr.get(); + mStorage[key] = std::ref(*fieldPtr); mNodeListPtrs.insert(fieldPtr->nodeListPtr()); mCache.push_back(fieldPtr); ENSURE(std::find(mNodeListPtrs.begin(), mNodeListPtrs.end(), fieldPtr->nodeListPtr()) != mNodeListPtrs.end()); @@ -223,9 +223,11 @@ std::vector::KeyType> StateBase:: fullFieldKeys() const { vector result; - for (auto [key, aptr]: mStorage) { + for (auto [key, aref]: mStorage) { try { - if (std::any_cast*>(aptr) != nullptr) result.push_back(key); + auto xref = std::any_cast>>(aref); + result.push_back(key); + CONTRACT_VAR(xref); } catch (const std::bad_any_cast& e) { } } @@ -240,9 +242,11 @@ std::vector::KeyType> StateBase:: miscKeys() const { vector result; - for (auto [key, aptr]: mStorage) { + for (auto [key, aref]: mStorage) { try { - if (std::any_cast*>(aptr) == nullptr) result.push_back(key); + auto xref = std::any_cast>>(aref); + result.push_back(key); + CONTRACT_VAR(xref); } catch(const std::bad_any_cast& e) { } } @@ -257,10 +261,10 @@ std::vector::FieldName> StateBase:: fieldNames() const { vector result; - for (auto [key, aptr]: mStorage) { + for (auto [key, aref]: mStorage) { try { - auto* fptr = std::any_cast*>(aptr); - if (fptr != nullptr) result.push_back(fptr->name()); + auto fref = std::any_cast>>(aref); + result.push_back(fref.get().name()); } catch (const std::bad_any_cast& e) { } } @@ -358,18 +362,18 @@ assign(const StateBase& rhs) { // Build a visitor that knows how to assign each of our datatypes AnyVisitor ASSIGN; - ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); - ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); - ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); - ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); - ASSIGN.addVisitor ([](std::any& x, const std::any& y) { *std::any_cast(x) = *std::any_cast(y); }); - ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); - ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); - ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); - ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); - ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); - ASSIGN.addVisitor*> ([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); - ASSIGN.addVisitor*>([](std::any& x, const std::any& y) { *std::any_cast*>(x) = *std::any_cast*>(y); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor> ([](std::any& x, const std::any& y) { std::any_cast>(x).get() = std::any_cast>(y).get(); }); + ASSIGN.addVisitor> ([](std::any& x, const std::any& y) { std::any_cast>(x).get() = std::any_cast>(y).get(); }); + ASSIGN.addVisitor> ([](std::any& x, const std::any& y) { std::any_cast>(x).get() = std::any_cast>(y).get(); }); + ASSIGN.addVisitor> ([](std::any& x, const std::any& y) { std::any_cast>(x).get() = std::any_cast>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>> ([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); + ASSIGN.addVisitor>>([](std::any& x, const std::any& y) { std::any_cast>>(x).get() = std::any_cast>>(y).get(); }); // Apply the assignment visitor to all the stored State data auto lhsitr = mStorage.begin(); @@ -414,26 +418,26 @@ copyState() { // Build a visitor to clone each type of state data AnyVisitor CLONE; - CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { - auto clone = std::any_cast*>(x)->clone(); - cache.push_back(clone); - storage[key] = clone.get(); - }); - CLONE.addVisitor ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); - CLONE.addVisitor ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); - CLONE.addVisitor ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); - CLONE.addVisitor ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); - CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); - CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); - CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); - CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); - CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); - CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); - CLONE.addVisitor*> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { + auto clone = std::any_cast>>(x).get().clone(); + cache.push_back(clone); + storage[key] = std::ref(*clone); + }); + CLONE.addVisitor> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); + CLONE.addVisitor>> ([](std::any& x, const KeyType& key, StorageType& storage, CacheType& cache) { genericClone>(x, key, storage, cache); }); // Clone all our stored data to cache - for (auto& [key, anyvalptr]: mStorage) { - CLONE.visit(anyvalptr, key, mStorage, mCache); + for (auto& [key, anyval]: mStorage) { + CLONE.visit(anyval, key, mStorage, mCache); } } diff --git a/src/DataBase/StateBaseInline.hh b/src/DataBase/StateBaseInline.hh index e17750466..0434809bf 100644 --- a/src/DataBase/StateBaseInline.hh +++ b/src/DataBase/StateBaseInline.hh @@ -10,7 +10,6 @@ namespace Spheral { //------------------------------------------------------------------------------ // Enroll an arbitrary type -// Must be one of the supported types in StateBase::AllowedType //------------------------------------------------------------------------------ template template @@ -19,7 +18,7 @@ void StateBase:: enroll(const KeyType& key, T& thing) { // std::cerr << "StateBase::enroll " << key << std::endl; - mStorage[key] = &thing; + mStorage[key] = std::ref(thing); } //------------------------------------------------------------------------------ @@ -31,8 +30,8 @@ inline Field& StateBase:: field(const KeyType& key) const { - FieldBase& fbase = this->template get>(key); - auto* fptr = dynamic_cast*>(&fbase); + FieldBase& fb = this->template get>(key); + auto* fptr = dynamic_cast*>(&fb); VERIFY2(fptr != nullptr, "StateBase::field ERROR: field type incorrect for key " << key); return *fptr; @@ -59,13 +58,11 @@ StateBase:: allFields() const { std::vector*> result; KeyType fieldName, nodeListName; - for (auto [key, aptr]: mStorage) { + for (auto [key, aref]: mStorage) { try { - auto* fbptr = std::any_cast*>(aptr); - if (fbptr != nullptr) { - auto* fptr = dynamic_cast*>(fbptr); - if (fptr != nullptr) result.push_back(fptr); - } + auto fb = std::any_cast>>(aref); + auto* fptr = dynamic_cast*>(&fb.get()); + if (fptr != nullptr) result.push_back(fptr); } catch(const std::bad_any_cast& e) { } } @@ -92,16 +89,14 @@ StateBase:: fields(const std::string& name) const { FieldList result; KeyType fieldName, nodeListName; - for (auto [key, aptr]: mStorage) { + for (auto [key, aref]: mStorage) { splitFieldKey(key, fieldName, nodeListName); if (fieldName == name) { CHECK(nodeListName != ""); try { - auto* fbptr = std::any_cast*>(aptr); - if (fbptr != nullptr) { - auto fptr = dynamic_cast*>(fbptr); - if (fptr != nullptr) result.appendField(*fptr); - } + auto fb = std::any_cast>>(aref); + auto* fptr = dynamic_cast*>(&fb.get()); + if (fptr != nullptr) result.appendField(*fptr); } catch(const std::bad_any_cast& e) { } } @@ -130,8 +125,7 @@ get(const typename StateBase::KeyType& key) const { auto itr = mStorage.find(key); VERIFY2(itr != mStorage.end(), "StateBase ERROR: failed lookup for key " << key); try { - auto* resultPtr = std::any_cast(itr->second); - return *resultPtr; + return std::any_cast>(itr->second); } catch(const std::bad_any_cast& e) { VERIFY2(false, "StateBase::get ERROR: unable to extract Value for " << key << "\n"); } diff --git a/src/DataBase/StateDerivatives.cc b/src/DataBase/StateDerivatives.cc index 5e527ab2b..ef898013c 100644 --- a/src/DataBase/StateDerivatives.cc +++ b/src/DataBase/StateDerivatives.cc @@ -152,23 +152,23 @@ StateDerivatives:: Zero() { // Build a visitor to zero each data type - AnyVisitor ZERO; - ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->Zero(); }); - ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = 0.0; }); - ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = Vector::zero; }); - ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = Tensor::zero; }); - ZERO.addVisitor ([](const std::any& x, const std::any& y) { *std::any_cast(x) = SymTensor::zero; }); - ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); - ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); - ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); - ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); - ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); - ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { std::any_cast*>(x)->clear(); }); - ZERO.addVisitor*> ([](const std::any& x, const std::any& y) { }); - - // Walk the state fields and zero them. - for (auto [key, anyvalptr]: mStorage) { - ZERO.visit(anyvalptr, anyvalptr); + AnyVisitor ZERO; + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().Zero(); }); + ZERO.addVisitor> ([](const std::any& x) { std::any_cast>(x).get() = 0.0; }); + ZERO.addVisitor> ([](const std::any& x) { std::any_cast>(x).get() = Vector::zero; }); + ZERO.addVisitor> ([](const std::any& x) { std::any_cast>(x).get() = Tensor::zero; }); + ZERO.addVisitor> ([](const std::any& x) { std::any_cast>(x).get() = SymTensor::zero; }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { std::any_cast>>(x).get().clear(); }); + ZERO.addVisitor>> ([](const std::any& x) { } ); + + // Walk the state values and zero them + for (auto itr: mStorage) { + ZERO.visit(itr.second); } // Reinitialize the node pair interaction information. From d443ecc215126a1353ab5bdb9676dc2ee8600a7c Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Thu, 14 Nov 2024 10:12:26 -0800 Subject: [PATCH 13/14] Using any::type to eliminate the try/catch nonsense in State classes --- src/DataBase/StateBase.cc | 23 +++++++---------------- src/DataBase/StateBaseInline.hh | 11 ++++------- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/src/DataBase/StateBase.cc b/src/DataBase/StateBase.cc index 699d458c9..698a91d58 100644 --- a/src/DataBase/StateBase.cc +++ b/src/DataBase/StateBase.cc @@ -224,11 +224,8 @@ StateBase:: fullFieldKeys() const { vector result; for (auto [key, aref]: mStorage) { - try { - auto xref = std::any_cast>>(aref); + if (aref.type() == typeid(std::reference_wrapper>)) { result.push_back(key); - CONTRACT_VAR(xref); - } catch (const std::bad_any_cast& e) { } } return result; @@ -243,11 +240,8 @@ StateBase:: miscKeys() const { vector result; for (auto [key, aref]: mStorage) { - try { - auto xref = std::any_cast>>(aref); + if (aref.type() != typeid(std::reference_wrapper>)) { result.push_back(key); - CONTRACT_VAR(xref); - } catch(const std::bad_any_cast& e) { } } return result; @@ -261,11 +255,12 @@ std::vector::FieldName> StateBase:: fieldNames() const { vector result; + KeyType fieldName, nodeListName; for (auto [key, aref]: mStorage) { - try { + if (aref.type() == typeid(std::reference_wrapper>)) { auto fref = std::any_cast>>(aref); - result.push_back(fref.get().name()); - } catch (const std::bad_any_cast& e) { + splitFieldKey(fref.get().name(), fieldName, nodeListName); + result.push_back(fieldName); } } @@ -381,11 +376,7 @@ assign(const StateBase& rhs) { for (; lhsitr != mStorage.end(); ++lhsitr, ++rhsitr) { CHECK(rhsitr != rhs.mStorage.end()); CHECK(lhsitr->first == rhsitr->first); - try { - ASSIGN.visit(lhsitr->second, rhsitr->second); - } catch(...) { - CHECK(false); - } + ASSIGN.visit(lhsitr->second, rhsitr->second); } // Copy the connectivity (by reference). This thing is too diff --git a/src/DataBase/StateBaseInline.hh b/src/DataBase/StateBaseInline.hh index 0434809bf..888e1cd82 100644 --- a/src/DataBase/StateBaseInline.hh +++ b/src/DataBase/StateBaseInline.hh @@ -59,11 +59,10 @@ allFields() const { std::vector*> result; KeyType fieldName, nodeListName; for (auto [key, aref]: mStorage) { - try { + if (aref.type() == typeid(std::reference_wrapper>)) { auto fb = std::any_cast>>(aref); auto* fptr = dynamic_cast*>(&fb.get()); if (fptr != nullptr) result.push_back(fptr); - } catch(const std::bad_any_cast& e) { } } return result; @@ -93,11 +92,10 @@ fields(const std::string& name) const { splitFieldKey(key, fieldName, nodeListName); if (fieldName == name) { CHECK(nodeListName != ""); - try { + if (aref.type() == typeid(std::reference_wrapper>)) { auto fb = std::any_cast>>(aref); auto* fptr = dynamic_cast*>(&fb.get()); if (fptr != nullptr) result.appendField(*fptr); - } catch(const std::bad_any_cast& e) { } } } @@ -124,11 +122,10 @@ StateBase:: get(const typename StateBase::KeyType& key) const { auto itr = mStorage.find(key); VERIFY2(itr != mStorage.end(), "StateBase ERROR: failed lookup for key " << key); - try { + if (itr->second.type() == typeid(std::reference_wrapper)) { return std::any_cast>(itr->second); - } catch(const std::bad_any_cast& e) { - VERIFY2(false, "StateBase::get ERROR: unable to extract Value for " << key << "\n"); } + VERIFY2(false, "StateBase::get ERROR: unable to extract Value for " << key << "\n"); } // Same thing passing a dummy argument to help with template type From f9a3b59a149616e90241c15fa08e5ab138861952 Mon Sep 17 00:00:00 2001 From: Mike Owen Date: Fri, 15 Nov 2024 13:26:09 -0800 Subject: [PATCH 14/14] Missing header for install targets --- src/Utilities/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Utilities/CMakeLists.txt b/src/Utilities/CMakeLists.txt index 471901587..3c73b46f3 100644 --- a/src/Utilities/CMakeLists.txt +++ b/src/Utilities/CMakeLists.txt @@ -131,6 +131,7 @@ set(Utilities_headers timingUtilities.hh uniform_random.hh uniform_random_Inline.hh + range.hh AnyVisitor.hh )