import asyncio
import bisect
import copy
import inspect
import math
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from operator import attrgetter
from typing import Any, AsyncGenerator, Optional, Type
import numpy as np
import pandas as pd
import scipy.signal
from tqdm import tqdm
from scipy.signal import resample_poly
import torch
import torch.nn as nn
import seisbench
import seisbench.util as sbu
from .base import SeisBenchModel
try:
import xdas
from xdas import DataArray
except ImportError:
xdas = sbu.MissingOptionalDependency("xdas", "xdas")
DataArray = sbu.MissingOptionalDependency("xdas", "xdas")
# Type conversion map: torch dtypes (keys) to the corresponding NumPy scalar types (values).
# Looked up by ``torch_dtype_to_numpy``; a dtype that is not listed here raises ``KeyError`` there.
TORCH_TO_NUMPY = {
    torch.bool: np.bool_,
    torch.uint8: np.uint8,
    torch.int8: np.int8,
    torch.int16: np.int16,
    torch.int32: np.int32,
    torch.int64: np.int64,
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.complex64: np.complex64,
    torch.complex128: np.complex128,
}
[docs]
def torch_dtype_to_numpy(dtype: torch.dtype) -> np.dtype:
    """
    Translate a torch dtype into the equivalent NumPy dtype.

    :param dtype: The torch dtype to translate.
    :return: The matching ``np.dtype`` instance.
    :raises KeyError: If ``dtype`` has no entry in ``TORCH_TO_NUMPY``.
    """
    numpy_scalar_type = TORCH_TO_NUMPY[dtype]
    return np.dtype(numpy_scalar_type)
[docs]
@dataclass
class PatchCoordinate:
    """
    Coordinates of a patch in the input or output array.

    Denotes the upper-left corner of the patch and the dimensions along each axis.
    Note that coordinates can take non-integer values due to transformations.
    Callbacks should be able to handle this, e.g., by casting to int.
    """

    sample: float
    channel: float
    w_sample: int
    w_channel: int

    def __array__(self, dtype=None, copy=None):
        """
        NumPy array protocol. Returns ``[sample, channel, w_sample, w_channel]``.

        :param dtype: Optional dtype requested by NumPy (e.g. via ``np.asarray(obj, dtype=...)``).
        :param copy: Copy flag passed by NumPy >= 2. The array is always freshly built, so a
                     request for ``copy=False`` cannot be honoured.
        :raises ValueError: If ``copy=False`` is requested.
        """
        if copy is False:
            raise ValueError(
                "PatchCoordinate cannot be converted to an array without copying."
            )
        return np.array(
            [self.sample, self.channel, self.w_sample, self.w_channel], dtype=dtype
        )

    @property
    def sample_int(self) -> int:
        """The ``sample`` coordinate rounded to the nearest integer."""
        return int(np.round(self.sample))

    @property
    def channel_int(self) -> int:
        """The ``channel`` coordinate rounded to the nearest integer."""
        return int(np.round(self.channel))
[docs]
@dataclass
class PatchingStructure:
    """
    Describes the size of input and output patches, which part of the input the output covers,
    and (optionally) how much adjacent patches overlap.
    """

    # Number of input samples per patch along the time axis
    in_samples: int
    # Number of input samples per patch along the channel axis
    in_channels: int
    # Number of output samples per patch along the time axis
    out_samples: int
    # Number of output samples per patch along the channel axis
    out_channels: int
    # Range of the input covered by the output along the time axis
    range_samples: tuple[int, int]
    # Range of the input covered by the output along the channel axis
    range_channels: tuple[int, int]
    # Overlap between adjacent patches along the time axis
    overlap_samples: Optional[int] = None
    # Overlap between adjacent patches along the channel axis
    overlap_channels: Optional[int] = None

    @property
    def shift_samples(self):
        """Input patch length minus the overlap along the time axis."""
        overlap = self.overlap_samples
        return self.in_samples - overlap

    @property
    def shift_channels(self):
        """Input patch length minus the overlap along the channel axis."""
        overlap = self.overlap_channels
        return self.in_channels - overlap
[docs]
class DASAnnotateCallback(ABC):
    """
    This abstract class describes the interface for callbacks used in the DAS annotate method.
    Callbacks will get streaming outputs from the annotate method, containing the different chunks after processing
    with the deep learning model. Different callbacks are available, e.g., for picking or for writing the full output.

    To implement a new callback, inherit from this class and implement the methods. Callbacks are stateful, allowing
    them, for example, to handle overlaps between adjacent chunks.

    .. warning::
        As callbacks are stateful, they should not be used in parallel.
    """

    # NOTE: ``VirtualTransformedDataArray`` is written as a string annotation because the name is
    # not defined or imported before this class in the module.
    def setup(
        self,
        data: "VirtualTransformedDataArray",
        patching_structure: PatchingStructure,
        annotate_keys: list[str],
    ) -> None:
        """
        Setup step for the callback. This is called before the first chunk is processed and can be used to initialize
        state variables, e.g., the shape of the output or arrays for intermediate results.
        The setup step is optional, however, it is usually good practice to reset all state variables in the setup step.

        :param data: The (virtually transformed) input data that will be processed.
        :param patching_structure: The patching structure used to cut the input.
        :param annotate_keys: The annotation keys that will be passed to :py:func:`handle_patch`.
        """
        pass

    def finalize(self) -> None:
        """
        Finalize step for the callback. This is called after the last chunk is processed and can be used to generate
        the final results based on the intermediate results processed in each chunk.
        The finalize step is optional.
        """
        pass

    @abstractmethod
    def handle_patch(
        self,
        annotations: dict[str, np.ndarray],
        in_coords: PatchCoordinate,
        out_coords: PatchCoordinate,
    ) -> None:
        """
        This method is called for each patch of the output after processing it with the deep learning model.
        Results inferred from this step should be stored in class variables.

        :param annotations: Model output for the patch, one array per annotation key.
        :param in_coords: Coordinates of the patch in the input array.
        :param out_coords: Coordinates of the patch in the output array.
        """
        pass

    @abstractmethod
    def get_results_dict(self) -> dict[str, Any]:
        """
        This method returns a dictionary with the results of the callback. It is used to generate the ClassifyOutput
        when using the callback through classify.
        """
        pass
[docs]
class MultiCallback(DASAnnotateCallback):
    """
    Fans out every callback event to a list of callbacks, in the order they were given.

    :param callbacks: The callbacks to invoke for each event.
    """

    def __init__(self, callbacks: list[DASAnnotateCallback]):
        self.callbacks = callbacks

    # ``VirtualTransformedDataArray`` is a string annotation because the name is not defined
    # or imported before this class in the module.
    def setup(
        self,
        data: "VirtualTransformedDataArray",
        patching_structure: PatchingStructure,
        annotate_keys: list[str],
    ) -> None:
        """Call ``setup`` on every wrapped callback with the same arguments."""
        for cb in self.callbacks:
            cb.setup(data, patching_structure, annotate_keys)

    def handle_patch(
        self,
        annotations: dict[str, np.ndarray],
        in_coords: PatchCoordinate,
        out_coords: PatchCoordinate,
    ) -> None:
        """Pass the patch to every wrapped callback."""
        for cb in self.callbacks:
            cb.handle_patch(annotations, in_coords, out_coords)

    def finalize(self) -> None:
        """Call ``finalize`` on every wrapped callback."""
        for cb in self.callbacks:
            cb.finalize()

    def get_results_dict(self) -> dict[str, Any]:
        """
        Not supported for the multi callback.

        :raises NotImplementedError: Always. Query the wrapped callbacks individually instead.
        """
        raise NotImplementedError(
            "MultiCallback does not support get_results_dict. "
            "Instead, query the results from each individual callback."
        )
