Skip to content

Commit

Permalink
updated drought - mapping
Browse files Browse the repository at this point in the history
Updated drought model: applied some remaining changes from December 2023; also includes the map-generation functions.
  • Loading branch information
faridradmehr committed May 3, 2024
1 parent a8bc5c2 commit 5ca1e8e
Show file tree
Hide file tree
Showing 2 changed files with 2,875 additions and 1,838 deletions.
85 changes: 52 additions & 33 deletions src/hazard/models/drought_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import PurePosixPath
from typing import Iterable, List, Optional, Protocol, Sequence
Expand All @@ -15,24 +16,24 @@
import zarr # type: ignore
import zarr.hierarchy
from pydantic import BaseModel
from pydantic.dataclasses import dataclass
from pydantic.type_adapter import TypeAdapter
from zarr.errors import GroupNotFoundError # type: ignore

from hazard.indicator_model import IndicatorModel # type: ignore
from hazard.inventory import Colormap, HazardResource, MapInfo, Scenario
from hazard.models.multi_year_average import MultiYearAverageIndicatorBase # type: ignore
from hazard.protocols import ReadWriteDataArray
from hazard.sources.nex_gddp_cmip6 import NexGddpCmip6
from hazard.sources.osc_zarr import OscZarr
from hazard.utilities.tiles import create_tiles_for_resource

logger = logging.getLogger(__name__)


@dataclass
class BatchItem:
    """One unit of drought-indicator work: a single GCM/scenario combination."""

    resource: HazardResource  # inventory entry describing the output indicator arrays
    gcm: str  # Global Circulation Model identifier (e.g. "MIROC6")
    scenario: str  # emissions scenario identifier (e.g. "ssp585")
    central_year: int  # NOTE(review): appears superseded by central_years below — confirm against run_single
    central_years: List[int]  # central years of the averaging windows, iterated by run_single


class ZarrWorkingStore(Protocol):
Expand All @@ -42,7 +43,7 @@ def get_store(self, path: str): ...
class S3ZarrWorkingStore(ZarrWorkingStore):
def __init__(self):
s3 = s3fs.S3FileSystem(
key=os.environ.get("OSC_S3_ACCESS_KEY_DEV", None), secret=os.environ.get("OSC_S3_SECRET_KEY_DEV", None)
anon=False, key=os.environ["OSC_S3_ACCESS_KEY_DEV"], secret=os.environ["OSC_S3_SECRET_KEY_DEV"]
)
base_path = os.environ["OSC_S3_BUCKET_DEV"] + "/drought/osc/v01"
self._base_path = base_path
Expand Down Expand Up @@ -106,12 +107,12 @@ def _write(self, indices: ChunkIndicesComplete):
f.write(json.dumps(indices.model_dump()))


class DroughtIndicator(IndicatorModel[BatchItem]):
class DroughtIndicator:
def __init__(
self,
working_zarr_store: ZarrWorkingStore,
window_years: int = MultiYearAverageIndicatorBase._default_window_years,
gcms: Iterable[str] = ["MIROC6"], # MultiYearAverageIndicatorBase._default_gcms,
gcms: Iterable[str] = MultiYearAverageIndicatorBase._default_gcms,
scenarios: Iterable[str] = MultiYearAverageIndicatorBase._default_scenarios,
central_years: Sequence[int] = [2005, 2030, 2040, 2050, 2080],
):
Expand Down Expand Up @@ -164,9 +165,9 @@ def download_dataset(variable, year, gcm, scenario, datasource=datasource):
{"time": 365, "lat": lat_chunk_size, "lon": lon_chunk_size}
)
if year == years[0]:
ds.to_zarr(store=self.working_zarr_store, group=group, mode="w")
ds.to_zarr(store=self.working_zarr_store.get_store(group), mode="w")
else:
ds.to_zarr(store=self.working_zarr_store, group=group, append_dim="time")
ds.to_zarr(store=self.working_zarr_store.get_store(group), append_dim="time")
logger.info(f"completed processing: variable={quantity}, year={year}.")

