Source code for cmip6_downscaling.methods.common.tasks

from __future__ import annotations

import contextlib
import datetime
import json
import os
import warnings
from dataclasses import asdict
from pathlib import PosixPath

import datatree as dt
import fsspec
import pandas as pd
import rechunker
import xarray as xr
import zarr
from carbonplan_data.metadata import get_cf_global_attrs
from carbonplan_data.utils import set_zarr_encoding as set_web_zarr_encoding
from ndpyramid import pyramid_regrid
from prefect import task
from prefect.triggers import any_failed
from upath import UPath
from xarray_schema import DataArraySchema, DatasetSchema
from xarray_schema.base import SchemaError

from ... import __version__ as version
from ... import config
from ...data.cmip import get_gcm
from ...data.observations import open_era5
from ...utils import str_to_hash
from .containers import RunParameters, TimePeriod
from .utils import (
    blocking_to_zarr,
    calc_auspicious_chunks_dict,
    is_cached,
    resample_wrapper,
    set_zarr_encoding,
    subset_dataset,
    validate_zarr_store,
)

xr.set_options(keep_attrs=True)
warnings.filterwarnings(
    "ignore",
    "(.*) filesystem path not explicitly implemented. falling back to default implementation. This filesystem may not be tested",
    category=UserWarning,
)

PIXELS_PER_TILE = 128
scratch_dir = UPath(config.get("storage.scratch.uri"))
intermediate_dir = UPath(config.get("storage.intermediate.uri")) / version
results_dir = UPath(config.get("storage.results.uri")) / version
use_cache = config.get('run_options.use_cache')


@task(log_stdout=True)
def make_run_parameters(**kwargs) -> RunParameters:
    """Prefect task wrapper for RunParameters"""
    return RunParameters(**kwargs)


@task(log_stdout=True)
def get_obs(run_parameters: RunParameters) -> UPath:
    """Task to return observation data subset from input parameters.

    Parameters
    ----------
    run_parameters : RunParameters
        RunParameter dataclass defined in common/containers.py. Constructed from prefect parameters.

    Returns
    -------
    UPath
        Path to subset observation dataset.
    """

    feature_string = '_'.join(run_parameters.features)
    frmt_str = "{obs}_{feature_string}_{latmin}_{latmax}_{lonmin}_{lonmax}_{train_dates[0]}_{train_dates[1]}".format(
        **asdict(run_parameters), feature_string=feature_string
    )
    title = f"obs ds: {frmt_str}"
    ds_hash = str_to_hash(frmt_str)
    target = intermediate_dir / 'get_obs' / ds_hash

    if use_cache and is_cached(target):
        print(f'found existing target: {target}')
        return target

    ds = open_era5(run_parameters.features, run_parameters.train_period)

    subset = subset_dataset(
        ds,
        run_parameters.features,
        run_parameters.train_period.time_slice,
        run_parameters.bbox,
        chunking_schema={'time': 365, 'lat': 150, 'lon': 150},
    )
    # rechunk the subset (not the full dataset) so the written store follows the target chunking
    subset = subset.chunk({'time': 365, 'lat': 150, 'lon': 150})
    for key in subset.variables:
        subset[key].encoding = {}

    subset.attrs.update({'title': title}, **get_cf_global_attrs(version=version))
    subset = set_zarr_encoding(subset)
    blocking_to_zarr(ds=subset, target=target, validate=True, write_empty_chunks=True)

    return target


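# Illustrative sketch only (not used by the pipeline): how ``make_run_parameters`` and
# ``get_obs`` might be exercised directly via ``Task.run`` outside of a registered Prefect
# flow. The keyword arguments in ``flow_kwargs`` are hypothetical and must match the fields
# of ``RunParameters`` in common/containers.py for your configuration.
def _example_get_obs_usage(flow_kwargs: dict) -> xr.Dataset:
    """Hypothetical helper demonstrating direct task execution."""
    run_parameters = make_run_parameters.run(**flow_kwargs)
    obs_path = get_obs.run(run_parameters)  # returns a UPath to the (possibly cached) zarr store
    return xr.open_zarr(obs_path)

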
@task(log_stdout=True)
def get_experiment(run_parameters: RunParameters, time_subset: str) -> UPath:
    """Prefect task that returns cmip GCM data from input run parameters.

    Parameters
    ----------
    run_parameters : RunParameters
        RunParameter dataclass defined in common/containers.py. Constructed from prefect parameters.
    time_subset : str
        String describing time subset request. Either 'train_period', 'predict_period', or 'both'.

    Returns
    -------
    UPath
        UPath to experiment dataset.
    """
    if time_subset == 'both':
        time_period = TimePeriod(
            start=str(
                min(
                    int(run_parameters.train_period.start), int(run_parameters.predict_period.start)
                )
            ),
            stop=str(
                max(int(run_parameters.train_period.stop), int(run_parameters.predict_period.stop))
            ),
        )
    else:
        time_period = getattr(run_parameters, time_subset)

    features = getattr(run_parameters, 'features')
    if features:
        feature_string = '_'.join(features)
        frmt_str = "{model}_{member}_{scenario}_{feature_string}_{latmin}_{latmax}_{lonmin}_{lonmax}_{time_period.start}_{time_period.stop}".format(
            time_period=time_period, **asdict(run_parameters), feature_string=feature_string
        )
    else:
        frmt_str = "{model}_{member}_{scenario}_{variable}_{latmin}_{latmax}_{lonmin}_{lonmax}_{time_period.start}_{time_period.stop}".format(
            time_period=time_period, **asdict(run_parameters)
        )

    if int(time_period.start) < 2015 and run_parameters.scenario != 'historical':
        scenarios = ['historical', run_parameters.scenario]
    else:
        scenarios = [run_parameters.scenario]

    title = f"experiment ds: {frmt_str}"
    ds_hash = str_to_hash(frmt_str)
    target = intermediate_dir / 'get_experiment' / ds_hash

    if use_cache and is_cached(target):
        print(f'found existing target: {target}')
        return target

    mode = 'w'
    for feature in run_parameters.features:
        ds_list = []
        for s in scenarios:
            ds_list.append(
                get_gcm(
                    scenario=s,
                    member_id=run_parameters.member,
                    table_id=run_parameters.table_id,
                    grid_label=run_parameters.grid_label,
                    source_id=run_parameters.model,
                    variable=feature,
                    time_slice=time_period.time_slice,
                )
            )
        ds = xr.concat(ds_list, dim='time')

        subset = subset_dataset(ds, feature, time_period.time_slice, run_parameters.bbox)
        # Note: dataset is chunked into time:365 chunks to standardize leap-year chunking.
        subset = subset.chunk({'time': 365})
        for key in subset.variables:
            subset[key].encoding = {}

        subset.attrs.update({'title': title}, **get_cf_global_attrs(version=version))
        subset = set_zarr_encoding(subset)
        subset[[feature]].to_zarr(target, mode=mode)
        mode = 'a'

    return target


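# Illustrative sketch only: fetching GCM data for the combined train and predict periods
# (``time_subset='both'``). ``run_parameters`` is assumed to be an existing RunParameters
# instance; the returned zarr store holds one data variable per requested feature.
def _example_get_experiment_usage(run_parameters: RunParameters) -> xr.Dataset:
    experiment_path = get_experiment.run(run_parameters, time_subset='both')
    return xr.open_zarr(experiment_path)

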
@task(log_stdout=True)
def rechunk(
    path: UPath,
    pattern: str = None,
    template: UPath = None,
    max_mem: str = "5GB",
) -> UPath:
    """Use `rechunker` package to adjust chunks of dataset to a form conducive for your processing.

    Parameters
    ----------
    path : UPath
        path to zarr store
    pattern : str
        The pattern of chunking you want to use. If used together with `template` it will override
        the template to ensure that the final dataset truly follows that `full_space` or `full_time`
        spec. This matters when you are passing a template that is either a shorter time length or
        a coarser grid (and thus a shorter lat/lon chunksize).
    template : UPath
        The path to the file you want to use as a chunking template. The utility will grab the chunk
        sizes and use them as the chunk target to feed to rechunker.
    max_mem : str
        The memory available for rechunking steps. Must look like "2GB". Optional, default is 5GB.

    Returns
    -------
    target : UPath
        Path to rechunked dataset
    """
    # if both are defined, take the spatial part of the template and override one dimension with the specified pattern
    if template is not None:
        pattern_string = 'matched'
        if pattern is not None:
            pattern_string += f'_{pattern}'
    elif pattern is not None:
        pattern_string = pattern

    task_hash = str_to_hash(str(path) + pattern_string + str(template) + max_mem)
    target = intermediate_dir / 'rechunk' / task_hash
    path_tmp = scratch_dir / 'rechunk' / task_hash

    target_store = fsspec.get_mapper(str(target))
    temp_store = fsspec.get_mapper(str(path_tmp))

    if use_cache and is_cached(target):
        print(f'found existing target: {target}')
        # If we wanted to check that the cached store is chunked correctly we could validate it
        # against the `target_schema` built below, but constructing that schema here would likely
        # just hurt performance unnecessarily. As a future note: if we encounter chunk issues,
        # this is where a schema check belongs.
        return target

    # if a cached target isn't found we'll go through the rechunking step
    # open the zarr group
    target_store.clear()
    temp_store.clear()
    group = zarr.open_consolidated(path)

    # open the dataset to access the coordinates
    ds = xr.open_zarr(path)
    example_var = list(ds.data_vars)[0]

    # if you have defined a template then use the chunks of that template
    # to form the desired chunk definition
    if template is not None:
        template_ds = xr.open_zarr(template)
        # define the chunk definition
        chunk_def = {
            'time': min(template_ds.chunks['time'][0], len(ds.time)),
            'lat': min(template_ds.chunks['lat'][0], len(ds.lat)),
            'lon': min(template_ds.chunks['lon'][0], len(ds.lon)),
        }
        # if you have also defined a pattern then override the dimension specified there
        if pattern is not None:
            # the chunking pattern returns the dimensions that you'll chunk along,
            # so `full_time` will return `('lat', 'lon')`
            chunk_dims = config.get(f"chunk_dims.{pattern}")
            for dim in chunk_def:
                if dim not in chunk_dims:
                    # override the chunksize of the unchunked dimensions to be the complete length
                    # (like passing chunksize=-1)
                    chunk_def[dim] = len(ds[dim])
    elif pattern is not None:
        chunk_dims = config.get(f"chunk_dims.{pattern}")
        chunk_def = calc_auspicious_chunks_dict(ds[example_var], chunk_dims=chunk_dims)
    else:
        raise AttributeError('must either define chunking pattern or template')

    # Note:
    # for rechunker v0.3.3 you would initialize the chunks_dict that you'll pass in with the
    # coordinates set to `None`, because you don't want to rechunk the coordinate arrays;
    # for rechunker v0.4.2, initialize chunks_dict using the `chunk_def` above (as done here).
    chunks_dict = {
        'time': (chunk_def['time'],),
        'lon': (chunk_def['lon'],),
        'lat': (chunk_def['lat'],),
    }
    for var in ds.data_vars:
        chunks_dict[var] = chunk_def

    # now that you have your chunks_dict you can check whether the dataset at `path` already
    # matches that schema. If so, we don't need to bother with rechunking and we'll skip it.
    schema_dict = {var: DataArraySchema(chunks=chunk_def) for var in ds.data_vars}
    target_schema = DatasetSchema(schema_dict)
    with contextlib.suppress(SchemaError):
        # if the initial dataset already matches the schema, just return the initial path
        # and work with that
        target_schema.validate(ds)
        return path

    rechunk_plan = rechunker.rechunk(
        source=group,
        target_chunks=chunks_dict,
        max_mem=max_mem,
        target_store=target_store,
        temp_store=temp_store,
        target_options={
            k: {'compressor': zarr.Blosc(clevel=1), 'write_empty_chunks': True} for k in chunks_dict
        },
        temp_options={k: {'compressor': None, 'write_empty_chunks': True} for k in chunks_dict},
        executor='dask',
    )
    rechunk_plan.execute()

    # consolidate_metadata here since when it comes out of rechunker it isn't consolidated
    zarr.consolidate_metadata(target_store)
    validate_zarr_store(target_store)

    temp_store.clear()
    return target


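# Illustrative sketch only: the two ways ``rechunk`` is typically driven. Passing
# ``pattern='full_time'`` produces contiguous-in-time chunks sized by
# ``calc_auspicious_chunks_dict``, while passing ``template`` copies the chunk sizes of an
# existing store (optionally overridden by ``pattern``). The input paths are hypothetical.
def _example_rechunk_usage(obs_path: UPath, experiment_path: UPath) -> tuple[UPath, UPath]:
    # chunk the observations so each spatial block carries the full time series
    obs_full_time_path = rechunk.run(obs_path, pattern='full_time', max_mem='2GB')
    # chunk the experiment data to match the layout of the observation store
    experiment_matched_path = rechunk.run(experiment_path, template=obs_full_time_path)
    return obs_full_time_path, experiment_matched_path

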
@task(log_stdout=True)
def time_summary(ds_path: UPath, freq: str) -> UPath:
    """Prefect task to create resampled data. Takes mean of `tasmax` and `tasmin` and sum of `pr`.

    Parameters
    ----------
    ds_path : UPath
        UPath to input zarr store at daily timestep
    freq : str
        aggregation frequency

    Returns
    -------
    UPath
        Path to resampled dataset.
    """
    ds_hash = str_to_hash(str(ds_path) + freq)
    target = results_dir / 'time_summary' / ds_hash

    if use_cache and is_cached(target):
        print(f'found existing target: {target}')
        return target

    ds = xr.open_zarr(ds_path)
    out_ds = resample_wrapper(ds, freq=freq)

    out_ds.attrs.update({'title': 'time_summary'}, **get_cf_global_attrs(version=version))
    out_ds = set_zarr_encoding(out_ds)
    blocking_to_zarr(ds=out_ds, target=target, validate=True, write_empty_chunks=True)

    return target


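# Illustrative sketch only: producing a monthly summary of a daily downscaled store.
# The frequency string is a hypothetical example; any resample alias understood by
# ``resample_wrapper`` can be passed.
def _example_time_summary_usage(daily_path: UPath) -> xr.Dataset:
    monthly_path = time_summary.run(daily_path, freq='1MS')  # month-start aggregation
    return xr.open_zarr(monthly_path)

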
@task(log_stdout=True)
def get_weights(*, run_parameters, direction, regrid_method="bilinear"):
    """Retrieve pre-generated regridding weights.

    Parameters
    ----------
    run_parameters : RunParameters
        Downscaling run parameter container
    direction : str
        Direction of regridding.
    regrid_method : str
        Regridding method.

    Returns
    -------
    path : UPath
        Path to weights file.
    """
    weights = pd.read_csv(config.get('weights.gcm_obs_weights.uri'))
    path = (
        weights[
            (weights.source_id == run_parameters.model)
            & (weights.table_id == run_parameters.table_id)
            & (weights.grid_label == run_parameters.grid_label)
            & (weights.regrid_method == regrid_method)
            & (weights.direction == direction)
        ]
        .iloc[0]
        .path
    )
    return path


@task(log_stdout=True)
def get_pyramid_weights(*, run_parameters, levels: int, regrid_method: str = "bilinear"):
    """Retrieve pre-generated pyramid regridding weights.

    Parameters
    ----------
    run_parameters : RunParameters
        Downscaling run parameter container
    levels : int
        Number of levels in the pyramid.
    regrid_method : str
        Regridding method.

    Returns
    -------
    path : UPath
        Path to pyramid weights file.
    """
    weights = pd.read_csv(config.get('weights.downscaled_pyramid_weights.uri'))
    path = (
        weights[(weights.regrid_method == regrid_method) & (weights.levels == levels)].iloc[0].path
    )
    return path


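# Illustrative sketch only: both weights lookups are simple reads of pre-generated CSV
# catalogs referenced in the project config. The ``direction`` value and pyramid level
# count below are hypothetical examples, not guaranteed entries in those catalogs.
def _example_weights_lookup(run_parameters: RunParameters):
    gcm_to_obs_weights = get_weights.run(run_parameters=run_parameters, direction='gcm_to_obs')
    pyramid_weights = get_pyramid_weights.run(
        run_parameters=run_parameters, levels=4, regrid_method='bilinear'
    )
    return gcm_to_obs_weights, pyramid_weights

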
@task(log_stdout=True)
def regrid(
    source_path: UPath,
    target_grid_path: UPath,
    weights_path: UPath = None,
    pre_chunk_def: dict = None,
) -> UPath:
    """Task to regrid a dataset to a target grid.

    Parameters
    ----------
    source_path : UPath
        Path to dataset that will be regridded
    target_grid_path : UPath
        Path to template grid dataset
    weights_path : UPath, optional
        Path to weights file
    pre_chunk_def : dict, optional
        Chunking applied to the source dataset before regridding

    Returns
    -------
    UPath
        Path to regridded output dataset.
    """
    import xesmf as xe

    ds_hash = str_to_hash(str(source_path) + str(target_grid_path))
    target = intermediate_dir / 'regrid' / ds_hash

    if use_cache and is_cached(target):
        print(f'found existing target: {target}')
        return target

    source_ds = xr.open_zarr(source_path)
    target_grid_ds = xr.open_zarr(target_grid_path)

    if pre_chunk_def is not None:
        source_ds = source_ds.chunk(**pre_chunk_def)

    if weights_path:
        from ndpyramid.regrid import _reconstruct_xesmf_weights

        weights = _reconstruct_xesmf_weights(xr.open_zarr(weights_path))
        regridder = xe.Regridder(
            source_ds,
            target_grid_ds,
            weights=weights,
            reuse_weights=True,
            method="bilinear",
            extrap_method="nearest_s2d",
            ignore_degenerate=True,
        )
    else:
        regridder = xe.Regridder(
            source_ds,
            target_grid_ds,
            method="bilinear",
            extrap_method="nearest_s2d",
            ignore_degenerate=True,
        )

    regridded_ds = regridder(source_ds, keep_attrs=True)
    regridded_ds.attrs.update(
        {'title': source_ds.attrs['title']}, **get_cf_global_attrs(version=version)
    )
    regridded_ds = set_zarr_encoding(regridded_ds)
    blocking_to_zarr(ds=regridded_ds, target=target, validate=True, write_empty_chunks=True)

    return target


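# Illustrative sketch only: regridding the coarse experiment data onto the observation grid,
# reusing pre-computed weights when available. The ``pre_chunk_def`` shown is a hypothetical
# chunking applied to the source dataset before regridding.
def _example_regrid_usage(
    experiment_path: UPath, obs_path: UPath, weights_path: UPath = None
) -> UPath:
    return regrid.run(
        experiment_path,
        obs_path,
        weights_path=weights_path,
        pre_chunk_def={'time': 365},
    )

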
def _load_coords(ds: xr.Dataset) -> xr.Dataset:
    '''Helper function to explicitly load all dataset coordinates'''
    for var, da in ds.coords.items():
        ds[var] = da.load()
    return ds


def _pyramid_postprocess(dt: dt.DataTree, levels: int, other_chunks: dict = None) -> dt.DataTree:
    '''Postprocess data pyramid

    Adds multiscales metadata and sets Zarr encoding

    Parameters
    ----------
    dt : dt.DataTree
        Input data pyramid
    levels : int
        Number of levels in pyramid
    other_chunks : dict
        Chunks for non-spatial dims

    Returns
    -------
    dt.DataTree
        Updated data pyramid with metadata / encoding set
    '''
    chunks = {"x": PIXELS_PER_TILE, "y": PIXELS_PER_TILE}
    if other_chunks is not None:
        chunks.update(other_chunks)

    for level in range(levels):
        slevel = str(level)
        dt.ds.attrs['multiscales'][0]['datasets'][level]['pixels_per_tile'] = PIXELS_PER_TILE

        # set dataset chunks
        dt[slevel].ds = dt[slevel].ds.chunk(chunks)
        if 'date_str' in dt[slevel].ds:
            dt[slevel].ds['date_str'] = dt[slevel].ds['date_str'].chunk(-1)

        # set dataset encoding
        dt[slevel].ds = set_web_zarr_encoding(
            dt[slevel].ds, codec_config={"id": "zlib", "level": 1}, float_dtype="float32"
        )
        for var in ['time', 'time_bnds']:
            if var in dt[slevel].ds:
                dt[slevel].ds[var].encoding['dtype'] = 'int32'

    # set global metadata
    dt.ds.attrs.update({'title': 'multiscale data pyramid'}, **get_cf_global_attrs(version=version))
    return dt


@task(log_stdout=True)
def pyramid(
    ds_path: UPath, weights_pyramid_path: str = None, levels: int = 2, other_chunks: dict = None
) -> UPath:
    '''Task to create a data pyramid from an xarray Dataset

    Parameters
    ----------
    ds_path : UPath
        Path to input dataset
    weights_pyramid_path : str
        Path to weights pyramid
    levels : int, optional
        Number of levels in pyramid, by default 2
    other_chunks : dict
        Chunks for non-spatial dims

    Returns
    -------
    target : UPath
        Path to output data pyramid
    '''
    ds_hash = str_to_hash(str(ds_path) + str(levels) + str(other_chunks))
    target = results_dir / 'pyramid' / ds_hash

    if use_cache and is_cached(target):
        print(f'found existing target: {target}')
        return target

    ds = xr.open_zarr(ds_path).pipe(_load_coords)

    ds.coords['date_str'] = ds['time'].dt.strftime('%Y-%m-%d').astype('S10')
    ds.attrs.update({'title': ds.attrs['title']}, **get_cf_global_attrs(version=version))

    target_pyramid = dt.open_datatree('az://static/target-pyramid', engine='zarr')
    if weights_pyramid_path is not None:
        weights_pyramid = dt.open_datatree(weights_pyramid_path, engine='zarr')
    else:
        weights_pyramid = None

    # create pyramid
    dta = pyramid_regrid(
        ds,
        target_pyramid=target_pyramid,
        levels=levels,
        weights_pyramid=weights_pyramid,
        regridder_kws={'ignore_degenerate': True},
    )

    dta = _pyramid_postprocess(dta, levels=levels, other_chunks=other_chunks)

    # write to target
    for child in dta.children.values():
        for variable in child.ds.data_vars:
            child[variable].encoding['write_empty_chunks'] = True
    dta.to_zarr(target, mode='w')
    validate_zarr_store(target)

    return target


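# Illustrative sketch only: building a web-ready multiscale pyramid from a finished
# downscaled store. The level count and the non-spatial chunking are hypothetical values;
# pre-generated pyramid weights can be supplied via ``get_pyramid_weights``.
def _example_pyramid_usage(monthly_summary_path: UPath, run_parameters: RunParameters) -> UPath:
    weights_path = get_pyramid_weights.run(run_parameters=run_parameters, levels=4)
    return pyramid.run(
        monthly_summary_path,
        weights_pyramid_path=weights_path,
        levels=4,
        other_chunks={'time': 512},
    )

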
@task(log_stdout=True)
def run_analyses(ds_path: UPath, run_parameters: RunParameters) -> UPath:
    """Prefect task to run the analyses on results from a downscaling run.

    Parameters
    ----------
    ds_path : UPath
        Path to input dataset
    run_parameters : RunParameters
        Downscaling run parameter container

    Returns
    -------
    PosixPath
        The local path of the executed notebook.
    """
    import papermill
    from azure.storage.blob import BlobServiceClient, ContentSettings

    from cmip6_downscaling.analysis import metrics

    root = PosixPath(metrics.__file__)
    template_path = root.parent / 'analyses_template.ipynb'
    executed_notebook_path = root.parent / f'analyses_{run_parameters.run_id}.ipynb'
    executed_html_path = root.parent / f'analyses_{run_parameters.run_id}.html'

    parameters = asdict(run_parameters)
    parameters['run_id'] = run_parameters.run_id

    # TODO: figure out how to unpack these fields in the notebook
    # asdict will return lists for train_dates and predict_dates
    # parameters['train_period_start'] = train_period.start
    # parameters['train_period_end'] = train_period.stop
    # parameters['predict_period_start'] = predict_period.start
    # parameters['predict_period_end'] = predict_period.stop

    # execute notebook with papermill
    papermill.execute_notebook(template_path, executed_notebook_path, parameters=parameters)

    # convert from ipynb to html
    # TODO: move this to a standalone function
    # Q: can we control the output path name?
    os.system(f"jupyter nbconvert {executed_notebook_path} --to html")

    # TODO: move to a standalone function
    connection_string = os.getenv('AZURE_STORAGE_CONNECTION_STRING', None)
    if connection_string is not None:
        # if you have a connection_string, copy the html to azure; if not, just return
        # because it is already on your local machine
        blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        # TODO: fix b/c the run_id has slashes now!!!
        blob_name = config.get('storage.web_results.blob') / parameters['run_id'] / 'analyses.html'
        blob_client = blob_service_client.get_blob_client(container='$web', blob=blob_name)
        # clean up before writing
        with contextlib.suppress(Exception):
            blob_client.delete_blob()

        # need to specify html content type so that it will render and not download
        with open(executed_html_path, "rb") as data:
            blob_client.upload_blob(
                data, content_settings=ContentSettings(content_type='text/html')
            )

    return executed_notebook_path


def _finalize(run_parameters, kind='runs', **paths):
    path_dict = dict(**paths)

    now = datetime.datetime.utcnow().isoformat()
    target1 = results_dir / kind / run_parameters.run_id / f'{now}.json'
    target2 = results_dir / kind / run_parameters.run_id / 'latest.json'
    print(f'finalize 1: {target1}')
    print(f'finalize 2: {target2}')

    out = {'parameters': asdict(run_parameters)}
    out['attrs'] = get_cf_global_attrs(version=version)
    out['datasets'] = {k: str(p) for k, p in path_dict.items()}

    with target1.open(mode='w') as f:
        json.dump(out, f, indent=2)
    with target2.open(mode='w') as f:
        json.dump(out, f, indent=2)


@task(log_stdout=True)
def finalize(run_parameters: RunParameters = None, **paths):
    """Prefect task to finalize the downscaling run.

    Parameters
    ----------
    run_parameters : RunParameters
        Downscaling run parameter container
    paths : dict
        Dictionary of paths to record in the result file
    """
    _finalize(run_parameters, kind='runs', **paths)


@task(log_stdout=True, trigger=any_failed)
def finalize_on_failure(run_parameters: RunParameters = None, **paths):
    """Prefect task to finalize a failed downscaling run.

    Parameters
    ----------
    run_parameters : RunParameters
        Downscaling run parameter container
    paths : dict
        Dictionary of paths to record in the result file
    """
    _finalize(run_parameters, kind='failed-runs', **paths)


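# Illustrative sketch only: how ``finalize`` and ``finalize_on_failure`` might be wired into
# a Prefect 1.x flow so that a result JSON is written whether the run succeeds or fails.
# The flow name, parameter names, and upstream task handles are hypothetical; a real flow
# would pass the full set of RunParameters fields.
def _example_finalize_wiring():
    from prefect import Flow, Parameter

    with Flow('downscaling-example') as flow:
        run_parameters = make_run_parameters(model=Parameter('model'))  # plus the other fields
        obs_path = get_obs(run_parameters)
        monthly_path = time_summary(obs_path, freq='1MS')
        finalize(run_parameters, obs=obs_path, monthly_summary=monthly_path)
        # trigger=any_failed means this task only runs if an upstream task fails
        finalize_on_failure(run_parameters, obs=obs_path, monthly_summary=monthly_path)
    return flow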