Source code for cmip6_downscaling.methods.gard.utils

from __future__ import annotations

from typing import Any

import numpy as np
import xarray as xr
from scipy.special import cbrt
from scipy.stats import norm as norm
from skdownscale.pointwise_models import AnalogRegression, PureAnalog, PureRegression

from ..common.containers import RunParameters

# Called at import time rather than inside a `with xr.set_options(...)` block,
# so keep_attrs=True becomes the process-wide xarray default for everything
# that runs after this module is imported, not only the code in this file.
# NOTE(review): presumably intended so DataArray attrs survive the arithmetic
# in add_random_effects — confirm, and consider scoping it with a context manager.
xr.set_options(keep_attrs=True)


def get_gard_model(
    model_type: str,
    model_params: dict[str, Any],
) -> AnalogRegression | PureAnalog | PureRegression:
    """Instantiate the skdownscale GARD model selected by ``model_type``.

    Parameters
    ----------
    model_type : str
        Name of the GARD flavour to build; must be exactly one of
        ``'AnalogRegression'``, ``'PureAnalog'`` or ``'PureRegression'``.
    model_params : dict
        Keyword arguments forwarded unchanged to the chosen model's constructor.

    Returns
    -------
    model : AnalogRegression, PureAnalog, or PureRegression
        A newly constructed skdownscale GARD model instance.

    Raises
    ------
    NotImplementedError
        If ``model_type`` does not name one of the supported models.
    """
    # Pick the class first; construction then happens in exactly one place.
    if model_type == 'AnalogRegression':
        model_cls = AnalogRegression
    elif model_type == 'PureAnalog':
        model_cls = PureAnalog
    elif model_type == 'PureRegression':
        model_cls = PureRegression
    else:
        raise NotImplementedError(
            'model_type must be AnalogRegression, PureAnalog, or PureRegression'
        )
    return model_cls(**model_params)
def add_random_effects(
    model_output: xr.Dataset, scrf: xr.DataArray, run_parameters: RunParameters
) -> xr.Dataset:
    """Combine a GARD model prediction with random-field noise.

    Parameters
    ----------
    model_output : xr.Dataset
        Model output. ``pred`` and ``prediction_error`` are always read; when a
        ``'thresh'`` model parameter is set, ``exceedance_prob`` is read too.
    scrf : xr.DataArray
        Random field values, treated as standard-normal draws (they are mapped
        to uniform with ``norm.cdf``).
    run_parameters : RunParameters
        Supplies ``model_params`` (may be ``None``; only its ``'thresh'`` key is
        consulted) and ``variable`` (selects the ``'pr'`` transform and names
        the output variable).

    Returns
    -------
    xr.Dataset
        Single-variable dataset whose variable is named ``run_parameters.variable``.
    """
    model_params = run_parameters.model_params
    thresh = model_params.get('thresh') if model_params is not None else None

    if thresh is None:
        # No threshold model: additive noise scaled by the prediction error.
        downscaled = model_output['pred'] + scrf * model_output['prediction_error']
        return downscaled.to_dataset(name=run_parameters.variable)

    exceedance_prob = model_output['exceedance_prob']
    # Uniform draws above this level count as threshold exceedances.
    non_exceedance_prob = 1 - exceedance_prob

    # convert scrf from a normal distribution to a uniform distribution
    scrf_uniform = xr.apply_ufunc(
        norm.cdf, scrf, dask='parallelized', output_dtypes=[scrf.dtype]
    )
    # find where exceedance prob is exceeded
    mask = scrf_uniform > non_exceedance_prob
    # Rescale the exceeding part of the draw, (1 - p, 1], onto (0, 1]. Outside
    # `mask` this is non-positive (or NaN), so norm.ppf yields NaN/-inf there;
    # those cells are zeroed explicitly below rather than relying on the NaN.
    new_uniform = (scrf_uniform - non_exceedance_prob) / exceedance_prob
    # Get the normal distribution equivalent of new_uniform
    r_normal = xr.apply_ufunc(
        norm.ppf, new_uniform, dask='parallelized', output_dtypes=[new_uniform.dtype]
    )

    if run_parameters.variable == 'pr':
        # Add the noise to the cube root of the prediction, then cube back.
        downscaled = (
            cbrt(model_output['pred']) + model_output['prediction_error'] * r_normal
        ) ** 3
    else:
        downscaled = model_output['pred'] + r_normal * model_output['prediction_error']

    # Keep a value only where the threshold was exceeded AND it is non-negative;
    # everything else (non-exceedances, negatives, NaN) becomes 0.
    # TODO: the zero floor is also applied to non-'pr' variables — what do we do
    # for thresholds like heat wave?
    downscaled = downscaled.where(mask & (downscaled >= 0), 0)
    return downscaled.to_dataset(name=run_parameters.variable)