Source code for seisbench.data.das_base

from abc import ABC
from pathlib import Path
from typing import Any
import datetime

import pandas as pd
import numpy as np
import h5py
import copy

import seisbench
from .base import AbstractBenchmarkDataset, MultiWaveformDataset

NPArrayOrHDF5Dataset = np.ndarray | h5py.Dataset


[docs] class DASDataset: # File path format for metadata and records METADATA_FILE = "metadata_$CHUNK.parquet" DATA_FILE = "records_$CHUNK.hdf5" def __init__(self, path: Path | str = None, chunks: list[str] | None = None): self._path = path if chunks is None: chunks = self.available_chunks(path) self._chunks = sorted(chunks) self._validate_chunks(path, self._chunks) self._validate_dataset() self._metadata = self._load_metadata() self._data_pointers = {} @property def chunks(self) -> list[str]: return self._chunks @property def path(self) -> Path: """ Path of the dataset """ # This path is overwritten by DASBenchmarkDataset return Path(self._path) @property def metadata(self) -> pd.DataFrame: return self._metadata def __len__(self) -> int: return len(self._metadata) def _validate_chunks(self, path: Path, chunks: list[str]) -> None: available_chunks = self.available_chunks(path) if any(chunk not in available_chunks for chunk in chunks): raise ValueError( f"The dataset does not contain the following chunks: " f"{[chunk for chunk in chunks if chunk not in self.available_chunks(path)]}" ) def _validate_dataset(self) -> None: for chunk in self.chunks: metadata_path = self.METADATA_FILE.replace("$CHUNK", chunk) data_path = self.DATA_FILE.replace("$CHUNK", chunk) if not (self.path / metadata_path).is_file(): raise ValueError(f"Metadata file {metadata_path} does not exist.") if not (self.path / data_path).is_file(): raise ValueError(f"Data file {data_path} does not exist.") def _load_metadata(self) -> pd.DataFrame: metadata = [] for chunk in self.chunks: metadata_path = self.METADATA_FILE.replace("$CHUNK", chunk) chunk_metadata = pd.read_parquet(self.path / metadata_path) chunk_metadata["chunk"] = chunk metadata.append(chunk_metadata) metadata = pd.concat(metadata) metadata.reset_index(drop=True, inplace=True) return metadata
[docs] @staticmethod def available_chunks(path: Path) -> list[str]: """ Determines the chunks of the dataset in the given path. If available, parses the chunks file. Otherwise, scans the dataset for metadata and records files. :param path: Dataset path :return: List of chunks """ chunks_path = path / "chunks" if chunks_path.is_file(): with open(chunks_path, "r") as f: chunks = [x for x in f.read().split("\n") if x.strip()] if len(chunks) == 0: chunks = [""] else: metadata_files = set( [ x.name[9:-8] for x in path.iterdir() if x.name.startswith("metadata_") and x.name.endswith(".parquet") ] ) waveform_files = set( [ x.name[8:-5] for x in path.iterdir() if x.name.startswith("records_") and x.name.endswith(".hdf5") ] ) chunks = list(metadata_files & waveform_files) return sorted(chunks)
[docs] def get_sample( self, idx: int, record_virtual: bool = True, annotations_virtual: bool = False ) -> tuple[dict[str, Any], NPArrayOrHDF5Dataset, dict[str, NPArrayOrHDF5Dataset]]: """ Load the sample with the given index. Use the `record_virtual` and `annotations_virtual` arguments to control whether the record and annotations are loaded into memory or only pointers are returned. By default, the record will not be loaded into memory, while the annotations will be loaded into memory. :param idx: Index of the sample to load :param record_virtual: If true, the record is returned as a virtual array. Otherwise, the record is loaded into memory. :param annotations_virtual: If true, the annotations are returned as virtual arrays. Otherwise, the annotations are loaded into memory. """ metadata = self._metadata.iloc[idx].to_dict() chunk = metadata["chunk"] if chunk not in self._data_pointers: self._data_pointers[chunk] = h5py.File( self.path / self.DATA_FILE.replace("$CHUNK", chunk), "r" )["records"] record_group = self._data_pointers[chunk][metadata["record_name"]] record = record_group["record"] annotations = {} for key in record_group["annotations"].keys(): annotations[key] = record_group["annotations"][key] if not annotations_virtual: annotations[key] = annotations[key][()] if not record_virtual: record = record[()] return metadata, record, annotations
[docs] def filter(self, mask: np.ndarray, inplace: bool = True) -> "DASDataset | None": """ Filters dataset, e.g. by distance/magnitude/..., using a binary mask. Default behaviour is to perform inplace filtering. Setting inplace equal to false will return a filtered copy of the data set. :param mask: Boolean mask to apply to metadata. :param inplace: If true, filter inplace. Example usage: .. code:: python dataset.filter(dataset.metadata["record_sampling_rate_hz"] > 100) """ if inplace: self._metadata = self._metadata[mask] # Drop pointers to chunks that are no longer in the metadata remaining_chunks = self._metadata["chunk"].unique() self._data_pointers = { k: v for k, v in self._data_pointers.values() if k in remaining_chunks } return None else: other = self.copy() other.filter(mask, inplace=True) return other
[docs] def copy(self) -> "DASDataset": """ Create a copy of the data set. All attributes are copied by value. """ other = copy.copy(self) other._metadata = self._metadata.copy() other._data_pointers = {} return other
[docs] def get_split(self, split: str, inplace: bool = False) -> "DASDataset | None": """ Returns a dataset with the requested split. :param split: Split name to return. Usually one of "train", "dev", "test" :return: Dataset filtered to the requested split. """ if "split" not in self.metadata.columns: raise ValueError("Split requested but no split defined in metadata") mask = (self.metadata["split"] == split).values return self.filter(mask, inplace=inplace)
[docs] def train(self, inplace: bool = False) -> "DASDataset | None": """ Convenience method for get_split("train"). :return: Training dataset """ return self.get_split("train", inplace=inplace)
[docs] def dev(self, inplace: bool = False) -> "DASDataset | None": """ Convenience method for get_split("dev"). :return: Development dataset """ return self.get_split("dev", inplace=inplace)
[docs] def test(self, inplace: bool = False) -> "DASDataset | None": """ Convenience method for get_split("test"). :return: Test dataset """ return self.get_split("test", inplace=inplace)
[docs] def train_dev_test(self) -> tuple["DASDataset", "DASDataset", "DASDataset"]: """ Convenience method for returning training, development and test set. Equal to: >>> self.train(), self.dev(), self.test() :return: Training dataset, development dataset, test dataset """ return self.train(), self.dev(), self.test()
def __add__(self, other) -> "MultiDASDataset": if isinstance(other, DASDataset): return MultiDASDataset([self, other]) elif isinstance(other, MultiDASDataset): return MultiDASDataset([self] + other.datasets) else: raise TypeError( "Can only add DASDataset and MultiDASDataset to DASDataset." )
[docs] class MultiDASDataset: """ This class is a wrapper for multiple DAS datasets. It allows combining multiple datasets into a single dataset. It has mostly the same API as :py:class:`DASDataset`. """ def __init__(self, datasets: list[DASDataset]): if not isinstance(datasets, list) or not all( isinstance(x, DASDataset) for x in datasets ): raise TypeError("MultiDASDataset expects a list of DASDataset as input.") if len(datasets) == 0: raise ValueError("MultiDASDatasets need to have at least one member.") self._datasets = [dataset.copy() for dataset in datasets] self._metadata = pd.concat(x.metadata for x in datasets) # Identify dataset self._metadata["trace_dataset"] = sum( ([i] * len(dataset) for i, dataset in enumerate(self._datasets)), [] ) self._metadata.reset_index(inplace=True, drop=True) @property def datasets(self): return list(self._datasets) def __add__(self, other: "DASDataset | MultiDASDataset") -> "MultiDASDataset": if isinstance(other, DASDataset): return MultiDASDataset(self._datasets + [other]) elif isinstance(other, MultiDASDataset): return MultiDASDataset(self._datasets + other._datasets) else: raise TypeError( "Can only add DASDataset and MultiDASDataset to MultiDASDataset." ) @property def metadata(self): return self._metadata
[docs] def filter( self, mask: np.ndarray, inplace: bool = True ) -> "MultiDASDataset | None": """ Filters dataset, similar to :py:func:`WaveformDataset.filter`. :param mask: Boolean mask to apple to metadata. :param inplace: If true, filter inplace. """ submasks = self._split_mask(mask) if inplace: for submask, dataset in zip(submasks, self.datasets): dataset.filter(submask, inplace=True) # Calculate new metadata self._metadata = pd.concat(x.metadata for x in self.datasets) return None else: return MultiDASDataset( [ dataset.filter(submask, inplace=False) for submask, dataset in zip(submasks, self.datasets) ] )
[docs] def get_sample(self, idx: int, *args, **kwargs): dataset_idx, local_idx = self._resolve_idx(idx) return self.datasets[dataset_idx].get_sample(local_idx, *args, **kwargs)
_split_mask = MultiWaveformDataset._split_mask _resolve_idx = MultiWaveformDataset._resolve_idx train = DASDataset.train dev = DASDataset.dev test = DASDataset.test train_dev_test = DASDataset.train_dev_test get_split = DASDataset.get_split __len__ = DASDataset.__len__
[docs] class DASBenchmarkDataset(AbstractBenchmarkDataset, DASDataset, ABC): """ This class is the base class for benchmark DAS datasets. For the functionality, see the superclasses. """ _files = [DASDataset.METADATA_FILE, DASDataset.DATA_FILE]
[docs] class RandomDASDataset(DASBenchmarkDataset): """ This is a purely random dataset for testing purposes. It does not contain any actual data and should only be used for unit tests. """ def __init__(self, **kwargs): super().__init__(compile_from_source=True, repository_lookup=False, **kwargs)
[docs] @classmethod def available_chunks(cls, force: bool = False, wait_for_file: bool = False): return ["0", "1"]
def _download_dataset(self, files: list[Path], chunk: str, **kwargs): with DASDataWriter(self.path, chunk, files[0], files[1]) as writer: for i in range(5): data = np.random.randn(500, 600) record_metadata = { "record_channel_spacing_m": 4.0, "record_sampling_rate_hz": 125, "record_start_time": datetime.datetime.now(), } annotations = { "P": np.random.rand(600) * 500, "S": np.random.rand(600) * 500, } writer.add_record(record_metadata, data, annotations)
[docs] class DASDataWriter: """ This class allows writing DAS datasets in SeisBench format. It only writes a single chunk. To write multiple chunks, use multiple data writers with different `chunk` arguments but identical `path`. :param path: Path to write the chunk to :param chunk: Chunk identifier :param metadata_path: Overwrite for the metadata path. If provided, writes the metadata here instead of the default location. The chunk key will be ignored in this case. Unless integrated into complex workflows, this parameter should not be used. :param data_path: Same as ``.metadata_path`` but for the data file. :param data_type: Data type of the data. Defaults to float32. :param strict: If true, raise an error if the metadata does not contain the key fields. Otherwise, only raise a warning. """ def __init__( self, path: Path | str, chunk: str = "", metadata_path: Path | str | None = None, data_path: Path | str | None = None, data_type: type[np.floating] | type[np.integer] = np.float32, strict: bool = True, ): self.path = Path(path) self.chunk = chunk self.data_type = data_type self.strict = strict self._metadata_path = metadata_path self._data_path = data_path self.path.mkdir(parents=False, exist_ok=True) self._metadata = [] self._data_file = h5py.File(self.data_path, "w") self._data_group = self._data_file.create_group("records") self._record_count = 0 @property def metadata_path(self) -> Path: if self._metadata_path is None: return self.path / DASDataset.METADATA_FILE.replace("$CHUNK", self.chunk) else: return Path(self._metadata_path) @property def data_path(self) -> Path: if self._data_path is None: return self.path / DASDataset.DATA_FILE.replace("$CHUNK", self.chunk) else: return Path(self._data_path)
[docs] def add_record( self, metadata: dict[str, Any], data: np.ndarray, annotations: dict[str, np.ndarray], ) -> None: """ Add a record to the dataset. While the data and annotations will immediately be written to disk, the metadata will be stored in memory and written to disk when the dataset is closed. :param metadata: Metadata of the record. There are no mandatory fields, but warnings will be issued if typical key fields are missing. :param data: Data of the record. The data needs to be a 2D array (time, channel). :param annotations: Annotations of the record. Each annotation consists of a 1D array with the same length as the number of channels. The entries are in samples along the time axis. For example, an annotation called ``"P"`` indicates the indices of the P wave arrival at each channel. NaN values are allowed. Annotations can differ between the records. """ self._validate_metadata_entry(metadata) self._validate_annotations(data, annotations) record_name = f"record_{self._record_count}" self._record_count += 1 if "record_name" in metadata: seisbench.logger.warning( "Column 'record_name' already exists in metadata. Will overwrite it." ) metadata["record_name"] = record_name record_group = self._data_group.create_group(record_name) record_group.create_dataset("record", data=data, dtype=self.data_type) annotations_group = record_group.create_group("annotations") for key, value in annotations.items(): annotations_group.create_dataset(key, data=value) self._metadata.append(metadata)
@staticmethod def _validate_annotations( data: np.ndarray, annotations: dict[str, np.ndarray] ) -> None: for key, value in annotations.items(): if len(value) != data.shape[1]: raise ValueError( f"Annotation {key} has length {len(value)}, but data has {data.shape[1]} channels." ) def _validate_metadata_entry(self, metadata: dict[str, Any]) -> None: keys = [ "record_sampling_rate_hz", "record_channel_spacing_m", "record_start_time", ] for key in keys: if key not in metadata: if self.strict: raise ValueError( f"Metadata entry {key} is missing. " f"Add the key or set strict=False to ignore this error." ) else: seisbench.logger.warning( f"Metadata entry {metadata} is missing some required keys." ) def _finalize(self): metadata = pd.DataFrame(self._metadata) metadata.to_parquet(self.metadata_path, index=False) self._data_file.close() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self._finalize() if exc_type is None: return True else: seisbench.logger.error( f"Error in creating dataset. " f"Saved current progress to {self.metadata_path} and {self.data_path}. Error message:\n" ) return False