"""Prefect tasks for the GARD downscaling method (``cmip6_downscaling.methods.gard.tasks``)."""

from dataclasses import asdict

import dask
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
from carbonplan_data.metadata import get_cf_global_attrs
from prefect import task
from scipy.special import cbrt
from skdownscale.pointwise_models import PointWiseDownscaler
from skdownscale.pointwise_models.utils import default_none_kwargs
from upath import UPath

from ... import __version__ as version
from ... import config
from ..common.bias_correction import bias_correct_gcm_by_method
from ..common.containers import RunParameters, str_to_hash
from ..common.utils import apply_land_mask, blocking_to_zarr, set_zarr_encoding, zmetadata_exists
from .utils import add_random_effects, get_gard_model

# Process-wide xarray setting: keep attributes through xarray operations.
xr.set_options(keep_attrs=True)

# Storage roots read from the project config. Intermediate and final results
# are namespaced by the package version; the scratch root is not.
scratch_dir = UPath(config.get('storage.scratch.uri'))
intermediate_dir = UPath(config.get('storage.intermediate.uri')) / version
results_dir = UPath(config.get('storage.results.uri')) / version

# When truthy, tasks below return an existing target (if its zarr metadata
# exists) instead of recomputing it.
use_cache = config.get('run_options.use_cache')


@task(log_stdout=True)
def coarsen_and_interpolate(fine_path: UPath, coarse_path: UPath) -> UPath:
    """Coarsen up obs and then interpolate it back to the original finescale grid.

    The fine dataset is bilinearly regridded onto the grid of the coarse
    (template) dataset. The result is then bilinearly regridded back onto the
    fine dataset's grid. Both steps use nearest-neighbour (``nearest_s2d``)
    extrapolation. The output is written to a content-addressed Zarr store
    under ``intermediate_dir``.

    Parameters
    ----------
    fine_path : UPath
        Path to finescale (likely observational) dataset
    coarse_path : UPath
        Path to coarse scale that will be the template for the coarsening.

    Returns
    -------
    UPath
        Path to interpolated dataset.
    """
    ds_hash = str_to_hash(str(fine_path) + str(coarse_path))
    target = intermediate_dir / 'coarsen_and_interpolate' / ds_hash

    if use_cache and zmetadata_exists(target):
        print(f'found existing target: {target}')
        # Cache hit: the caller only needs the path, so there is no reason to
        # open the existing store here.
        return target

    fine_ds = xr.open_zarr(fine_path)
    target_ds = xr.open_zarr(coarse_path)

    # coarsen: fine grid -> coarse template grid
    regridder = xe.Regridder(fine_ds, target_ds, "bilinear", extrap_method="nearest_s2d")
    coarse_ds = regridder(fine_ds, keep_attrs=True)

    # interpolate back to the fine grid
    regridder = xe.Regridder(coarse_ds, fine_ds, "bilinear", extrap_method="nearest_s2d")
    interpolated_ds = regridder(coarse_ds, keep_attrs=True)

    interpolated_ds.attrs.update(
        {'title': 'coarsen_and_interpolate'}, **get_cf_global_attrs(version=version)
    )
    interpolated_ds = set_zarr_encoding(interpolated_ds)
    blocking_to_zarr(ds=interpolated_ds, target=target, validate=True, write_empty_chunks=True)
    return target
