seisbench.models

Base classes

class ActivationLSTMCell(input_size, hidden_size, gate_activation=<function hard_sigmoid>, recurrent_dropout=0)[source]

Bases: Module

LSTM cell with a configurable gate activation, by default the hard sigmoid.

If gate_activation=torch.sigmoid, this is the standard LSTM cell.

Uses recurrent dropout strategy from https://arxiv.org/abs/1603.05118 to match Keras implementation.
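
As a rough illustration of the gate activation, the sketch below compares a hard sigmoid with the standard logistic sigmoid in plain Python. The piecewise-linear form clip(0.2 * x + 0.5, 0, 1) is the Keras convention and is an assumption here; it is not stated on this page, and the SeisBench function operates on tensors rather than floats.

```python
import math


def hard_sigmoid(x: float) -> float:
    """Piecewise-linear approximation of the sigmoid (assumed Keras-style form)."""
    return min(1.0, max(0.0, 0.2 * x + 0.5))


def sigmoid(x: float) -> float:
    """Standard logistic sigmoid, as used in a conventional LSTM cell."""
    return 1.0 / (1.0 + math.exp(-x))


# Both functions agree at 0, but the hard variant reaches
# exactly 0 and 1 at finite inputs instead of approaching them.
for value in (-5.0, 0.0, 1.0, 5.0):
    print(f"x={value:+.1f}  hard={hard_sigmoid(value):.3f}  soft={sigmoid(value):.3f}")
```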

forward(input, state)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

init_weights()[source]
class CustomLSTM(cell, *cell_args, bidirectional=True, **cell_kwargs)[source]

Bases: Module

LSTM to be used with custom cells

forward(input, state=None)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class GroupingHelper(grouping)[source]

Bases: object

A helper class for grouping streams for the annotate function. In most cases, no direct interaction with this class is required. However, when implementing new models, subclassing this helper allows for more flexibility.

group_stream(stream, strict, min_length_s, comp_dict)[source]

Performs grouping of the input stream. In addition, enforces the strict mode, i.e., if strict=True, only keeps segments where all components are available, and discards segments that are too short. For grouping=channel, no checks are performed.

Parameters:
  • stream (Stream) – Input stream

  • strict (bool) – Whether streams should be treated strictly, as for the waveform model. Only applied if grouping is “full”.

  • min_length_s (float) – Minimum length of a segment in seconds. Only applied if grouping is “full”.

  • comp_dict (dict[str, int]) – Mapping of component characters to int. Only used if grouping is “full”.

Return type:

list[list[Trace]]

Returns:

Grouped list of lists of traces.
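
The grouping and strict-mode logic can be pictured with a minimal sketch that works on trace IDs only. It assumes IDs of the form NET.STA.LOC.CHA whose last character is the component, and it ignores time segments and min_length_s entirely. The function names are hypothetical and this is not the SeisBench implementation.

```python
from collections import defaultdict


def id_without_component(trace_id: str) -> str:
    """Drop the trailing component character, e.g. 'XX.AAA..HHZ' -> 'XX.AAA..HH'."""
    return trace_id[:-1]


def group_ids(trace_ids, strict, comp_dict):
    """Group trace IDs by instrument; in strict mode keep only complete groups."""
    groups = defaultdict(list)
    for trace_id in trace_ids:
        groups[id_without_component(trace_id)].append(trace_id)

    result = []
    for members in groups.values():
        components = {trace_id[-1] for trace_id in members}
        if strict and components != set(comp_dict):
            continue  # at least one component is missing
        result.append(sorted(members))
    return result


ids = ["XX.AAA..HHZ", "XX.AAA..HHN", "XX.AAA..HHE", "XX.BBB..HHZ"]
comp_dict = {"Z": 0, "N": 1, "E": 2}
print(group_ids(ids, strict=True, comp_dict=comp_dict))   # only station AAA is complete
print(group_ids(ids, strict=False, comp_dict=comp_dict))  # both stations are kept
```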

property grouping
static trace_id_without_component(trace)[source]
class SeisBenchModel(citation=None)[source]

Bases: Module

Base SeisBench model interface for processing waveforms.

Parameters:

citation (str, optional) – Citation reference, defaults to None.

property citation
classmethod from_pretrained(name, version_str='latest', update=False, force=False, wait_for_file=False)[source]

Load pretrained model with weights.

Pretrained model weights consist of two files: a weights file [name].pt and a config file [name].json. The config file can (and should) contain the following entries, even though all of them are optional:

  • “docstring”: A string documenting the pipeline. Usually also contains information on the author.

  • “model_args”: Argument dictionary passed to the init function of the pipeline.

  • “seisbench_requirement”: The minimal version of SeisBench required to use the weights file.

  • “default_args”: Default args for the annotate()/classify() functions. These arguments will supersede any potential constructor settings.

  • “version”: The version string of the model. For all but the latest version, version names should furthermore be denoted in the file names, i.e., the files should end with the suffix “.v[VERSION]”. If no version is specified in the json, the assumed version string is “1”.
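
For orientation, a config file with these entries might look like the sketch below. All values are invented for illustration; only the key names and the fallback version "1" come from the description above, and the helper name is hypothetical.

```python
import json

# Hypothetical contents of a [name].json config file (values are invented).
config_text = """
{
    "docstring": "Example weights trained on a demo dataset. Author: Jane Doe",
    "model_args": {"sampling_rate": 100},
    "seisbench_requirement": "0.3.0",
    "default_args": {"overlap": 1500}
}
"""


def read_config(text: str) -> dict:
    """Parse a config and fill in the version assumed when none is given."""
    config = json.loads(text)
    config.setdefault("version", "1")
    return config


config = read_config(config_text)
print(config["version"])       # no "version" key in the file, so "1" is assumed
print(config["default_args"])  # defaults for annotate()/classify()
```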

Warning

Even though the version is set to “latest” by default, this will only use the latest version locally available. Only if no weight is available locally, the remote repository will be queried. This behaviour is implemented for privacy reasons, as it avoids contacting the remote repository for every call of the function. To explicitly update to the latest version from the remote repository, set update=True.
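
The lookup order described in this warning can be sketched as follows: the highest locally available version wins, and the remote listing is only consulted when nothing is available locally or when an update is requested. The sketch assumes purely numeric, dot-separated version strings and uses hypothetical function and argument names; how SeisBench compares versions internally is not specified here.

```python
def version_key(version: str):
    """Turn '1.10' into (1, 10) so that versions compare numerically."""
    return tuple(int(part) for part in version.split("."))


def resolve_latest(local_versions, remote_versions, update=False):
    """Return the version that a request for 'latest' would resolve to."""
    if update or not local_versions:
        # Only in these cases is the remote listing consulted.
        candidates = list(local_versions) + list(remote_versions)
    else:
        candidates = list(local_versions)
    return max(candidates, key=version_key)


print(resolve_latest(["1", "2"], ["1", "2", "3"]))               # local versions only
print(resolve_latest(["1", "2"], ["1", "2", "3"], update=True))  # remote included
print(resolve_latest([], ["1", "2", "3"]))                       # nothing available locally
```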

Parameters:
  • name (str) – Model name prefix.

  • version_str (str) – Version of the weights to load. Either a version string or “latest”. The “latest” model is the model with the highest version number.

  • force (bool, optional) – Force execution of download callback, defaults to False

  • update (bool) – If true, downloads potential new weights file and config from the remote repository. The old files are retained with their version suffix.

  • wait_for_file (bool, optional) – Whether to wait on partially downloaded files, defaults to False

Returns:

Model instance

Return type:

SeisBenchModel

abstract get_model_args()[source]

Obtain all model parameters for saving.

Returns:

Dictionary of all parameters for a model to store during saving.

Return type:

Dict

classmethod list_pretrained(details=False, remote=True)[source]

Returns list of available pretrained weights and optionally their docstrings.

Parameters:
  • details (bool) – If true, instead of returning only a list, also return their docstrings. By default, returns the docstring of the “latest” version for each weight. Note that this requires downloading the json files for each model in the background and is therefore slower. Defaults to false.

  • remote (bool) – If true, reports both locally available weights and versions in the remote repository. Otherwise only reports local versions.

Returns:

List of available weights or dict of weights and their docstrings

Return type:

list or dict

classmethod list_versions(name, remote=True)[source]

Returns list of available versions for a given weight name.

Parameters:
  • name (str) – Name of the queried weight

  • remote (bool) – If true, reports both locally available versions and versions in the remote repository. Otherwise only reports local versions.

Returns:

List of available versions

Return type:

list[str]

classmethod load(path, version_str=None)[source]

Load a SeisBench model from local path.

For more information on the SeisBench model format see save().

Parameters:
  • path (pathlib.Path or str) – Define the path to the SeisBench model.

  • version_str (str, None) – Version string of the model. If none, no version string is appended.

Returns:

Model instance

Return type:

SeisBenchModel

property name
save(path, weights_docstring='', version_str=None)[source]

Save a SeisBench model locally.

SeisBench models are saved in two parts: the model configuration is stored in JSON format at [path].json, and the underlying model weights in PyTorch format at [path].pt, where ‘path’ is the output path. The suffixes are appended to the path parameter automatically.

In addition, the models can have a version string which is appended to the json and the pt path. For example, setting version_str="1" will append .v1 to the file names.
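
The file naming can be illustrated with a small helper. The function name is hypothetical; the suffix order (.json or .pt first, then .v followed by the version string) follows the description above.

```python
from pathlib import Path


def model_file_names(path, version_str=None):
    """Return the (config, weights) paths derived from a base path."""
    base = str(Path(path))
    suffix = "" if version_str is None else f".v{version_str}"
    return Path(base + ".json" + suffix), Path(base + ".pt" + suffix)


config_path, weights_path = model_file_names("models/my_picker", version_str="1")
print(config_path.name)   # my_picker.json.v1
print(weights_path.name)  # my_picker.pt.v1
```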

The model config should contain the following information, which is automatically created from the model instance state:

  • “weights_docstring”: A string documenting the pipeline. Usually also contains information on the author.

  • “model_args”: Argument dictionary passed to the init function of the pipeline.

  • “seisbench_requirement”: The minimal version of SeisBench required to use the weights file.

  • “default_args”: Default args for the annotate()/classify() functions.

Non-serializable arguments (e.g. functions) cannot be saved to JSON, so are not converted.

Parameters:
  • path (pathlib.Path or str) – Define the path to the output model.

  • weights_docstring (str, default to '') – Documentation for the model weights (training details, author etc.)

  • version_str (str, None) – Version string of the model. If none, no version string is appended.

property weights_docstring
property weights_version
class WaveformModel(component_order=None, sampling_rate=None, output_type=None, default_args=None, in_samples=None, pred_sample=0, labels=None, filter_args=None, filter_kwargs=None, grouping='instrument', allow_padding=False, **kwargs)[source]

Bases: SeisBenchModel, ABC

Abstract interface for models processing waveforms. Based on the properties specified by inheriting models, WaveformModel automatically provides the respective annotate()/classify() functions. Both functions take obspy streams as input. The annotate() function has a rather strictly defined output, i.e., it always outputs obspy streams with the annotations. These can for example be functions of pick probability over time. In contrast, the classify() function can tailor its output to the model type. For example, a picking model might output picks, while a magnitude estimation model might only output a scalar magnitude. Internally, classify() will usually rely on annotate() and simply add steps to its output.

For details see the documentation of these functions.

The following parameters are available for the annotate/classify functions:

  • batch_size – Batch size for the model. Default: 256

  • overlap – Overlap between prediction windows in samples (only for window prediction models). Default: 0

  • stacking – Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. Default: avg

  • stride – Stride in samples (only for point prediction models). Default: 1

  • strict – If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. Default: False

  • flexible_horizontal_components – If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. Default: True

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
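
To make the overlap and stacking arguments concrete, the following sketch merges equally long prediction windows in plain Python. It is a simplified illustration with hypothetical names that assumes each window starts (window length - overlap) samples after the previous one; it is not the SeisBench implementation.

```python
def stack_windows(windows, overlap, stacking="avg"):
    """Merge equally long prediction windows that overlap by `overlap` samples."""
    if stacking not in ("avg", "max"):
        raise ValueError("stacking must be 'avg' or 'max'")
    length = len(windows[0])
    step = length - overlap
    total = step * (len(windows) - 1) + length

    # Collect all predictions that fall onto the same output sample.
    buckets = [[] for _ in range(total)]
    for index, window in enumerate(windows):
        for offset, value in enumerate(window):
            buckets[index * step + offset].append(value)

    if stacking == "avg":
        return [sum(bucket) / len(bucket) for bucket in buckets]
    return [max(bucket) for bucket in buckets]


windows = [[0.0, 0.0, 0.5, 1.0], [1.0, 0.0, 0.0, 0.0]]
print(stack_windows(windows, overlap=2, stacking="avg"))
print(stack_windows(windows, overlap=2, stacking="max"))
```

In the overlapping samples, ‘avg’ smooths disagreements between windows, while ‘max’ keeps the strongest prediction.
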
Parameters:
  • component_order (list, optional) – Specify component order (e.g. [‘ZNE’]), defaults to None.

  • sampling_rate (float) – Sampling rate of the model, defaults to None. If sampling rate is not None, the annotate and classify functions will automatically resample incoming traces and validate correct sampling rate if the model overwrites annotate_stream_pre().

  • output_type (str) –

    The type of output from the model. Current options are:

    • “point” for a point prediction, i.e., the probability of containing a pick in the window or of a pick at a certain location. This will provide an annotate() function. If a classify_aggregate() function is provided by the inheriting model, this will also provide a classify() function.

    • “array” for prediction curves, i.e., probabilities over time for the arrival of certain wave types. This will provide an annotate() function. If a classify_aggregate() function is provided by the inheriting model, this will also provide a classify() function.

    • “regression” for a regression value, i.e., the sample of the arrival within a window. This will only provide a classify() function.

  • default_args (dict[str, any]) – Default arguments to use in annotate and classify functions

  • in_samples (int) – Number of input samples in time

  • pred_sample (int, tuple) – For a “point” prediction: sample number of the sample in a window for which the prediction is valid. For an “array” prediction: a tuple of first and last sample defining the prediction range. Note that the number of output samples and input samples within the given range are not required to agree.

  • labels (list or string or callable) – Labels for the different predictions in the output, e.g., Noise, P, S. If a function is passed, it will be called for every label generation and be provided with the stats of the trace that was annotated.

  • filter_args (tuple) – Arguments to be passed to obspy.filter() in annotate_stream_pre()

  • filter_kwargs (dict) – Keyword arguments to be passed to obspy.filter() in annotate_stream_pre()

  • grouping (Union[str, GroupingHelper]) – Level of grouping for annotating streams. Supports “instrument”, “channel” and “full”. Alternatively, a custom GroupingHelper can be passed.

  • allow_padding (bool) – If True, annotate will pad different windows if they have different sizes. This is useful, for example, for multi-station methods.

  • kwargs – Kwargs are passed to the superclass

annotate(stream, copy=True, **kwargs)[source]

Annotates an obspy stream using the model based on the configuration of the WaveformModel superclass. For example, for a picking model, annotate will give a characteristic function/probability function for picks over time. The annotate function contains multiple subfunctions, which can be overwritten individually by inheriting models to accommodate their requirements. These functions are:

  • annotate_stream_pre()

  • annotate_stream_validate()

  • annotate_batch_pre()

  • annotate_batch_post()

Please see the respective documentation for details on their functionality, inputs and outputs.

Hint

If your machine is equipped with a GPU, this function will usually run faster when making use of the GPU. Just call model.cuda(). In addition, you might want to increase the batch size by passing the batch_size argument to the function. Possible values might be 2048 or 4096 (or larger if your GPU permits).

Warning

Even though the asyncio implementation itself is not parallel, this does not guarantee that only a single CPU core will be used, as the underlying libraries (pytorch, numpy, scipy, …) might be parallelised. If you need to limit the parallelism of these libraries, check their documentation. Bear in mind that a lower number of threads might occasionally improve runtime performance, as it limits overheads.

Parameters:
  • stream (obspy.core.Stream) – Obspy stream to annotate

  • copy (bool) – If true, copies the input stream. Otherwise, the input stream is modified in place.

  • kwargs

Returns:

Obspy stream of annotations

async annotate_async(stream, copy=True, **kwargs)[source]

Implementation of annotate() based on asyncio. Parameters are the same as for annotate().

annotate_batch_post(batch, piggyback, argdict)[source]

Runs postprocessing on the predictions of a window for the annotate function, e.g., reformatting them. By default, returns the original prediction. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Predictions for the batch. The data type depends on the model.

  • argdict (dict[str, Any]) – Dictionary of arguments

  • piggyback (Any) – Piggyback information, by default None.

Return type:

Tensor

Returns:

Postprocessed predictions

annotate_batch_pre(batch, argdict)[source]

Runs preprocessing on batch level for the annotate function, e.g., normalization. By default, returns the input batch unmodified. Optionally, this can return a tuple of the preprocessed batch and piggyback information that is passed to annotate_batch_post(). This can for example be used to transfer normalization information. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Input batch

  • argdict (dict[str, Any]) – Dictionary of arguments

Return type:

Tensor

Returns:

Preprocessed batch and optionally piggyback information that is passed to annotate_batch_post()
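
The piggyback mechanism can be sketched with a pair of functions that normalize a batch and later undo the normalization. Plain nested lists stand in for tensors, the model call between the two steps is omitted, and the peak normalization itself is only an illustrative choice rather than what any particular model does.

```python
def annotate_batch_pre(batch, argdict):
    """Scale every window to unit peak amplitude and remember the scale."""
    # An all-zero window has a peak of 0.0; use 1.0 instead to avoid dividing by zero.
    peaks = [max(abs(value) for value in window) or 1.0 for window in batch]
    normalized = [[value / peak for value in window] for window, peak in zip(batch, peaks)]
    return normalized, peaks  # (preprocessed batch, piggyback information)


def annotate_batch_post(batch, piggyback, argdict):
    """Undo the scaling using the piggyback information."""
    return [[value * peak for value in window] for window, peak in zip(batch, piggyback)]


batch = [[1.0, -4.0, 2.0], [0.0, 0.0, 0.0]]
pre, piggyback = annotate_batch_pre(batch, argdict={})
post = annotate_batch_post(pre, piggyback, argdict={})
print(pre)
print(piggyback)
print(post)
```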

annotate_stream_pre(stream, argdict)[source]

Runs preprocessing on stream level for the annotate function, e.g., filtering or resampling. By default, this function will resample all traces if a sampling rate for the model is provided. Furthermore, if a filter is specified in the class, the filter will be executed. As annotate creates a copy of the input stream, this function can safely modify the stream in place. Inheriting classes should overwrite this function if necessary. To keep the default functionality, a call to the overwritten method can be included.

Parameters:
  • stream (obspy.Stream) – Input stream

  • argdict – Dictionary of arguments

Returns:

Preprocessed stream

annotate_stream_validate(stream, argdict)[source]

Validates stream for the annotate function. This function should raise an exception if the stream is invalid. By default, this function will check if the sampling rate fits the provided one, unless it is None, and check for mismatching traces, i.e., traces covering the same time range on the same instrument with different values. Inheriting classes should overwrite this function if necessary. To keep the default functionality, a call to the overwritten method can be included.

Parameters:
  • stream (obspy.Stream) – Input stream

  • argdict – Dictionary of arguments

Returns:

None

classify(stream, parallelism=None, **kwargs)[source]

Classifies the stream. The classification can contain any information, but should be consistent with existing models.

Parameters:
  • stream (obspy.core.Stream) – Obspy stream to classify

  • kwargs

Return type:

ClassifyOutput

Returns:

A classification for the full stream, e.g., a list of picks or the source magnitude.

classify_aggregate(annotations, argdict)[source]

An aggregation function that converts the annotation streams returned by annotate() into a classification. A classification consists of a ClassifyOutput, essentially a namespace that can hold an arbitrary set of keys. However, when implementing a model which already exists in similar form, we recommend using the same output format. For example, all pick outputs should have the same format.

Parameters:
  • annotations – Annotations returned from annotate()

  • argdict – Dictionary of arguments

Return type:

ClassifyOutput

Returns:

Classification object

async classify_async(stream, **kwargs)[source]

Async interface to the classify() function. See details there.

Return type:

ClassifyOutput

classify_stream_pre(stream, argdict)[source]

Runs preprocessing on stream level for the classify function, e.g., subselecting traces. By default, this function will simply return the input stream. In contrast to annotate_stream_pre(), this function operates on the original input stream. The stream should therefore not be modified in place. Note that annotate_stream_pre() will be executed on the output of this function within the classify() function.

Parameters:
  • stream (obspy.Stream) – Input stream

  • argdict – Dictionary of arguments

Returns:

Preprocessed stream

property component_order
static detections_from_annotations(annotations, threshold)[source]

Converts the annotations streams for a single phase to discrete detections using a classical trigger on/off. The lower threshold is set to half the higher threshold. Detections are represented by Detection objects. The detection start_time and end_time are set to the trigger on and off times.

Parameters:
  • annotations – Stream of annotations

  • threshold – Higher threshold for trigger

Return type:

DetectionList

Returns:

List of detections
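
A classical trigger on/off with the lower threshold at half the upper one can be sketched as follows. The sketch works on a plain list of values and returns sample indices. The function name and the handling of a trigger that is still active at the end of the data are assumptions; the SeisBench function returns Detection objects with times instead.

```python
def trigger_on_off(values, threshold):
    """Return (on, off) sample index pairs for a two-threshold trigger."""
    off_threshold = threshold / 2  # lower threshold is half the upper one
    detections = []
    start = None
    for index, value in enumerate(values):
        if start is None and value > threshold:
            start = index  # trigger switches on
        elif start is not None and value < off_threshold:
            detections.append((start, index))  # trigger switches off
            start = None
    if start is not None:
        # Assumption: close a still-active trigger at the last sample.
        detections.append((start, len(values) - 1))
    return detections


curve = [0.1, 0.6, 0.4, 0.3, 0.2, 0.1, 0.7, 0.9]
print(trigger_on_off(curve, threshold=0.5))
```

Because the trigger only switches off below the lower threshold, the values 0.4 and 0.3 in the example still belong to the first detection.
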

property device
get_model_args()[source]

Obtain all model parameters for saving.

Returns:

Dictionary of all parameters for a model to store during saving.

Return type:

Dict

static picks_from_annotations(annotations, threshold, phase)[source]

Converts the annotations streams for a single phase to discrete picks using a classical trigger on/off. The lower threshold is set to half the higher threshold. Picks are represented by Pick objects. The pick start_time and end_time are set to the trigger on and off times.

Parameters:
  • annotations – Stream of annotations

  • threshold – Higher threshold for trigger

  • phase – Phase to label, only relevant for output phase labelling

Return type:

PickList

Returns:

List of picks

static resample(stream, sampling_rate)[source]

Perform inplace resampling of stream to a given sampling rate.

Parameters:
  • stream (obspy.core.Stream) – Input stream

  • sampling_rate (float) – Sampling rate (sps) to resample to

static sanitize_mismatching_overlapping_records(stream)[source]

Detects if for any id the stream contains overlapping traces that do not match. If yes, all mismatching parts are removed and a warning is issued.

Parameters:

stream (obspy.core.Stream) – Input stream

Returns:

The stream object without mismatching traces

Return type:

obspy.core.Stream
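
The check for mismatching overlaps can be pictured with a small sketch on plain lists. Records are described by a start sample index and their values; the function name is hypothetical, and removing the offending parts and issuing the warning are left out.

```python
def overlap_mismatch(start_a, data_a, start_b, data_b):
    """Return True if two records of the same ID overlap and disagree.

    Starts are sample indices; data are equally sampled lists of values.
    """
    first = max(start_a, start_b)
    last = min(start_a + len(data_a), start_b + len(data_b))
    if first >= last:
        return False  # no overlap at all
    part_a = data_a[first - start_a:last - start_a]
    part_b = data_b[first - start_b:last - start_b]
    return part_a != part_b


print(overlap_mismatch(0, [1, 2, 3, 4], 2, [3, 4, 5]))  # overlap agrees
print(overlap_mismatch(0, [1, 2, 3, 4], 2, [9, 4, 5]))  # overlap disagrees
print(overlap_mismatch(0, [1, 2], 5, [7, 8]))           # no overlap
```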

stream_to_array(stream, argdict)[source]

Converts streams into a start time and a numpy array. Assumes:

  • All traces within a group can be put into an array, i.e., the strict parameter is already enforced. Every remaining gap is intended to be filled with zeros. The selection/cutting of intervals has already been done by GroupingHelper.group_stream().

  • No overlapping traces of the same component exist

  • All traces have the same sampling rate

Parameters:
  • stream (obspy.core.Stream) – Input stream

  • argdict – Dictionary of arguments

Returns:

  • output_times – Start times for each array

  • output_data – Arrays with waveforms
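
The conversion for a single group can be sketched in plain Python. Each trace is reduced to a (component, start sample, values) tuple, and gaps are filled with zeros as described above. The function name and the tuple layout are hypothetical, and the sketch assumes that comp_dict maps the components to the row indices 0 to n-1; the SeisBench method works on obspy traces and numpy arrays.

```python
def traces_to_array(traces, comp_dict):
    """Place traces into a (components x samples) grid, zero-filling gaps.

    Returns the first sample index and the nested list of data.
    """
    start = min(first for _, first, _ in traces)
    end = max(first + len(values) for _, first, values in traces)
    data = [[0.0] * (end - start) for _ in range(len(comp_dict))]
    for component, first, values in traces:
        row = comp_dict[component]
        offset = first - start
        data[row][offset:offset + len(values)] = values
    return start, data


traces = [
    ("Z", 10, [1.0, 2.0, 3.0, 4.0]),
    ("N", 11, [5.0, 6.0]),
    ("Z", 16, [7.0]),  # second Z segment after a gap
]
start, data = traces_to_array(traces, {"Z": 0, "N": 1, "E": 2})
print(start)
for row in data:
    print(row)
```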

class WaveformPipeline(components, citation=None)[source]

Bases: ABC

A waveform pipeline is a collection of models that together expose an annotate() and a classify() function. Examples of waveform pipelines would be multi-step picking models, conducting first a detection with one model and then a pick identification with a second model. This could also easily be extended by adding further models, e.g., estimating magnitude for each detection.

In contrast to WaveformModel, a waveform pipeline is not a pytorch module and has no forward function. This also means that all components of a pipeline will usually be trained separately. As a rule of thumb, if the pipeline can be trained end to end, it should most likely be a WaveformModel instead. For a waveform pipeline, the annotate() and classify() functions are not automatically generated, but need to be implemented manually.

Waveform pipelines offer functionality for downloading pipeline configurations from the SeisBench repository. Similarly to SeisBenchModel, waveform pipelines expose a from_pretrained() function, that will download the configuration for a pipeline and its components.

To implement a waveform pipeline, this class needs to be subclassed. This class will throw an exception when trying to instantiate.

Warning

In contrast to SeisBenchModel, this class does not yet feature versioning for weights. By default, all underlying models will use the latest, locally available version. This functionality will eventually be added. Please raise an issue on GitHub if you require this functionality.

Parameters:
  • components (dict [str, SeisBenchModel]) – Dictionary of components contained in the model. This should contain all models used in the pipeline.

  • citation (str, optional) – Citation reference, defaults to None.

annotate(stream, **kwargs)[source]
property citation
classify(stream, **kwargs)[source]
abstract classmethod component_classes()[source]

Returns a mapping of component names to their classes. This function needs to be defined in each pipeline, as it is required to load configurations.

Returns:

Dictionary mapping component names to their classes.

Return type:

Dict[str, SeisBenchModel classes]
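
The subclassing pattern can be sketched with a stand-in base class built on Python's abc module. PipelineSketch, the stub components, and TwoStepPipeline are all hypothetical and only mirror the behaviour described above: the base class cannot be instantiated, and every pipeline has to define component_classes().

```python
from abc import ABC, abstractmethod


class PipelineSketch(ABC):
    """Stand-in for a pipeline base class; not the SeisBench class itself."""

    def __init__(self, components):
        self.components = components

    @classmethod
    @abstractmethod
    def component_classes(cls):
        """Map component names to the classes used to build them."""


class DetectorStub:
    """Placeholder for a detection model."""


class PickerStub:
    """Placeholder for a picking model."""


class TwoStepPipeline(PipelineSketch):
    @classmethod
    def component_classes(cls):
        return {"detector": DetectorStub, "picker": PickerStub}

    def classify(self, stream, **kwargs):
        # A real pipeline would run the detector first and the picker afterwards.
        return {"n_components": len(self.components)}


try:
    PipelineSketch(components={})
except TypeError as error:
    print("Cannot instantiate the abstract base:", error)

components = {name: cls() for name, cls in TwoStepPipeline.component_classes().items()}
pipeline = TwoStepPipeline(components)
print(pipeline.classify(stream=None))
```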

property docstring
classmethod from_pretrained(name, force=False, wait_for_file=False)[source]

Load pipeline from configuration. Automatically loads all dependent pretrained model weights.

A pipeline configuration is a json file. On the top level, it has three entries:

  • “components”: A dictionary listing all contained models and the pretrained weight to use for this model.

    The instances of these classes will be created using the from_pretrained() method. The components need to match the components from the dictionary returned by component_classes().

  • “docstring”: A string documenting the pipeline. Usually also contains information on the author.

  • “model_args”: Argument dictionary passed to the init function of the pipeline. (optional)
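
A configuration with these three entries might look like the sketch below, together with a check that the configured components match the expected names. The component names, weight names, and the helper function are invented for illustration; only the three top-level keys come from the description above.

```python
import json

# Hypothetical pipeline configuration; component and weight names are invented.
config_text = """
{
    "components": {"detector": "weights_a", "picker": "weights_b"},
    "docstring": "Two-step demo pipeline. Author: Jane Doe",
    "model_args": {"threshold": 0.5}
}
"""


def check_components(config, expected_names):
    """Verify that the configured components match the expected names."""
    configured = set(config["components"])
    if configured != set(expected_names):
        raise ValueError(
            f"Component mismatch: {sorted(configured)} vs {sorted(expected_names)}"
        )
    return config.get("model_args", {})  # "model_args" is optional


config = json.loads(config_text)
print(check_components(config, ["detector", "picker"]))
```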

Parameters:
  • name (str) – Configuration name

  • force (bool, optional) – Force execution of download callback, defaults to False

  • wait_for_file (bool, optional) – Whether to wait on partially downloaded files, defaults to False

Returns:

Pipeline instance

Return type:

WaveformPipeline

classmethod list_pretrained(details=False)[source]

Returns list of available configurations and optionally their docstrings.

Parameters:

details (bool) – If true, instead of returning only a list, also return their docstrings. Note that this requires downloading the json files for each model in the background and is therefore slower. Defaults to false.

Returns:

List of available weights or dict of weights and their docstrings

Return type:

list or dict

property name
hard_sigmoid(x)[source]

BasicPhaseAE

class BasicPhaseAE(in_channels=3, classes=3, phases='NPS', sampling_rate=100, **kwargs)[source]

Bases: WaveformModel

Simple autoencoder network architecture to pick P-/S-phases, from Woollam et al. (2019).

The following parameters are available for the annotate/classify functions:

  • batch_size – Batch size for the model. Default: 256

  • overlap – Overlap between prediction windows in samples (only for window prediction models). Default: 300

  • stacking – Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. Default: avg

  • stride – Stride in samples (only for point prediction models). Default: 1

  • strict – If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. Default: False

  • flexible_horizontal_components – If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. Default: True

  • *_threshold – Detection threshold for the provided phase. Default: 0.3

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
Parameters:
  • in_channels (int) – Number of input channels, by default 3.

  • in_samples (int) – Number of input samples per channel, by default 600. The model expects input shape (in_channels, in_samples)

  • classes (int) – Number of output classes, by default 3.

  • phases (list, str) – Phase hints for the classes, by default “NPS”. Can be None.

  • sampling_rate (float) – Sampling rate of traces, by default 100.

  • kwargs – Keyword arguments passed to the constructor of WaveformModel.

annotate_batch_post(batch, piggyback, argdict)[source]

Runs postprocessing on the predictions of a window for the annotate function, e.g., reformatting them. By default, returns the original prediction. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Predictions for the batch. The data type depends on the model.

  • argdict (dict[str, Any]) – Dictionary of arguments

  • piggyback (Any) – Piggyback information, by default None.

Return type:

Tensor

Returns:

Postprocessed predictions

annotate_batch_pre(batch, argdict)[source]

Runs preprocessing on batch level for the annotate function, e.g., normalization. By default, returns the input batch unmodified. Optionally, this can return a tuple of the preprocessed batch and piggyback information that is passed to annotate_batch_post(). This can for example be used to transfer normalization information. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Input batch

  • argdict (dict[str, Any]) – Dictionary of arguments

Return type:

Tensor

Returns:

Preprocessed batch and optionally piggyback information that is passed to annotate_batch_post()

classify_aggregate(annotations, argdict)[source]

Converts the annotations to discrete picks using picks_from_annotations(). Trigger onset thresholds for picks are derived from the argdict at keys “[phase]_threshold”.

Parameters:
  • annotations – See description in superclass

  • argdict – See description in superclass

Return type:

ClassifyOutput

Returns:

List of picks
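
The lookup of per-phase thresholds can be sketched as a dictionary comprehension over the phase hints. The helper name is hypothetical, skipping the noise class “N” is an assumption, and the fallback of 0.3 is the default listed for *_threshold in this class.

```python
def phase_thresholds(argdict, phases="NPS", default=0.3, noise_label="N"):
    """Collect the trigger threshold for every non-noise phase."""
    return {
        phase: argdict.get(f"{phase}_threshold", default)
        for phase in phases
        if phase != noise_label  # assumption: no picks are created for noise
    }


print(phase_thresholds({"P_threshold": 0.5}))
```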

forward(x, logits=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_model_args()[source]

Obtain all model parameters for saving.

Returns:

Dictionary of all parameters for a model to store during saving.

Return type:

Dict

CRED

class CRED(in_samples=3000, in_channels=3, sampling_rate=100, original_compatible=False, **kwargs)[source]

Bases: WaveformModel

Note: There are subtle differences between the model presented in the paper (as in Figure 1) and the code on GitHub.

Here we follow the implementation from GitHub to allow for compatibility with the pretrained weights.

The following parameters are available for the annotate/classify functions:

  • batch_size – Batch size for the model. Default: 256

  • overlap – Overlap between prediction windows in samples (only for window prediction models). Default: 1500

  • stacking – Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. Default: avg

  • stride – Stride in samples (only for point prediction models). Default: 1

  • strict – If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. Default: False

  • flexible_horizontal_components – If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. Default: True

  • detection_threshold – Detection threshold. Default: 0.5

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
annotate_batch_pre(batch, argdict)[source]

Runs preprocessing on batch level for the annotate function, e.g., normalization. By default, returns the input batch unmodified. Optionally, this can return a tuple of the preprocessed batch and piggyback information that is passed to annotate_batch_post(). This can for example be used to transfer normalization information. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Input batch

  • argdict (dict[str, Any]) – Dictionary of arguments

Return type:

Tensor

Returns:

Preprocessed batch and optionally piggyback information that is passed to annotate_batch_post()

classify_aggregate(annotations, argdict)[source]

Converts the annotations to discrete detections using detections_from_annotations(). Trigger onset thresholds are derived from the argdict at key “detection_threshold”.

Parameters:
  • annotations – See description in superclass

  • argdict – See description in superclass

Return type:

ClassifyOutput

Returns:

List of detections

forward(x, logits=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_model_args()[source]

Obtain all model parameters for saving.

Returns:

Dictionary of all parameters for a model to store during saving.

Return type:

Dict

static waveforms_to_spectrogram(batch)[source]

Transforms waveforms into spectrograms using the short-time Fourier transform.

Parameters:

batch (Tensor) – Waveforms with shape (channels, samples)

Return type:

Tensor

Returns:

Spectrogram with shape (channels, times, frequencies)
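
For readers unfamiliar with the transform, the sketch below computes the magnitude of a short-time Fourier transform for a single channel in plain Python, giving a (times, frequencies) grid. The frame length, the hop size, and the absence of a window function are illustrative assumptions; they are not the settings used by CRED, and the function name is hypothetical.

```python
import cmath
import math


def stft_magnitude(samples, frame_length, hop):
    """Return a (times x frequencies) grid of magnitudes for one channel."""
    n_freqs = frame_length // 2 + 1  # non-negative frequencies only
    frames = []
    for start in range(0, len(samples) - frame_length + 1, hop):
        frame = samples[start:start + frame_length]
        row = []
        for k in range(n_freqs):
            # Discrete Fourier transform of the frame at frequency bin k
            coefficient = sum(
                value * cmath.exp(-2j * math.pi * k * n / frame_length)
                for n, value in enumerate(frame)
            )
            row.append(abs(coefficient))
        frames.append(row)
    return frames


# Sine wave with two cycles per eight samples, so the energy sits in bin 2.
signal = [math.sin(2 * math.pi * 2 * n / 8) for n in range(16)]
spectrogram = stft_magnitude(signal, frame_length=8, hop=4)
print(len(spectrogram), len(spectrogram[0]))  # number of times, number of frequencies
```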

DeepDenoiser

class DeepDenoiser(sampling_rate=100, **kwargs)[source]

Bases: WaveformModel

The following parameters are available for the annotate/classify functions:

  • batch_size – Batch size for the model. Default: 256

  • overlap – Overlap between prediction windows in samples (only for window prediction models). Default: 1500

  • stacking – Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. Default: avg

  • stride – Stride in samples (only for point prediction models). Default: 1

  • strict – If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. Default: False

  • flexible_horizontal_components – If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. Default: True

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
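The interplay of the overlap and stacking arguments can be sketched as follows. This is a simplified illustration on plain lists with a hypothetical helper name. It assumes equally long windows and is not the library's code.

```python
def stack_windows(predictions, window_len, overlap, method="avg"):
    """Merge equally long window predictions overlapping by `overlap` samples."""
    step = window_len - overlap
    total = step * (len(predictions) - 1) + window_len
    # Collect every prediction that covers a given output sample.
    values = [[] for _ in range(total)]
    for i, window in enumerate(predictions):
        for j, p in enumerate(window):
            values[i * step + j].append(p)
    if method == "avg":
        return [sum(v) / len(v) for v in values]
    if method == "max":
        return [max(v) for v in values]
    raise ValueError(f"Unknown stacking method: {method}")


windows = [[0.0, 0.25, 0.5, 0.75], [1.0, 0.25, 0.0, 0.0]]
avg = stack_windows(windows, window_len=4, overlap=2, method="avg")
mx = stack_windows(windows, window_len=4, overlap=2, method="max")
```

Samples covered by a single window are passed through unchanged. Only the overlapping region differs between the two stacking methods.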
annotate_batch_post(batch, piggyback, argdict)[source]

Runs postprocessing on the predictions of a window for the annotate function, e.g., reformatting them. By default, returns the original prediction. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Predictions for the batch. The data type depends on the model.

  • argdict (dict[str, Any]) – Dictionary of arguments

  • piggyback (Any) – Piggyback information, by default None.

Return type:

Tensor

Returns:

Postprocessed predictions

annotate_batch_pre(batch, argdict)[source]

Runs preprocessing on batch level for the annotate function, e.g., normalization. By default, returns the input batch unmodified. Optionally, this can return a tuple of the preprocessed batch and piggyback information that is passed to annotate_batch_post(). This can for example be used to transfer normalization information. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Input batch

  • argdict (dict[str, Any]) – Dictionary of arguments

Return type:

Tensor

Returns:

Preprocessed batch and optionally piggyback information that is passed to annotate_batch_post()

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

static generate_label(stations)[source]
get_model_args()[source]

Obtain all model parameters for saving.

Returns:

Dictionary of all parameters for a model to store during saving.

Return type:

Dict

Depth phase models

class DepthPhaseModel(time_before=12.5, depth_levels=None, tt_args=None, qc_std=None, qc_depth=None)[source]

Bases: object

Helper class implementing all tools for determining depth from depth phases

Parameters:
  • time_before (float) – Time included before the P pick in seconds

  • depth_levels (Optional[ndarray]) – Array of depth levels to search for

  • tt_args (Optional[dict[str, Any]]) – Arguments for the TTLookup

  • qc_std (Optional[float]) – Maximum standard deviation to pass quality control. If None, no quality control is applied.

  • qc_depth (Optional[float]) – Quality control is only applied to predictions shallower than this depth. If None, quality control is applied to all depth levels.
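One possible reading of how qc_std and qc_depth interact is sketched below. The function name and the exact comparison operators are assumptions made for illustration. The documentation above only states the intent of the two parameters.

```python
def passes_quality_control(depth, std, qc_std=None, qc_depth=None):
    """Return True if a depth estimate should be kept."""
    if qc_std is None:
        return True  # no quality control requested
    if qc_depth is not None and depth >= qc_depth:
        return True  # QC only applies to predictions shallower than qc_depth
    return std <= qc_std
```

For example, with qc_std=20 and qc_depth=200, an estimate at 50 km with a standard deviation of 30 would be rejected, while an estimate at 300 km with the same standard deviation would be kept.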

class DepthFinder(networks, depth_model, phase_model, p_window=10, p_threshold=0.15)[source]

Bases: object

This class is a high-level interface to the depth phase models. It determines event depth at teleseismic distances based on a preliminary location. In contrast to the depth phase models, it is not provided with waveforms, but automatically downloads data through FDSN. Furthermore, it automatically determines first P arrivals using predicted travel times and a deep learning picker.

The processing consists of several steps:

  • determine available stations at the time of the event

  • predict P arrivals

  • download waveforms through FDSN

  • repick P arrivals with a deep learning model

  • determine depth with deep learning based depth model

If waveforms and P wave picks are already available, it is highly recommended to directly use the underlying depth phase model instead of this helper.

Example application
import seisbench.models as sbm

networks = {"GFZ": ["GE"], "IRIS": ["II", "IU"]}  # FDSN providers and networks
depth_model = sbm.DepthPhaseTEAM.from_pretrained("original")  # A depth phase model
phase_model = sbm.PhaseNet.from_pretrained("geofon")  # A teleseismic picking model
depth_finder = sbm.DepthFinder(networks, depth_model, phase_model)
Parameters:
  • networks (dict[str, list[str]]) – Dictionary of FDSN providers and seismic network codes to query

  • depth_model (DepthPhaseModel) – The depth phase model to use

  • phase_model (WaveformModel) – The phase picking model to use for pick refinement

  • p_window (float) – Seconds around the predicted P arrival to search for actual arrival

  • p_threshold (float) – Minimum detection confidence for the primary P phase to include a record

get_depth(lat, lon, depth, origin_time)[source]

Get the depth of an event based on its preliminary latitude, longitude, depth and origin time. A preliminary depth estimate is required as input because it is needed to predict preliminary P arrivals. This is not circular reasoning, as depth and origin_time trade off against each other.

Parameters:
  • lat (float) – Latitude of the event

  • lon (float) – Longitude of the event

  • depth (float) – Preliminary depth of the event

  • origin_time (UTCDateTime) – Preliminary origin time of the event

Return type:

ClassifyOutput

class DepthPhaseNet(phases=('P', 'pP', 'sP'), sampling_rate=20.0, depth_phase_args=None, norm='peak', **kwargs)[source]

Bases: PhaseNet, DepthPhaseModel

The following parameters are available for the annotate/classify functions:

Argument | Description | Default value
batch_size | Batch size for the model | 256
overlap | Overlap between prediction windows in samples (only for window prediction models) | 1500
stacking | Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. | avg
stride | Stride in samples (only for point prediction models) | 1
strict | If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. | False
flexible_horizontal_components | If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. | True
*_threshold | Detection threshold for the provided phase | 0.3
blinding | Number of prediction samples to discard on each side of each window prediction | (0, 0)

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
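The effect of the blinding argument listed above can be illustrated with a short sketch. The helper name is hypothetical, and the discarded samples are marked as None so that the window keeps its length. Whether the library drops or masks these samples is not stated here.

```python
def apply_blinding(window_prediction, blinding=(0, 0)):
    """Mask `blinding[0]` samples at the start and `blinding[1]` at the end."""
    left, right = blinding
    n = len(window_prediction)
    return [
        p if left <= i < n - right else None
        for i, p in enumerate(window_prediction)
    ]


blinded = apply_blinding([0.1, 0.2, 0.3, 0.4, 0.5], blinding=(1, 2))
unchanged = apply_blinding([0.1, 0.2, 0.3], blinding=(0, 0))
```

With the default of (0, 0), every sample of a window prediction is kept.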
annotate(stream, p_picks, **kwargs)[source]

Get depth phase probability curves for one event. Note that the annotations are aligned to have the P arrival at UTCDateTime(0), i.e., 1970-01-01 00:00:00. The probability curves are not normalized; their absolute values are therefore meaningless.

Warning

This class does not expose an annotate_async function directly.

Parameters:
  • stream (Stream) – Obspy stream to annotate

  • p_picks (dict[str, UTCDateTime]) – Dictionary of P pick times. Station codes will be truncated to NET.STA.LOC.

  • kwargs – All kwargs are passed to the annotate function of the superclass.
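The two conventions described above, truncating station codes to NET.STA.LOC and aligning the time axis to the P arrival, can be sketched with plain Python. Times are represented as seconds since the epoch instead of UTCDateTime objects, and the station code and helper names are hypothetical.

```python
def truncate_station_code(trace_id):
    """Reduce 'NET.STA.LOC.CHA' (or longer) to 'NET.STA.LOC'."""
    return ".".join(trace_id.split(".")[:3])


def align_to_p(sample_times, p_time):
    """Shift absolute times (seconds since epoch) so the P pick is at 0."""
    return [t - p_time for t in sample_times]


p_picks = {"GE.ABC.00.BHZ": 1_700_000_012.5}  # hypothetical station and pick
picks = {truncate_station_code(k): v for k, v in p_picks.items()}
relative = align_to_p(
    [1_700_000_000.0, 1_700_000_012.5, 1_700_000_020.0],
    picks["GE.ABC.00"],
)
```

After alignment, negative times lie before the P arrival and positive times after it, independent of the absolute origin time of the event.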

classify(stream, p_picks, distances=None, inventory=None, epicenter=None, **kwargs)[source]

Calculate depth of an event using depth phase picking and a line search over the depth axis. Can only handle one event at a time.

For the line search, the epicentral distances of the stations to the event are required. These can either be provided directly or through an inventory and the event epicenter.

Warning

This class does not expose a classify_async function directly.

Parameters:
  • stream (Stream) – Obspy stream to classify

  • p_picks (dict[str, UTCDateTime]) – Dictionary of P pick times. Station codes will be truncated to NET.STA.LOC.

  • distances (Optional[dict[str, float]]) – Dictionary of epicentral distances for the stations in degrees

  • inventory (Optional[Inventory]) – Inventory for the stations

  • epicenter (Optional[tuple[float, float]]) – (latitude, longitude) of the event epicenter

Return type:

ClassifyOutput
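The idea of a line search over the depth axis can be sketched as follows: for every candidate depth, look up the score of each station's curve at the predicted depth phase delay and keep the depth with the highest total. The delay function is a toy stand-in for a travel-time lookup, and all names are hypothetical. The actual scoring used by the model may differ.

```python
def line_search_depth(curves, depth_levels, delay_fn, sampling_rate):
    """Pick the depth whose predicted delays collect the highest score.

    curves: {station: list of scores, index 0 = P arrival}
    delay_fn(station, depth): predicted depth-phase delay after P in seconds
    """
    best_depth, best_score = None, float("-inf")
    for depth in depth_levels:
        score = 0.0
        for station, curve in curves.items():
            idx = round(delay_fn(station, depth) * sampling_rate)
            if 0 <= idx < len(curve):
                score += curve[idx]
        if score > best_score:
            best_depth, best_score = depth, score
    return best_depth


def toy_delay(station, depth_km):
    """Toy travel-time model: delay grows linearly with depth (assumption)."""
    return 0.25 * depth_km


curves = {
    "GE.AAA.": [0.0] * 10 + [0.9] + [0.0] * 9,  # peak 10 s after P
    "GE.BBB.": [0.0] * 10 + [0.7] + [0.0] * 9,
}
depth = line_search_depth(curves, [20, 40, 60], toy_delay, sampling_rate=1.0)
```

In this toy setup only a depth of 40 km predicts a delay of 10 s, which is where both curves peak, so that depth is selected.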

forward(x, logits=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class DepthPhaseTEAM(phases=('P', 'pP', 'sP'), classes=3, sampling_rate=20.0, depth_phase_args=None, norm='peak', **kwargs)[source]

Bases: PhaseTEAM, DepthPhaseModel

The following parameters are available for the annotate/classify functions:

Argument | Description | Default value
batch_size | Batch size for the model | 256
overlap | Overlap between prediction windows in samples (only for window prediction models) | 0
stacking | Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. | avg
stride | Stride in samples (only for point prediction models) | 1
strict | If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. | False
flexible_horizontal_components | If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. | True

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
annotate(stream, p_picks, **kwargs)[source]

Get depth phase probability curves for one event. Note that the annotations are aligned to have the P arrival at UTCDateTime(0), i.e., 1970-01-01 00:00:00. The probability curves are not normalized; their absolute values are therefore meaningless.

Parameters:
  • stream (Stream) – Obspy stream to annotate

  • p_picks (dict[str, UTCDateTime]) – Dictionary of P pick times. Station codes will be truncated to NET.STA.LOC.

  • kwargs – All kwargs are passed to the annotate function of the superclass.

classify(stream, p_picks, distances=None, inventory=None, epicenter=None, **kwargs)[source]

Calculate depth of an event using depth phase picking and a line search over the depth axis. Can only handle one event at a time.

For the line search, the epicentral distances of the stations to the event are required. These can either be provided directly or through an inventory and the event epicenter.

Parameters:
  • stream (Stream) – Obspy stream to classify

  • p_picks (dict[str, UTCDateTime]) – Dictionary of P pick times. Station codes will be truncated to NET.STA.LOC.

  • distances (Optional[dict[str, float]]) – Dictionary of epicentral distances for the stations in degrees

  • inventory (Optional[Inventory]) – Inventory for the stations

  • epicenter (Optional[tuple[float, float]]) – (latitude, longitude) of the event epicenter

Return type:

ClassifyOutput
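When distances are derived from station coordinates and the epicenter, each epicentral distance in degrees is the great-circle angle between two (latitude, longitude) points. The sketch below uses the haversine formula on a spherical Earth. The station codes and coordinates are hypothetical, and the library may compute distances differently. ObsPy offers obspy.geodetics.locations2degrees for the same purpose.

```python
import math


def epicentral_distance_deg(station, epicenter):
    """Great-circle distance in degrees between two (lat, lon) points."""
    lat1, lon1 = map(math.radians, station)
    lat2, lon2 = map(math.radians, epicenter)
    a = (math.sin((lat2 - lat1) / 2) ** 2
         + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2)
    # min() guards against rounding slightly above 1.
    return math.degrees(2 * math.asin(min(1.0, math.sqrt(a))))


stations = {"GE.AAA.": (0.0, 0.0), "GE.BBB.": (0.0, 90.0)}  # hypothetical
epicenter = (0.0, 30.0)  # (latitude, longitude)
distances = {
    code: epicentral_distance_deg(position, epicenter)
    for code, position in stations.items()
}
```

The resulting dictionary has the same layout as the distances argument: station codes mapped to epicentral distances in degrees.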

EQTransformer

class EQTransformer(in_channels=3, in_samples=6000, classes=2, phases='PS', lstm_blocks=3, drop_rate=0.1, original_compatible=False, sampling_rate=100, norm='std', **kwargs)[source]

Bases: WaveformModel

The EQTransformer from Mousavi et al. (2020)

Implementation adapted from the GitHub repository https://github.com/smousavi05/EQTransformer. Assumes padding=”same” and activation=”relu” as in the pretrained EQTransformer models.

By instantiating the model with from_pretrained(“original”) a binary compatible version of the original EQTransformer with the original weights from Mousavi et al. (2020) can be loaded.

The following parameters are available for the annotate/classify functions:

Argument | Description | Default value
batch_size | Batch size for the model | 256
overlap | Overlap between prediction windows in samples (only for window prediction models) | 3000
stacking | Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. | max
stride | Stride in samples (only for point prediction models) | 1
strict | If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. | False
flexible_horizontal_components | If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. | True
*_threshold | Detection threshold for the provided phase | 0.1
detection_threshold | Detection threshold | 0.3
blinding | Number of prediction samples to discard on each side of each window prediction | (500, 500)

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
Parameters:
  • in_channels – Number of input channels, by default 3.

  • in_samples – Number of input samples per channel, by default 6000. The model expects input shape (in_channels, in_samples)

  • classes – Number of output classes, by default 2. The detection channel is not counted.

  • phases – Phase hints for the classes, by default “PS”. Can be None.

  • res_cnn_blocks – Number of residual convolutional blocks

  • lstm_blocks – Number of LSTM blocks

  • drop_rate – Dropout rate

  • original_compatible – If True, uses a few custom layers for binary compatibility with original model from Mousavi et al. (2020). This option defaults to False. It is usually recommended to stick to the default value, as the custom layers show slightly worse performance than the PyTorch builtins. The exception is when loading the original weights using from_pretrained().

  • norm – Data normalization strategy, either “peak” or “std”.

  • kwargs – Keyword arguments passed to the constructor of WaveformModel.
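The two normalization strategies named for the norm parameter can be compared with a small sketch. Removing the mean first and adding a small eps to the denominator are assumptions made here for a self-contained example. They are not taken from the documentation above.

```python
import math


def normalize(trace, norm="std", eps=1e-10):
    """Demean a trace and scale it by its peak or its standard deviation."""
    mean = sum(trace) / len(trace)
    centered = [v - mean for v in trace]
    if norm == "peak":
        scale = max(abs(v) for v in centered)
    elif norm == "std":
        scale = math.sqrt(sum(v * v for v in centered) / len(centered))
    else:
        raise ValueError(f"Unknown norm: {norm}")
    return [v / (scale + eps) for v in centered]


peak_scaled = normalize([1.0, 3.0, 5.0, 7.0], norm="peak")
std_scaled = normalize([1.0, 3.0, 5.0, 7.0], norm="std")
```

Peak normalization bounds the largest absolute amplitude at about 1, whereas standard deviation normalization gives the trace unit variance and lets individual samples exceed 1.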

annotate_batch_post(batch, piggyback, argdict)[source]

Runs postprocessing on the predictions of a window for the annotate function, e.g., reformatting them. By default, returns the original prediction. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Predictions for the batch. The data type depends on the model.

  • argdict (dict[str, Any]) – Dictionary of arguments

  • piggyback (Any) – Piggyback information, by default None.

Return type:

Tensor

Returns:

Postprocessed predictions

annotate_batch_pre(batch, argdict)[source]

Runs preprocessing on batch level for the annotate function, e.g., normalization. By default, returns the input batch unmodified. Optionally, this can return a tuple of the preprocessed batch and piggyback information that is passed to annotate_batch_post(). This can for example be used to transfer normalization information. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Input batch

  • argdict (dict[str, Any]) – Dictionary of arguments

Return type:

Tensor

Returns:

Preprocessed batch and optionally piggyback information that is passed to annotate_batch_post()

classify_aggregate(annotations, argdict)[source]

Converts the annotations to discrete picks using picks_from_annotations() and to discrete detections using detections_from_annotations(). Trigger onset thresholds for picks are derived from the argdict at keys “[phase]_threshold”. Trigger onset thresholds for detections are derived from the argdict at key “detection_threshold”.

Parameters:
  • annotations – See description in superclass

  • argdict – See description in superclass

Return type:

ClassifyOutput

Returns:

List of picks, list of detections
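The role of the thresholds in the argdict can be sketched with a simple trigger on toy curves: a detection is an interval in which the detection curve stays above its threshold, and a pick is placed at the peak of each interval in which a phase curve exceeds its threshold. The helpers are hypothetical simplifications and not the library's picks_from_annotations() or detections_from_annotations().

```python
def trigger_intervals(curve, threshold):
    """Return (start, end) index pairs where curve stays above threshold."""
    intervals, start = [], None
    for i, p in enumerate(curve):
        if p > threshold and start is None:
            start = i
        elif p <= threshold and start is not None:
            intervals.append((start, i - 1))
            start = None
    if start is not None:
        intervals.append((start, len(curve) - 1))
    return intervals


def picks_from_curve(curve, threshold):
    """One pick per triggered interval, placed at the interval's peak."""
    return [
        max(range(s, e + 1), key=lambda i: curve[i])
        for s, e in trigger_intervals(curve, threshold)
    ]


p_curve = [0.0, 0.05, 0.4, 0.9, 0.3, 0.0, 0.0, 0.2, 0.05]
det_curve = [0.1, 0.5, 0.8, 0.9, 0.9, 0.7, 0.4, 0.35, 0.1]
argdict = {"P_threshold": 0.1, "detection_threshold": 0.3}
picks = picks_from_curve(p_curve, argdict["P_threshold"])
detections = trigger_intervals(det_curve, argdict["detection_threshold"])
```

Lowering a threshold yields more and longer triggers, while raising it keeps only the most confident ones.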

forward(x, logits=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_model_args()[source]

Obtain all model parameters for saving.

Returns:

Dictionary of all parameters for a model to store during saving.

Return type:

Dict

property phases

GPD

class GPD(in_channels=3, classes=3, phases=None, eps=1e-10, sampling_rate=100, pred_sample=200, original_compatible=False, **kwargs)[source]

Bases: WaveformModel

The following parameters are available for the annotate/classify functions:

Argument | Description | Default value
batch_size | Batch size for the model | 256
overlap | Overlap between prediction windows in samples (only for window prediction models) | 0
stacking | Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. | avg
stride | Stride in samples (only for point prediction models) | 10
strict | If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. | False
flexible_horizontal_components | If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. | True
*_threshold | Detection threshold for the provided phase | 0.7

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
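For a point prediction model, the stride argument controls how far the input window advances between predictions, and each window yields a prediction for a single sample. The sketch below enumerates the windows for a short trace. The helper name and the window length of 400 samples are assumptions for illustration, while stride=10 and pred_sample=200 mirror the defaults listed above.

```python
def point_prediction_windows(n_samples, window_len, stride, pred_sample):
    """Return (window_start, predicted_sample_index) for a sliding window."""
    return [
        (start, start + pred_sample)
        for start in range(0, n_samples - window_len + 1, stride)
    ]


windows = point_prediction_windows(
    n_samples=430, window_len=400, stride=10, pred_sample=200
)
```

A stride of 10 therefore produces one prediction every 10 samples, which reduces the number of model evaluations tenfold compared to a stride of 1.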
annotate_batch_pre(batch, argdict)[source]

Runs preprocessing on batch level for the annotate function, e.g., normalization. By default, returns the input batch unmodified. Optionally, this can return a tuple of the preprocessed batch and piggyback information that is passed to annotate_batch_post(). This can for example be used to transfer normalization information. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Input batch

  • argdict (dict[str, Any]) – Dictionary of arguments

Return type:

Tensor

Returns:

Preprocessed batch and optionally piggyback information that is passed to annotate_batch_post()

classify_aggregate(annotations, argdict)[source]

Converts the annotations to discrete picks using picks_from_annotations(). Trigger onset thresholds for picks are derived from the argdict at keys “[phase]_threshold”.

Parameters:
  • annotations – See description in superclass

  • argdict – See description in superclass

Return type:

ClassifyOutput

Returns:

List of picks

forward(x, logits=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_model_args()[source]

Obtain all model parameters for saving.

Returns:

Dictionary of all parameters for a model to store during saving.

Return type:

Dict

property phases

OBSTransformer

class OBSTransformer(lstm_blocks=2, drop_rate=0.2, original_compatible='non-conservative', **kwargs)[source]

Bases: EQTransformer

Initialize an instance of the OBSTransformer model. OBSTransformer is built on the original (non-conservative) EQTransformer model.

Warning

Creating an OBSTransformer instance does not automatically load the model weights. To do so, use OBSTransformer.from_pretrained(“obst2024”).

LFEDetect

class LFEDetect(*args, **kwargs)[source]

Bases: VariableLengthPhaseNet

This detection and phase picking model for low-frequency earthquakes (LFEs) is based on PhaseNet. Please note that, for the time being, LFE detection models do not reach the quality of EQ detection models.

The following parameters are available for the annotate/classify functions:

Argument | Description | Default value
batch_size | Batch size for the model | 256
overlap | Overlap between prediction windows in samples (only for window prediction models) | 1500
stacking | Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. | avg
stride | Stride in samples (only for point prediction models) | 1
strict | If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. | False
flexible_horizontal_components | If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. | True
*_threshold | Detection threshold for the provided phase | 0.3
blinding | Number of prediction samples to discard on each side of each window prediction | (0, 0)

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.

PhaseNet

class PhaseNet(in_channels=3, classes=3, phases='NPS', sampling_rate=100, norm='std', **kwargs)[source]

Bases: WaveformModel

The following parameters are available for the annotate/classify functions:

Argument | Description | Default value
batch_size | Batch size for the model | 256
overlap | Overlap between prediction windows in samples (only for window prediction models) | 1500
stacking | Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. | avg
stride | Stride in samples (only for point prediction models) | 1
strict | If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. | False
flexible_horizontal_components | If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. | True
*_threshold | Detection threshold for the provided phase | 0.3
blinding | Number of prediction samples to discard on each side of each window prediction | (0, 0)

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
annotate_batch_post(batch, piggyback, argdict)[source]

Runs postprocessing on the predictions of a window for the annotate function, e.g., reformatting them. By default, returns the original prediction. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Predictions for the batch. The data type depends on the model.

  • argdict (dict[str, Any]) – Dictionary of arguments

  • piggyback (Any) – Piggyback information, by default None.

Return type:

Tensor

Returns:

Postprocessed predictions

annotate_batch_pre(batch, argdict)[source]

Runs preprocessing on batch level for the annotate function, e.g., normalization. By default, returns the input batch unmodified. Optionally, this can return a tuple of the preprocessed batch and piggyback information that is passed to annotate_batch_post(). This can for example be used to transfer normalization information. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Input batch

  • argdict (dict[str, Any]) – Dictionary of arguments

Return type:

Tensor

Returns:

Preprocessed batch and optionally piggyback information that is passed to annotate_batch_post()

classify_aggregate(annotations, argdict)[source]

Converts the annotations to discrete picks using picks_from_annotations(). Trigger onset thresholds for picks are derived from the argdict at keys “[phase]_threshold”.

Parameters:
  • annotations – See description in superclass

  • argdict – See description in superclass

Return type:

ClassifyOutput

Returns:

List of picks

forward(x, logits=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

classmethod from_pretrained_expand(name, version_str='latest', update=False, force=False, wait_for_file=False)[source]

Load pretrained model with weights and copy the input channel weights that match the Z component to a new, 4th dimension that is used to process the hydrophone component of the input trace.

For further instructions, see from_pretrained(). This method differs from from_pretrained() in that it does not call helper functions to load the model weights. Instead it covers the same logic and, in addition, takes intermediate steps to insert a new in_channels dimension to the loaded model and copy weights.

Parameters:
  • name (str) – Model name prefix.

  • version_str (str) – Version of the weights to load. Either a version string or “latest”. The “latest” model is the model with the highest version number.

  • force (bool, optional) – Force execution of download callback, defaults to False

  • update (bool) – If true, downloads potential new weights file and config from the remote repository. The old files are retained with their version suffix.

  • wait_for_file (bool, optional) – Whether to wait on partially downloaded files, defaults to False

Returns:

Model instance

Return type:

SeisBenchModel
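The weight copy described above can be pictured on a toy first-layer weight with shape (out_channels, in_channels, kernel_size), written as nested lists. That the Z component sits at input index 0 is an assumption of this sketch, and the helper name is hypothetical. The method itself operates on the loaded model weights.

```python
def expand_input_channels(weight, source_channel=0):
    """Append a copy of one input channel to a (out, in, kernel) weight."""
    return [
        in_channels + [list(in_channels[source_channel])]
        for in_channels in weight
    ]


# Toy weight: 2 output channels, 3 input channels, kernel size 2.
weight = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
]
expanded = expand_input_channels(weight, source_channel=0)
```

The expanded weight has four input channels, and the new fourth channel starts from the same filter values as the source channel, so the additional component is initially processed like the vertical one.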

get_model_args()[source]

Obtain all model parameters for saving.

Returns:

Dictionary of all parameters for a model to store during saving.

Return type:

Dict

class PhaseNetLight(in_channels=3, classes=3, phases='NPS', sampling_rate=100, norm='std', **kwargs)[source]

Bases: PhaseNet

The following parameters are available for the annotate/classify functions:

Argument | Description | Default value
batch_size | Batch size for the model | 256
overlap | Overlap between prediction windows in samples (only for window prediction models) | 1500
stacking | Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. | avg
stride | Stride in samples (only for point prediction models) | 1
strict | If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. | False
flexible_horizontal_components | If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. | True
*_threshold | Detection threshold for the provided phase | 0.3
blinding | Number of prediction samples to discard on each side of each window prediction | (0, 0)

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.

PhaseNetLight is a slightly reduced version of PhaseNet. It is primarily included for compatibility reasons with an earlier, incomplete implementation of PhaseNet in SeisBench prior to v0.3.

forward(x, logits=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class VariableLengthPhaseNet(in_samples=600, in_channels=3, classes=3, phases='PSN', sampling_rate=100, norm='peak', norm_axis=(-1,), output_activation='softmax', empty=False, **kwargs)[source]

Bases: PhaseNet

This version of PhaseNet has extended functionality:

  • The number of input samples can be changed. However, the number of layers in the model does not change, i.e., the receptive field stays unchanged. In addition, models will usually not perform well if applied to a different input length than trained on.

  • Output activation can be switched between softmax (all components sum to 1, i.e., no overlapping phases) and sigmoid (each component is normed individually between 0 and 1).

  • The axis for normalizing the waveforms before passing them to the model can be specified explicitly.
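The difference between the two output activations can be seen on a single sample with three class scores. This is a plain Python illustration with made-up logits, not the model code.

```python
import math


def softmax(logits):
    """Scores compete: the outputs sum to 1 across classes."""
    shifted = [x - max(logits) for x in logits]  # for numerical stability
    exps = [math.exp(x) for x in shifted]
    return [e / sum(exps) for e in exps]


def sigmoid(logits):
    """Scores are independent: each output lies in (0, 1) on its own."""
    return [1.0 / (1.0 + math.exp(-x)) for x in logits]


logits = [2.0, 2.0, -1.0]  # two phases that are both strongly indicated
soft = softmax(logits)
sig = sigmoid(logits)
```

With softmax, the two strongly indicated classes must share the probability mass and each stays below 0.5. With sigmoid, both can be close to 1 at the same time, which is what permits overlapping phases.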

The following parameters are available for the annotate/classify functions:

Argument | Description | Default value
batch_size | Batch size for the model | 256
overlap | Overlap between prediction windows in samples (only for window prediction models) | 1500
stacking | Stacking method for overlapping windows (only for window prediction models). Options are ‘max’ and ‘avg’. | avg
stride | Stride in samples (only for point prediction models) | 1
strict | If true, only annotate if recordings for all components are available, otherwise impute missing data with zeros. | False
flexible_horizontal_components | If true, accepts traces with Z12 components as ZNE and vice versa. This is usually acceptable for rotationally invariant models, e.g., most picking models. | True
*_threshold | Detection threshold for the provided phase | 0.3
blinding | Number of prediction samples to discard on each side of each window prediction | (0, 0)

Hint

Please note that the default parameters can be superseded by the pretrained model weights. Check model.default_args to see which parameters are overwritten.
annotate_batch_pre(batch, argdict)[source]

Runs preprocessing on batch level for the annotate function, e.g., normalization. By default, returns the input batch unmodified. Optionally, this can return a tuple of the preprocessed batch and piggyback information that is passed to annotate_batch_post(). This can for example be used to transfer normalization information. Inheriting classes should overwrite this function if necessary.

Parameters:
  • batch (Tensor) – Input batch

  • argdict (dict[str, Any]) – Dictionary of arguments

Return type:

Tensor

Returns:

Preprocessed batch and optionally piggyback information that is passed to annotate_batch_post()

forward(x, logits=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_model_args()[source]

Obtain all model parameters for saving.

Returns:

Dictionary of all parameters for a model to store during saving.

Return type:

Dict

PickBlue

PickBlue(base='phasenet', **kwargs)[source]

Initialize a PickBlue model. All kwargs are passed to from_pretrained.

Parameters:

base (str) – Base model to use. Currently supports either eqtransformer or phasenet.