Source code for cmip6_downscaling.methods.bcsd.tasks
from __future__ import annotations
import warnings
import dask
import xarray as xr
from carbonplan_data.metadata import get_cf_global_attrs
from prefect import task
from skdownscale.pointwise_models import PointWiseDownscaler
from skdownscale.pointwise_models.bcsd import BcsdPrecipitation, BcsdTemperature
from upath import UPath
from ... import __version__ as version
from ... import config
from ...constants import ABSOLUTE_VARS, RELATIVE_VARS
from ...utils import str_to_hash
from ..common.containers import RunParameters
from ..common.utils import apply_land_mask, zmetadata_exists
from .utils import reconstruct_finescale
xr.set_options(keep_attrs=True)
warnings.filterwarnings(
"ignore",
"(.*) filesystem path not explicitly implemented. falling back to default implementation. This filesystem may not be tested",
category=UserWarning,
)
intermediate_dir = UPath(config.get("storage.intermediate.uri")) / version
results_dir = UPath(config.get("storage.results.uri")) / version
use_cache = config.get('run_options.use_cache')
[docs]
@task(log_stdout=True)
def spatial_anomalies(obs_full_time_path: UPath, interpolated_obs_full_time_path: UPath) -> UPath:
"""Returns spatial anomalies
Calculate the seasonal cycle (12 timesteps) spatial anomaly associated
with aggregating the fine_obs to a given coarsened scale and then reinterpolating
it back to the original spatial resolution. The outputs of this function are
dependent on three parameters:
* a grid (as opposed to a specific GCM since some GCMs run on the same grid)
* the time period which fine_obs (and by construct coarse_obs) cover
* the variable
We will save these anomalies to use them in the post-processing. We will add them to the
spatially-interpolated coarse predictions to add the spatial heterogeneity back in.
Conceptually, this step figures out, for example, how much colder a finer-scale pixel
containing Mt. Rainier is compared to the coarse pixel where it exists. By saving those anomalies,
we can then preserve the fact that "Mt Rainier is x degrees colder than the pixels around it"
for the prediction. It is important to note that that spatial anomaly is the same for every month of the
year and the same for every day. So, if in January a finescale pixel was on average 4 degrees colder than
the neighboring pixel to the west, in every day in the prediction (historic or future) that pixel
will also be 4 degrees colder.
Parameters
----------
obs_full_time_path : UPath
UPath to observation dataset chunked in full_time.
interpolated_obs_full_time_path : UPath
UPath to observation dataset interpolated to gcm grid and chunked in full time.
Returns
-------
UPath
Path to spatial anomalies dataset. (shape (nlat, nlon, 12))
"""
ds_hash = str_to_hash(str(obs_full_time_path) + str(interpolated_obs_full_time_path))
target = intermediate_dir / 'bcsd_spatial_anomalies' / ds_hash
if use_cache and zmetadata_exists(target):
print(f"found existing target: {target}")
return target
interpolated_obs_full_time_ds = xr.open_zarr(interpolated_obs_full_time_path)
obs_full_time_ds = xr.open_zarr(obs_full_time_path)
# calculate the difference between the actual obs (with finer spatial heterogeneity)
# and the interpolated coarse obs this will be saved and added to the
# spatially-interpolated coarse predictions to add the spatial heterogeneity back in.
spatial_anomalies = obs_full_time_ds - interpolated_obs_full_time_ds
seasonal_cycle_spatial_anomalies = spatial_anomalies.groupby("time.month").mean()
seasonal_cycle_spatial_anomalies.attrs.update(
{'title': 'bcsd_spatial_anomalies'}, **get_cf_global_attrs(version=version)
)
seasonal_cycle_spatial_anomalies.to_zarr(target, mode='w')
return target
def _fit_and_predict_wrapper(xtrain, ytrain, xpred, run_parameters, dim='time'):
"""Wrapper for map_blocks for fit and predict task
Parameters
----------
xtrain : xr.Dataset
Experiment training dataset
ytrain : xr.Dataset
Observation training dataset
xpred : xr.Dataset
Experiment prediction dataset
run_parameters : RunParameters
Prefect run parameters
dim : str, optional
dimension, by default 'time'
Returns
-------
xr.Dataset
Output bias corrected dataset
Raises
------
ValueError
raise ValueError if the given variable is not implimented.
"""
xpred = xpred.rename({'t2': 'time'})
if run_parameters.variable in ABSOLUTE_VARS:
model = BcsdTemperature(return_anoms=False)
elif run_parameters.variable in RELATIVE_VARS:
model = BcsdPrecipitation(return_anoms=False)
else:
raise ValueError('run_parameters.variable not found in ABSOLUTE_VARS OR RELATIVE_VARS.')
pointwise_model = PointWiseDownscaler(model=model, dim=dim)
pointwise_model.fit(xtrain[run_parameters.variable], ytrain[run_parameters.variable])
bias_corrected_da = pointwise_model.predict(xpred[run_parameters.variable])
bias_corrected_ds = bias_corrected_da.astype('float32').to_dataset(name=run_parameters.variable)
return bias_corrected_ds
[docs]
@task(log_stdout=True)
def fit_and_predict(
experiment_train_full_time_path: UPath,
experiment_predict_full_time_path: UPath,
coarse_obs_full_time_path: UPath,
run_parameters: RunParameters,
) -> UPath:
"""Fit bcsd model on prepared CMIP data with obs at corresponding spatial scale.
Then predict for a set of CMIP data (likely future).
Parameters
----------
experiment_train_full_time_path : UPath
UPath to experiment training dataset chunked in full time
experiment_predict_full_time_path : UPath
UPath to experiment prediction dataset chunked in full time
coarse_obs_full_time_path : UPath
UPath to coarse observation dataset chunked in full time
run_parameters : RunParameters
Prefect run parameters
Returns
-------
UPath
UPath to prediction results dataset.
Raises
------
ValueError
ValueError checking validity of input variables.
"""
title = "bcsd_predictions"
ds_hash = str_to_hash(
str(experiment_train_full_time_path)
+ str(experiment_predict_full_time_path)
+ str(coarse_obs_full_time_path)
)
target = intermediate_dir / 'bcsd_fit_and_predict' / ds_hash
if use_cache and zmetadata_exists(target):
print(f"found existing target: {target}")
return target
xtrain = xr.open_zarr(coarse_obs_full_time_path)
ytrain = xr.open_zarr(experiment_train_full_time_path)
xpred = xr.open_zarr(experiment_predict_full_time_path)
# Create a template dataset for map blocks
# This feals a bit fragile.
template_var = list(xpred.data_vars.keys())[0]
template = (
xpred[[template_var]].astype('float32').rename({template_var: run_parameters.variable})
)
out = xr.map_blocks(
_fit_and_predict_wrapper,
xtrain,
args=(ytrain, xpred.rename({'time': 't2'}), run_parameters),
kwargs={'dim': 'time'},
template=template,
)
out = dask.optimize(out)[0]
out.attrs.update({'title': title}, **get_cf_global_attrs(version=version))
out.to_zarr(target, mode='w')
return target
[docs]
@task(log_stdout=True)
def postprocess_bcsd(
bias_corrected_fine_full_time_path: UPath, spatial_anomalies_path: UPath
) -> UPath:
"""Downscale the bias-corrected data (fit_and_predict results) by interpolating and then
adding the spatial anomalies back in.
Parameters
----------
bias_corrected_fine_full_time_path : UPath
UPath to output dataset from the fit_and_predict task.
spatial_anomalies_path : UPath
UPath to the output of the spatial_anomalies task.
Returns
-------
UPath
UPath to post-processed dataset.
"""
title = "bcsd_postprocess"
ds_hash = str_to_hash(str(bias_corrected_fine_full_time_path) + str(spatial_anomalies_path))
target = results_dir / title / ds_hash
print(target)
if use_cache and zmetadata_exists(target):
print(f"found existing target: {target}")
return target
bias_corrected_fine_full_time_ds = xr.open_zarr(bias_corrected_fine_full_time_path)
# hint for mapblocks about which month each day corresponds to
bias_corrected_fine_full_time_ds = bias_corrected_fine_full_time_ds.assign_coords(
{'month': bias_corrected_fine_full_time_ds['time.month']}
)
spatial_anomalies_ds = xr.open_zarr(spatial_anomalies_path)
# make all spatial anomalies into one chunk so that map_blocks has access to every month.
# otherwise it will only have access to one chunk and will only grab the last chunk (december)
# and result in nans in all months except for december
spatial_anomalies_ds = spatial_anomalies_ds.chunk({'month': -1}).persist()
bcsd_results_ds = xr.map_blocks(
reconstruct_finescale,
bias_corrected_fine_full_time_ds,
args=[spatial_anomalies_ds],
template=bias_corrected_fine_full_time_ds,
)
# masking out ocean regions
bcsd_results_ds = apply_land_mask(bcsd_results_ds)
bcsd_results_ds.attrs.update({'title': title}, **get_cf_global_attrs(version=version))
bcsd_results_ds.to_zarr(target, mode='w')
return target