from __future__ import annotations
import asyncio
import random
import string
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from itertools import groupby
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Iterable, Literal, cast
from urllib.parse import urljoin
from zipfile import ZipFile
import numpy as np
from obspy import Inventory, Stream, UTCDateTime, read_inventory
from obspy.clients.fdsn import Client
from obspy.clients.fdsn.header import FDSNException
from obspy.core.event.base import WaveformStreamID
from obspy.core.event.catalog import Catalog, Event
from obspy.core.event.origin import Pick
from obspy.geodetics import gps2dist_azimuth
from obspy.io.nordic.core import read_nordic
import seisbench
from seisbench import logger
from seisbench.data.base import BenchmarkDataset, EventParameters, TraceParameters
from seisbench.util import download_http
from seisbench.util.trace_ops import (
rotate_stream_to_zne,
stream_to_array,
trace_has_spikes,
)
# Channel codes that are never kept: fixup_catalog() drops every pick whose
# (rewritten) channel code is a member of this set.
REMOVE_CHANNELS = {
    "HHT", # Stray channel or typo
    "BHN", # Low sampling rate channel BH
    "BHE",
    "BHZ",
    "SHN", # Low sampling rate channel SH
    "SHE",
    "SHZ",
}
# Station code as found in the catalog picks -> station code used for the FDSN
# lookups. Applied by fixup_catalog(); codes without an entry are kept as is.
STATION_MAPPING = {
    "LACW": "LAC",
    "STCW": "STC",
    "VACW": "VAC",
    "KVCW": "KVC",
    "KRCW": "KRC",
    "ZEIT": "ZEITZ",
    "LEIB": "LEIB1",
}
# Written to the writer's data_format as "component_order".
COMPONENT_ORDER = "ZNE"
# Fixed-seed generator used by get_split(), so the sequence of split draws is
# the same on every run.
RNG = np.random.default_rng(31337)
# Phase hints that survive fixup_catalog(); picks with any other hint are dropped.
PICK_WHITELIST = {
    "Pg",
    "Sg",
}
# fixup_catalog() rewrites WB picks to "SH?" channels for events with an origin
# time before this date and to "CH?" channels otherwise.
WEBNET_SWITCH_SH_CH_CHANNEL = datetime(2015, 5, 26)
# Memo of station code -> network code, filled lazily by get_network_code().
STATION_NETWORK_CACHE: dict[str, str] = {}
# Alphabet for the random suffix of generated source ids (see get_event_params()).
ASCII_SET = string.ascii_lowercase + string.digits
[docs]
class BohemiaSaxony(BenchmarkDataset):
    """
    Regional benchmark dataset of waveform data and metadata
    for the North-West Bohemia and Saxony region in Germany/Czech Republic.

    .. warning ::

        This dataset contains restricted data from the West Bohemia Local Seismic Network (WEBNET).
        To compile the full dataset, you will need to provide an EIDA token.
        Please see the `WEBNET site <https://doi.org/10.7914/SN/WB>`_ for more information.

    :param eida_token: Path to an EIDA token file. It is handed to the GEOFON
        FDSN client; without it a warning is logged when the dataset is compiled.
    :param kwargs: Passed on unchanged to :py:class:`BenchmarkDataset`.
    """

    _client: MultiClient

    def __init__(self, eida_token: str | None = None, **kwargs):
        # The citation text is a runtime string; its lines are deliberately kept
        # flush left so the stored value stays exactly as it was.
        citation = """
The data is compiled from data of the networks from Saxony Seismic Network (SX),
West Bohemia Local Seismic Network (WEBNET; WB), Thuringia Seismological Network (TH),
and Czech Regional Seismic Network (CZ) and is provided by the BGR (Bundesanstalt für
Geowissenschaften und Rohstoffe) and GFZ GEOFON.
The dataset includes restricted data from the WB network, which requires an EIDA token
for access. Please provide the path to your EIDA token file when initializing the
dataset via the `eida_token` argument.
Catalog and Picks:
* Earthquakes in Saxony (Germany) and surroundings
from 2006 to 2024 -- onsets and locations,
https://opara.zih.tu-dresden.de/items/5387886f-25f2-4faf-8dca-33981d898ab9
Seismic Networks:
* SXNET Saxon Seismic Network, https://doi.org/10.7914/SN/SX
* West Bohemia Local Seismic Network (WEBNET), https://doi.org/10.7914/SN/WB
* Thüringer Seismologisches Netz, https://doi.org/10.7914/SN/TH
* Czech Regional Seismic Network, https://doi.org/10.7914/SN/CZ
"""
        # Must be set before the base-class constructor runs, because the
        # download path reads it.
        self._eida_token = eida_token
        super().__init__(
            citation=citation,
            license="CC0-1.0",
            repository_lookup=False,
            **kwargs,
        )

    def _init_client(self):
        """Create the multi-client and register the BGR, GEOFON and LMU FDSN clients."""
        self._client = MultiClient()
        self._client.add_client(Client("BGR"))
        self._client.add_client(Client("GEOFON", eida_token=self._eida_token))
        self._client.add_client(Client("LMU"))

    def _download_catalog_colm(self, path: Path | None = None) -> Path:
        """
        Make sure the NORDIC catalog files exist below ``path / "final2"``.

        If no ``cll_*.txt`` file is found there, the catalog archive is
        downloaded from the SeisBench remote, unpacked into ``path`` and a known
        typo in ``cll_2019.txt`` is patched.

        :param path: Directory to unpack the catalog into. Defaults to the
            current working directory, resolved at call time.
        :return: The directory the catalog lives in (``path``).
        """
        if path is None:
            path = Path.cwd()

        catalog_dir = path / "final2"
        # Look in the directory we extract to, so that check and download agree.
        if any(catalog_dir.glob("cll_*.txt")):
            logger.debug("Catalog files already exist, skipping download.")
            return path

        logger.info("Downloading Bohemia Saxony catalog files.")
        catalog_url = urljoin(seisbench.remote_root, "auxiliary/collm-catalog.zip")

        # Only reserve a unique file name here and close the handle right away;
        # the download then writes to a file nobody else holds open.
        with NamedTemporaryFile(suffix=".zip", delete=False, dir=path) as temp_file:
            archive = Path(temp_file.name)
        try:
            download_http(
                catalog_url,
                str(archive),
                desc="Downloading Bohemia Saxony catalog",
            )
            with ZipFile(archive, "r") as zip_file:
                zip_file.extractall(path)
        finally:
            # delete=False means nobody else cleans up the archive.
            archive.unlink(missing_ok=True)

        # Fixup
        bad_file = catalog_dir / "cll_2019.txt"
        bad_file.write_text(bad_file.read_text().replace("GRZ!", "GRZ1"))
        return path

    def get_inventory(
        self,
        catalog: Catalog,
        force_download: bool = False,
    ) -> Inventory:
        """
        Return the station inventory for all stations picked in ``catalog``.

        The inventory is cached as ``inventory.xml`` in the dataset directory.

        .. note ::
            Only a fresh download registers which FDSN client serves which
            network. Loading the cached file does not, so waveform downloads
            need ``force_download=True`` (as ``_download_dataset`` does).

        :param catalog: Catalog whose pick station codes are requested.
        :param force_download: Query the FDSN services even if a cached
            inventory file exists.
        :return: The downloaded or cached inventory.
        """
        inventory_file = self.path / "inventory.xml"
        if inventory_file.exists() and not force_download:
            logger.info("Loading inventory from %s", inventory_file)
            return read_inventory(str(inventory_file))

        # The clients are normally created by _download_dataset; create them on
        # demand so this public method also works when called on its own.
        if getattr(self, "_client", None) is None:
            self._init_client()

        # The request is intentionally not restricted to the catalog time range.
        inv = self._client.get_inventory(stations=get_stations(catalog))
        inv.write(str(inventory_file), format="STATIONXML")
        return inv

    def get_catalog(self) -> Catalog:
        """
        Download the NORDIC catalog files if necessary and read them.

        :return: One catalog holding the events of all ``cll_*.txt`` files,
            read in sorted file-name order.
        """
        self._download_catalog_colm(self.path)
        nordic_files = sorted((self.path / "final2").glob("cll_*.txt"))
        catalog = Catalog()
        for file in nordic_files:
            logger.info("Reading NORDIC file %s", file)
            catalog += read_nordic(str(file))
        logger.info("Loaded catalog with %d events.", len(catalog))
        return catalog

    def _download_dataset(
        self,
        writer,
        time_before: float = 60.0,
        time_after: float = 60.0,
        **kwargs,
    ):
        """
        Compile the dataset from the FDSN services and write it with ``writer``.

        For every event the picks are grouped per (network, station, location),
        one download job is created per group, and every finished job is written
        as one trace. Each event is assigned a random split.

        :param writer: Dataset writer; ``data_format`` is set on it and
            ``add_trace`` is called for every downloaded example.
        :param time_before: Forwarded to ``get_station_waveform_data``
            (presumably seconds before the pick -- confirm there).
        :param time_after: Forwarded to ``get_station_waveform_data``
            (presumably seconds after the pick -- confirm there).
        :param kwargs: Accepted for interface compatibility, not used here.
        """
        logger.info(
            "No pre-downloaded dataset available, downloading from BGR and Geofon. "
            "This may take a while.",
        )
        if not self._eida_token:
            logger.warning(
                "No EIDA token provided. "
                "Restricted WB network data will not be accessible.",
            )
        self._init_client()
        writer.data_format = {
            "dimension_order": "CW",
            "component_order": COMPONENT_ORDER,
            "measurement": "velocity",
            "unit": "counts",
            "instrument_response": "not restituted",
        }

        cat = self.get_catalog()
        # A forced download is required: it is what tells the multi-client which
        # FDSN service hosts which network.
        inv = self.get_inventory(cat, force_download=True)
        cat = fixup_catalog(cat, inv)

        def station_key(pick: Pick) -> tuple[str, str, str]:
            # None codes are mapped to "" so the keys are always sortable.
            wid = pick.waveform_id
            return (
                wid.network_code or "",
                wid.station_code or "",
                wid.location_code or "",
            )

        async def download_event_data(
            event: Event,
            event_work: list,
            timeout: float | None = 60.0,
            split: Literal["train", "dev", "test"] = "train",
        ) -> int:
            """Await all station jobs of one event and write the successful ones.

            Returns the number of traces written.
            """
            loop = asyncio.get_running_loop()
            deadline = None if timeout is None else loop.time() + timeout
            n_stations = 0
            n_work = len(event_work)
            for i_result, result in enumerate(
                asyncio.as_completed(event_work, timeout=timeout)
            ):
                try:
                    event_params, trace_params, data = await result
                    event_params["split"] = split
                except asyncio.TimeoutError as exc:
                    # as_completed raises its timeout from the awaited item, i.e.
                    # right here. Once the deadline has passed every remaining
                    # item would raise as well, so report once and stop.
                    if deadline is not None and loop.time() >= deadline:
                        logger.error(
                            "Timeout while downloading event data for %s.",
                            event.resource_id,
                        )
                        break
                    # A timeout raised by the job itself is an ordinary failure.
                    logger.warning(
                        "Error processing event %s: %s",
                        event.short_str(),
                        exc,
                    )
                    continue
                except Exception as exc:
                    # Best effort: a failing station must not abort the event.
                    logger.warning(
                        "Error processing event %s: %s",
                        event.short_str(),
                        exc,
                    )
                    continue
                # Writing may block, so keep it off the event loop thread.
                await asyncio.to_thread(
                    writer.add_trace,
                    {**event_params, **trace_params},
                    data,
                )
                logger.debug(
                    "Processed station %s (%d/%d)",
                    event.short_str(),
                    i_result + 1,
                    n_work,
                )
                n_stations += 1
            return n_stations

        n_stations_total = 0
        n_events = len(cat)
        logger.info("Downloading data for %d events.", n_events)
        for i_event, event in enumerate(cat):
            # groupby only merges *adjacent* equal keys, so sort first; otherwise
            # a station whose picks are not contiguous is downloaded repeatedly.
            picks_by_station = groupby(
                sorted(event.picks, key=station_key), key=station_key
            )
            event_work = [
                self.get_station_waveform_data(
                    event=event,
                    picks=list(pick_group),
                    inventory=inv,
                    time_before=time_before,
                    time_after=time_after,
                )
                for _, pick_group in picks_by_station
            ]
            n_stations = asyncio.run(
                download_event_data(event, event_work, split=get_split())
            )
            n_stations_total += n_stations
            logger.info(
                "Downloaded %d waveform examples. %d stations for event %s (%d/%d)",
                n_stations_total,
                n_stations,
                event.short_str(),
                i_event + 1,
                n_events,
            )
