Skip to content

Commit

Permalink
feat: Update datamodel cache from all mutating rpc response (#3041)
Browse files Browse the repository at this point in the history
* feat: return state changes from mutating rpcs

* test: fix

* feat: add logging

* test: fix

* feat: Handle dict-type parameters in cache

* fix: test

* fix: test

* fix: comment

* feat: revert test changes

* feat: rewrite rules query using generated classes

* feat: add test

* feat: Refresh tasks after rpc calls

* fix: test

* test: skip
  • Loading branch information
mkundu1 authored Sep 18, 2024
1 parent a3fbbb9 commit 68e5321
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 32 deletions.
5 changes: 4 additions & 1 deletion src/ansys/fluent/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,12 @@ def version_info() -> str:
# Whether to use datamodel attribute caching
DATAMODEL_USE_ATTR_CACHE = True

# Whether stream and cache commands state
# Whether to stream and cache commands state
DATAMODEL_USE_NOCOMMANDS_DIFF_STATE = True

# Whether to return the state changes on mutating datamodel rpcs
DATAMODEL_RETURN_STATE_CHANGES = True

# Whether to use remote gRPC file transfer service
USE_FILE_TRANSFER_SERVICE = False

Expand Down
71 changes: 63 additions & 8 deletions src/ansys/fluent/core/data_model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Dict, List

from ansys.api.fluent.v0.variant_pb2 import Variant
from ansys.fluent.core.utils.fluent_version import FluentVersion

StateType = (
bool
Expand Down Expand Up @@ -101,6 +102,31 @@ def update(self, d: dict[str, Any], d1: dict[str, Any]):
d[k] = v1


def _is_dict_parameter_type(version: FluentVersion, rules: str, rules_path: str):
"""Check if a parameter is a dict type."""
from ansys.fluent.core import CODEGEN_OUTDIR
from ansys.fluent.core.services.datamodel_se import (
PyDictionary,
PyNamedObjectContainer,
PyParameter,
)
from ansys.fluent.core.utils import load_module

module = load_module(
rules, CODEGEN_OUTDIR / f"datamodel_{version.number}" / f"{rules}.py"
)
cls = module.Root
comps = rules_path.split("/")
for i, comp in enumerate(comps):
if hasattr(cls, comp):
cls = getattr(cls, comp)
if issubclass(cls, PyParameter) and i < len(comps) - 1:
return False
if issubclass(cls, PyNamedObjectContainer):
cls = getattr(cls, f"_{comp}")
return issubclass(cls, PyDictionary)


class DataModelCache:
"""Class to manage datamodel cache."""

Expand Down Expand Up @@ -177,6 +203,8 @@ def _update_cache_from_variant_state(
key: str,
state: Variant,
updaterFn,
rules_str: str,
version,
):
if state.HasField("bool_state"):
updaterFn(source, key, state.bool_state)
Expand All @@ -198,7 +226,13 @@ def _update_cache_from_variant_state(
updaterFn(source, key, [])
for item in state.variant_vector_state.item:
self._update_cache_from_variant_state(
rules, source, key, item, lambda d, k, v: d[k].append(v)
rules,
source,
key,
item,
lambda d, k, v: d[k].append(v),
rules_str + "/" + key.split(":", maxsplit=1)[0],
version,
)
elif state.HasField("variant_map_state"):
internal_names_as_keys = (
Expand Down Expand Up @@ -226,15 +260,28 @@ def _update_cache_from_variant_state(
else:
if key not in source:
source[key] = {}
source = source[key]
for k, v in state.variant_map_state.item.items():
self._update_cache_from_variant_state(
rules, source, k, v, dict.__setitem__
)
if version and _is_dict_parameter_type(version, rules, rules_str):
source[key] = {}
if state.variant_map_state.item:
source = source[key]
for k, v in state.variant_map_state.item.items():
self._update_cache_from_variant_state(
rules,
source,
k,
v,
dict.__setitem__,
rules_str + "/" + k.split(":", maxsplit=1)[0],
version,
)
else:
source[key] = {}
else:
updaterFn(source, key, None)

def update_cache(self, rules: str, state: Variant, deleted_paths: List[str]):
def update_cache(
self, rules: str, state: Variant, deleted_paths: List[str], version=None
):
"""Update datamodel cache from streamed state.
Parameters
Expand All @@ -245,6 +292,8 @@ def update_cache(self, rules: str, state: Variant, deleted_paths: List[str]):
streamed state
deleted_paths : List[str]
list of deleted paths
version : FluentVersion, optional
Fluent version
"""
cache = self.rules_str_to_cache[rules]
with self._with_lock(rules):
Expand Down Expand Up @@ -280,7 +329,13 @@ def update_cache(self, rules: str, state: Variant, deleted_paths: List[str]):
break
for k, v in state.variant_map_state.item.items():
self._update_cache_from_variant_state(
rules, cache, k, v, dict.__setitem__
rules,
cache,
k,
v,
dict.__setitem__,
k.split(":", maxsplit=1)[0],
version,
)

@staticmethod
Expand Down
73 changes: 66 additions & 7 deletions src/ansys/fluent/core/services/datamodel_se.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from ansys.fluent.core.services.streaming import StreamingService
from ansys.fluent.core.solver.error_message import allowed_name_error_message
from ansys.fluent.core.utils.fluent_version import FluentVersion

Path = list[tuple[str, str]]
_TValue = None | bool | int | float | str | Sequence["_TValue"] | dict[str, "_TValue"]
Expand Down Expand Up @@ -452,6 +453,7 @@ def __init__(
self,
channel: grpc.Channel,
metadata: list[tuple[str, str]],
version: FluentVersion,
fluent_error_state,
file_transfer_service: Any | None = None,
) -> None:
Expand All @@ -465,6 +467,7 @@ def __init__(
self.subscriptions = SubscriptionList()
self.file_transfer_service = file_transfer_service
self.cache = DataModelCache() if pyfluent.DATAMODEL_USE_STATE_CACHE else None
self.version = version

def get_attribute_value(self, rules: str, path: str, attribute: str) -> _TValue:
"""Get attribute value."""
Expand Down Expand Up @@ -495,7 +498,14 @@ def rename(self, rules: str, path: str, new_name: str) -> None:
request.path = path
request.new_name = new_name
request.wait = True
self._impl.rename(request)
response = self._impl.rename(request)
if self.cache is not None:
self.cache.update_cache(
rules,
response.state,
response.deletedpaths,
version=self.version,
)

def delete_child_objects(
self, rules: str, path: str, obj_type: str, child_names: list[str]
Expand All @@ -507,7 +517,14 @@ def delete_child_objects(
for name in child_names:
request.child_names.names.append(name)
request.wait = True
self._impl.delete_child_objects(request)
response = self._impl.delete_child_objects(request)
if self.cache is not None:
self.cache.update_cache(
rules,
response.state,
response.deletedpaths,
version=self.version,
)

def delete_all_child_objects(self, rules: str, path: str, obj_type: str) -> None:
"""Delete all child objects."""
Expand All @@ -516,22 +533,43 @@ def delete_all_child_objects(self, rules: str, path: str, obj_type: str) -> None
request.path = path + "/" + obj_type
request.delete_all = True
request.wait = True
self._impl.delete_child_objects(request)
response = self._impl.delete_child_objects(request)
if self.cache is not None:
self.cache.update_cache(
rules,
response.state,
response.deletedpaths,
version=self.version,
)

def set_state(self, rules: str, path: str, state: _TValue) -> None:
"""Set state."""
request = DataModelProtoModule.SetStateRequest(
rules=rules, path=path, wait=True
)
_convert_value_to_variant(state, request.state)
self._impl.set_state(request)
response = self._impl.set_state(request)
if self.cache is not None:
self.cache.update_cache(
rules,
response.state,
response.deletedpaths,
version=self.version,
)

def fix_state(self, rules, path) -> None:
"""Fix state."""
request = DataModelProtoModule.FixStateRequest()
request.rules = rules
request.path = convert_path_to_se_path(path)
self._impl.fix_state(request)
response = self._impl.fix_state(request)
if self.cache is not None:
self.cache.update_cache(
rules,
response.state,
response.deletedpaths,
version=self.version,
)

def update_dict(
self, rules: str, path: str, dict_state: dict[str, _TValue]
Expand All @@ -541,14 +579,28 @@ def update_dict(
rules=rules, path=path, wait=True
)
_convert_value_to_variant(dict_state, request.dicttomerge)
self._impl.update_dict(request)
response = self._impl.update_dict(request)
if self.cache is not None:
self.cache.update_cache(
rules,
response.state,
response.deletedpaths,
version=self.version,
)

def delete_object(self, rules: str, path: str) -> None:
"""Delete an object."""
request = DataModelProtoModule.DeleteObjectRequest(
rules=rules, path=path, wait=True
)
self._impl.delete_object(request)
response = self._impl.delete_object(request)
if self.cache is not None:
self.cache.update_cache(
rules,
response.state,
response.deletedpaths,
version=self.version,
)

def execute_command(
self, rules: str, path: str, command: str, args: dict[str, _TValue]
Expand All @@ -559,6 +611,13 @@ def execute_command(
)
_convert_value_to_variant(args, request.args)
response = self._impl.execute_command(request)
if self.cache is not None:
self.cache.update_cache(
rules,
response.state,
response.deletedpaths,
version=self.version,
)
return _convert_variant_to_value(response.result)

def execute_query(
Expand Down
1 change: 1 addition & 0 deletions src/ansys/fluent/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def _build_from_fluent_connection(
self._datamodel_service_se = service_creator("datamodel").create(
fluent_connection._channel,
fluent_connection._metadata,
self.get_fluent_version(),
self._error_state,
self._file_transfer_service,
)
Expand Down
4 changes: 3 additions & 1 deletion src/ansys/fluent/core/session_pure_meshing.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def __init__(
stream = DatamodelStream(datamodel_service_se)
stream.register_callback(
functools.partial(
datamodel_service_se.cache.update_cache, rules=rules
datamodel_service_se.cache.update_cache,
rules=rules,
version=datamodel_service_se.version,
)
)
self.datamodel_streams[rules] = stream
Expand Down
11 changes: 11 additions & 0 deletions src/ansys/fluent/core/streaming_services/datamodel_streaming.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
"""Provides a module for datamodel streaming."""

import logging

from google.protobuf.json_format import MessageToDict

from ansys.api.fluent.v0 import datamodel_se_pb2
import ansys.fluent.core as pyfluent
from ansys.fluent.core.streaming_services.streaming import StreamingService

network_logger: logging.Logger = logging.getLogger("pyfluent.networking")


class DatamodelStream(StreamingService):
"""Encapsulates a datamodel streaming service."""
Expand All @@ -28,6 +35,7 @@ def _process_streaming(
"""Processes datamodel events."""
data_model_request = datamodel_se_pb2.DataModelRequest(*args, **kwargs)
data_model_request.rules = rules
data_model_request.returnstatechanges = pyfluent.DATAMODEL_RETURN_STATE_CHANGES
if no_commands_diff_state:
data_model_request.diffstate = datamodel_se_pb2.DIFFSTATE_NOCOMMANDS
responses = self._streaming_service.begin_streaming(
Expand All @@ -39,6 +47,9 @@ def _process_streaming(
while True:
try:
response: datamodel_se_pb2.DataModelResponse = next(responses)
network_logger.debug(
f"GRPC_TRACE: RPC = /grpcRemoting.DataModel/BeginStreaming, response = {MessageToDict(response)}"
)
with self._lock:
self._streaming = True
for _, cb_list in self._service_callbacks.items():
Expand Down
20 changes: 12 additions & 8 deletions src/ansys/fluent/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ansys.fluent.core.services.datamodel_se import (
PyCallableStateObject,
PyCommand,
PyMenu,
PyMenuGeneric,
PySingletonCommandArgumentsSubItem,
)
Expand Down Expand Up @@ -149,11 +150,13 @@ def _convert_task_list_to_display_names(workflow_root, task_list):
return [workflow_state[f"TaskObject:{x}"]["_name_"] for x in task_list]
else:
_display_names = []
_org_path = workflow_root.path
for _task_name in task_list:
workflow_root.path = [("TaskObject", _task_name), ("_name_", "")]
_display_names.append(workflow_root())
workflow_root.path = _org_path
name_obj = PyMenu(
service=workflow_root.service,
rules=workflow_root.rules,
path=[("TaskObject", _task_name), ("_name_", "")],
)
_display_names.append(name_obj())
return _display_names


Expand Down Expand Up @@ -530,9 +533,8 @@ def _insert_next_task(self, task_name: str):
raise ValueError(
f"'{task_name}' cannot be inserted next to '{self.python_name()}'."
)
return self._task.InsertNextTask(
CommandName=self._python_task_names_map[task_name]
)
self._task.InsertNextTask(CommandName=self._python_task_names_map[task_name])
_call_refresh_task_accessors(self._command_source)

@property
def insertable_tasks(self):
Expand Down Expand Up @@ -570,7 +572,9 @@ def __repr__(self):
def __call__(self, **kwds) -> Any:
if kwds:
self._task.Arguments.set_state(**kwds)
return self._task.Execute()
result = self._task.Execute()
_call_refresh_task_accessors(self._command_source)
return result

def _tasks_with_matching_attributes(self, attr: str, other_attr: str) -> list:
this_command = self._command()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data_model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_data_model_cache():
({"r1": {}}, "r1", {"A": [3.0, 6.0]}, [], {"r1": {"A": [3.0, 6.0]}}),
({"r1": {}}, "r1", {"A": ["ab", "cd"]}, [], {"r1": {"A": ["ab", "cd"]}}),
({"r1": {"A": {}}}, "r1", {"A": {"B": 5}}, [], {"r1": {"A": {"B": 5}}}),
({"r1": {"A": 5}}, "r1", {"A": {}}, [], {"r1": {"A": 5}}),
({"r1": {"A": 5}}, "r1", {"A": {}}, [], {"r1": {"A": {}}}),
({"r1": {"A": 5}}, "r1", {"A": None}, [], {"r1": {"A": None}}),
(
{"r1": {"A": {}}},
Expand Down
Loading

0 comments on commit 68e5321

Please sign in to comment.