[docs]
class DASPickingCallback(DASAnnotateCallback):
    """
    Pick arrivals from probability curves using scipy.signal.find_peaks. The picking is performed independently on each
    channel, i.e., no continuity is assumed between channels.

    :param thresholds: Confidence thresholds for picking. Can be a single value for all phases,
                       or a dictionary with thresholds per phase.
    :param min_time_separation: Minimum time separation between two picks of the same phase in seconds.
    """

    def __init__(
        self,
        thresholds: float | dict[str, float] = 0.2,
        min_time_separation: float = 1.0,
    ):
        self._picks: dict[str, list[sbu.DASPick]] = {}
        # Candidate picks grouped by (phase, channel index in the output array)
        self._picks_per_channel_idx: dict[tuple[str, int], list[sbu.DASPick]] = (
            defaultdict(list)
        )
        self.thresholds = thresholds
        self.min_time_separation = min_time_separation

        self._thresholds: dict[str, float] = {}
        self._annotate_keys: list[str] = []
        self._dt = 0.0
        self._output_coords = None
        self.channel_coord_name = None

    # ``VirtualTransformedDataArray`` is a string annotation because the name is not defined
    # or imported before this class in the module.
    def setup(
        self,
        data: "VirtualTransformedDataArray",
        patching_structure: PatchingStructure,
        annotate_keys: list[str],
    ) -> None:
        """
        Reset all state and resolve the per-phase thresholds.

        :raises ValueError: If ``thresholds`` is a dictionary that misses one of the annotate keys.
        """
        self.channel_coord_name = data.channel_coord_name
        self._annotate_keys = annotate_keys
        self._picks = {key: [] for key in annotate_keys}
        self._picks_per_channel_idx = defaultdict(list)
        self._dt = data.dt
        _, self._output_coords = DASModel.calc_output_shape_and_coordinates(
            data, patching_structure
        )

        # Any non-dict value (float, int, numpy scalar, ...) is a single threshold for all phases.
        if isinstance(self.thresholds, dict):
            self._thresholds = self.thresholds
        else:
            self._thresholds = {phase: self.thresholds for phase in annotate_keys}

        for key in annotate_keys:
            if key not in self._thresholds:
                raise ValueError(f"Threshold for key {key} not specified.")

    def handle_patch(
        self,
        annotations: dict[str, np.ndarray],
        in_coords: PatchCoordinate,
        out_coords: PatchCoordinate,
    ) -> None:
        """
        Detect peaks above the threshold on every channel of the patch and store them as candidate picks.
        Candidates are de-duplicated across patches in :py:func:`finalize`.
        """
        # find_peaks requires distance >= 1, so clamp for separations shorter than one sample.
        min_separation_samples = max(1, int(self.min_time_separation / self._dt))

        for key in self._annotate_keys:
            threshold = self._thresholds[key]
            ann = annotations[key]
            for channel_idx in range(ann.shape[1]):
                trace = ann[:, channel_idx]
                if not (trace > threshold).any():
                    continue

                peaks, peak_properties = scipy.signal.find_peaks(
                    trace,
                    height=threshold,
                    distance=min_separation_samples,
                )
                if len(peaks) == 0:
                    continue

                # Group by the channel index in the full output array, not by the index within
                # the patch. Otherwise picks from different physical channels would compete in
                # ``finalize`` while duplicates from overlapping patches would be missed.
                global_channel_idx = out_coords.channel_int + channel_idx
                channel_value = self._translate_coords(
                    out_coords.channel + channel_idx, self.channel_coord_name
                )
                bucket = self._picks_per_channel_idx[(key, global_channel_idx)]
                for peak, confidence in zip(peaks, peak_properties["peak_heights"]):
                    bucket.append(
                        sbu.DASPick(
                            time=self._translate_coords(
                                out_coords.sample + peak, "time"
                            ),
                            channel=channel_value,
                            confidence=confidence,
                            phase=key,
                        )
                    )

    def _translate_coords(self, idx: float, coord_name: str) -> float | np.datetime64:
        """
        Convert a (possibly fractional) index into a coordinate value by linear interpolation
        between the two neighbouring coordinate entries.
        """
        coord = self._output_coords[coord_name]
        base_idx = int(idx)
        v0 = coord.get_value(base_idx)
        # Clamp the upper neighbour to the last entry of the coordinate.
        v1 = coord.get_value(min(base_idx + 1, len(coord) - 1))
        return v0 + (v1 - v0) * (idx - base_idx)

    def finalize(self) -> None:
        """
        Build the final pick lists. Within each (phase, channel) group, a pick is kept only if no
        other pick within ``min_time_separation`` has a strictly higher confidence.
        """
        min_separation = np.timedelta64(int(self.min_time_separation * 1e9), "ns")
        get_time = attrgetter("time")

        for (phase, _), picks in self._picks_per_channel_idx.items():
            sorted_by_time = sorted(picks, key=get_time)
            sorted_by_confidence = sorted(
                picks, key=attrgetter("confidence"), reverse=True
            )

            for pick in sorted_by_confidence:
                idx1 = bisect.bisect_left(
                    sorted_by_time, pick.time - min_separation, key=get_time
                )
                idx2 = bisect.bisect_right(
                    sorted_by_time, pick.time + min_separation, key=get_time
                )
                if any(
                    other.confidence > pick.confidence
                    for other in sorted_by_time[idx1:idx2]
                ):
                    continue
                self._picks[phase].append(pick)

    def get_results_dict(self) -> dict[str, Any]:
        """Return the final picks as a dictionary mapping phase to a list of picks."""
        return self._picks

    def get_results_dataframe(self) -> pd.DataFrame:
        """
        Return the final picks of all phases in a single data frame.
        Returns an empty data frame if no phases are registered yet.
        """
        frames = [pd.DataFrame(picks) for picks in self._picks.values()]
        if not frames:
            # pd.concat raises on an empty list
            return pd.DataFrame()
        return pd.concat(frames)