def get_catalog_timerange(catalog: Catalog) -> tuple[datetime, datetime]:
    """
    Get the time range of the catalog.

    Events without a preferred origin, or whose origin has no time, are
    ignored instead of raising.

    :param catalog: Catalog to get the time range from
    :return: Earliest and latest origin time of the catalog. If no event
        contributes a time, ``(datetime.max, datetime.min)`` is returned, i.e.
        an empty range with start after end.
    """
    ev: Event  # type hint only; catalogs iterate over events
    tmin = datetime.max
    tmax = datetime.min
    for ev in catalog:
        origin = ev.preferred_origin()
        if origin is None or origin.time is None:
            continue
        origin_time = origin.time.datetime
        tmin = min(tmin, origin_time)
        tmax = max(tmax, origin_time)
    return tmin, tmax
def get_stations(catalog: Catalog) -> set[str]:
    """
    Collect the station codes of all picks in the catalog.

    Picks without a waveform id or without a station code are skipped, so the
    result only contains non-empty strings and can be joined safely.

    :param catalog: Catalog to collect the station codes from
    :return: Set of distinct station codes
    """
    ev: Event  # type hint only; catalogs iterate over events
    stations: set[str] = set()
    for ev in catalog:
        for pick in ev.picks:
            waveform_id = pick.waveform_id
            if waveform_id is not None and waveform_id.station_code:
                stations.add(waveform_id.station_code)
    return stations
