from pathlib import Path
import numpy as np
import pandas as pd
import h5py
from tqdm import tqdm
from typing import Any
from scipy.signal import butter, sosfilt
from seisbench import logger
from .das_base import DASBenchmarkDataset, DASDataWriter
[docs]
class MLSubDAS(DASBenchmarkDataset):
"""
MLSubDAS dataset by Xiao et al. (2026)
"""
chunk_count = 13
chunk_size = 100
min_total_labels = 500
min_phase_labels = 250
def __init__(self, **kwargs):
logger.warning(
"This dataset has been annotated semi-automatically. The annotations will be incomplete and high picking "
"errors (>1 s) or spatially incoherent labeling occurs. Check out the original publication for details on "
"the dataset and annotation strategy. The limitations of this dataset should be taken into account when "
"using it for training or evaluation."
)
citation = (
"Xiao, H., Tilmann, F., van den Ende, M., Rivet, D., Loureiro, A., Tsuji, T., ... & Denolle, M. A. (2026). "
"DeepSubDAS: an earthquake phase picker from submarine distributed acoustic sensing data. "
"Geophysical Journal International, 245(2), ggag061."
"https://doi.org/10.1093/gji/ggag061"
)
license = "CC BY 4.0"
super().__init__(
citation=citation,
license=license,
compile_from_source=False,
repository_lookup=True,
**kwargs,
)
[docs]
@classmethod
def available_chunks(cls, force: bool = False, wait_for_file: bool = False):
return [f"{i:02d}" for i in range(cls.chunk_count)]
def _download_dataset(
self,
files: list[Path],
chunk: str,
base_path: Path = None,
selection_path: Path = None,
catalog_path: Path = None,
**kwargs,
):
"""
Converts dataset from local files.
This function is only for reference, as it relies on a local file structure.
"""
if base_path is None:
raise ValueError("`base_path` must be provided for conversion from source.")
if selection_path is None:
raise ValueError(
"`selection_path` must be provided for conversion from source."
)
if catalog_path is None:
raise ValueError(
"`catalog_path` must be provided for conversion from source."
)
entries = self._scan_files(base_path)
entries = entries[entries["total_labels"] >= self.min_total_labels]
entries = self._subselect_events(entries, selection_path)
n_train = int(0.7 * len(entries))
n_dev = int(0.1 * len(entries))
n_test = len(entries) - n_train - n_dev
split = ["train"] * n_train + ["dev"] * n_dev + ["test"] * n_test
np.random.seed(42)
np.random.shuffle(split)
entries["split"] = split
chunk_idx = int(chunk)
entries = entries[
chunk_idx * self.chunk_size : (chunk_idx + 1) * self.chunk_size
]
data_key = "data"
catalog = pd.read_parquet(catalog_path)
catalog_dict = {
row["file_name"]: row.to_dict() for _, row in catalog.iterrows()
}
with DASDataWriter(
self.path, chunk, files[0], files[1], strict=False
) as writer:
for _, entry in tqdm(entries.iterrows(), total=len(entries)):
csv_path = base_path / entry["folder"] / (entry["file"] + ".csv")
if not csv_path.is_file():
csv_path = (
base_path / entry["folder"] / (entry["file"] + ".mat.csv")
)
hdf_path = base_path / entry["folder"] / (entry["file"] + ".h5")
metadata = pd.read_csv(csv_path)
with h5py.File(hdf_path, "r") as f:
if data_key not in f:
print(f"Problem with file {entry['folder']}/{entry['file']}")
continue
# Data shape: (samples, channels)
data = f[data_key][()]
data = self._preprocess_data(data)
annotations = self._convert_annotations(metadata, data.shape[1], entry)
record_metadata = self._get_record_metadata(entry)
event_metadata = catalog_dict.get(
entry["file"], {"file_name": "just_so_the_next_line_doesnt_fail"}
)
del event_metadata["file_name"]
writer.add_record(
{**record_metadata, **event_metadata}, data, annotations
)
@staticmethod
def _scan_files(base_path: Path) -> pd.DataFrame:
entries = []
def truncate_csv_name(x):
x = x.name[:-4]
if x.endswith(".mat"):
x = x[:-4]
return x
for folder in base_path.iterdir():
csv_files = [truncate_csv_name(x) for x in sorted(folder.glob("*.csv"))]
hdf_files = set(x.name[:-3] for x in sorted(folder.glob("*.h5")))
for file in csv_files:
if file in hdf_files:
try:
metadata = pd.read_csv(
folder / f"{file}.csv",
dtype={
"p_wave_index": np.float32,
"s_wave_index": np.float32,
},
)
except FileNotFoundError:
metadata = pd.read_csv(
folder / f"{file}.mat.csv",
dtype={
"p_wave_index": np.float32,
"s_wave_index": np.float32,
},
)
if "p_wave_index" in metadata.columns:
p_labels = np.sum(~np.isnan(metadata["p_wave_index"]))
s_labels = np.sum(~np.isnan(metadata["s_wave_index"]))
else:
p_labels, s_labels = 0, 0
else:
p_labels, s_labels = 0, 0
entries.append(
{
"folder": folder.name,
"file": file,
"has_hdf5": file in hdf_files,
"p_labels": p_labels,
"s_labels": s_labels,
}
)
entries = pd.DataFrame(entries)
entries["total_labels"] = entries["p_labels"] + entries["s_labels"]
return entries
@staticmethod
def _subselect_events(entries: pd.DataFrame, selection_path: Path) -> pd.DataFrame:
with open(selection_path) as f:
selected = set(f.read().split("\n"))
return entries[entries["file"].isin(selected)].copy()
@staticmethod
def _convert_annotations(
metadata: pd.DataFrame,
n_channels: int,
entry: dict[str, Any],
) -> dict[str, np.ndarray]:
# Should output P_0, P_1, ... in case of multiple P waves
index_column = "Unnamed: 0"
annotations = {}
for phase in "PS":
if entry[f"{phase.lower()}_labels"] >= MLSubDAS.min_phase_labels:
phase_labels = metadata[
~np.isnan(metadata[f"{phase.lower()}_wave_index"])
]
break_points = (
[0]
+ list(np.where(np.diff(phase_labels[index_column]) < 0)[0])
+ [len(phase_labels)]
)
for i, (p0, p1) in enumerate(zip(break_points[:-1], break_points[1:])):
annotation = np.nan * np.zeros(n_channels)
index_vals = phase_labels[index_column].values[p0:p1]
pick_samples = phase_labels[f"{phase.lower()}_wave_index"].values[
p0:p1
]
annotation[index_vals] = pick_samples
annotations[f"{phase}_{i}"] = annotation
return annotations
@staticmethod
def _preprocess_data(data: np.ndarray) -> np.ndarray:
def butter_bandpass(low, high, fs=100, order=4):
nyq = 0.5 * fs
low /= nyq
high /= nyq
sos = butter(order, [low, high], btype="band", output="sos")
return sos
def cosine_window(n, boundary_samples=100):
boundary_samples = min(boundary_samples, n // 2)
x = np.ones(n)
x[:boundary_samples] = (
1 - np.cos(np.linspace(0, np.pi, boundary_samples))
) / 2
x[-boundary_samples:] = (
1 + np.cos(np.linspace(0, np.pi, boundary_samples))
) / 2
return x
sos = butter_bandpass(1, 20, 100)
window = cosine_window(data.shape[0]).reshape(-1, 1)
data = window * (data - np.mean(data, axis=0, keepdims=True))
data = sosfilt(sos, data, axis=0)
return data.astype(np.float32)
@staticmethod
def _get_record_metadata(entry: pd.Series) -> dict[str, Any]:
general = {
"record_identifier": entry["file"],
"record_sampling_rate_hz": 100.0,
"record_p_labels": entry["p_labels"]
if entry["p_labels"] >= MLSubDAS.min_phase_labels
else 0,
"record_s_labels": entry["s_labels"]
if entry["s_labels"] >= MLSubDAS.min_phase_labels
else 0,
"split": entry["split"],
}
cable = {}
if entry["folder"] == "das_alaska":
# Stated in Xiao et al. (2026)
cable = {
"record_channel_spacing_m": 9.57,
}
elif entry["folder"] == "das_chillie":
if "CCN" in entry["file"] or "SER" in entry["file"]:
# Stated in Xiao et al. (2026)
cable = {
"record_channel_spacing_m": 10.0,
}
else:
# Stated in Xiao et al. (2026)
cable = {
"record_channel_spacing_m": 4.08,
}
elif entry["folder"] == "das_japan":
# Stated in Xiao et al. (2026)
cable = {
"record_channel_spacing_m": 10.0,
}
elif entry["folder"] == "das_spain":
# Stated in Xiao et al. (2026)
cable = {
"record_channel_spacing_m": 10.0,
}
elif entry["folder"] == "other_data":
if "NyAalesund" in entry["file"]:
# Stated in Bouffaut et al. (2022)
cable = {
"record_channel_spacing_m": 4.08,
}
elif "maderia" in entry["file"]:
# Stated in Loureiro et al. (2025)
cable = {
"record_channel_spacing_m": 5.1,
}
return {**general, **cable}