[docs]
class WriterBuffer:
    """
    A buffer to handle intersections between overlapping output data. The buffer expects data in patches of equal size.
    The patch order needs to be left to right (samples), top to bottom (channels), i.e., first all samples for a range
    of channels need to be processed before the next row can be processed.
    The buffer keeps up to two rows in memory and writes slices along the sample axis once they are fully predicted.

    :param data: Input data object. Only its ``dtype`` attribute is read here.
    :param stacking: How to combine overlapping predictions, either ``"avg"`` or ``"max"``.
    :param output_shape: Shape ``(samples, channels)`` of the full output array.
    """

    # ``VirtualTransformedDataArray`` is a string annotation because the name is not defined
    # or imported before this class in the module.
    def __init__(
        self,
        data: "VirtualTransformedDataArray",
        stacking: str,
        output_shape: tuple[int, int],
    ):
        self.stacking = stacking
        self.output_shape = output_shape
        self._dtype = data.dtype

        self._last_channel: int | None = None  # First channel of the channel buffer
        self._previous_sample: int | None = (
            None  # First sample of the previous sample buffer
        )
        self._current_sample: int | None = (
            None  # First sample of the current sample buffer
        )

        self._channel_buffer = None  # Buffer along the current line
        self._previous_sample_buffer = None  # Line above, always complete
        self._current_sample_buffer = None  # Current line, potentially incomplete

        # The count buffers are only filled for "avg" stacking, but they are always created so
        # that ``finalize`` can inspect them for every stacking mode.
        self._channel_buffer_count = None
        self._previous_sample_buffer_count = None
        self._current_sample_buffer_count = None

    @property
    def stacking(self) -> str:
        """The stacking mode, either ``"avg"`` or ``"max"``."""
        return self._stacking

    @stacking.setter
    def stacking(self, value: str):
        if value not in ["avg", "max"]:
            raise ValueError("Stacking must be either 'avg' or 'max'.")
        self._stacking = value

    def add_data(
        self, data: np.ndarray, out_coords: PatchCoordinate
    ) -> Optional[tuple[np.ndarray, PatchCoordinate]]:
        """
        Add one output patch to the buffer.

        :param data: Patch data with shape ``(out_coords.w_sample, out_coords.w_channel)``.
        :param out_coords: Coordinates of the patch in the output array.
        :return: ``None`` while no rows are finalized. Once a row of patches is complete and earlier
                 rows exist (or it is the last row), a tuple of the finalized segment and its coordinates.
        :raises ValueError: If patches arrive in an order the buffer cannot handle.
        """
        if self._last_channel is not None and self._last_channel >= out_coords.channel:
            raise ValueError(
                "Channel indices within a row must be strictly increasing."
            )
        if (
            self._current_sample is not None
            and self._current_sample != out_coords.sample
        ):
            raise ValueError("Can't start new row without finishing the previous one.")

        output = None

        if self._last_channel is None:
            # First patch of a row
            if out_coords.channel > 0:
                raise ValueError("Each row needs to start at channel 0.")

            self._channel_buffer = data.copy()
            if self.stacking == "avg":
                self._channel_buffer_count = np.ones_like(
                    self._channel_buffer, dtype=np.uint32
                )
            finalized_segment = None
            self._last_channel = 0
        else:
            # Update channel buffer
            boundary_channels = (
                out_coords.channel_int - self._last_channel
            )  # channels that are only in one of the data
            tmp_buffer = np.empty(
                (out_coords.w_sample, out_coords.w_channel), dtype=self._dtype
            )

            # Left - channels no later patch of this row will touch
            finalized_segment = self._channel_buffer[:, :boundary_channels]
            if self.stacking == "avg":
                finalized_segment_count = self._channel_buffer_count[
                    :, :boundary_channels
                ]

            tmp_buffer[:, -boundary_channels:] = data[
                :, -boundary_channels:
            ]  # Right - new data

            # Middle - manage overlap
            if self.stacking == "max":
                tmp_buffer[:, :-boundary_channels] = np.maximum(
                    self._channel_buffer[:, boundary_channels:],
                    data[:, :-boundary_channels],
                )
            elif self.stacking == "avg":
                # Sums are accumulated here; the division happens when a segment is emitted.
                tmp_buffer[:, :-boundary_channels] = (
                    self._channel_buffer[:, boundary_channels:]
                    + data[:, :-boundary_channels]
                )
                tmp_buffer_count = np.zeros_like(tmp_buffer, dtype=np.uint32)
                tmp_buffer_count[:, :-boundary_channels] = self._channel_buffer_count[
                    :, boundary_channels:
                ]
                tmp_buffer_count += 1
                self._channel_buffer_count = tmp_buffer_count

            self._channel_buffer = tmp_buffer

        # Row completed - Extend finalized segment
        if out_coords.channel + out_coords.w_channel == self.output_shape[1]:
            if (
                finalized_segment is None
            ):  # First entry in the row, i.e., row has only a single patch
                finalized_segment = self._channel_buffer
                if self.stacking == "avg":
                    finalized_segment_count = self._channel_buffer_count
            else:
                finalized_segment = np.concatenate(
                    [finalized_segment, self._channel_buffer], axis=1
                )
                if self.stacking == "avg":
                    finalized_segment_count = np.concatenate(
                        [finalized_segment_count, self._channel_buffer_count], axis=1
                    )

        # Update sample buffer - Note that horizontal overlap is already handled
        if self._current_sample is None:  # First entry in the row
            self._current_sample_buffer = np.empty(
                (out_coords.w_sample, self.output_shape[1]), dtype=self._dtype
            )
            if self.stacking == "avg":
                self._current_sample_buffer_count = np.zeros_like(
                    self._current_sample_buffer, dtype=np.uint32
                )
            self._current_sample = out_coords.sample_int

        if finalized_segment is not None:
            self._current_sample_buffer[
                :, self._last_channel : self._last_channel + finalized_segment.shape[1]
            ] = finalized_segment
            if self.stacking == "avg":
                self._current_sample_buffer_count[
                    :,
                    self._last_channel : self._last_channel
                    + finalized_segment.shape[1],
                ] = finalized_segment_count

        self._last_channel = out_coords.channel_int

        if (
            out_coords.channel + out_coords.w_channel == self.output_shape[1]
        ):  # Row completed - Write sample buffer
            if self._previous_sample_buffer is None:
                # Very first row: nothing can be emitted yet, it only becomes the previous row.
                finalized_segment = None
                finalized_segment_count = None
                finalized_sample = None
                self._previous_sample_buffer = self._current_sample_buffer
                if self.stacking == "avg":
                    self._previous_sample_buffer_count = (
                        self._current_sample_buffer_count
                    )
                self._previous_sample = self._current_sample
            else:
                boundary_samples = self._current_sample - self._previous_sample

                # Top - samples of the previous row that the current row does not touch
                finalized_segment = self._previous_sample_buffer[:boundary_samples]
                if self.stacking == "avg":
                    finalized_segment_count = self._previous_sample_buffer_count[
                        :boundary_samples
                    ]

                tmp_buffer = np.empty(
                    (out_coords.w_sample, self.output_shape[1]), dtype=self._dtype
                )
                tmp_buffer[-boundary_samples:] = self._current_sample_buffer[
                    -boundary_samples:
                ]  # Bottom - New data

                # Middle - manage overlap
                if self.stacking == "max":
                    tmp_buffer[:-boundary_samples] = np.maximum(
                        self._previous_sample_buffer[boundary_samples:],
                        self._current_sample_buffer[:-boundary_samples],
                    )
                elif self.stacking == "avg":
                    tmp_buffer[:-boundary_samples] = (
                        self._previous_sample_buffer[boundary_samples:]
                        + self._current_sample_buffer[:-boundary_samples]
                    )
                    tmp_buffer_count = np.zeros_like(tmp_buffer, dtype=np.uint32)
                    tmp_buffer_count[:-boundary_samples] = (
                        self._previous_sample_buffer_count[boundary_samples:]
                    )
                    tmp_buffer_count += self._current_sample_buffer_count
                    self._previous_sample_buffer_count = tmp_buffer_count

                finalized_sample = self._previous_sample
                self._previous_sample_buffer = tmp_buffer
                self._previous_sample = self._current_sample

            # Clear buffers for the current row
            self._current_sample_buffer = None
            self._current_sample_buffer_count = None
            self._current_sample = None
            self._channel_buffer = None
            self._channel_buffer_count = None
            self._last_channel = None

            if (
                out_coords.sample + out_coords.w_sample == self.output_shape[0]
            ):  # Last row
                if finalized_segment is None:
                    finalized_segment = self._previous_sample_buffer
                    if self.stacking == "avg":
                        finalized_segment_count = self._previous_sample_buffer_count
                else:
                    finalized_segment = np.concatenate(
                        [finalized_segment, self._previous_sample_buffer], axis=0
                    )
                    if self.stacking == "avg":
                        finalized_segment_count = np.concatenate(
                            [
                                finalized_segment_count,
                                self._previous_sample_buffer_count,
                            ],
                            axis=0,
                        )

                self._previous_sample_buffer = None
                self._previous_sample_buffer_count = None

            if finalized_segment is not None:
                if self.stacking == "avg":
                    # NOTE(review): dividing by the uint32 counts may promote float32 sums to
                    # float64 under NumPy's promotion rules - confirm this is intended.
                    finalized_segment = finalized_segment / finalized_segment_count
                else:
                    # Copy to avoid passing out pointers to a view of a large array
                    finalized_segment = finalized_segment.copy()

                if finalized_sample is None:
                    finalized_sample = 0

                output = (
                    finalized_segment,
                    PatchCoordinate(
                        finalized_sample,
                        0,
                        finalized_segment.shape[0],
                        finalized_segment.shape[1],
                    ),
                )

        return output

    def finalize(self) -> None:
        """
        Verify that all buffered data has been emitted.

        :raises ValueError: If any buffer still holds data, i.e., the output is incomplete.
        """
        if any(
            buffer is not None
            for buffer in [
                self._channel_buffer,
                self._channel_buffer_count,
                self._current_sample_buffer,
                self._current_sample_buffer_count,
                self._previous_sample_buffer,
                self._previous_sample_buffer_count,
            ]
        ):
            raise ValueError("Incomplete return data.")