class MultiClient:
    """
    Thin wrapper that queries several FDSN clients as if they were one.

    ``get_inventory`` asks every registered client and remembers which client
    returned which network; ``get_waveforms`` then routes a request to the
    client remembered for the requested network.
    """

    # Registered FDSN clients, in registration order.
    _clients: list[Client]
    # Network code -> client that returned this network in get_inventory().
    _network: dict[str, Client]
    # Worker threads for the blocking waveform requests.
    _executor_pool: ThreadPoolExecutor

    def __init__(self):
        self._clients = []
        self._network = {}
        self._executor_pool = ThreadPoolExecutor(max_workers=8)

    def add_client(self, client: Client):
        """Register another FDSN client to be queried by ``get_inventory``."""
        self._clients.append(client)

    @staticmethod
    def _join_station_codes(stations: str | Iterable[str]) -> str:
        """Return the comma-separated station selector for an FDSN request.

        A plain string is passed through unchanged. It must be tested for
        first, because a string is itself iterable and would otherwise be
        split into single characters.
        """
        if isinstance(stations, str):
            return stations
        return ",".join(stations)

    async def get_waveforms(
        self,
        network: str,
        station: str,
        location: str,
        channel: str,
        starttime: datetime = datetime.min,
        endtime: datetime = datetime.max,
    ) -> Stream:
        """
        Fetch waveforms from the client that serves ``network``.

        The blocking request runs in the instance's thread pool. Empty
        ``location`` / ``channel`` values are sent as the wildcard ``"*"``.

        :raises ValueError: If no network has been discovered yet, i.e.
            ``get_inventory`` has not been called successfully.
        :raises KeyError: If ``network`` was not returned by any client.
        """
        if not self._network:
            raise ValueError(
                "No networks discovered. "
                "Please call get_inventory() first to discover networks."
            )
        client = self._network[network]
        logger.debug("Fetching waveforms from client %s", client.base_url)
        # We are inside a coroutine, so there is always a running loop.
        loop = asyncio.get_running_loop()
        stream = await loop.run_in_executor(
            self._executor_pool,
            partial(
                client.get_waveforms,
                network=network,
                station=station,
                location=location or "*",
                channel=channel or "*",
                starttime=UTCDateTime(starttime),
                endtime=UTCDateTime(endtime),
            ),
        )
        return cast(Stream, stream)

    def get_inventory(
        self,
        stations: str | Iterable[str],
        starttime: datetime = datetime.min,
        endtime: datetime = datetime.max,
        level: Literal["response", "station", "channel"] = "channel",
    ) -> Inventory:
        """
        Query all registered clients for the given stations and merge the results.

        As a side effect the network -> client routing used by
        ``get_waveforms`` is updated. The method starts its own event loop via
        ``asyncio.run`` and therefore must not be called from a running loop.

        :param stations: One station selector string, or an iterable of
            station codes that is joined with commas.
        :param starttime: Start of the requested time span.
        :param endtime: End of the requested time span.
        :param level: Level of detail passed to the clients.
        :return: Inventory holding the networks returned by all clients.
        """
        inventory = Inventory()
        station_selector = self._join_station_codes(stations)

        async def fetch(client: Client) -> None:
            logger.info("Fetching inventory from client %s", client.base_url)
            inv = await asyncio.to_thread(
                client.get_stations,
                network="*",
                station=station_selector,
                location="*",
                channel="*",
                level=level,
                starttime=UTCDateTime(starttime),
                endtime=UTCDateTime(endtime),
            )
            # Back on the event loop thread: safe to touch the shared state.
            for network in inv:
                self._network[network.code] = client
            inventory.extend(inv)

        async def worker():
            return await asyncio.gather(*(fetch(client) for client in self._clients))

        asyncio.run(worker())
        return inventory