def _fit_and_predict_wrapper(xtrain, ytrain, xpred, scrf, run_parameters, dim='time'):
    """Bias-correct, fit a point-wise GARD model, predict, and add random effects.

    Designed to be called per block through ``xr.map_blocks``. ``xpred`` and
    ``scrf`` arrive with their time dimension named ``t2``, because map_blocks
    can't have the same dimension name on later arguments. They are renamed
    back to ``time`` here.

    Parameters
    ----------
    xtrain : xr.Dataset
        Training predictors; must contain every name in ``run_parameters.features``.
    ytrain : xr.Dataset
        Training target; ``run_parameters.variable`` is the fitted variable.
    xpred : xr.Dataset
        Prediction-period predictors, time dimension named ``t2``.
    scrf : xr.Dataset
        Random fields holding a ``scrf`` variable, time dimension named ``t2``.
    run_parameters : RunParameters
        Run configuration (features, variable, bias-correction and model settings).
    dim : str, optional
        Dimension the point-wise model is applied along. Default is ``time``.

    Returns
    -------
    xr.Dataset
        Result of ``add_random_effects`` applied to the model prediction.
    """
    xpred = xpred.rename({'t2': 'time'})
    scrf = scrf.rename({'t2': 'time'})
    kws = default_none_kwargs(run_parameters.bias_correction_kwargs, copy=True)

    # transformed gcm is the interpolated GCM for the prediction period transformed
    # w.r.t. the interpolated obs used in the training (because that transformation
    # is essentially part of the model)
    bias_corrected_gcm_pred = xr.Dataset()
    for feature in run_parameters.features:
        bias_corrected_gcm_pred[feature] = (
            bias_correct_gcm_by_method(
                gcm_pred=xpred[feature],
                method=run_parameters.bias_correction_method,
                bc_kwargs=kws[feature],
                obs=xtrain[feature],
            )
            .sel(variable='variable_0')
            # drop the now-scalar 'variable' coordinate left by .sel
            .drop_vars('variable')
        )

    # model definition
    model = PointWiseDownscaler(
        model=get_gard_model(run_parameters.model_type, run_parameters.model_params), dim=dim
    )

    # Cube-root transform precipitation before fitting. ``assign`` returns new
    # datasets instead of overwriting variables on the objects passed in.
    # TODO need to fix this to only transform some variables
    # TODO: at this point there is negative precip in some chunks - why?
    #   e.g. a block with lat 49.0-50.0, lon -113.0 to -101.25 and daily time
    #   1950-01-01 to 2013-12-31 contained pr values of -0.5607.
    if 'pr' in run_parameters.features:
        bias_corrected_gcm_pred = bias_corrected_gcm_pred.assign(
            pr=cbrt(bias_corrected_gcm_pred['pr'])
        )
        xtrain = xtrain.assign(pr=cbrt(xtrain['pr']))
    if run_parameters.variable == 'pr':
        ytrain = ytrain.assign(pr=cbrt(ytrain['pr']))

    # model fitting and prediction
    model.fit(xtrain[run_parameters.features], ytrain[run_parameters.variable])
    out = model.predict(bias_corrected_gcm_pred[run_parameters.features]).to_dataset(
        dim='variable'
    )
    if run_parameters.variable == 'pr':
        # undo the cube-root transform on the prediction
        out['pred'] = out['pred'] ** 3

    return add_random_effects(out, scrf.scrf, run_parameters)
@task(log_stdout=True)
def fit_and_predict(
    xtrain_path: UPath,
    ytrain_path: UPath,
    xpred_path: UPath,
    scrf_path: UPath,
    run_parameters: RunParameters,
    dim: str = 'time',
) -> UPath:
    """Prepare inputs, use them to fit a GARD model based upon specified
    parameters and then use that fitted model to make a prediction.

    The fit/predict step runs block-wise through ``xr.map_blocks`` using
    ``_fit_and_predict_wrapper``. The land-masked result is written to a
    content-addressed Zarr store under ``results_dir``.

    Parameters
    ----------
    xtrain_path : UPath
        Path to training dataset (interpolated GCM) chunked full_time
    ytrain_path : UPath
        Path to target dataset (interpolated obs) chunked full_time
    xpred_path : UPath
        Path to future prediction dataset (interpolated GCM) chunked full_time
    scrf_path : UPath
        Path to scrf chunked in full_time
    run_parameters : RunParameters
        Parameters for run set-up and model specs
    dim : str, optional
        Dimension to apply the model along. Default is ``time``.

    Returns
    -------
    path : UPath
        Path to output dataset chunked full_time

    Raises
    ------
    ValueError
        If a data variable of the prediction dataset is absent from the
        target (obs) dataset.
    """
    ds_hash = str_to_hash(
        str(xtrain_path)
        + str(ytrain_path)
        + str(xpred_path)
        + str(scrf_path)
        + run_parameters.run_id_hash
        + str(dim)
    )
    target = results_dir / 'gard_fit_and_predict' / ds_hash

    if use_cache and zmetadata_exists(target):
        print(f'found existing target: {target}')
        return target

    # load in datasets
    xtrain = xr.open_zarr(xtrain_path).pipe(apply_land_mask)
    ytrain = xr.open_zarr(ytrain_path).pipe(apply_land_mask)
    xpred = xr.open_zarr(xpred_path).pipe(apply_land_mask)
    scrf = xr.open_zarr(scrf_path).pipe(apply_land_mask)

    # make sure you have the variables you need in obs. Raise explicitly
    # (rather than assert) so the check still runs under ``python -O``.
    missing = [v for v in xpred.data_vars if v not in ytrain.data_vars]
    if missing:
        raise ValueError(
            f'variables {missing} in {xpred_path} are missing from {ytrain_path}'
        )

    # data transformation (this wants full-time chunking)
    # transformed_obs is for the training period
    # we need only the prediction GCM (xpred), but we'll transform it into the space of the
    # transformed interpolated obs (xtrain)

    # Create a template dataset for map blocks: a single variable named
    # run_parameters.variable, copied from the first prediction variable.
    # NOTE: this feels a bit fragile -- it relies on the wrapper's output
    # matching that variable's dims, coords and chunks.
    template_var = list(xpred.data_vars)[0]
    template = xr.Dataset()
    template[run_parameters.variable] = xpred[template_var]

    # rename time variable to play nice with mapblocks - can't have same dimension name on later arguments
    out = xr.map_blocks(
        _fit_and_predict_wrapper,
        xtrain,
        args=(
            ytrain,
            xpred.rename({'time': 't2'}),
            scrf.rename({'time': 't2'}),
            run_parameters,
        ),
        kwargs={'dim': dim},
        template=template,
    )
    out.attrs.update({'title': 'gard_fit_and_predict'}, **get_cf_global_attrs(version=version))
    out = dask.optimize(out)[0]
    # remove apply_land_mask after scikit-downscale#110 is merged
    out_ds = out.pipe(apply_land_mask).pipe(set_zarr_encoding)
    blocking_to_zarr(ds=out_ds, target=target, validate=True, write_empty_chunks=True)
    return target
