From f3574714ac41c219e33f757a7395fa9a31b19c59 Mon Sep 17 00:00:00 2001
From: tomsail
Date: Thu, 20 Jun 2024 07:43:52 +0200
Subject: [PATCH] added model metrics

---
 Model_metrics.html  | 11775 ++++++++++++++++++++++++++++++++++++++++++
 Model_metrics.ipynb |  1385 +++++
 index.html          |     1 +
 3 files changed, 13161 insertions(+)
 create mode 100644 Model_metrics.html
 create mode 100644 Model_metrics.ipynb

diff --git a/Model_metrics.html b/Model_metrics.html
new file mode 100644
index 0000000..519dbaa
--- /dev/null
+++ b/Model_metrics.html
@@ -0,0 +1,11775 @@
[11,775 added lines omitted: Model_metrics.html is the exported HTML render of the Model_metrics.ipynb notebook added below]
+ + diff --git a/Model_metrics.ipynb b/Model_metrics.ipynb new file mode 100644 index 0000000..f76ffc4 --- /dev/null +++ b/Model_metrics.ipynb @@ -0,0 +1,1385 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## first part: processing " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "import holoviews as hv\n", + "import panel as pn\n", + "import param\n", + "import geopandas as gp\n", + "import shapely\n", + "import json\n", + "from holoviews import opts \n", + "import geoviews as gv\n", + "import hvplot.pandas\n", + "from bokeh.models import HoverTool\n", + "from holoviews import streams\n", + "from holoviews.plotting.links import RangeToolLink\n", + "from holoviews.operation import histogram\n", + "from holoviews.operation.datashader import rasterize, spread\n", + "import bokeh.palettes as bp\n", + "\n", + "from scipy.stats import linregress\n", + "from scipy.spatial import cKDTree\n", + "\n", + "from typing import Tuple, Dict\n", + "\n", + "hv.extension('bokeh')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## model info\n", + "we need to specify the different versions of the global surge model:\n", + " * `v0`\n", + " * `v0.2` \n", + " * `v1.2`\n", + " * `v2.0`\n", + " * `v2.2`\n", + "\n", + "more info in [seareport_meshes](https://github.com/seareport/seareport_meshes) repository" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "versions = {\n", + " 'v0': 'results/2D/v0.nc', \n", + " 'v0.2': 'results/2D/v0.2.nc', \n", + " # 'v1.2': 'results/2D/v0.2.nc', # for now we use v0.2 results\n", + " # 'v2.0': 'results/2D/v0.2.nc', # for now we use v0.2 results\n", + " # 'v2.1': 'results/2D/v0.2.nc', # for now we use v0.2 results\n", + " # 'v2.2': 'results/2D/vnp.nan.2.nc', # for now we use v0.2 results\n", + "}\n", + "\n", + "# time for the comparison\n", + "tmin = '2023-01-01'\n", + "tmax = '2023-12-31'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## observations info\n", + "we need to confront the model with the observations. \n", + "\n", + "Here we want to compare modeled surge, so we had to clean up and detide the data. \n", + "\n", + " * data has been extracted from the IOC api, using [`searvey`](https://github.com/oceanmodeling/searvey)\n", + " * the data clean-up has been done with [`ioc_cleanup`](https://github.com/seareport/ioc_cleanup), we selected in total ~180 candidates for the period 2022/2023, depending on their data quality. 
\n", + " * The detide has been done with [`Utide`](https://github.com/wesleybowman/UTide) through [`analysea`](https://github.com/seareport/analysea/blob/tide-chunk/analysea/tide.py)'s `detide` function.\n", + "\n", + "the example can be shown for one station: \n", + " * `acnj` (Atlantic City, New Jersey)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "raw = pd.read_parquet('obs/raw/acnj.parquet')\n", + "clean = pd.read_parquet('obs/clean/acnj.parquet')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from analysea.tide import detide\n", + "surge = detide(clean[clean.columns[0]], lat=39.355) #lat for atlantic city" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ts_view = dict(width = 1200, height = 800)\n", + "scatter_view = dict(width = 800, height = 800)\n", + "mod_opts_raster = dict(cmap=[\"blue\"])\n", + "mod_opts = dict(color=\"blue\")\n", + "obs_opts_raster = dict(cmap=[\"red\"])\n", + "obs_opts = dict(color=\"red\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "example for a month: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tmin_ = pd.Timestamp(2022,7,1)\n", + "tmax_ = pd.Timestamp(2022,8,1)\n", + "# \n", + "raw_ = raw[tmin_:tmax_].hvplot(label='raw')\n", + "clean_ = clean[tmin_:tmax_].hvplot(label='clean')\n", + "surge_ = surge[tmin_:tmax_].hvplot(label='surge')\n", + "# \n", + "(raw_ * clean_ * surge_).opts(show_legend=True, **ts_view)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "seaset = pd.read_csv('https://raw.githubusercontent.com/tomsail/seaset/main/Notebooks/catalog_full.csv', index_col=0)\n", + "# seaset needs to be corrected because it does not contains all stations names" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "list_clean = glob.glob('*.parquet', root_dir = \"obs/clean/\")\n", + "ioc_cleanup_list = [item.split('.')[0] for item in list_clean]\n", + "surge_stations = seaset[seaset.ioc_code.isin(ioc_cleanup_list)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for ii,ioc_code in enumerate(surge_stations.ioc_code):\n", + " lat = surge_stations.iloc[ii].latitude\n", + " if not os.path.exists(f\"obs/surge/{ioc_code}.parquet\"):\n", + " df = pd.read_parquet(f\"obs/clean/{ioc_code}.parquet\")\n", + " surge = detide(df[df.columns[0]], lat=lat)\n", + " surge.to_frame().to_parquet(f\"obs/surge/{ioc_code}.parquet\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## functions to extract TS from model\n", + "we need to extract the surge data from the model. 
\n", + "\n", + "Ideally we should have a dataset with indexed at the observations coordinate, this is still a WIP (on-going work in [seaset](https://github.com/oceanmodeling/seaset) for unique station indexing and TELEMAC 1D results file output)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def closest_n_points(nodes, N, meshXY, dist_max=np.inf):\n", + " mytree = cKDTree(meshXY)\n", + " d_, indice = mytree.query(nodes, range(1, N + 1))\n", + " indice[d_ > dist_max] = -1\n", + " mask = indice != -1\n", + " return indice[mask].T, d_[mask].T\n", + "\n", + "\n", + "def extract_t_elev_2D(\n", + " ds: xr.Dataset, \n", + " x: float, \n", + " y: float, \n", + " xstr: str = 'longitude', \n", + " ystr: str = 'latitude'\n", + " )-> Tuple[pd.Series, float, float, float]:\n", + " lons, lats = ds[xstr].values, ds[ystr].values\n", + " indx, dist_ = closest_n_points(np.array([x, y]).T, 1, np.array([lons,lats]).T)\n", + " ds_ = ds.isel(node=indx[0])\n", + " elev_ = ds_.elev.values\n", + " t_ = [pd.Timestamp(ti) for ti in ds_.time.values]\n", + " return pd.Series(elev_, index=t_), np.round(dist_, 2), float(ds_[xstr]), float(ds_[ystr])\n", + "\n", + "\n", + "def get_obs(folder : str, ioc_code: str, ext: str = \".parquet\")->pd.Series: \n", + " obs = pd.read_parquet(f\"{folder}/{ioc_code}{ext}\")\n", + " # hack\n", + " obs = obs[obs.columns[0]]\n", + " return obs\n", + "\n", + "def get_model(model_file: str, ioc_code: str, catalog: pd.DataFrame)->Tuple[pd.Series, float, float]:\n", + " # Extract model data and calculate correlation\n", + " ds = xr.open_dataset(model_file)\n", + " s = catalog[catalog.ioc_code == ioc_code]\n", + " mod, d_, mlon, mlat = extract_t_elev_2D(ds, s.longitude.values[0], s.latitude.values[0], 'lon', 'lat')\n", + " return mod, mlon, mlat" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Stats functions\n", + "We are supposed we compare to time series: \n", + " * `ts1`: modelled surge time series\n", + " * `ts2`: observed surge time series\n", + "\n", + "We need metrics to assess the quality of the model. \n", + "We define the most important ones, as stated on this [Stats wiki](https://cirpwiki.info/wiki/Statistics): \n", + "### A. Dimensional Statistics: \n", + "#### 1. Mean Error (or Bias) \n", + "$$\\langle x_c - x_m \\rangle = \\langle x_c \\rangle - \\langle x_m \\rangle$$\n", + "#### 2. RMSE (Root Mean Squared Error) \n", + "$$\\sqrt{\\langle(x_c - x_m)^2\\rangle}$$\n", + "#### 3. Mean-Absolute Error (MAE): \n", + "$$\\langle |x_c - x_m| \\rangle$$\n", + "### B. Dimentionless Statistics: \n", + "#### 1. 
Performance Scores (PS) or Nash-Sutcliffe Coefficient (NSE): $$1 - \\frac{\\langle (x_c - x_m)^2 \\rangle}{\\langle (x_m - x_R)^2 \\rangle}$$\n", + " * Range Qualification: \n", + " * 0.8float: \n", + " bias = df1.mean() - df2.mean()\n", + " return np.round(bias, round)\n", + "\n", + "\n", + "def get_mse(df1: pd.Series, df2: pd.Series, round: int = 3)->float: \n", + " mse = np.square(np.subtract(df2, df1)).mean()\n", + " return np.round(mse, round)\n", + "\n", + "\n", + "def get_rmse(df1: pd.Series, df2: pd.Series, round:int = 3)->float:\n", + " rmse = np.sqrt(get_mse(df1, df2, 10))\n", + " return np.round(rmse, round)\n", + "\n", + "\n", + "def get_mae(df1: pd.Series, df2: pd.Series, round:int = 3)->float:\n", + " mae = np.abs(np.subtract(df2, df1)).mean()\n", + " return np.round(mae, round)\n", + "\n", + "def get_mad(df1: pd.Series, df2: pd.Series, round:int = 3)->float:\n", + " mae = np.abs(np.subtract(df2, df1)).std()\n", + " return np.round(mae, round)\n", + "\n", + "\n", + "def get_madp(df1: pd.Series, df2: pd.Series, round:int = 3)->float:\n", + " pc1, pc2 = get_percentiles(df1, df2)\n", + " return get_mad(pc1, pc2, round)\n", + "\n", + "\n", + "def get_madc(df1: pd.Series, df2: pd.Series, round:int = 3)->float:\n", + " madp = get_madp(df1, df2, round)\n", + " return get_mad(df1, df2, round) + madp\n", + "\n", + "\n", + "def get_rms(df1: pd.Series, df2: pd.Series, round:int = 3)->float:\n", + " crmsd = ((df1 - df1.mean()) - (df2 - df2.mean()))**2\n", + " return np.round(np.sqrt(crmsd.mean()), round)\n", + "\n", + "\n", + "def get_corr(df1: pd.Series, df2: pd.Series, round: int = 3)->float:\n", + " corr = df1.corr(df2)\n", + " return np.round(corr, round)\n", + "\n", + "\n", + "def get_nse(df1: pd.Series, df2: pd.Series, round: int = 3)->float:\n", + " nse = 1 - np.nansum(np.subtract(df2, df1) ** 2) / np.nansum((df2 - np.nanmean(df2)) ** 2)\n", + " return np.round(nse, round)\n", + "\n", + "\n", + "def get_lambda(df1: pd.Series, df2: pd.Series, round: int = 3)->float:\n", + " Xmean = np.nanmean(df2)\n", + " Ymean = np.nanmean(df1)\n", + " nObs = len(df2)\n", + " corr = get_corr(df1, df2, 10)\n", + " if corr >= 0:\n", + " kappa = 0\n", + " else:\n", + " kappa = 2 * abs(np.nansum((df2 - Xmean) * (df1 - Ymean)))\n", + "\n", + " Nomin = np.nansum((df2 - df1) ** 2)\n", + " Denom = (\n", + " np.nansum((df2 - Xmean) ** 2)\n", + " + np.nansum((df1 - Ymean) ** 2)\n", + " + nObs * ((Xmean - Ymean) ** 2)\n", + " + kappa\n", + " )\n", + " lambda_index = 1 - Nomin / Denom\n", + " return np.round(lambda_index, round)\n", + " \n", + "\n", + "def get_kge(df1: pd.Series, df2: pd.Series, round: int =3)->float:\n", + " corr = get_corr(df1, df2, 10)\n", + " b = (df1.mean() - df2.mean())/df2.std()\n", + " g = df1.std()/df2.std()\n", + " kge = 1 - np.sqrt((corr-1)**2 + b**2 + (g-1)**2)\n", + " return np.round(kge, round)\n", + "\n", + "\n", + "def align_ts(df1: pd.Series, df2: pd.Series)->Tuple[pd.Series, pd.Series]:\n", + " ts1, ts2 = df1.align(df2, axis = 0)\n", + " ts1 = ts1.interpolate()\n", + " nan_mask1 = pd.isna(ts1)\n", + " nan_mask2 = pd.isna(ts2)\n", + " nan_mask = np.logical_or(nan_mask1.values, nan_mask2.values)\n", + " ts1 = ts1[~nan_mask]\n", + " ts2 = ts2[~nan_mask]\n", + " return ts1, ts2\n", + "\n", + "\n", + "def get_percentiles(df1: pd.Series, df2: pd.Series, higher_tail:bool = False) -> Tuple[pd.Series, pd.Series]:\n", + " x = np.arange(0, 0.99, 0.01)\n", + " if higher_tail:\n", + " x = np.hstack([x, np.arange(0.99, 1, 0.001)])\n", + " pc1 = np.zeros(len(x))\n", + " pc2 = 
np.zeros(len(x))\n", + " for it, thd in enumerate(x):\n", + " pc1[it] = df1.quantile(thd)\n", + " pc2[it] = df2.quantile(thd)\n", + " return pd.Series(pc1), pd.Series(pc2)\n", + "\n", + "\n", + "def get_stats(ts1: pd.Series, ts2: pd.Series)->Dict[str, float]:\n", + " \"\"\"\n", + " it is STRONGLY advised to use : \n", + " * model data for ts1\n", + " * observed data for ts2\n", + " \"\"\"\n", + " version_stat = {\n", + " \"bias\": get_bias(ts1, ts2),\n", + " \"rmse\": get_rmse(ts1, ts2),\n", + " \"rms\": get_rms(ts1, ts2),\n", + " \"mean_df1\": np.round(ts1.mean(), 3),\n", + " \"mean_df2\": np.round(ts2.mean(), 3),\n", + " \"std_df1\": np.round(ts1.std(), 3),\n", + " \"std_df2\": np.round(ts2.std(), 3),\n", + " \"nse\": get_nse(ts1, ts2),\n", + " \"lamba\": get_lambda(ts1, ts2),\n", + " \"cr\": get_corr(ts1, ts2),\n", + " \"slope\": linregress(ts1, ts2).slope,\n", + " \"intercept\": linregress(ts1, ts2).intercept,\n", + " \"mad\" : get_mad(ts1, ts2),\n", + " \"madp\" : get_madp(ts1, ts2),\n", + " \"madc\" : get_madc(ts1, ts2),\n", + " \"kge\": get_kge(ts1, ts2)\n", + " }\n", + " return version_stat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def json_format(d):\n", + " for key, value in d.items():\n", + " if isinstance(value, (dict, list, tuple)):\n", + " json_format(value) # Recurse into nested dictionaries\n", + " elif isinstance(value, np.ndarray):\n", + " d[key] = value.tolist() # Convert NumPy array to list\n", + " elif isinstance(value, pd.Timestamp):\n", + " d[key] = value.strftime(\"%Y-%m-%d %H:%M:%S\") # Convert pandas Timestamp to string\n", + " elif isinstance(value, pd.Timedelta):\n", + " d[key] = str(value) # Convert pandas Timedelta to string\n", + " else: \n", + " d[key] = str(value)\n", + " return d" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model vs observations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ts_folder = 'obs/surge'\n", + "model_file = 'results/2D/v0.nc'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1 - Example for `horn`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ioc_code = 'horn'\n", + "obs = get_obs(ts_folder, ioc_code, '.parquet')\n", + "mod, slon, slat = get_model(model_file, ioc_code, surge_stations)\n", + "mod_, obs_ = align_ts(mod, obs)\n", + "stats_ = get_stats(mod_, obs_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "general metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "stats_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### storm detection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "we select the 99th quantile of the modeled TS and compare it with the 99th quantile of the observed TS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyextremes import get_extremes\n", + "\n", + "threshold = mod_.quantile(0.99)\n", + "print(f\"threshold is: {np.round(threshold, 3)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ext_ = get_extremes(mod_, \"POT\", threshold=threshold, r=\"72H\")\n", + "modeled_extremes = pd.DataFrame({\"modeled\" : ext_, \"time_model\" : 
ext_.index}, index=ext_.index)\n", + "modeled_extremes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "obs_ = obs_[tmin:tmax]\n", + "threshold = min(1, obs_.quantile(0.99))\n", + "print(f\"threshold is: {np.round(threshold, 3)}\")\n", + "ext_ = get_extremes(obs_, \"POT\", threshold=threshold, r=\"72H\")\n", + "observed_extremes = pd.DataFrame({\"observed\" : ext_, \"time_obs\" : ext_.index}, index=ext_.index)\n", + "observed_extremes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "we can plot the time series and the extremes detected" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_extreme_raster(ts:pd.Series, quantile: float, duration_cluster: int = 72, color = 'black', label = \"\"):\n", + " \"\"\"\n", + " this function might induce overhead if the time series is too long\n", + " \"\"\"\n", + " ext = get_extremes(ts, \"POT\", threshold=ts.quantile(quantile), r=f\"{duration_cluster}H\")\n", + " ts_ = rasterize(hv.Curve(ts, label=label),line_width = 0.5).opts(cmap=[color], **ts_view, show_grid=True, alpha = 0.7,)\n", + " sc_ = hv.Scatter(ext,label=label).opts(opts.Scatter(line_color=\"black\", fill_color=color, size=8))\n", + " th_ = hv.HLine(ts.quantile(quantile)).opts(opts.HLine(color=color, line_dash=\"dashed\"))\n", + " th_text_ = hv.Text(ts.index[int(len(ts)/2)],ts.quantile(quantile), f\"{ts.quantile(quantile):.2f}\")\n", + " return ts_ * sc_ * th_ * th_text_\n", + "\n", + "def plot_extreme(ts:pd.Series, quantile: float, duration_cluster: int = 72, color = 'k', label = \"\"):\n", + " ext = get_extremes(ts, \"POT\", threshold=ts.quantile(quantile), r=f\"{duration_cluster}H\")\n", + " ts_ = hv.Curve(ts,label=label).opts(color=color, **ts_view, show_grid=True, alpha = 0.7,)\n", + " sc_ = hv.Scatter(ext,label=label).opts(opts.Scatter(line_color=\"black\", fill_color=color, size=8))\n", + " th_ = hv.HLine(ts.quantile(quantile), label=label).opts(opts.HLine(color=color, line_dash=\"dashed\"))\n", + " th_text_ = hv.Text(ts.index[int(len(ts)/2)],ts.quantile(quantile), f\"{ts.quantile(quantile):.2f}\")\n", + " return ts_ * sc_ * th_ * th_text_\n", + "\n", + "\n", + "mod_plot = plot_extreme_raster(mod_, 0.9, color = 'blue', label = \"model\")\n", + "obs_plot = plot_extreme_raster(obs_, 0.9, color = 'red', label = \"observed\")\n", + "\n", + "mod_plot * obs_plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "extremes = pd.concat([modeled_extremes, observed_extremes], axis=1)\n", + "extremes = extremes.groupby(pd.Grouper(freq='3D')).mean().dropna(how='all')\n", + "extremes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "extremes_match = extremes.groupby(pd.Grouper(freq='3D')).mean().dropna()\n", + "extremes_match" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"database match: {len(extremes_match)/len(extremes)*100}%, over {len(extremes)} total storms\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "let's build useful metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "extremes_match['difference'] = extremes_match['modeled'] - extremes_match['observed']\n", + "extremes_match['norm_diff'] = 
extremes_match['difference']/extremes_match['observed']\n", + "extremes_match['error'] = extremes_match[\"norm_diff\"].abs()\n", + "# extremes_match['mean'] = (extremes_match['observed'] + extremes_match['modeled'])/2\n", + "extremes_match" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# R1: diff for the biggest storm in each dataset\n", + "idx_max = extremes_match['observed'].idxmax()\n", + "R1 = extremes_match['error'][idx_max]\n", + "R1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# R3: Difference between observed and modelled for the biggest storm\n", + "idx_max = extremes_match['observed'].nlargest(3).index\n", + "R3 = extremes_match['error'][idx_max].mean()\n", + "R3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "extremes_match['error'].mean()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "let's build a function to calculate the storm metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_extremes_ts(ts1: pd.Series, ts2: pd.Series, quantile: float, cluster_duration:int = 72): \n", + " # first ts\n", + " threshold = ts1.quantile(quantile)\n", + " ext_ = get_extremes(ts1, \"POT\", threshold=threshold, r=f\"{cluster_duration}h\")\n", + " extremes1 = pd.DataFrame({\"modeled\" : ext_, \"time_model\" : ext_.index}, index=ext_.index)\n", + " # second ts\n", + " threshold = ts2.quantile(quantile)\n", + " ext_ = get_extremes(ts2, \"POT\", threshold=threshold, r=f\"{cluster_duration}h\")\n", + " extremes2 = pd.DataFrame({\"observed\" : ext_, \"time_obs\" : ext_.index}, index=ext_.index)\n", + " extremes = pd.concat([extremes1, extremes2], axis=1)\n", + " if extremes.empty:\n", + " return pd.DataFrame()\n", + " else: \n", + " extremes = extremes.groupby(pd.Grouper(freq='2D')).mean().dropna(how='all')\n", + " return extremes\n", + "\n", + "\n", + "def match_extremes(extremes: pd.DataFrame):\n", + " if extremes.empty:\n", + " return pd.DataFrame()\n", + " extremes_match = extremes.groupby(pd.Grouper(freq='2D')).mean().dropna()\n", + " if len(extremes_match) == 0:\n", + " return pd.DataFrame()\n", + " else: \n", + " extremes_match['difference'] = extremes_match['observed'] - extremes_match['modeled']\n", + " extremes_match['error'] = np.abs(extremes_match['difference']/extremes_match['observed'])\n", + " extremes_match['error_m'] = extremes_match[\"error\"] * extremes_match['observed']\n", + " return extremes_match\n", + "\n", + "\n", + "def storm_metrics(ts1: pd.Series, ts2: pd.Series, quantile: float, cluster_duration:int = 72\n", + ")->Dict[str, float]:\n", + " extremes = get_extremes_ts(ts1, ts2, quantile, cluster_duration)\n", + " extremes_match = match_extremes(extremes)\n", + " if extremes_match.empty:\n", + " return {\n", + " \"db_match\" : np.nan,\n", + " \"R1_norm\": np.nan,\n", + " \"R1\": np.nan,\n", + " \"R3_norm\": np.nan,\n", + " \"R3\": np.nan,\n", + " \"error\": np.nan,\n", + " \"error_metric\": np.nan\n", + " }\n", + " else: \n", + " # R1: diff for the biggest storm in each dataset\n", + " idx_max = extremes_match['observed'].idxmax()\n", + " R1_norm = extremes_match['error'][idx_max]\n", + " R1 = extremes_match['difference'][idx_max]\n", + " # R3: Difference between observed and modelled for the biggest storm\n", + " idx_max = 
extremes_match['observed'].nlargest(3).index\n", + " R3_norm = extremes_match['error'][idx_max].mean()\n", + " R3 = extremes_match['difference'][idx_max].mean()\n", + " metrics = {\n", + " \"db_match\" : len(extremes_match)/len(extremes),\n", + " \"R1_norm\": R1_norm, \n", + " \"R1\": R1, \n", + " \"R3_norm\": R3_norm, \n", + " \"R3\": R3, \n", + " \"error\": extremes_match['error'].mean(),\n", + " \"error_metric\": extremes_match['error_m'].mean()\n", + " }\n", + " return metrics\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_extremes_ts(mod_, obs_, quantile=0.99, cluster_duration=72)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "storm_metrics(mod_, obs_, quantile=0.99, cluster_duration=72)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "storm_metrics(mod_, obs_, quantile=0.95, cluster_duration=72)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "we can use `get_extremes_ts` or `match_extremes` to provide insightful info scatter plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cmap_ = bp.Turbo256\n", + "def scatter_plot_raster(ts1: pd.Series, ts2: pd.Series, quantile: float, cluster_duration:int = 72, pp_plot: bool = False):\n", + " extremes = get_extremes_ts(ts1, ts2, quantile, cluster_duration)\n", + " extremes_match = match_extremes(extremes)\n", + " p = hv.Points((ts1.values, ts2.values))\n", + " sc_ = spread(rasterize(p)).opts(cmap=cmap_, cnorm='linear', alpha = 0.9, **scatter_view)\n", + " if extremes_match.empty:\n", + " ext_ = hv.Points((0,0))\n", + " else: \n", + " ext_ = hv.Points((extremes_match['modeled'].values, extremes_match['observed']),label=f\"extremes\").opts(size = 8, fill_color='r', line_color = 'k')\n", + " ax_plot = hv.Slope(1,0).opts(color='grey', show_grid=True)\n", + " lr = linregress(ts1, ts2)\n", + " lr_plot = hv.Slope(lr.slope,lr.intercept, label = f\"y = {lr.slope:.2f}x + {lr.intercept:.2f}\").opts(color='red',line_dash=\"dashed\")\n", + " # \n", + " if pp_plot: \n", + " pc1, pc2 = get_percentiles(ts1, ts2,higher_tail=True)\n", + " ppp = hv.Scatter((pc1, pc2),('modeled', 'observed'), label=f\"percentiles\").opts(fill_color='g', line_color = 'b', size=10)\n", + " return ax_plot * lr_plot * sc_ * ext_ * ppp\n", + " else: \n", + " return ax_plot * lr_plot * sc_ * ext_\n", + " \n", + "def scatter_plot(ts1: pd.Series, ts2: pd.Series, quantile: float, cluster_duration:int = 72, pp_plot: bool = False):\n", + " extremes = get_extremes_ts(ts1, ts2, quantile, cluster_duration)\n", + " extremes_match = match_extremes(extremes)\n", + " p = hv.Points((ts1.values, ts2.values))\n", + " sc_ = p.opts(alpha = 0.9, **scatter_view)\n", + " if extremes_match.empty:\n", + " ext_ = hv.Points((0,0))\n", + " else: \n", + " ext_ = hv.Points((extremes_match['modeled'].values, extremes_match['observed']),label=f\"extremes\").opts(size = 8, fill_color='r', line_color = 'k')\n", + " ax_plot = hv.Slope(1,0).opts(color='grey', show_grid=True)\n", + " lr = linregress(ts1, ts2)\n", + " lr_plot = hv.Slope(lr.slope,lr.intercept, label = f\"y = {lr.slope:.2f}x + {lr.intercept:.2f}\").opts(color='red',line_dash=\"dashed\")\n", + "\n", + " # \n", + " if pp_plot: \n", + " pc1, pc2 = get_percentiles(ts1, ts2)\n", + " ppp = hv.Scatter((pc1, pc2),('modeled', 'observed'), 
label=f\"percentiles\").opts(fill_color='g', line_color = 'k', alpha=0.9, size=8)\n", + " return ax_plot * lr_plot * sc_ * ext_ * ppp\n", + " else: \n", + " return ax_plot * lr_plot * sc_ * ext_\n", + " \n", + "scatter_plot_raster(mod_, obs_, quantile=0.99, cluster_duration=72, pp_plot=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2 - Example for `viti` (Viti Levu, Fiji Islands):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ioc_code = 'viti'\n", + "obs = get_obs(ts_folder, ioc_code, '.parquet')\n", + "mod, slon, slat = get_model(model_file, ioc_code, surge_stations)\n", + "mod_, obs_ = align_ts(mod, obs)\n", + "stats_ = get_stats(mod_, obs_)\n", + "\n", + "mod_plot = plot_extreme_raster(mod_, 0.99, color = 'blue', label='model')\n", + "obs_plot = plot_extreme_raster(obs_, 0.99, color ='red', label='obs')\n", + "\n", + "mod_plot * obs_plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scatter_plot_raster(mod_, obs_, quantile=0.99, cluster_duration=72)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "as we can see on the graphs above, since there is no extreme event recorded in 2023, the peaks in the modeled & observed TS don't match" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "storm_metrics(mod_, obs_, quantile=0.99, cluster_duration=72)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## all stations\n", + "now we can iterate over all stations and calculate the stats for the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "stats = {}\n", + "for v in versions.keys():\n", + " model_file = versions[v]\n", + " if v not in stats:\n", + " stats[v] = {}\n", + " for i_s, name in enumerate(surge_stations.Station_Name):\n", + " ioc_code = surge_stations.iloc[i_s].ioc_code\n", + " print(name, ioc_code)\n", + " obs = get_obs(ts_folder, ioc_code, '.parquet')\n", + " mod, mlon, mlat = get_model(model_file, ioc_code, surge_stations)\n", + " mod_, obs_ = align_ts(mod, obs)\n", + " stats_ = get_stats(mod_, obs_)\n", + " # try: \n", + " # storm metrics\n", + " metric99 = storm_metrics(mod_, obs_, quantile=0.99, cluster_duration=72)\n", + " metric95 = storm_metrics(mod_, obs_, quantile=0.95, cluster_duration=72)\n", + " # Create a dictionary for the current version's statistics\n", + " stats_[\"obs_lat\"]= surge_stations.iloc[i_s].latitude\n", + " stats_[\"obs_lon\"]= surge_stations.iloc[i_s].longitude\n", + " stats_[\"mod_lat\"]= float(mlat)\n", + " stats_[\"mod_lon\"]= float(mlon)\n", + " stats_[\"R1\"] = metric99[\"R1\"]\n", + " stats_[\"R1_norm\"] = metric99[\"R1_norm\"]\n", + " stats_[\"R3\"] = metric99[\"R3\"]\n", + " stats_[\"R3_norm\"] = metric99[\"R3_norm\"]\n", + " stats_[\"error99\"] = metric99[\"error\"]\n", + " stats_[\"error99m\"] = metric99[\"error_metric\"]\n", + " stats_[\"error95\"] = metric95[\"error\"]\n", + " stats_[\"error95m\"] = metric95[\"error_metric\"]\n", + " # Create a dictionary for the current version's statistics\n", + " stats[v][ioc_code] = stats_\n", + " # except Exception as e:\n", + " # print(e)\n", + " # continue\n", + "with open(f'stats_all.json', 'w') as f:\n", + " json.dump(json_format(stats), f, indent=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## second 
part: visualisation of the results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "open JSON" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(f'stats_all.json') as f:\n", + " stats = json.load(f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "load word maritime areas from https://tomsail.github.io/static/renumber.html (no direct download possible)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "oceans_ = gp.read_file('assets/world_oceans_final.json')\n", + "oceans_plot = oceans_.hvplot(color = 'name', alpha=0.5, height=800, width= 1200,cmap = 'glasbey', legend = False)\n", + "countries = gp.read_file(gp.datasets.get_path('naturalearth_lowres'))\n", + "map_ = countries.hvplot().opts(color='grey',line_alpha=0.9, tools=[])\n", + "good_obs = surge_stations[surge_stations.ioc_code.isin(stats['v0'].keys())]\n", + "obs_ = good_obs.hvplot.scatter(x = 'longitude', y = 'latitude', color = 'r', line_color='k', height=600, width= 1200, hover_cols=['ioc_code','Station_Name'])\n", + "oceans_plot * map_ * obs_" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Assuming oceans_ is already read and available as a GeoDataFrame and stations is the stations dataframe\n", + "def find_ocean_for_station(station, oceans_df, xstr = \"longitude\", ystr = \"latitude\"):\n", + " # Create a Point object for the station's location\n", + " point = gp.GeoSeries([shapely.Point(station[xstr], station[ystr])], crs=\"EPSG:4326\")\n", + " \n", + " # Check for each ocean if the point is within it, and return the ocean's name\n", + " for _, ocean in oceans_df.iterrows():\n", + " if point.within(ocean['geometry']).any():\n", + " return ocean['name']\n", + " return None\n", + "\n", + "# Apply the function to the stations dataframe\n", + "surge_stations['ocean'] = surge_stations.apply(lambda station: find_ocean_for_station(station, oceans_), axis=1)\n", + "\n", + "ocean_counts = surge_stations['ocean'].value_counts().reset_index()\n", + "ocean_counts.columns = ['ocean', 'count']\n", + "\n", + "hv_ocean_counts = hv.Dataset(ocean_counts)\n", + "ocean_histogram = hv_ocean_counts.to(hv.Bars, 'ocean', 'count')\n", + "ocean_histogram.opts(opts.Bars(width=1000, height=500, show_grid=True, tools=['hover'], xrotation=45))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initial plot with the first version selected\n", + "initial_version = list(stats.keys())[0]\n", + "df = pd.DataFrame(stats[initial_version]).T.reset_index()\n", + "\n", + "def scatter_hist(src, x, y, z):\n", + " p = hv.Points(src, kdims=[x,y]\n", + " ).hist(num_bins=90, dimension=[x, y]).opts(\n", + " opts.Points(\n", + " show_title=False, \n", + " tools=['hover','box_select', 'tap'], \n", + " size = 10, color=z, \n", + " cmap=\"rainbow4\", \n", + " line_color='k', \n", + " # line_='Category20',\n", + " # line_width=2,\n", + " # show_legend = False, \n", + " colorbar = True, \n", + " ), \n", + " opts.Histogram(tools=['hover','box_select']),\n", + " # opts.Layout(shared_axes=True, shared_datasource=True, merge_tools=True,), \n", + " )\n", + " return p\n", + "\n", + "# Assuming plot_signal and plot_model are defined as follows:\n", + "def plot_all(ioc_code, model_file, ts_folder):\n", + " # get obs\n", + " obs = get_obs(ts_folder, ioc_code, 
'.parquet')\n", + " mod, mlon, mlat = get_model(model_file, ioc_code, surge_stations)\n", + " mod_, obs_ = align_ts(mod, obs)\n", + " mod_plot = plot_extreme(mod_, 0.95, color = 'blue', label='model')\n", + " obs_plot = plot_extreme(obs_, 0.95, color ='red', label='obs')\n", + " corr_plot = scatter_plot(mod_, obs_, quantile=0.99, cluster_duration=72)\n", + " return obs_plot, mod_plot, corr_plot\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_std_dev_circles(std_dev_range: np.ndarray) -> hv.Overlay:\n", + " std_dev_circles = []\n", + " for std in std_dev_range:\n", + " angle = np.linspace(0, np.pi/2, 100)\n", + " radius = np.full(100, std)\n", + " x = radius * np.cos(angle)\n", + " y = radius * np.sin(angle)\n", + " std_dev_circles.append(\n", + " hv.Curve((x, y)).opts(color='gray', line_dash='dotted', line_width=1)\n", + " )\n", + " return hv.Overlay(std_dev_circles)\n", + "\n", + "\n", + "def create_std_ref(radius: float) -> hv.Overlay:\n", + " angle = np.linspace(0, np.pi/2, 100)\n", + " x = radius * np.cos(angle)\n", + " y = radius * np.sin(angle)\n", + " return hv.Curve((x, y)).opts(color='gray', line_dash='dashed', line_width=2) * \\\n", + " hv.Text(radius, 0., f'REF', halign='right', valign='bottom').opts(\n", + " text_font_size='10pt', text_color='gray')\n", + "\n", + "\n", + "def create_corr_lines(corr_range: np.ndarray, std_dev_max: float) -> hv.Overlay:\n", + " corr_lines = []\n", + " for corr in corr_range:\n", + " theta = np.arccos(corr)\n", + " radius = np.linspace(0, std_dev_max, 2)\n", + " x = radius * np.cos(theta)\n", + " y = radius * np.sin(theta)\n", + " corr_lines.append(\n", + " hv.Curve((x, y)).opts(color='blue', line_dash='dashed', line_width=1) *\n", + " hv.Text(x[-1], y[-1], f'{corr:.2f}', halign='left', valign='bottom').opts(\n", + " text_font_size='10pt', text_color='blue')\n", + " )\n", + " corr_label = hv.Text( 0.75 * std_dev_max, 0.75 * std_dev_max, f'Correlation Coefficient' ).opts( text_font_size='12pt', text_color='blue', angle=-45 )\n", + " return hv.Overlay(corr_lines) * corr_label\n", + "\n", + "\n", + "def create_rms_contours(standard_ref: float, std_dev_max: float, rms_range: np.ndarray, norm:bool) -> hv.Overlay:\n", + " rms_contours = []\n", + " for rms in rms_range:\n", + " angle = np.linspace(0, np.pi, 100)\n", + " x = standard_ref + rms * np.cos(angle)\n", + " y = rms * np.sin(angle)\n", + " inside_max_std = np.sqrt(x**2 + y**2) < std_dev_max\n", + " x[~inside_max_std] = np.nan\n", + " y[~inside_max_std] = np.nan\n", + " rms_contours.append(\n", + " hv.Curve((x, y)).opts(color='green', line_dash='dashed', line_width=1) *\n", + " hv.Text(standard_ref + rms * np.cos(2*np.pi/3), rms * np.sin(2*np.pi/3), f'{rms:.2f}', halign='left', valign='bottom').opts(\n", + " text_font_size='10pt', text_color='green')\n", + " )\n", + " label = \"RMS %\" if norm else \"RMS\"\n", + " rms_label = hv.Text( standard_ref, rms_range[1]*np.sin(np.pi/2), label, halign='left', valign='bottom' ).opts( text_font_size='11pt', text_color='green' )\n", + " return hv.Overlay(rms_contours) * rms_label\n", + "\n", + "\n", + "def taylor_diagram(df: pd.DataFrame,\n", + " norm: bool = True, \n", + " marker: str = \"circle\", \n", + " color: str = \"black\", \n", + " label: str = \"Taylor Diagram\"\n", + " ) -> hv.Overlay:\n", + " if df.empty:\n", + " std_range = np.arange(0, 1.5, np.round(1/5, 2))\n", + " corr_range = np.arange(0, 1, 0.1)\n", + " rms_range = np.arange(0, 1.5, np.round(1/5, 2))\n", + " 
std_dev_overlay = create_std_dev_circles(std_range) * create_std_ref(1)\n", + " corr_lines_overlay = create_corr_lines(corr_range, std_range.max())\n", + " rms_contours_overlay = create_rms_contours(1, std_range.max(), rms_range, norm=norm)\n", + " return std_dev_overlay * corr_lines_overlay * rms_contours_overlay\n", + " theta = np.arccos(df['cr']) # Convert Cr to radians for polar plot\n", + " if norm: \n", + " std_ref = 1\n", + " std_mod = df['std_df1'] / df['std_df2']\n", + " else: \n", + " if len(df) > 1:\n", + " raise ValueError('for not normalised Taylor diagrams, you need only 1 data point')\n", + " std_ref = df['std_df1'].mean()\n", + " std_mod = df['std_df2'].mean()\n", + " # \n", + " std_range = np.arange(0, 1.5 * std_ref, np.round(std_ref/5, 2))\n", + " corr_range = np.arange(0, 1, 0.1)\n", + " rms_range = np.arange(0, 1.5 * std_ref, np.round(std_ref/5, 2))\n", + "\n", + " std_dev_overlay = create_std_dev_circles(std_range) * create_std_ref(std_ref)\n", + " corr_lines_overlay = create_corr_lines(corr_range, std_range.max())\n", + " rms_contours_overlay = create_rms_contours(std_ref, std_range.max(), rms_range, norm=norm)\n", + "\n", + " x = std_mod * np.cos(theta)\n", + " y = std_mod * np.sin(theta)\n", + " df['x'] = x\n", + " df['y'] = y\n", + " df['rms_perc'] = df['rms'] / df['std_df2']\n", + " # hover parameters\n", + " tooltips = [\n", + " ('Bias', '@bias'),\n", + " ('Corr Coef (%)', '@cr'),\n", + " ('RMSE (m)', '@rmse'),\n", + " ('Centered RMS (m)', '@rms'),\n", + " ('KGE', '@kge'),\n", + " ('Std Dev Model (m)', '@std_df1'),\n", + " ('Std Dev Measure (m)', '@std_df2'),\n", + " ('Station (m)', '@ioc_code'),\n", + " ('Ocean', '@ocean'),\n", + " ]\n", + " if norm: \n", + " tooltips.append(('RMS %', '@rms_perc'))\n", + " hover = HoverTool(tooltips=tooltips)\n", + "\n", + " # Scatter plot for models with hover tool\n", + " scatter_plot = hv.Points(\n", + " df, ['x', 'y'],['cr', 'std_df1', 'std_df2', 'rms', 'rmse', 'rms_perc', 'ioc_code', 'ocean'],label=label\n", + " ).opts(\n", + " color=color,\n", + " # cmap='Category20',\n", + " # line_color='k', \n", + " line_width=1,\n", + " marker = marker,\n", + " size=8, \n", + " tools=[hover],\n", + " default_tools=[],\n", + " show_legend=True,\n", + " hover_fill_color='firebrick',\n", + " xlim=(0, std_range.max()*1.05),\n", + " ylim=(0, std_range.max()*1.05),\n", + " # clabel = 'ocean'\n", + " )\n", + " # Combine all the elements\n", + " taylor_diagram = scatter_plot\n", + " return taylor_diagram\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def layout_fun(obs_, mod_, corr_):\n", + " tgt = (obs_ * mod_ ).opts(width=ts_view['width'], labelled=['y'], show_grid=True)\n", + " src = (obs_ * mod_ ).opts(width=ts_view['width'], height = 100, yaxis=None, default_tools=[])\n", + " RangeToolLink(src, tgt)\n", + " ts_plot = (src + tgt).cols(1)\n", + " # \n", + " corr_.opts(**scatter_view, show_grid=True)\n", + " return (ts_plot + corr_).cols(1)\n", + "\n", + "\n", + "def stacked_hist(plot, element):\n", + " \"\"\"found here https://discourse.holoviz.org/t/stacked-histogram/6205/2\"\"\"\n", + " offset = 0\n", + " for r in plot.handles[\"plot\"].renderers:\n", + " r.glyph.bottom = \"bottom\"\n", + "\n", + " data = r.data_source.data\n", + " new_offset = data[\"top\"] + offset\n", + " data[\"top\"] = new_offset\n", + " data[\"bottom\"] = offset * np.ones_like(data[\"top\"])\n", + " offset = new_offset\n", + "\n", + " plot.handles[\"plot\"].y_range.end = max(offset) * 1.1\n", + " 
plot.handles[\"plot\"].y_range.reset_end = max(offset) * 1.1\n", + "\n", + "\n", + "\n", + "def hist_(src,z, g = 'ocean', map = None):\n", + " if z in ['rmse', 'rms', 'bias']:\n", + " range_ = (0,0.5)\n", + " else: \n", + " range_ = (0,1)\n", + " \n", + " df = src[[z,g]].reset_index()\n", + " # \n", + " unique_oceans = df[g].unique()\n", + " # Create a new DataFrame with one-hot encoded structure\n", + " rows = []\n", + " for index, row in df.iterrows():\n", + " new_row = {group: np.nan for group in unique_oceans}\n", + " new_row[row[g]] = row[z]\n", + " rows.append(new_row)\n", + " one_hot_df = pd.DataFrame(rows, columns=unique_oceans)\n", + " # \n", + " mean = src[z].mean()\n", + " color_key = hv.Cycle('Category20').values\n", + " # only way to get the colors to match the ocean mapping\n", + " if map is None: \n", + " map = {ocean: color_key[i % len(color_key)] for i, ocean in enumerate(unique_oceans)}\n", + " colors = [map[ocean] for ocean in unique_oceans]\n", + " return one_hot_df.hvplot.hist(\n", + " bins=20, \n", + " bin_range = range_,\n", + " # cmap = ocean_mapping, \n", + " color = colors, \n", + " ).opts(\n", + " hooks=[stacked_hist], \n", + " **scatter_view, \n", + " title = f\"{z} mean: {mean:.2f}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initial ts plot \n", + "SURGE_FOLDER = \"obs/surge/\"\n", + "obs = pd.read_parquet(SURGE_FOLDER+\"abed.parquet\")\n", + "zeros = pd.Series(index=obs.index)\n", + "empt_ = hv.Curve(zeros)\n", + "empty_layout_ts = layout_fun(empt_, empt_, scatter_plot(zeros, zeros, 0.9))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define a class to encapsulate dashboard logic and state\n", + "import xarray as xr\n", + "\n", + "class Dashboard(param.Parameterized):\n", + " version = param.Selector(objects=list(stats.keys()))\n", + " parameter = param.Selector(objects=['rmse', 'rms', 'bias', 'kge', 'nse2', 'R1_norm', 'R3_norm','mad','madp', 'error95','error99','R1', 'R3','error95m','error99m', 'std_df2', 'nse2', 'lamba', 'cr'])\n", + " selected_station = param.Integer(default=0)\n", + " \n", + " def __init__(self, **params):\n", + " super().__init__(**params)\n", + " self.oceans_ = gp.read_file('assets/world_oceans_final.json')\n", + " self.df = pd.DataFrame()\n", + " self.ds = None\n", + " self.ocean_mapping = {}\n", + " self.update_data()\n", + "\n", + " @param.depends('version', watch=True)\n", + " def update_data(self):\n", + " self.df = pd.DataFrame(stats[self.version]).T\n", + " self.df = self.df.astype(float)\n", + " self.df['ocean'] = self.df.apply(lambda station: find_ocean_for_station(station, oceans_, \"obs_lon\", \"obs_lat\"), axis=1)\n", + " # Create a color mapping for oceans\n", + " unique_oceans = self.df['ocean'].unique()\n", + " color_key = hv.Cycle('Category20').values\n", + " self.ocean_mapping = {ocean: color_key[i % len(color_key)] for i, ocean in enumerate(unique_oceans)}\n", + " # Apply the color mapping to the oceans map\n", + " self.map_ = self.oceans_[self.oceans_[\"name\"].isin(unique_oceans)].hvplot(\n", + " color='name',\n", + " alpha=0.9,\n", + " **ts_view,\n", + " cmap=self.ocean_mapping, # Use the color mapping dictionary\n", + " tools=[],\n", + " legend=False\n", + " ) * map_\n", + " self.df = self.df.dropna()\n", + " self.df['ioc_code'] = self.df.index\n", + " self.df.loc[self.df['nse'] > 0, 'nse2'] = self.df['nse']\n", + " self.df.loc[self.df['nse'] < 0, 'nse2'] = 0\n", + " 
self.file_ds = f\"{versions[self.version]}\"\n", + " self.ds = xr.open_dataset(f\"{versions[self.version]}\")\n", + "\n", + " def update_ts(self, index):\n", + " if not index:\n", + " return empty_layout_ts\n", + " else: \n", + " station_name = self.df.iloc[index[0]]['ioc_code'] # 'index' column holds the station names after reset_index\n", + " obs_, mod_, corr_ = plot_all(station_name, self.file_ds, SURGE_FOLDER)\n", + " layout = layout_fun(obs_, mod_, corr_)\n", + " return layout\n", + " \n", + " @param.depends('version', 'parameter')\n", + " def view(self):\n", + " # Update the DataFrame based on the selected version\n", + " self.update_data()\n", + " \n", + " # Your plotting logic here (simplified for this example)\n", + " scatter_ = scatter_hist(self.df, 'obs_lon', 'obs_lat', self.parameter)\n", + " \n", + " stream = streams.Selection1D(source=scatter_[0])\n", + " time_series = hv.DynamicMap(self.update_ts, streams=[stream])\n", + "\n", + " # Update the layout with new plots\n", + " layout = pn.Column(self.map_ * scatter_, time_series.opts(shared_axes=False))\n", + " # layout = pn.Column(scatter_)\n", + " return layout\n", + "\n", + " @param.depends('version')\n", + " def taylor(self):\n", + " self.update_data()\n", + " diagram = taylor_diagram(pd.DataFrame())\n", + " for ocean in self.ocean_mapping.keys():\n", + " df = self.df[self.df['ocean'] == ocean]\n", + " diagram *= taylor_diagram(df, norm=True, color=self.ocean_mapping[ocean], label=ocean)\n", + " return diagram.opts(**scatter_view, shared_axes=False)\n", + " \n", + " @param.depends('version', 'parameter')\n", + " def hist(self):\n", + " # Update the DataFrame based on the selected version\n", + " self.update_data()\n", + " hist = hist_(self.df, self.parameter, g = 'ocean', map = self.ocean_mapping)\n", + " return hist.opts(shared_axes=False)\n", + "\n", + "\n", + "# Instantiate the dashboard and create the layout\n", + "dashboard = Dashboard()\n", + "layout = pn.Row(pn.Column(pn.Row(dashboard.param.version, \n", + " dashboard.param.parameter), \n", + " dashboard.view), \n", + " pn.Column(dashboard.taylor, dashboard.hist))\n", + "\n", + "# Serve the Panel app\n", + "pn.serve(layout)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/index.html b/index.html index 331cb9d..3d38329 100644 --- a/index.html +++ b/index.html @@ -46,6 +46,7 @@