def get_event_params(event: Event):
    """
    Build the event-level metadata for one catalog event.

    The source id is the last ``/``-separated part of the event's resource id;
    if that part is ``"1"`` a random ``sb_id_xxxxxx`` id is generated instead.
    Magnitude fields are only added when the event has a preferred magnitude.
    Missing depth, depth uncertainty or magnitude author are stored as ``None``
    rather than aborting the event.

    :param event: Event to describe
    :return: Event parameters with ``split`` preset to ``"train"``
    :raises ValueError: If the event has no preferred origin
    """
    origin = event.preferred_origin()
    magnitude = event.preferred_magnitude()

    sb_id = str(event.resource_id).split("/")[-1]
    if sb_id == "1":
        sb_id = "sb_id_" + "".join(random.choice(ASCII_SET) for _ in range(6))

    if origin is None:
        raise ValueError("Event has no preferred origin.")

    # Depth values are divided by 1e3 (presumably metres -> kilometres, matching
    # the *_km field names). Either value may be unset in the catalog.
    depth = origin.depth
    depth_uncertainty = origin.depth_errors["uncertainty"]

    # NOTE(review): latitude/longitude uncertainties are passed through
    # unchanged although the field names say km -- confirm the unit the
    # catalog reader stores them in.
    params: EventParameters = EventParameters(
        split="train",
        source_id=sb_id,
        source_origin_time=str(origin.time),
        source_origin_uncertainty_sec=origin.time_errors["uncertainty"],
        source_latitude_deg=origin.latitude,
        source_latitude_uncertainty_km=origin.latitude_errors["uncertainty"],
        source_longitude_deg=origin.longitude,
        source_longitude_uncertainty_km=origin.longitude_errors["uncertainty"],
        source_depth_km=None if depth is None else depth / 1e3,
        source_depth_uncertainty_km=(
            None if depth_uncertainty is None else depth_uncertainty / 1e3
        ),
    )

    if magnitude is not None:
        creation_info = magnitude.creation_info
        params["source_magnitude"] = magnitude.mag
        params["source_magnitude_uncertainty"] = magnitude.mag_errors["uncertainty"]
        params["source_magnitude_type"] = magnitude.magnitude_type
        params["source_magnitude_author"] = (
            None if creation_info is None else creation_info.agency_id
        )
    return params