[docs]
class WriterCallback(DASAnnotateCallback):
    """
    Writes the raw predictions of the model to disk. The callback implements streaming processing to avoid excessive
    memory usage, while ensuring correct splicing at the overlaps between adjacent patches.

    The output writing relies on the
    `xdas DataArrayWriter <https://xdas.readthedocs.io/en/latest/_autosummary/xdas.processing.DataArrayWriter.html>`_ .
    This means that the output will be written in multiple files using one output folder per annotation key.
    To load the files for key ``x`` use ``xdas.open_mfdataarray("output_path/x/*")``. Note that the time coordinate will
    have minor discontinuities due to the chunked writing. These can be fixed by calling
    ``data.coords["time"] = data.coords["time"].simplify(tolerance=np.timedelta64(1, "us"))``.

    :param output_path: Directory for the output. It is created if missing; its parent must exist.
    :param stacking: How to combine overlapping predictions, either ``"avg"`` or ``"max"``.
    """

    def __init__(self, output_path: Path | str, stacking: str = "avg"):
        # Validate before touching the file system so that invalid arguments leave no directory behind.
        if stacking not in ["avg", "max"]:
            raise ValueError("Stacking must be either 'avg' or 'max'.")
        self.stacking = stacking

        self.output_path = Path(output_path)
        self.output_path.mkdir(parents=False, exist_ok=True)

        self.channel_coord_name = None
        self._writer_buffers = None
        self._data_writers = None
        self._output_coords = None
        self._annotate_keys = None
        self._sample_buffer = None
        self._channel_buffer = None

    # ``VirtualTransformedDataArray`` is a string annotation because the name is not defined
    # or imported before this class in the module.
    def setup(
        self,
        data: "VirtualTransformedDataArray",
        patching_structure: PatchingStructure,
        annotate_keys: list[str],
    ) -> None:
        """
        Create one output folder, writer and splice buffer per annotation key.
        The per-key folders must not exist yet (``mkdir`` is called with ``exist_ok=False``).
        """
        output_shape, self._output_coords = DASModel.calc_output_shape_and_coordinates(
            data, patching_structure
        )
        self.channel_coord_name = data.channel_coord_name
        self._annotate_keys = annotate_keys
        self._data_writers = {}
        self._writer_buffers = {}

        for key in annotate_keys:
            annotation_path = self.output_path / key
            annotation_path.mkdir(parents=False, exist_ok=False)
            self._data_writers[key] = xdas.processing.DataArrayWriter(annotation_path)
            self._writer_buffers[key] = WriterBuffer(
                data, stacking=self.stacking, output_shape=output_shape
            )

    def finalize(self) -> None:
        """Close all writers. Does nothing if no writers are open."""
        if self._data_writers is None:
            return
        for writer in self._data_writers.values():
            writer.result()  # Closes writer
        self._data_writers = None

    def handle_patch(
        self,
        annotations: dict[str, np.ndarray],
        in_coords: PatchCoordinate,
        out_coords: PatchCoordinate,
    ) -> None:
        """Feed the patch into the splice buffers and write every segment they release."""
        for key in self._annotate_keys:
            segment = self._writer_buffers[key].add_data(annotations[key], out_coords)
            if segment is None:
                continue

            seg_annotation, seg_coords = segment
            sample_slice = slice(
                seg_coords.sample_int, seg_coords.sample_int + seg_coords.w_sample
            )
            channel_slice = slice(
                seg_coords.channel_int, seg_coords.channel_int + seg_coords.w_channel
            )
            self._data_writers[key].write(
                xdas.DataArray(
                    data=seg_annotation,
                    coords={
                        "time": self._output_coords["time"][sample_slice],
                        self.channel_coord_name: self._output_coords[
                            self.channel_coord_name
                        ][channel_slice],
                    },
                )
            )

    def get_results_dict(self) -> dict[str, Any]:
        """Return the output folder for each annotation key (empty before ``setup``)."""
        return {key: self.output_path / key for key in self._annotate_keys or []}