def read_quantity_from_s3_store(self, gcm, scenario, quantity, lat_min, lat_max, lon_min, lon_max) -> xr.Dataset:
Expand Down Expand Up @@ -224,7 +225,7 @@ def calculate_spei(self, gcm, scenario, progress_store: Optional[ProgressStore]
logger.info(f"chunk {chunk_name} complete.")
if progress_store:
progress_store.add_completed([chunk_names.index(chunk_name)])
except Exception:
except Exception as exc:
logger.info(f"chunk {futures[future]} failed.")

def _calculate_spei_chunk(self, chunk_name, data_chunks, ds_chunked: xr.Dataset, gcm, scenario):
Expand Down Expand Up @@ -253,28 +254,34 @@ def _calculate_spei_chunk(self, chunk_name, data_chunks, ds_chunked: xr.Dataset,
)
# compute=False to avoid calculating array
ds_spei.to_zarr(store=store, mode="w", compute=False)
logger.info("Created new zarr array.")
# see https://docs.xarray.dev/en/stable/user-guide/io.html?appending-to-existing-zarr-stores=#appending-to-existing-zarr-stores # noqa: E501
logger.info(f"created new zarr array.")
# see https://docs.xarray.dev/en/stable/user-guide/io.html?appending-to-existing-zarr-stores=#appending-to-existing-zarr-stores

lat_indexes = np.where(np.logical_and(lats_all >= lat_min, lats_all <= lat_max))[0]
lon_indexes = np.where(np.logical_and(lons_all >= lon_min, lons_all <= lon_max))[0]
time_indexes = np.arange(0, len(ds_spei_slice["time"].values))
ds_spei_slice.to_zarr(
store=self.working_zarr_store,
store=store,
mode="r+",
region={
"lat": slice(lat_indexes[0], lat_indexes[-1] + 1),
"lon": slice(lon_indexes[0], lon_indexes[-1] + 1),
"time": slice(time_indexes[0], time_indexes[-1] + 1),
},
)
logger.info(f"written chunk {chunk_name} to zarr array.")
return chunk_name

def _calculate_spei_for_slice(self, lat_min, lat_max, lon_min, lon_max, *, gcm, scenario, num_workers=2):
ds_tas = self.read_quantity_from_s3_store(gcm, scenario, "tas", lat_min, lat_max, lon_min, lon_max).chunk(
{"time": 100000}
ds_tas = (
self.read_quantity_from_s3_store(gcm, scenario, "tas", lat_min, lat_max, lon_min, lon_max)
.chunk({"time": 100000})
.compute()
)
ds_pr = self.read_quantity_from_s3_store(gcm, scenario, "pr", lat_min, lat_max, lon_min, lon_max).chunk(
{"time": 100000}
ds_pr = (
self.read_quantity_from_s3_store(gcm, scenario, "pr", lat_min, lat_max, lon_min, lon_max)
.chunk({"time": 100000})
.compute()
)
ds_tas = ds_tas.drop_duplicates(dim=..., keep="last").sortby("time")
ds_pr = ds_pr.drop_duplicates(dim=..., keep="last").sortby("time")
Expand All @@ -300,8 +307,8 @@ def _calculate_spei_for_slice(self, lat_min, lat_max, lon_min, lon_max, *, gcm,
return ds_spei

def calculate_annual_average_spei(self, gcm: str, scenario: str, central_year: int, target: OscZarr):
"""Calculate average number of months where 12-month SPEI index is below thresholds
[0, -1, -1.5, -2, -2.5, -3.6] for 20 years period.
"""Calculate average number of months where 12-month SPEI index is below thresholds [0, -1, -1.5, -2, -2.5, -3.6]
for 20 years period.
Args:
gcm (str): Global Circulation Model ID.
Expand All @@ -311,9 +318,8 @@ def calculate_annual_average_spei(self, gcm: str, scenario: str, central_year: i
"""

def get_spei_full_results(gcm, scenario):
ds_spei = xr.open_zarr(
store=self.working_zarr_store, group=os.path.join("SPEI", "Aggregated", gcm + "_" + scenario)
)
path = os.path.join("spei", gcm + "_" + scenario)
ds_spei = xr.open_zarr(self.working_zarr_store.get_store(path))
return ds_spei

period = [
Expand All @@ -330,38 +336,51 @@ def get_spei_full_results(gcm, scenario):
spei_temp = spei_temp["spei"]
for i in range(len(self.spei_threshold)):
spei_ext = xr.where((spei_temp <= self.spei_threshold[i]), 1, 0)
spei_ext_sum = spei_ext.mean("time")
spei_ext_sum = spei_ext.mean("time") * 12
spei_annual[i, :, :] = spei_ext_sum
spei_annual_all = xr.DataArray(
spei_annual,
coords={
"spei_idx": self.spei_threshold,
"spei_index": self.spei_threshold,
"lat": lats_all,
"lon": lons_all,
},
dims=["spei_idx", "lat", "lon"],
dims=["spei_index", "lat", "lon"],
)
path = self.resource.path.format(gcm=gcm, scenario=scenario, year=central_year)
print(path)
target.write(path, spei_annual_all)
return spei_annual_all

def run_single(self, item: BatchItem, source, target: ReadWriteDataArray, client):
assert isinstance(target, OscZarr)
calculate_spei = True
calculate_average_spei = True
def run_single(
self,
item: BatchItem,
target: OscZarr,
calculate_spei=True,
calculate_average_spei=True,
progress_store: Optional[ProgressStore] = None,
):
if calculate_spei:
self.calculate_spei(item.gcm, item.scenario)
self.calculate_spei(item.gcm, item.scenario, progress_store)
if calculate_average_spei:
self.calculate_annual_average_spei(item.gcm, item.scenario, item.central_year, target)
for central_year in item.central_years:
self.calculate_annual_average_spei(item.gcm, item.scenario, central_year, target)

def batch_items(self) -> Iterable[BatchItem]:
    """Return every batch item to be processed; none are predefined here."""
    items: List[BatchItem] = []
    return items

def create_maps(self, source: OscZarr, target: OscZarr):
    """
    Create map images.

    Builds the map-tile pyramid for this model's resource by delegating to
    create_tiles_for_resource.

    Args:
        source: Zarr store holding the full-resolution indicator arrays.
        target: Zarr store the map tiles are written to.
    """
    ...  # NOTE(review): leftover placeholder from the previous stub implementation — likely removable
    create_tiles_for_resource(source, target, self.resource)

def inventory(self) -> Iterable[HazardResource]:
    """Return the single (unexpanded) hazard resource that this model comprises."""
    resource = self._resource()
    return [resource]

def _resource(self) -> HazardResource:
# with open(os.path.join(os.path.dirname(__file__), "days_tas_above.md"), "r") as f:
# description = f.read()
Expand All @@ -388,7 +407,7 @@ def _resource(self) -> HazardResource:
),
bounds=[(-180.0, 85.0), (180.0, 85.0), (180.0, -60.0), (-180.0, -60.0)],
index_values=self.spei_threshold,
path="drought/osc/v1/months_spei12m_below_index_{gcm}_{scenario}_{year}_map",
path="maps/drought/osc/v1/months_spei12m_below_index_{gcm}_{scenario}_{year}_map",
source="map_array_pyramid",
),
units="months/year",
Expand Down
Loading

0 comments on commit 5ca1e8e

Please sign in to comment.