def get_trace_params(
    waveform_id: WaveformStreamID,
    inventory: Inventory,
    event_params: EventParameters,
) -> TraceParameters:
    """
    Build the station-level metadata for one waveform id.

    :param waveform_id: Waveform id of the station/channel to describe
    :param inventory: Inventory used to look up the station coordinates
    :param event_params: Event parameters; source latitude and longitude are
        read to compute the back azimuth
    :return: Trace parameters with a placeholder trace name
    :raises ValueError: If the inventory lookup fails for the waveform id
    """
    try:
        coordinates = inventory.get_coordinates(waveform_id.get_seed_string())
    except Exception as r:
        raise ValueError(
            f"Could not get coordinates for {waveform_id.get_seed_string()}: {r}"
        ) from r

    station_lat = coordinates["latitude"]
    station_lon = coordinates["longitude"]

    # The product is NaN as soon as either coordinate is NaN.
    if np.isnan(station_lat * station_lon):
        back_azimuth = np.nan
    else:
        # Third element of the result: azimuth from the station to the source.
        geodesic = gps2dist_azimuth(
            event_params["source_latitude_deg"],
            event_params["source_longitude_deg"],
            station_lat,
            station_lon,
        )
        back_azimuth = geodesic[2]

    return TraceParameters(
        trace_name="foo",
        path_back_azimuth_deg=back_azimuth,
        station_network_code=waveform_id.network_code,
        station_code=waveform_id.station_code,
        # Only the first two characters of the channel code are stored.
        trace_channel=waveform_id.channel_code[:2],
        station_location_code=waveform_id.location_code,
        station_latitude_deg=station_lat,
        station_longitude_deg=station_lon,
        station_elevation_m=coordinates["elevation"],
    )