[docs]
class InMemoryCollectionCallback(DASAnnotateCallback):
    """
    Collects the raw predictions of the model in memory and splices the DAS array back together from the individual
    patches. To avoid memory overflows, this callback should only be used for small datasets.

    :param stacking: How to combine overlapping predictions, either ``"avg"`` or ``"max"``.
    """

    def __init__(self, stacking: str = "avg"):
        # Fail early on invalid modes, consistent with WriterCallback. Without this check the
        # error would only surface in ``setup`` when the buffers are created.
        if stacking not in ["avg", "max"]:
            raise ValueError("Stacking must be either 'avg' or 'max'.")
        self.annotations = None
        self.stacking = stacking
        self._annotate_keys = None
        self._output_coords = None
        self._write_buffers = None
        self._outputs = None

    # ``VirtualTransformedDataArray`` is a string annotation because the name is not defined
    # or imported before this class in the module.
    def setup(
        self,
        data: "VirtualTransformedDataArray",
        patching_structure: PatchingStructure,
        annotate_keys: list[str],
    ) -> None:
        """Allocate one splice buffer and one full-size output array per annotation key."""
        output_shape, self._output_coords = DASModel.calc_output_shape_and_coordinates(
            data, patching_structure
        )
        self._write_buffers = {
            key: WriterBuffer(data, stacking=self.stacking, output_shape=output_shape)
            for key in annotate_keys
        }
        # The output arrays start uninitialised and are filled segment by segment in handle_patch.
        self._outputs = {
            key: np.empty(output_shape, dtype=data.dtype) for key in annotate_keys
        }
        self._annotate_keys = annotate_keys

    def handle_patch(
        self,
        annotations: dict[str, np.ndarray],
        in_coords: PatchCoordinate,
        out_coords: PatchCoordinate,
    ) -> None:
        """Feed the patch into the splice buffers and copy every released segment into the output."""
        for key in self._annotate_keys:
            segment = self._write_buffers[key].add_data(annotations[key], out_coords)
            if segment is None:
                continue

            seg_annotation, seg_coords = segment
            sample_slice = slice(
                seg_coords.sample_int, seg_coords.sample_int + seg_coords.w_sample
            )
            channel_slice = slice(
                seg_coords.channel_int, seg_coords.channel_int + seg_coords.w_channel
            )
            self._outputs[key][sample_slice, channel_slice] = seg_annotation

    def finalize(self) -> None:
        """Wrap the collected arrays into ``DataArray`` objects with the output coordinates."""
        self.annotations = {}
        for key in self._annotate_keys:
            self.annotations[key] = DataArray(
                data=self._outputs[key], coords=self._output_coords
            )

    def get_results_dict(self) -> dict[str, Any]:
        """Return the annotations per key (``None`` before ``finalize``)."""
        return self.annotations