@task(log_stdout=True)
def read_scrf(prediction_path: UPath, run_parameters: RunParameters) -> UPath:
    """
    Read spatial-temporally correlated random fields on file and subset into the correct
    spatial/temporal domain of the prediction dataset.

    The random fields are stored in decade (10 year) long time series for the global domain
    and pre-generated using `scrf.ipynb`. The single stored decade
    (``ERA5_{variable}_1981_1990``) is tiled end to end. It is then relabelled with a
    daily time axis starting on January 1 of the prediction period's start year. Finally
    it is subset to the prediction dataset's time, lat and lon coordinates.

    Parameters
    ----------
    prediction_path : UPath
        Path to prediction dataset
    run_parameters : RunParameters
        Parameters for run set-up and model specs

    Returns
    -------
    UPath
        Path to the Zarr store holding the subset spatio-temporally correlated
        random fields (SCRF).

    Raises
    ------
    ValueError
        If the subset random fields do not match the prediction dataset's
        time, lat or lon length.
    """
    # TODO: this is a temporary creation of random fields. ultimately we probably want to have
    # ~150 years of random fields, but this is fine.
    ds_hash = str_to_hash(
        "{obs}_{variable}_{latmin}_{latmax}_{lonmin}_{lonmax}_{predict_dates[0]}_{predict_dates[1]}".format(
            **asdict(run_parameters)
        )
    )
    target = intermediate_dir / 'scrf' / ds_hash

    if use_cache and zmetadata_exists(target):
        print(f'found existing target: {target}')
        return target

    prediction_ds = xr.open_zarr(prediction_path)
    scrf_ten_years = xr.open_zarr(f'az://static/scrf/ERA5_{run_parameters.variable}_1981_1990.zarr')

    # Tile the stored decade once per decade step from the start year up to
    # stop + 10, so the tiled series covers the whole prediction period.
    n_decades = len(
        range(
            int(run_parameters.predict_period.start),
            int(run_parameters.predict_period.stop) + 10,
            10,
        )
    )
    decade = scrf_ten_years.drop_vars('time')
    scrf = xr.concat([decade] * n_decades, dim='time')
    scrf['time'] = pd.date_range(
        start=f'{run_parameters.predict_period.start}-01-01', periods=scrf.sizes['time']
    )
    scrf = scrf.sel(time=run_parameters.predict_period.time_slice)
    scrf = scrf.drop_vars('spatial_ref').astype('float32')
    scrf = scrf.sel(
        lat=prediction_ds.lat.values, lon=prediction_ds.lon.values, time=prediction_ds.time.values
    )

    # explicit checks (not asserts) so they still run under ``python -O``
    for name in ('time', 'lat', 'lon'):
        if len(scrf[name]) != len(prediction_ds[name]):
            raise ValueError(
                f'scrf {name} length {len(scrf[name])} does not match '
                f'prediction dataset {name} length {len(prediction_ds[name])}'
            )

    scrf = scrf.assign_coords(
        {'lat': prediction_ds.lat, 'lon': prediction_ds.lon, 'time': prediction_ds.time}
    )
    if (scrf.chunks['lon'][0] != 48) or (scrf.chunks['lat'][0] != 48):
        scrf = scrf.chunk({'lon': 48, 'lat': 48, 'time': 3652})

    scrf = dask.optimize(scrf)[0]
    scrf = set_zarr_encoding(scrf)
    blocking_to_zarr(ds=scrf, target=target, validate=True, write_empty_chunks=True)
    return target