def _fixup_pick(pick: Pick, event: Event, inventory: Inventory) -> bool:
    """Normalise the waveform id of one pick in place; return False to drop the pick."""
    if pick.phase_hint not in PICK_WHITELIST:
        return False

    waveform_id = pick.waveform_id
    if waveform_id is None:
        return False

    try:
        # Expand a two-character code such as "HZ" to "HHZ".
        waveform_id.channel_code = "%sH%s" % tuple(waveform_id.channel_code)
    except TypeError:
        # Raised for every code that does not have exactly two characters
        # (including a missing code); such picks are dropped.
        return False

    # Fix station names to FDSN
    waveform_id.station_code = STATION_MAPPING.get(
        waveform_id.station_code,
        waveform_id.station_code,
    )

    try:
        waveform_id.network_code = get_network_code(
            waveform_id.station_code, inventory
        )
    except KeyError:
        # Station is not part of the inventory.
        return False

    if waveform_id.network_code == "WB":
        origin_time = event.preferred_origin().time.datetime
        band = "SH" if origin_time < WEBNET_SWITCH_SH_CH_CHANNEL else "CH"
        waveform_id.channel_code = band + waveform_id.channel_code[-1]

    # Note: SHZ/SHN/SHE are all listed in REMOVE_CHANNELS, so WB picks from
    # before the channel switch end up being dropped here.
    return waveform_id.channel_code not in REMOVE_CHANNELS