[docs]
class FKFilter(nn.Module):
    """
    An F-k filter implemented in PyTorch. The filter processes batched data, i.e., the input format should be
    (batch, samples, channels).

    :param dt: Sample spacing in time
    :param dx: Channel spacing in space
    :param v_min: Minimum velocity to be considered in the filter. If None, no lower bound is applied.
    :param v_max: Maximum velocity to be considered in the filter. If None, no upper bound is applied.
    :param mode: Either "pass" or "reject". If "pass", all velocities between v_min and v_max are retained.
                 If "reject", all velocities outside this band are retained.
    """

    def __init__(
        self,
        dt: float,
        dx: float,
        v_min: Optional[float] = None,
        v_max: Optional[float] = None,
        mode: str = "pass",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if mode not in ["pass", "reject"]:
            raise ValueError("mode must be 'pass' or 'reject'")
        self.dt = dt
        self.dx = dx
        self.v_min = v_min
        self.v_max = v_max
        self.mode = mode

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        Apply the filter.

        :param data: Tensor of shape (batch, samples, channels).
        :return: Real part of the filtered data, same shape as the input.
        """
        _, nt, nx = data.shape
        device = data.device

        Fk = torch.fft.fft2(data)
        # Build the grids and the mask on the device of the input. The filter holds no parameters
        # or buffers, so ``.to(device)`` on the module would not move anything.
        f = torch.fft.fftfreq(nt, self.dt, device=device)  # 1/s
        k = torch.fft.fftfreq(nx, self.dx, device=device)  # 1/m
        F, K = torch.meshgrid(f, k, indexing="ij")
        # Apparent velocity; the small offset avoids a division by zero at k = 0
        c = torch.abs(F / (K + 1e-10))

        mask = torch.ones(Fk.shape[1:], dtype=torch.bool, device=device)
        if self.v_min is not None:
            mask = mask & (c >= self.v_min)
        if self.v_max is not None:
            mask = mask & (c < self.v_max)
        if self.mode == "reject":
            mask = ~mask

        Fk_filtered = Fk * mask
        return torch.fft.ifft2(Fk_filtered).real
[docs]
class DASModel(SeisBenchModel, ABC):
    """
    This is the base class for all models processing DAS data.

    .. admonition:: Hint

        If you are an end-user looking to apply pretrained models, you most likely won't interact with this class
        directly. Instead, you will use classes inheriting from this class and their :py:func:`annotate` and
        :py:func:`classify` functions. If you aim to develop your own model, you should inherit from this class and
        have a look at the details below.

    .. admonition:: Hint

        When calling ``annotate`` or ``classify``, the model can perform automatic resampling along both axes. This
        ensures that the model can be flexibly applied to data of different sampling rates and channel spacings.
        However, as models are typically stable with respect to small changes in sampling rate and channel spacing,
        this class allows for a range of sampling rates and channel spacings to be specified. When called on data that
        does not fall into this range, the model will search for the smallest set of integers for upsampling and
        downsampling. The resampling is done using ``scipy.signal.resample_poly``. To get the exact resampling ratio
        for a particular input array, check the function :py:func:`get_resample_ratios`.

    :param patching_structure: The structure of the patches to cut for annotation. If None, the function
                               :py:func:`get_patching_structure` needs to be implemented, allowing to dynamically adjust
                               the patching structure to the input data.
    :param dt_range: Admissible range for the time step of data to be processed. This value is only taken into account
                     for the execution of the ``annotate``/``classify`` functions. See the above hint on the resampling
                     behavior. Values are in seconds.
    :param dx_range: Same as ``dt_range`` but along the channel axis. Values are in meters.
    :param buffer_queue_size: Maximum number of chunks to keep in the intermediate buffers.
    :param annotate_forward_kwargs: Additional keyword arguments to pass to the ``forward`` method of the model when
                                    running ``annotate``/``classify``.
    :param annotate_keys: List of annotation keys to read from the output.
    :param default_args: Default arguments for the optional keyword arguments of ``annotate``/``classify``.
    :param fk_filter_args: Arguments for the F-k filter. See :py:class:`FKFilter` for details.
    :param filter_samples: Filter to apply along the sample axis. See :py:class:`VirtualTransformedDataArray`
                           for details.
    """

    # Optional keyword arguments of ``annotate``/``classify``.
    # Each entry maps the argument name to a tuple ``(description, default value)``.
    # NOTE(review): presumably read through ``_argdict_get_with_default`` - confirm against its definition.
    _annotate_args = {
        "batch_size": ("Batch size for the model", 2),
        "pbar": ("Show progress bar", True),
        "overlap_samples": (
            "Overlap between patches along the sample axis. "
            "Values between 0 and 1 are treated as fractions of the patch length. "
            "Values above 1 a sample counts.",
            0.5,
        ),
        "overlap_channels": (
            "Overlap between patches along the channel axis. "
            "Values between 0 and 1 are treated as fractions of the patch length. "
            "Values above 1 a channel counts.",
            0.5,
        ),
        "channel_coord_name": (
            "Name of the channel coordinate in the input data. "
            "The same will be used in the output. "
            "If None, the name is inferred automatically using a list of candidates.",
            None,
        ),
    }
def __init__(
    self,
    dt_range: Optional[tuple[float, float]] = None,
    dx_range: Optional[tuple[float, float]] = None,
    patching_structure: Optional[PatchingStructure] = None,
    buffer_queue_size: int = 8,
    annotate_forward_kwargs: Optional[dict[str, Any]] = None,
    annotate_keys: Optional[list[str]] = None,
    default_args: Optional[dict[str, Any]] = None,
    fk_filter_args: Optional[dict[str, Any]] = None,
    filter_samples: Optional[tuple[str, dict[str, Any]]] = None,
    **kwargs,
):
    """
    Store the annotate configuration. Remaining keyword arguments are forwarded to the parent
    constructor. Container arguments that are ``None`` are replaced by fresh empty containers.
    """
    super().__init__(**kwargs)

    # Resampling and patching configuration
    self.dt_range = dt_range
    self.dx_range = dx_range
    self.patching_structure = patching_structure
    self.buffer_queue_size = buffer_queue_size

    # Fresh containers per instance instead of shared mutable defaults
    self.annotate_forward_kwargs = (
        {} if annotate_forward_kwargs is None else annotate_forward_kwargs
    )
    self.annotate_keys = [] if annotate_keys is None else annotate_keys
    self.default_args = {} if default_args is None else default_args

    # Optional preprocessing
    self.fk_filter_args = fk_filter_args
    self.filter_samples = filter_samples
[docs]
def get_model_args(self):
    """
    Return the arguments of the parent class extended by ``buffer_queue_size``.

    This function does not export technical arguments that are usually set in the subclass.
    """
    exported_args = dict(super().get_model_args())
    exported_args["buffer_queue_size"] = self.buffer_queue_size
    return exported_args
[docs]
def get_patching_structure(
    self, data_shape: tuple[float, float], argdict: dict[str, Any]
) -> PatchingStructure:
    """
    To enable dynamic window sizes, depending on the shape of the input record, this function can be overwritten.
    By default, returns the predefined patching structure. In addition, this function allows to overwrite the
    overlap dynamically.

    The ``data_shape`` is provided for adaptive models. Note that the data shape can have float coordinates due to
    in-memory resampling of the data. The actual output shape can only be inferred once the patching structure has
    been defined, as the number of truncated samples depends on the patching structure. Therefore, models should be
    flexible towards the case of slightly smaller data shapes than the theoretical one.

    :param data_shape: Theoretical shape of the (resampled) data. Unused in the default implementation.
    :param argdict: Keyword arguments of the annotate call. ``overlap_samples`` and ``overlap_channels`` are
                    read from it (with defaults) for every overlap not fixed in the predefined structure.
    :return: A copy of the predefined patching structure with both overlaps set.
    :raises ValueError: If no patching structure is defined, or if a requested overlap is negative or
                        not smaller than the patch length.
    """
    if self.patching_structure is None:
        raise ValueError(
            "Patching structure needs to be defined. "
            "This can happen either by setting it in the constructor or by implementing "
            "get_patching_structure in a subclass."
        )

    def resolve_overlap(key: str, patch_length: int) -> int:
        # Values below 1 are fractions of the patch length, everything else is a count.
        overlap = self._argdict_get_with_default(argdict, key)
        if overlap < 1:
            overlap = overlap * patch_length
        overlap = int(overlap)
        # The shift between patches is patch_length - overlap and must stay positive.
        if not 0 <= overlap < patch_length:
            raise ValueError(
                f"{key} resolves to {overlap}, but must be at least 0 and smaller "
                f"than the patch length {patch_length}."
            )
        return overlap

    # Work on a copy so the predefined structure is never modified.
    patching_structure = copy.deepcopy(self.patching_structure)
    if patching_structure.overlap_samples is None:
        patching_structure.overlap_samples = resolve_overlap(
            "overlap_samples", patching_structure.in_samples
        )
    if patching_structure.overlap_channels is None:
        patching_structure.overlap_channels = resolve_overlap(
            "overlap_channels", patching_structure.in_channels
        )

    return patching_structure
[docs]
def annotate(self, *args, **kwargs) -> None:
    """
    Synchronous entry point for :py:func:`annotate_async`. All arguments are forwarded unchanged
    and the coroutine is executed with ``asyncio.run``.
    """
    coroutine = self.annotate_async(*args, **kwargs)
    asyncio.run(coroutine)
[docs]
def classify(self, *args, **kwargs) -> sbu.ClassifyOutput:
    """
    Synchronous entry point for ``classify_async``. All arguments are forwarded unchanged and
    the result of the coroutine, executed with ``asyncio.run``, is returned.
    """
    coroutine = self.classify_async(*args, **kwargs)
    result = asyncio.run(coroutine)
    return result
[docs]
async def annotate_async(
    self, data: "DataArray", callback: DASAnnotateCallback, **kwargs
) -> None:
    """
    Annotate ``data`` and stream the results into ``callback``.

    The processing runs as three concurrent stages connected by bounded queues: slicing and
    preprocessing, model prediction, and postprocessing with the callback.

    :param data: The DAS data to annotate.
    :param callback: Callback receiving the patch-wise outputs. ``setup`` is called before and
                     ``finalize`` after all patches are processed.
    :param kwargs: Optional arguments overwriting ``default_args``.
    """
    self._verify_argdict(kwargs)
    # Kwargs overwrite default args
    argdict = self.default_args.copy()
    argdict.update(kwargs)

    channel_coord_name = self._argdict_get_with_default(
        argdict, "channel_coord_name"
    )

    resample_samples, resample_channels = self.get_resample_ratios(
        data, channel_coord_name
    )
    patching_structure = self.get_patching_structure(
        VirtualTransformedDataArray.estimate_theoretical_output_shape(
            data, resample_samples, resample_channels
        ),
        argdict,
    )

    virtual_data = VirtualTransformedDataArray(
        data,
        patching_structure,
        resample_samples=resample_samples,
        resample_channels=resample_channels,
        filter_samples=self.filter_samples,
        force_dtype=torch_dtype_to_numpy(self.dtype),
        channel_coord_name=channel_coord_name,
    )

    if self.fk_filter_args is None:
        fk_filter = None
    else:
        fk_filter = FKFilter(
            dt=virtual_data.dt, dx=virtual_data.dx, **self.fk_filter_args
        )

    callback.setup(virtual_data, patching_structure, self.annotate_keys)

    queue_chunks = asyncio.Queue(self.buffer_queue_size)
    queue_output = asyncio.Queue(self.buffer_queue_size)

    stages = [
        asyncio.ensure_future(stage)
        for stage in (
            self._slice_and_preprocess(virtual_data, queue_chunks),
            self._process_patch(
                queue_chunks, queue_output, patching_structure, fk_filter, argdict
            ),
            self._postprocess_patch(
                queue_output, callback, len(virtual_data), argdict
            ),
        )
    ]
    try:
        await asyncio.gather(*stages)
    except BaseException:
        # gather does not cancel the other stages when one fails. Without explicit
        # cancellation they would stay blocked on their queues in the caller's event loop.
        for stage in stages:
            stage.cancel()
        await asyncio.gather(*stages, return_exceptions=True)
        raise

    callback.finalize()
async def _slice_and_preprocess(
self,
data: VirtualTransformedDataArray,
queue_out: asyncio.Queue,
) -> None:
async for patch, coord in data:
await queue_out.put((patch, coord))
await queue_out.put(None)
async def _process_patch(
self,
queue_in: asyncio.Queue,
queue_out: asyncio.Queue,
patching_structure: PatchingStructure,
fk_filter: Optional[FKFilter],
argdict: dict[str, Any],
) -> None:
batch_size = self._argdict_get_with_default(argdict, "batch_size")
buffer = []
while True:
item = await queue_in.get()
if item is None:
if len(buffer) > 0:
annotation_items = await asyncio.to_thread(
self._predict_buffer,
buffer,
patching_structure,
fk_filter,
argdict,
)
for annotation_item in annotation_items:
await queue_out.put(annotation_item)
break
buffer.append(item)
if len(buffer) == batch_size:
annotation_items = await asyncio.to_thread(
self._predict_buffer, buffer, patching_structure, fk_filter, argdict
)
for annotation_item in annotation_items:
await queue_out.put(annotation_item)
buffer = []
queue_in.task_done()
await queue_out.put(None)
def _predict_buffer(
    self,
    buffer: list[tuple[np.ndarray, PatchCoordinate]],
    patching_structure: PatchingStructure,
    fk_filter: Optional[FKFilter],
    argdict: dict[str, Any],
) -> list[tuple[dict[str, np.ndarray], PatchCoordinate, PatchCoordinate]]:
    """
    Runs the model on one batch of patches and pairs every result with its coordinates.

    The patches are stacked along a new leading axis and converted to a tensor on ``self.device`` with dtype
    ``self.dtype``. The model is put into eval mode and called under ``torch.no_grad()``. If it was in training
    mode before the call, training mode is restored afterwards, also when the forward pass raises.

    :param buffer: ``(patch, coordinate)`` pairs forming the batch.
    :param patching_structure: Used to transform each input coordinate into the output coordinate.
    :param fk_filter: If not ``None``, applied to the stacked tensor before the model is called.
    :param argdict: Passed to the model call as ``argdict`` keyword argument.
    :return: One ``(annotations, in_coord, out_coord)`` tuple per patch, in the order of ``buffer``.
             ``annotations`` maps every key in ``self.annotate_keys`` to the corresponding numpy array.
    """
    coords = [coord for _, coord in buffer]
    batch = torch.tensor(
        np.stack([patch for patch, _ in buffer], axis=0),
        device=self.device,
        dtype=self.dtype,
    )

    was_training = self.training
    try:
        self.eval()
        with torch.no_grad():
            if fk_filter is not None:
                batch = fk_filter(batch)
            annotations = self(
                batch, **self.annotate_forward_kwargs, argdict=argdict
            )
    finally:
        # Only switch back if the model was in training mode on entry
        if was_training:
            self.train()

    output = []
    for i, in_coord in enumerate(coords):
        out_coord = self._transform_patch_coordinates(in_coord, patching_structure)
        sample_annotations = {
            key: annotations[key][i].cpu().numpy() for key in self.annotate_keys
        }
        output.append((sample_annotations, in_coord, out_coord))
    return output
[docs]
def get_resample_ratios(
    self, data: "DataArray", channel_coord_name: Optional[str]
) -> tuple[tuple[int, int], tuple[int, int]]:
    """
    Determines integer resampling ratios for the sample (time) axis and the channel axis.

    For each axis, the spacing is taken as the difference between the first two coordinate values of ``data``
    and passed to ``_find_range`` together with the configured range (``self.dt_range`` for the time axis,
    converted to seconds; ``self.dx_range`` for the channel axis). If the range of an axis is ``None``,
    the ratio ``(1, 1)`` is returned for this axis.

    :param data: Data array whose coordinates are inspected.
    :param channel_coord_name: Name of the channel coordinate. If ``None``, it is guessed using
                               ``VirtualTransformedDataArray.guess_channel_coord_name``.
    :return: Tuple ``(resample_samples, resample_channels)`` with one ratio per axis.
    """
    if channel_coord_name is None:
        channel_coord_name = VirtualTransformedDataArray.guess_channel_coord_name(
            data
        )

    resample_samples = (1, 1)
    if self.dt_range is not None:
        times = data.coords["time"]
        one_second = np.timedelta64(1, "s")
        data_dt = (times[1] - times[0]) / one_second
        resample_samples = self._find_range(data_dt, *self.dt_range)

    resample_channels = (1, 1)
    if self.dx_range is not None:
        positions = data.coords[channel_coord_name]
        spacing = positions[1] - positions[0]
        data_dx = float(spacing.values)
        resample_channels = self._find_range(data_dx, *self.dx_range)

    return resample_samples, resample_channels
@staticmethod
def _find_range(
    v: float, vmin: float, vmax: float, limit: int = 10
) -> tuple[int, int]:
    """
    Searches for an integer ratio ``(up, down)`` such that ``v * down / up`` lies within ``[vmin, vmax]``.

    Candidates are tried with ``up`` ascending from 1 to ``limit`` and, for each ``up``, ``down`` ascending
    from 1 to ``limit``. The first matching pair is returned.

    :param v: Value to rescale.
    :param vmin: Lower bound (inclusive) for the rescaled value.
    :param vmax: Upper bound (inclusive) for the rescaled value.
    :param limit: Largest integer tried for ``up`` and ``down``.
    :return: The first matching pair ``(up, down)``.
    :raises ValueError: If no pair within the limit matches.
    """
    factors = range(1, limit + 1)
    # Lazy generator: evaluation stops at the first hit, in the same order as two nested loops
    matches = (
        (up, down)
        for up in factors
        for down in factors
        if vmin <= (v * down) / up <= vmax
    )
    ratio = next(matches, None)
    if ratio is None:
        raise ValueError("No integer ratio found for resampling data.")
    return ratio
@staticmethod
def _transform_patch_coordinates(
    coord: PatchCoordinate, patching_structure: PatchingStructure
) -> PatchCoordinate:
    """
    Maps a patch coordinate in the input array to the matching patch coordinate in the output array.

    The position is multiplied, per axis, by the number of output points divided by the length of the
    corresponding range (``range_samples`` / ``range_channels``). The widths of the returned coordinate are
    the output patch dimensions of the patching structure.

    :param coord: Patch coordinate in the input array. Its widths must equal ``in_samples`` and ``in_channels``.
    :param patching_structure: Patching structure defining input/output patch sizes and ranges.
    :return: Patch coordinate in the output array.
    :raises ValueError: If the widths of ``coord`` differ from the input patch size.
    """

    def axis_scale(n_out, valid_range):
        return n_out / (valid_range[1] - valid_range[0])

    widths_match = (
        coord.w_sample == patching_structure.in_samples
        and coord.w_channel == patching_structure.in_channels
    )
    if not widths_match:
        raise ValueError("Patch coordinates do not match patching structure.")

    scale_samples = axis_scale(
        patching_structure.out_samples, patching_structure.range_samples
    )
    scale_channels = axis_scale(
        patching_structure.out_channels, patching_structure.range_channels
    )

    out_sample = coord.sample * scale_samples
    out_channel = coord.channel * scale_channels
    return PatchCoordinate(
        sample=out_sample,
        channel=out_channel,
        w_sample=patching_structure.out_samples,
        w_channel=patching_structure.out_channels,
    )
[docs]
@staticmethod
def calc_output_shape_and_coordinates(
    da: VirtualTransformedDataArray, patching_structure: PatchingStructure
) -> "tuple[tuple[int, int], dict[str, xdas.InterpCoordinate]]":
    """
    Derives shape and coordinate axes of the output array for the given input array and patching structure.

    Along each axis, the part of an input patch lying outside the patch's range is subtracted from the input
    length and the remainder is scaled by the output-points-per-range-length factor. Fractional results are
    rounded up, i.e., an extra sample is added along that axis.

    Each output coordinate is an ``xdas.InterpCoordinate`` tied at the first and the last output index. The
    first tie value is the input coordinate at the start of the range. The last tie value is the input
    coordinate that lies as many points before the end as the patch extends beyond the end of its range.

    :param da: Input data array providing ``shape``, ``coords`` and ``channel_coord_name``.
    :param patching_structure: Patching structure defining patch sizes and ranges.
    :return: Tuple of the output shape ``(samples, channels)`` and a dict with the two output coordinates,
             keyed by ``"time"`` and the channel coordinate name.
    """
    ps = patching_structure
    sample_start, sample_stop = ps.range_samples[0], ps.range_samples[1]
    channel_start, channel_stop = ps.range_channels[0], ps.range_channels[1]

    scale_samples = ps.out_samples / (sample_stop - sample_start)
    scale_channels = ps.out_channels / (channel_stop - channel_start)

    # Number of input points per patch that lie outside the range
    truncated_samples = ps.in_samples - sample_stop + sample_start
    truncated_channels = ps.in_channels - channel_stop + channel_start

    n_out_samples = np.ceil(
        (da.shape[0] - truncated_samples) * scale_samples
    ).astype(int)
    n_out_channels = np.ceil(
        (da.shape[1] - truncated_channels) * scale_channels
    ).astype(int)
    output_shape = (n_out_samples, n_out_channels)

    # Negative indices counting from the end of the input axes
    last_sample_idx = -(ps.in_samples - sample_stop) - 1
    last_channel_idx = -(ps.in_channels - channel_stop) - 1

    channel_name = da.channel_coord_name
    time_axis = da.coords["time"]
    channel_axis = da.coords[channel_name]

    time_ties = {
        "tie_indices": [0, n_out_samples - 1],
        "tie_values": [
            time_axis[sample_start].data,
            time_axis[last_sample_idx].data,
        ],
    }
    channel_ties = {
        "tie_indices": [0, n_out_channels - 1],
        "tie_values": [
            channel_axis[channel_start].data,
            channel_axis[last_channel_idx].data,
        ],
    }

    output_coords = {
        "time": xdas.InterpCoordinate(data=time_ties, dim="time"),
        channel_name: xdas.InterpCoordinate(data=channel_ties, dim=channel_name),
    }
    return output_shape, output_coords
async def _postprocess_patch(
    self,
    queue_in: asyncio.Queue,
    callback: DASAnnotateCallback,
    n_patches: int,
    argdict: dict[str, Any],
) -> None:
    """
    Consumer stage: passes every item from ``queue_in`` to ``callback.handle_patch``.

    Items are read until a ``None`` marker arrives. Each item is unpacked into the positional arguments of
    ``callback.handle_patch``, which is executed through ``asyncio.to_thread``. If the ``pbar`` argument is
    truthy, a tqdm progress bar with ``n_patches`` steps is advanced once per handled item. The bar is closed
    when the method exits, including when the callback raises or the task is cancelled.

    :param queue_in: Queue delivering the items to handle, terminated by ``None``.
    :param callback: Callback whose ``handle_patch`` method receives each item.
    :param n_patches: Total passed to the progress bar.
    :param argdict: Source of the ``pbar`` argument.
    """
    show_progress = self._argdict_get_with_default(argdict, "pbar")
    pbar = tqdm(total=n_patches) if show_progress else None

    try:
        while True:
            item = await queue_in.get()
            if item is None:
                break

            await asyncio.to_thread(callback.handle_patch, *item)

            if pbar is not None:
                pbar.update(1)

            queue_in.task_done()
    finally:
        # Close the bar on every exit path so a failing callback does not leave it dangling
        if pbar is not None:
            pbar.close()
[docs]
async def classify_async(self, data: "DataArray", **kwargs) -> sbu.ClassifyOutput:
    """
    Processes the data with ``annotate_async`` using the default callback and returns the callback's results.

    The ``kwargs`` are split into two groups: keyword arguments whose name appears in the signature of the
    callback class (``self.classify_callback``) are passed to the callback constructor, all others are passed
    to ``annotate_async``.

    :param data: Data array to process.
    :param kwargs: Keyword arguments for the callback constructor and for ``annotate_async``.
    :return: ``ClassifyOutput`` built from the model name and the callback's ``get_results_dict()``.
    """
    callback_cls = self.classify_callback

    # Inspect the constructor signature once instead of once per keyword argument
    callback_params = inspect.signature(callback_cls).parameters

    callback_kwargs = {}
    annotate_kwargs = {}
    for name, value in kwargs.items():
        if name in callback_params:
            callback_kwargs[name] = value
        else:
            annotate_kwargs[name] = value

    callback = callback_cls(**callback_kwargs)
    await self.annotate_async(data, callback, **annotate_kwargs)

    return sbu.ClassifyOutput(self.name, **callback.get_results_dict())
@property
def classify_callback(self) -> Type[DASAnnotateCallback]:
    """
    Default callback class of this model, e.g., a DASPickingCallback for picking models.

    ``classify_async`` instantiates the returned class, taking the constructor arguments from the ``kwargs``
    it receives, and uses the instance to process the output of the annotate method.

    This implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
def _argdict_get_with_default(self, argdict, key):
    """
    Returns the value for ``key`` from ``argdict``, falling back to the registered default.

    The default is element ``1`` of the entry ``self._annotate_args[key]``. It is only looked up if ``key`` is
    missing from ``argdict``, so a value passed explicitly is returned even when ``key`` has no verbatim entry
    in ``self._annotate_args`` (for example, when it is only covered by a ``*`` wildcard entry).

    :param argdict: Dictionary of arguments passed by the caller.
    :param key: Name of the argument.
    :return: ``argdict[key]`` if present, otherwise the registered default.
    :raises KeyError: If ``key`` is neither in ``argdict`` nor in ``self._annotate_args``.
    """
    if key in argdict:
        return argdict[key]
    return self._annotate_args[key][1]
def _verify_argdict(self, argdict):
    """
    Logs a warning for every key in ``argdict`` that matches none of the names in ``self._annotate_args``.

    Each name in ``self._annotate_args`` is used as a regular expression after replacing ``*`` with ``.*``
    and must match the complete key.

    :param argdict: Dictionary of arguments to check.
    """

    def is_known(name):
        for pattern in self._annotate_args.keys():
            if re.fullmatch(pattern.replace("*", ".*"), name):
                return True
        return False

    for key in argdict.keys():
        if is_known(key):
            continue
        seisbench.logger.warning(f"Unknown argument '{key}' will be ignored.")