Skip to content

Commit

Permalink
run qa
Browse files Browse the repository at this point in the history
  • Loading branch information
jjlk committed Oct 2, 2024
1 parent 853bb79 commit 7a13506
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions src/ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def __exit__(self, *args):


class ArchiveCollector:
UNIQUE = {"date", "hdate", "time",
"referenceDate", "type", "stream", "expver"}
UNIQUE = {"date", "hdate", "time", "referenceDate", "type", "stream", "expver"}

def __init__(self) -> None:
self.expect = 0
Expand All @@ -56,8 +55,7 @@ def add(self, field):
self.request[k].add(str(v))
if k in self.UNIQUE:
if len(self.request[k]) > 1:
raise ValueError(
f"Field {field} has different values for {k}: {self.request[k]}")
raise ValueError(f"Field {field} has different values for {k}: {self.request[k]}")


class Model:
Expand Down Expand Up @@ -162,8 +160,7 @@ def json_default(obj):
raise TypeError

print(
json.dumps(json_requests, separators=(
",", ":"), default=json_default, sort_keys=True),
json.dumps(json_requests, separators=(",", ":"), default=json_default, sort_keys=True),
file=f,
)

Expand All @@ -173,8 +170,7 @@ def download_assets(self, **kwargs):
if not os.path.exists(asset):
os.makedirs(os.path.dirname(asset), exist_ok=True)
LOG.info("Downloading %s", asset)
download(self.download_url.format(
file=file), asset + ".download")
download(self.download_url.format(file=file), asset + ".download")
os.rename(asset + ".download", asset)

@property
Expand Down Expand Up @@ -447,8 +443,7 @@ def _requests(self):
def filter_constant(request):
# We check for 'sfc' because param 'z' can be ambiguous
if request.get("levtype") == "sfc":
param = set(self.constant_fields) & set(
request.get("param", []))
param = set(self.constant_fields) & set(request.get("param", []))
if param:
request["param"] = list(param)
return True
Expand All @@ -459,8 +454,7 @@ def filter_prognostic(request):
# TODO: We assume here that prognostic fields are
# the ones that are not constant. This may not always be true
if request.get("levtype") == "sfc":
param = set(request.get("param", [])) - \
set(self.constant_fields)
param = set(request.get("param", [])) - set(self.constant_fields)
if param:
request["param"] = list(param)
return True
Expand Down Expand Up @@ -502,8 +496,7 @@ def peek_into_checkpoint(self, path):

def parse_model_args(self, args):
if args:
raise NotImplementedError(
f"This model does not accept arguments {args}")
raise NotImplementedError(f"This model does not accept arguments {args}")

def provenance(self):
from .provenance import gather_provenance_info
Expand Down Expand Up @@ -597,8 +590,7 @@ def write_input_fields(
"""

template = base64.b64decode(template)
accumulations_template = ekd.from_source(
"memory", template)[0]
accumulations_template = ekd.from_source("memory", template)[0]

for param in accumulations:
self.write(
Expand Down

0 comments on commit 7a13506

Please sign in to comment.