Source code for seisbench.models.das_wrapper

from typing import Any, Optional

import numpy as np
import torch
import seisbench

from .das_base import DASModel, PatchingStructure
from .base import WaveformModel


class DASWaveformModelWrapper(DASModel):
    """
    This class is a wrapper to allow applying WaveformModels (trained for regular
    seismic data) to DAS datasets. The models are applied channel by channel.

    Example usage:

    .. code-block:: python

        base_model = PhaseNet.from_pretrained("instance")
        model = DASWaveformModelWrapper(base_model)

    :param model: The WaveformModel to apply to DAS data
    :param component_strategy: The strategy to transform the single-component DAS
                               channels into three-component data for the model.
                               Supports ``clone`` (provide same channel for each
                               component) and ``pad`` (provide channel as first
                               component, zero-padding for second and third
                               component).
    :param kwargs: Additional keyword arguments passed on to the ``DASModel``
                   constructor.
    :raises ValueError: If ``model`` is not a WaveformModel, if its output type is
                        not ``"array"``, or if ``component_strategy`` is unknown.
    """

    _annotate_args = DASModel._annotate_args.copy()

    def __init__(
        self, model: WaveformModel, component_strategy: str = "clone", **kwargs
    ):
        if not isinstance(model, WaveformModel):
            raise ValueError("Can only wrap WaveformModels.")
        if model.output_type != "array":
            raise ValueError("Only 'array' models are supported.")

        # Accept sample spacings between half and twice the spacing implied by
        # the wrapped model's sampling rate.
        if model.sampling_rate is not None:
            dt_range = (0.5 / model.sampling_rate, 2.0 / model.sampling_rate)
        else:
            dt_range = None

        annotate_keys = [x for x in model.labels if x != "N"]  # Drop noise output

        super().__init__(
            citation=model.citation,
            dt_range=dt_range,
            annotate_keys=annotate_keys,
            filter_samples=self._get_filter_args(model, dt_range),
            **kwargs,
        )

        self.model = model
        self.component_strategy = component_strategy

        self.default_args.update(model.default_args)

        # Build a per-instance copy. Updating the class-level dictionary in place
        # would leak the annotate arguments of one wrapped model into every other
        # wrapper instance.
        annotate_args = self._annotate_args.copy()
        model_annotate_args = model._annotate_args.copy()
        # Otherwise we get an unreasonable batch size
        model_annotate_args.pop("batch_size", None)
        annotate_args.update(model_annotate_args)
        annotate_args.pop("overlap", None)  # Use overlap_samples instead
        self._annotate_args = annotate_args

    @property
    def component_strategy(self):
        """
        The strategy used to build multi-component input from a single DAS
        channel, either ``clone`` or ``pad``.
        """
        return self._component_strategy

    @component_strategy.setter
    def component_strategy(self, value):
        if value not in ["clone", "pad"]:
            raise ValueError("component_strategy must be either 'clone' or 'pad'.")
        self._component_strategy = value

    @staticmethod
    def _get_filter_args(
        model: WaveformModel, dt_range: Optional[tuple[float, float]]
    ) -> Optional[tuple[str, dict[str, Any]]]:
        """
        Translates the filter specification of the wrapped model into a Butterworth
        ``iirfilter`` specification.

        :param model: The wrapped WaveformModel whose ``filter_args`` and
                      ``filter_kwargs`` are inspected.
        :param dt_range: The ``(min, max)`` range handed to the ``DASModel``
                         constructor or None. If given, upper corner frequencies
                         are clipped just below ``0.5 / dt_range[1]``.
        :return: A tuple of the filter name (``"iirfilter"``) and its keyword
                 arguments, or None if the model defines no filter or the filter
                 can not be translated.
        """
        if model.filter_args is None:
            return None

        if len(model.filter_args) != 1 or isinstance(model.filter_args, dict):
            seisbench.logger.warning(
                "Automatic filter inference failed due to incompatible filter specification. "
                "Will not apply any filter."
            )
            return None

        # As the DAS model allows for a wider frequency range than the WaveformModel,
        # we sometimes need to adjust the filter frequency.
        if dt_range is not None:
            # No filter frequency may reach the Nyquist frequency of the largest
            # permitted sample spacing, so stay marginally below it.
            max_freq = 0.999999 * 0.5 / dt_range[1]
        else:
            max_freq = np.inf

        filter_name = model.filter_args[0]
        filter_kwargs = model.filter_kwargs or {}
        # Fall back to 4 corners, the ObsPy default, if the order is not specified.
        corners = filter_kwargs.get("corners", 4)

        if filter_kwargs.get("zerophase", False):
            seisbench.logger.warning(
                "Zero phase filtering not supported for DASWaveformModelWrapper. "
                "Doubling filter order."
            )
            corners *= 2

        if filter_name in ("bandpass", "bandstop"):
            critical_freqs = [
                filter_kwargs["freqmin"],
                min(filter_kwargs["freqmax"], max_freq),
            ]
        elif filter_name in ("lowpass", "highpass"):
            critical_freqs = min(filter_kwargs["freq"], max_freq)
        else:
            seisbench.logger.warning(
                "Automatic filter inference failed due to unsupported filter type "
                f"('{filter_name}'). Will not apply any filter."
            )
            return None

        return "iirfilter", {
            "N": corners,
            "btype": filter_name,
            "ftype": "butter",
            "Wn": critical_freqs,
        }

    def get_patching_structure(
        self, data_shape: tuple[float, float], argdict: dict[str, Any]
    ) -> PatchingStructure:
        """
        Determines how the data are split into patches for the wrapped model.

        :param data_shape: Shape of the data as ``(samples, channels)``. Both
                           entries are rounded down to integers.
        :param argdict: Dictionary of annotate arguments. ``overlap_samples`` and
                        ``blinding`` are read from it, falling back to defaults.
                        Values of ``overlap_samples`` below 1 are interpreted as a
                        fraction of the input window length.
        :return: The patching structure. At most 1024 channels go into one patch
                 and channel patches do not overlap.
        """
        n_samples = int(np.floor(data_shape[0]))
        n_channels = int(np.floor(data_shape[1]))

        # The array only serves as a placeholder of the right length.
        in_samples, pred_samples = self.model._get_in_pred_samples(np.empty(n_samples))
        in_channels = min(1024, n_channels)

        overlap_samples = self._argdict_get_with_default(argdict, "overlap_samples")
        if overlap_samples < 1:
            overlap_samples = int(overlap_samples * in_samples)

        blinding = self._argdict_get_with_default(argdict, "blinding")

        return PatchingStructure(
            in_channels=in_channels,
            out_channels=in_channels,
            range_channels=(0, in_channels),
            overlap_channels=0,
            out_samples=pred_samples[1] - pred_samples[0] - blinding[0] - blinding[1],
            in_samples=in_samples,
            range_samples=(
                pred_samples[0] + blinding[0],
                pred_samples[1] - blinding[1],
            ),  # Note the slightly different convention here
            overlap_samples=overlap_samples,
        )

    def forward(self, x: torch.Tensor, argdict: Optional[dict[str, Any]] = None):
        """
        Applies the wrapped model to every DAS channel independently.

        :param x: Input of shape (batch, samples, channels_das)
        :param argdict: Dictionary of annotate arguments, handed to the pre- and
                        postprocessing of the wrapped model. None is treated as an
                        empty dictionary.
        :return: Dictionary mapping each label of the wrapped model except ``N`` to
                 a tensor of shape (batch, samples, channels_das), with the blinded
                 samples at both ends removed.
        """
        if argdict is None:
            argdict = {}

        # x shape: (batch, samples, channels_das)
        x_original_shape = x.shape
        x = x.permute(0, 2, 1)  # -> (batch, channels_das, samples)
        x = x.reshape((-1, 1, x.shape[2]))  # -> (batch * channels_das, 1, samples)

        # Waveform model input shape: (batch, channels_3c, samples)
        n_components = len(self.model.component_order)
        if self.component_strategy == "clone":
            x = x.repeat(1, n_components, 1)
        elif self.component_strategy == "pad":
            x = torch.concatenate(
                [x] + (n_components - 1) * [torch.zeros_like(x)], dim=1
            )
        else:
            raise ValueError(f"Unknown strategy {self.component_strategy}")

        preprocessed = self.model.annotate_batch_pre(x, argdict=argdict)
        if isinstance(preprocessed, tuple):  # Contains piggyback information
            assert len(preprocessed) == 2
            preprocessed, piggyback = preprocessed
        else:
            piggyback = None

        preds = self.model(preprocessed)
        preds = self.model.annotate_batch_post(
            preds, piggyback=piggyback, argdict=argdict
        )

        # The blinding is identical for all labels, so look it up only once.
        blinding = self._argdict_get_with_default(argdict, "blinding")

        output = {}
        for i, key in enumerate(self.model.labels):
            if key == "N":
                continue

            ann = preds[:, :, i].reshape(
                x_original_shape[0], x_original_shape[2], preds.shape[1]
            )  # -> (batch, channels_das, samples)
            ann = ann.permute(0, 2, 1)  # -> (batch, samples, channels_das)

            # Cut away the blinded samples at the start and end of the window.
            output[key] = ann[:, blinding[0] : ann.shape[1] - blinding[1], :]

        return output

    def save(self, *args, **kwargs):
        """
        This model does not provide a save function.
        Instead, save the underlying WaveformModel.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Saving not supported for this type of model. "
            "Instead, save the underlying WaveformModel."
        )