Skip to content

Commit

Permalink
Include model version in metadata request
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Apr 9, 2024
1 parent 0824583 commit bd0ecec
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
10 changes: 7 additions & 3 deletions ai_models/remote/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,15 @@ def run(self, cfg: dict):

LOG.debug("Result written to %s", self.output_file)

def metadata(self, model, param) -> dict:
def metadata(self, model, model_version, param) -> dict:
if isinstance(param, str):
return self._request(requests.get, f"metadata/{model}/{param}")
return self._request(
requests.get, f"metadata/{model}/{model_version}/{param}"
)
elif isinstance(param, (list, dict)):
return self._request(requests.post, f"metadata/{model}", json=param)
return self._request(
requests.post, f"metadata/{model}/{model_version}", json=param
)
else:
raise ValueError("param must be a string, list, or dict with 'param' key.")

Expand Down
5 changes: 4 additions & 1 deletion ai_models/remote/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, **kwargs):
self.cfg["assets_extra_dir"] = None

self.model = self.cfg["model"]
self.model_version = self.cfg.get("model_version", "latest")
self._param = {}
self.api = RemoteAPI()

Expand Down Expand Up @@ -61,6 +62,7 @@ def patch_retrieve_request(self, request):
def load_parameters(self):
params = self.api.metadata(
self.model,
self.model_version,
[
"expver",
"version",
Expand All @@ -80,7 +82,8 @@ def get_parameter(self, name):
if (param := self._param.get(name)) is not None:
return param

self._param.update(self.api.metadata(self.model, name))
_param = self.api.metadata(self.model, self.model_version, name)
self._param.update(_param)

return self._param.get(name)

Expand Down

0 comments on commit bd0ecec

Please sign in to comment.