def fixup_catalog(catalog: Catalog, inventory: Inventory) -> Catalog:
    """
    Clean up the picks of a catalog in place.

    For every pick the channel code is expanded, the station code is mapped to
    its FDSN name and the network code is looked up in the inventory. Picks
    with a phase hint outside the whitelist, an unusable channel code, an
    unknown station or a channel listed for removal are dropped, and events
    left without picks are removed from the catalog.

    :param catalog: Catalog to fix; modified in place
    :param inventory: Inventory used to resolve network codes
    :return: The same catalog object
    """
    kept_events = []
    for event in catalog.events:
        # Filter by object identity instead of list.remove(): remove() compares
        # by equality, which is quadratic and may delete an equal duplicate.
        kept_picks = []
        for pick in event.picks:
            if _fixup_pick(pick, event, inventory):
                kept_picks.append(pick)
        event.picks[:] = kept_picks

        if kept_picks:
            kept_events.append(event)
        else:
            logger.info("Removing event %s with no picks.", event.resource_id)
    catalog.events[:] = kept_events

    for station, network in STATION_NETWORK_CACHE.items():
        if not network:
            logger.warning(
                "Station %s has no network code.",
                station,
            )
    return catalog
async def homogenize_sampling_rate(st: Stream, sampling_rate: float) -> float:
    """
    Bring all traces of a stream to one common sampling rate.

    Traces are resampled in place, one after the other, in a worker thread.

    :param st: Stream whose traces are resampled
    :param sampling_rate: Target sampling rate. A value of zero or below
        selects the sampling rate of the first trace instead.
    :return: The sampling rate that was applied
    """
    target_rate = st[0].stats.sampling_rate if sampling_rate <= 0.0 else sampling_rate

    for trace in st:
        current_rate = trace.stats.sampling_rate
        if current_rate == target_rate:
            continue
        logger.info(
            "Resampling trace %s from %.2f Hz to %.2f Hz.",
            trace.id,
            current_rate,
            target_rate,
        )
        await asyncio.to_thread(trace.resample, target_rate)

    return target_rate
def get_network_code(station_code: str, inventory: Inventory) -> str:
    """
    Look up the network a station belongs to.

    Results are memoised in ``STATION_NETWORK_CACHE``; the first network of the
    inventory that contains the station wins.

    :param station_code: Station code to look up
    :param inventory: Inventory to search
    :return: Code of the network containing the station
    :raises KeyError: If no network of the inventory contains the station
    """
    try:
        return STATION_NETWORK_CACHE[station_code]
    except KeyError:
        pass  # not cached yet, search the inventory

    for network in inventory:
        if any(station.code == station_code for station in network):
            STATION_NETWORK_CACHE[station_code] = network.code
            return network.code

    raise KeyError(f"Unknown network code for station {station_code}.")
def get_split(test: float = 0.1, dev: float = 0.1) -> Literal["train", "dev", "test"]:
    """
    Draw a random split label from the module-level generator ``RNG``.

    :param test: Fraction of draws that should end up in ``"test"``
    :param dev: Fraction of draws that should end up in ``"dev"``
    :return: ``"train"``, ``"dev"`` or ``"test"``
    """
    draw = RNG.random()
    # Upper bound of the "train" interval; "dev" follows, the rest is "test".
    train = 1.0 - test - dev
    if draw < train:
        return "train"
    return "dev" if draw < train + dev else "test"
# Command line entry point: compile the dataset, optionally with an EIDA token.
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "--eida-token",
        type=str,
        default=None,
        help="Path to EIDA token file.",
    )
    args = parser.parse_args()
    logger.root.setLevel("INFO")

    if args.eida_token is not None:
        # Expand "~" before the existence check; otherwise a token given as
        # "~/token" is rejected although the file exists.
        path = Path(args.eida_token).expanduser()
        if not path.exists():
            parser.error(f"EIDA token file {path} does not exist.")
        args.eida_token = str(path.resolve())

    ds = BohemiaSaxony(eida_token=args.eida_token, output_order="ZNE")