Skip to content

Commit

Permalink
change constants url
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 15, 2024
1 parent c9c7545 commit e27e6e7
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/ai_models/inputs/opendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"slor",
)

CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants.grib2"
CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants-{resol}.grib2"


class OpenDataInput(RequestBasedInput):
Expand Down Expand Up @@ -62,12 +62,12 @@ def _adjust(self, kwargs):

if interp:
logging.debug("Interpolating from %s to %s", source, grid)
return Interpolate(grid, source)
return (Interpolate(grid, source), source)
else:
return lambda x: x
return (lambda x: x, source)

def pl_load_source(self, **kwargs):
pproc = self._adjust(kwargs)
pproc, _ = self._adjust(kwargs)
kwargs["levtype"] = "pl"
request = kwargs.copy()

Expand All @@ -87,7 +87,7 @@ def pl_load_source(self, **kwargs):
return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request)

def sfc_load_source(self, **kwargs):
pproc = self._adjust(kwargs)
pproc, resol = self._adjust(kwargs)
kwargs["levtype"] = "sfc"
request = kwargs.copy()

Expand Down Expand Up @@ -117,16 +117,17 @@ def sfc_load_source(self, **kwargs):
constants = []

cachedir = os.path.expanduser("~/.cache/ai-models")
basename = os.path.basename(CONSTANTS_URL)
constant_url = CONSTANTS_URL.format(resol=resol)
basename = os.path.basename(constant_url)

if not os.path.exists(cachedir):
os.makedirs(cachedir)

path = os.path.join(cachedir, basename)

if not os.path.exists(path):
logging.info("Downloading %s to %s", CONSTANTS_URL, path)
download(CONSTANTS_URL, path + ".tmp")
logging.info("Downloading %s to %s", constant_url, path)
download(constant_url, path + ".tmp")
os.rename(path + ".tmp", path)

ds = ekd.from_source("file", path)
Expand Down Expand Up @@ -169,7 +170,7 @@ def sfc_load_source(self, **kwargs):
return self.check_sfc(fields, request)

def ml_load_source(self, **kwargs):
pproc = self._adjust(kwargs)
pproc, _ = self._adjust(kwargs)
kwargs["levtype"] = "ml"
request = kwargs.copy()

Expand Down

0 comments on commit e27e6e7

Please sign in to comment.