Source code for dep_tools.loaders

"""ABC definition and implementation of Loader objects."""

from abc import ABC, abstractmethod
from typing import Any, Iterable
import warnings

from geopandas import GeoDataFrame
from odc.geo.geobox import GeoBox
from odc.geo.geom import Geometry
from odc.stac import load as stac_load
from pystac import Item
from rasterio.errors import RasterioError, RasterioIOError
import rioxarray
from stackstac import stack
from xarray import DataArray, Dataset, concat


[docs] class Loader(ABC): """A base class for something that loads data based on input areas.""" def __init__(self): pass
[docs] @abstractmethod def load(self, areas): pass
[docs] class StacLoader(Loader): """A loader which loads data based on (STAC) items and areas."""
[docs] @abstractmethod def load(self, items, areas) -> Any: pass
[docs] class OdcLoader(StacLoader): """A wrapper around :func:`odc.stac.load`. In addition to allowing conformance to the :class:`Loader` form, this class offers a number of convenience operations which compliment :func:`odc.stac.load`: - If the data is loaded as floating point, any nodata values (as defined by the `"nodata"` attribute of the loaded data) are set to NaN, and the attribute itself is reset to NaN. - The nodata value is also set to be accessed via the rioxarray accessor (`.rio.nodata`). Args: load_as_dataset: If False, load as a DataArray with each variable in the `band` dimension. clip_to_areas: If True, loaded data is clipped to the given areas using :func:`odc.geo.mask`. **kwargs: Additional arguments to :func:`odc.stac.load`. """ def __init__( self, load_as_dataset: bool = True, clip_to_areas: bool = False, **kwargs, ): super().__init__() self._kwargs = kwargs self._clip_to_area = clip_to_areas self._load_as_dataset = load_as_dataset
[docs] def load( self, items: Iterable[Item], areas: GeoDataFrame | GeoBox | None = None ) -> Dataset | DataArray: """Load STAC Items into an xarray object. If `nodata` is passed as a kwarg on initialization, or the stac item contains the nodata value, `xr[variable].nodata` will be set on load. Args: items: The items to load. areas: If `clip_to_areas` is `True`, the output is clipped to these areas. Returns: If `load_to_dataset` is True on initialization, a :class:`xarray.Dataset` is returned. Otherwise a :class:`xarray:DataArray` is returned, with variables set on the `"band"` dimension. Raises: ValueError: If `areas` is a GeoBox and `clip_to_areas` is set to `True`. """ if isinstance(areas, GeoDataFrame): load_geometry = dict(geopolygons=areas) elif isinstance(areas, GeoBox): load_geometry = dict(geobox=areas) else: load_geometry = dict() ds = stac_load( items, **load_geometry, **self._kwargs, ) for name in ds: # Since nan is more-or-less universally accepted as a nodata value, # if the dtype of a band is some sort of floating point, then recode # existing values that are equal to the value set on load to nan if ds[name].dtype.kind == "f": # Should I make this an option? if "nodata" in ds[name].attrs.keys(): ds[name] = ds[name].where(ds[name] != ds[name].nodata, float("nan")) ds[name].attrs["nodata"] = float("nan") # To be helpful, set the nodata for rioxarray accessor ds[name].rio.write_nodata(ds[name].attrs.get("nodata"), inplace=True) if self._clip_to_area: if isinstance(areas, GeoBox): raise ValueError( "Clip not supported for GeoBox (nor should it be needed)" ) if areas is None: warnings.warn("clip_to_area is True but areas is None, ignoring") geom = Geometry(areas.geometry.unary_union, crs=areas.crs) ds = ds.odc.mask(geom) if not self._load_as_dataset: return ( ds.to_array("band") .rename(band="data") .rio.write_crs(ds[list(ds.data_vars)[0]].odc.crs) ) return ds
[docs] class StackStacLoader(StacLoader): def __init__( self, stack_kwargs=dict(resolution=30), resamplers_and_assets=None, **kwargs ): super().__init__(**kwargs) self.stack_kwargs = stack_kwargs self.resamplers_and_assets = resamplers_and_assets
[docs] def load( self, items, areas: GeoDataFrame, ) -> DataArray: areas_proj = areas.to_crs(self._current_epsg) if self.resamplers_and_assets is not None: s = concat( [ stack( items, chunksize=self.dask_chunksize, epsg=self._current_epsg, errors_as_nodata=(RasterioError(".*"),), assets=resampler_and_assets["assets"], resampling=resampler_and_assets["resampler"], bounds=areas_proj.total_bounds.tolist(), band_coords=False, # needed or some coords are often missing # from qa pixel and we get an error. Make sure it doesn't # mess anything up else where (e.g. rio.crs) **self.stack_kwargs, ) for resampler_and_assets in self.resamplers_and_assets ], dim="band", ) else: s = stack( items, chunksize=self.dask_chunksize, epsg=self._current_epsg, errors_as_nodata=(RasterioIOError(".*"),), **self.stack_kwargs, ) return s.rio.write_crs(self._current_epsg).rio.clip( areas_proj.geometry, all_touched=True, from_disk=True, )