import json
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version
import seisbench.util as sbu
from .base import Conv1dSame, WaveformModel, _cache_migration_v0_v3
class PhaseNet(WaveformModel):
    """
    .. document_args:: seisbench.models PhaseNet

    Convolutional encoder-decoder with skip connections, configured with ``in_samples=3001``.

    :param in_channels: Number of input channels expected by the first convolution.
    :param classes: Number of output channels produced by the final convolution.
    :param phases: Labels of the output channels. Passed to the superclass as ``labels``.
    :param sampling_rate: Sampling rate, passed to the superclass.
    :param norm: Amplitude normalization used in :py:func:`annotate_batch_pre`, either "std" or "peak".
    :param filter_factor: Increase the number of filters used in each layer by this factor compared to the original
                          PhaseNet. Based on PhaseNetWC proposed by Naoi et al. (2024)
    :param kwargs: The PickBlue options ``norm_amp_per_comp`` and ``norm_detrend`` are consumed here (both default
                   to False). All remaining keyword arguments are passed to the superclass constructor.
    """

    _annotate_args = WaveformModel._annotate_args.copy()
    _annotate_args["*_threshold"] = ("Detection threshold for the provided phase", 0.3)
    _annotate_args["blinding"] = (
        "Number of prediction samples to discard on each side of each window prediction",
        (0, 0),
    )
    _annotate_args["overlap"] = (_annotate_args["overlap"][0], 1500)

    _weight_warnings = [
        (
            "ethz|geofon|instance|iquique|lendb|neic|scedc|stead",
            "1",
            "The normalization for this weight version is incorrect and will lead to degraded performance. "
            "Run from_pretrained with update=True once to solve this issue. "
            "For details, see https://github.com/seisbench/seisbench/pull/188 .",
        ),
        (
            "diting",
            "1",
            "This version of the Diting picker uses an incorrect sampling rate (100 Hz instead of 50 Hz). "
            "Run from_pretrained with update=True once to solve this issue. "
            "For details, see https://github.com/JUNZHU-SEIS/USTC-Pickers/issues/1 .",
        ),
    ]

    # Explicit (left, right) zero padding applied in forward() right before the strided convolution of the
    # given encoder level. The strided convolutions of these levels are created with padding=0 in __init__,
    # so both places must agree; keeping the table here gives them a single source of truth.
    _down_manual_padding = {1: (2, 3), 2: (1, 3), 3: (2, 3)}

    def __init__(
        self,
        in_channels=3,
        classes=3,
        phases="NPS",
        sampling_rate=100,
        norm="std",
        filter_factor: int = 1,
        **kwargs,
    ):
        citation = (
            "Zhu, W., & Beroza, G. C. (2019). "
            "PhaseNet: a deep-neural-network-based seismic arrival-time picking method. "
            "Geophysical Journal International, 216(1), 261-273. "
            "https://doi.org/10.1093/gji/ggy423"
        )

        # PickBlue options: remove them from kwargs so they are not forwarded to the superclass
        for option in ("norm_amp_per_comp", "norm_detrend"):
            setattr(self, option, kwargs.pop(option, False))

        super().__init__(
            citation=citation,
            in_samples=3001,
            output_type="array",
            pred_sample=(0, 3001),
            labels=phases,
            sampling_rate=sampling_rate,
            **kwargs,
        )

        self.in_channels = in_channels
        self.classes = classes
        self.norm = norm
        self.filter_factor = filter_factor
        self.depth = 5
        self.kernel_size = 7
        self.stride = 4
        self.filters_root = 8
        self.activation = torch.relu

        self.inc = nn.Conv1d(
            self.in_channels,
            self.filters_root * filter_factor,
            self.kernel_size,
            padding="same",
        )
        self.in_bn = nn.BatchNorm1d(self.filters_root * filter_factor, eps=1e-3)

        self.down_branch = nn.ModuleList()
        self.up_branch = nn.ModuleList()

        # Encoder: each level is a "same" convolution followed by a strided convolution.
        # The deepest level has no strided convolution.
        last_filters = self.filters_root * filter_factor
        for i in range(self.depth):
            filters = int(2**i * self.filters_root) * filter_factor
            conv_same = nn.Conv1d(
                last_filters, filters, self.kernel_size, padding="same", bias=False
            )
            last_filters = filters
            bn1 = nn.BatchNorm1d(filters, eps=1e-3)
            if i == self.depth - 1:
                conv_down = None
                bn2 = None
            else:
                if i in self._down_manual_padding:
                    padding = 0  # Pad manually in forward()
                else:
                    padding = self.kernel_size // 2
                conv_down = nn.Conv1d(
                    filters,
                    filters,
                    self.kernel_size,
                    self.stride,
                    padding=padding,
                    bias=False,
                )
                bn2 = nn.BatchNorm1d(filters, eps=1e-3)

            self.down_branch.append(nn.ModuleList([conv_same, bn1, conv_down, bn2]))

        # Decoder: transposed convolution, then a "same" convolution on the concatenation with the skip
        # connection (hence 2 * filters input channels).
        for i in range(self.depth - 1):
            filters = int(2 ** (3 - i) * self.filters_root) * filter_factor
            conv_up = nn.ConvTranspose1d(
                last_filters, filters, self.kernel_size, self.stride, bias=False
            )
            last_filters = filters
            bn1 = nn.BatchNorm1d(filters, eps=1e-3)
            conv_same = nn.Conv1d(
                2 * filters, filters, self.kernel_size, padding="same", bias=False
            )
            bn2 = nn.BatchNorm1d(filters, eps=1e-3)

            self.up_branch.append(nn.ModuleList([conv_up, bn1, conv_same, bn2]))

        self.out = nn.Conv1d(last_filters, self.classes, 1, padding="same")
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x, logits=False):
        """
        Run the network on a batch of inputs.

        :param x: Input tensor, indexed as (batch, channels, samples).
        :param logits: If true, return the output of the last convolution without applying the softmax.
        :return: Tensor with ``classes`` channels along dimension 1. Softmax over dimension 1 is applied
                 unless ``logits`` is true.
        """
        x = self.activation(self.in_bn(self.inc(x)))

        skips = []
        for i, (conv_same, bn1, conv_down, bn2) in enumerate(self.down_branch):
            x = self.activation(bn1(conv_same(x)))

            if conv_down is not None:
                # The skip connection is taken before padding and downsampling
                skips.append(x)
                manual_padding = self._down_manual_padding.get(i)
                if manual_padding is not None:
                    x = F.pad(x, manual_padding, "constant", 0)

                x = self.activation(bn2(conv_down(x)))

        # Skip connections are consumed from the deepest level upwards
        for (conv_up, bn1, conv_same, bn2), skip in zip(
            self.up_branch, reversed(skips)
        ):
            x = self.activation(bn1(conv_up(x)))
            # Drop one leading and two trailing samples of the upsampled trace
            x = x[:, :, 1:-2]

            x = self._merge_skip(skip, x)
            x = self.activation(bn2(conv_same(x)))

        x = self.out(x)
        if logits:
            return x
        else:
            return self.softmax(x)

    @staticmethod
    def _merge_skip(skip, x):
        """
        Crop ``x`` along the last axis to the length of ``skip`` and concatenate both along dimension 1.

        The crop window starts at half the length difference, i.e., it is centered (rounded down).

        :param skip: Tensor from the encoder.
        :param x: Tensor from the decoder, indexed as (batch, channels, samples).
        :return: Concatenation ``[skip, cropped x]`` along dimension 1.
        """
        offset = (x.shape[-1] - skip.shape[-1]) // 2
        x_resize = x[:, :, offset : offset + skip.shape[-1]]

        return torch.cat([skip, x_resize], dim=1)

    def annotate_batch_pre(
        self, batch: torch.Tensor, argdict: dict[str, Any]
    ) -> torch.Tensor:
        """
        Demean and normalize a batch along the last axis before it is passed to the model.

        :param batch: Input batch.
        :param argdict: Annotation arguments. Not used by this implementation.
        :return: Normalized batch.
        """
        batch = batch - batch.mean(axis=-1, keepdims=True)
        if self.norm_detrend:
            batch = sbu.torch_detrend(batch)

        # norm_amp_per_comp forces peak normalization irrespective of self.norm
        if self.norm_amp_per_comp or self.norm == "peak":
            peak = batch.abs().max(axis=-1, keepdims=True)[0]
            batch = batch / (peak + 1e-10)  # Small offset avoids division by zero
        elif self.norm == "std":
            std = batch.std(axis=-1, keepdims=True)
            batch = batch / (std + 1e-10)

        return batch

    def annotate_batch_post(
        self, batch: torch.Tensor, piggyback: Any, argdict: dict[str, Any]
    ) -> torch.Tensor:
        """
        Transpose the last two axes of the predictions and blank the edges of each window.

        :param batch: Model output.
        :param piggyback: Not used by this implementation.
        :param argdict: Annotation arguments. The key "blinding" holds the number of samples to overwrite
                        with NaN at the start and at the end of each window.
        :return: Transposed predictions with blinded edges.
        """
        # Transpose predictions to correct shape
        batch = torch.transpose(batch, -1, -2)
        prenan, postnan = argdict.get(
            "blinding", self._annotate_args.get("blinding")[1]
        )
        # The guards matter: a slice like [-0:] would cover the whole axis
        if prenan > 0:
            batch[:, :prenan] = np.nan
        if postnan > 0:
            batch[:, -postnan:] = np.nan
        return batch

    def classify_aggregate(self, annotations, argdict) -> sbu.ClassifyOutput:
        """
        Converts the annotations to discrete picks using
        :py:func:`~seisbench.models.base.WaveformModel.picks_from_annotations`.
        Trigger onset thresholds for picks are derived from the argdict at keys "[phase]_threshold".

        :param annotations: See description in superclass
        :param argdict: See description in superclass
        :return: Classification output holding the sorted picks
        """
        picks = sbu.PickList()
        for phase in self.labels:
            if phase == "N":
                # Don't pick noise
                continue

            picks += self.picks_from_annotations(
                annotations.select(channel=f"{self.__class__.__name__}_{phase}"),
                argdict.get(
                    f"{phase}_threshold", self._annotate_args.get("*_threshold")[1]
                ),
                phase,
            )

        picks = sbu.PickList(sorted(picks))

        return sbu.ClassifyOutput(self.name, picks=picks)

    def get_model_args(self):
        """
        Collect the constructor arguments describing this model.

        Arguments that are fixed by the constructor are removed from the superclass result and the
        PhaseNet-specific arguments are added.

        :return: Dictionary of model arguments
        """
        model_args = super().get_model_args()
        for key in [
            "citation",
            "in_samples",
            "output_type",
            "default_args",
            "pred_sample",
            "labels",
        ]:
            del model_args[key]

        model_args["in_channels"] = self.in_channels
        model_args["classes"] = self.classes
        model_args["phases"] = self.labels
        model_args["sampling_rate"] = self.sampling_rate
        model_args["norm"] = self.norm
        model_args["norm_amp_per_comp"] = self.norm_amp_per_comp
        model_args["norm_detrend"] = self.norm_detrend

        # Without this entry a model built with filter_factor != 1 can not be rebuilt from its arguments.
        # It is only emitted for non-default values so the arguments of existing models stay unchanged.
        # getattr guards subclasses that bypass PhaseNet.__init__ and therefore never set the attribute.
        filter_factor = getattr(self, "filter_factor", 1)
        if filter_factor != 1:
            model_args["filter_factor"] = filter_factor

        return model_args

    @classmethod
    def from_pretrained_expand(
        cls, name, version_str="latest", update=False, force=False, wait_for_file=False
    ):
        """
        Load pretrained model with weights and copy the input channel weights that match the Z component to a new,
        4th input channel that is used to process the hydrophone component of the input trace.

        For further instructions, see :py:func:`~seisbench.models.base.SeisBenchModel.from_pretrained`. This method
        differs from :py:func:`~seisbench.models.base.SeisBenchModel.from_pretrained` in that it does not call helper
        functions to load the model weights. Instead it covers the same logic and, in addition, takes intermediate
        steps to insert a new `in_channels` dimension to the loaded model and copy weights.

        :param name: Model name prefix.
        :type name: str
        :param version_str: Version of the weights to load. Either a version string or "latest". The "latest" model is
                            the model with the highest version number.
        :type version_str: str
        :param force: Force execution of download callback, defaults to False
        :type force: bool, optional
        :param update: If true, downloads potential new weights file and config from the remote repository.
                       The old files are retained with their version suffix.
        :type update: bool
        :param wait_for_file: Whether to wait on partially downloaded files, defaults to False
        :type wait_for_file: bool, optional
        :raises ValueError: If no version of the requested weights is available.
        :return: Model instance
        :rtype: SeisBenchModel
        """
        cls._cleanup_local_repository()
        _cache_migration_v0_v3()

        if version_str == "latest":
            versions = cls.list_versions(name, remote=update)
            # Always query remote versions if cache is empty
            if len(versions) == 0:
                versions = cls.list_versions(name, remote=True)

            if len(versions) == 0:
                raise ValueError(f"No version for weight '{name}' available.")
            version_str = max(versions, key=version.parse)

        weight_path, metadata_path = cls._pretrained_path(name, version_str)

        cls._ensure_weight_files(
            name, version_str, weight_path, metadata_path, force, wait_for_file
        )

        if metadata_path.is_file():
            with open(metadata_path, "r") as f:
                weights_metadata = json.load(f)
        else:
            weights_metadata = {}

        # Note: this edits the dictionary stored inside weights_metadata, so the metadata attached to the
        # model below reports the expanded number of input channels.
        model_args = weights_metadata.get("model_args", {})
        model_args["in_channels"] = 4
        cls._check_version_requirement(weights_metadata)
        model = cls(**model_args)

        model._weights_metadata = weights_metadata
        model._parse_metadata()

        # Map to CPU so checkpoints saved from an accelerator also load on hosts without one.
        # The model itself has just been created on the CPU.
        state_dict = torch.load(weight_path, map_location="cpu")

        old_weight = state_dict["inc.weight"]
        out_channels, old_in_channels, kernel_size = old_weight.shape
        new_weight = torch.zeros(
            out_channels, old_in_channels + 1, kernel_size
        ).type_as(old_weight)
        new_weight[:, :3, ...] = old_weight
        # Channel 0 is used as the Z component - assumes the pretrained weights are Z-first; TODO confirm
        new_weight[:, 3, ...] = old_weight[:, 0, ...]
        state_dict["inc.weight"] = new_weight

        model.load_state_dict(state_dict)

        return model
class PhaseNetLight(PhaseNet):
    """
    .. document_args:: seisbench.models PhaseNetLight

    PhaseNetLight is a slightly reduced version of PhaseNet. It is primarily included for compatibility reasons with
    an earlier, incomplete implementation of PhaseNet in SeisBench prior to v0.3.
    """

    _weight_warnings = [
        (
            "ethz|geofon|instance|iquique|lendb|neic|scedc|stead",
            "1",
            "The normalization for this weight version is incorrect and will lead to degraded performance. "
            "Run from_pretrained with update=True once to solve this issue. "
            "For details, see https://github.com/seisbench/seisbench/pull/188 .",
        ),
    ]

    def __init__(
        self,
        in_channels=3,
        classes=3,
        phases="NPS",
        sampling_rate=100,
        norm="std",
        **kwargs,
    ):
        # PickBlue options: strip them from kwargs, falling back to False when absent
        for option in ("norm_amp_per_comp", "norm_detrend"):
            setattr(self, option, kwargs.pop(option, False))

        # Skip super call in favour of super-super class
        WaveformModel.__init__(
            self,
            citation=(
                "Zhu, W., & Beroza, G. C. (2019). "
                "PhaseNet: a deep-neural-network-based seismic arrival-time picking method. "
                "Geophysical Journal International, 216(1), 261-273. "
                "https://doi.org/10.1093/gji/ggy423"
            ),
            in_samples=3001,
            output_type="array",
            pred_sample=(0, 3001),
            labels=phases,
            sampling_rate=sampling_rate,
            **kwargs,
        )

        self.in_channels = in_channels
        self.classes = classes
        self.norm = norm
        self.kernel_size = 7
        self.stride = 4
        self.activation = torch.relu

        kernel, stride = self.kernel_size, self.stride

        # Input layer
        self.inc = nn.Conv1d(self.in_channels, 8, 1)
        self.in_bn = nn.BatchNorm1d(8)

        # Downsampling path: 8 -> 11 -> 16 -> 22 -> 32 channels
        self.conv1 = Conv1dSame(8, 11, kernel, stride)
        self.bnd1 = nn.BatchNorm1d(11)

        self.conv2 = Conv1dSame(11, 16, kernel, stride)
        self.bnd2 = nn.BatchNorm1d(16)

        self.conv3 = Conv1dSame(16, 22, kernel, stride)
        self.bnd3 = nn.BatchNorm1d(22)

        self.conv4 = Conv1dSame(22, 32, kernel, stride)
        self.bnd4 = nn.BatchNorm1d(32)

        # Upsampling path. From up2 on, the input channel count is doubled by the concatenated skip tensor.
        self.up1 = nn.ConvTranspose1d(
            32, 22, kernel, stride, padding=self.conv4.padding
        )
        self.bnu1 = nn.BatchNorm1d(22)

        self.up2 = nn.ConvTranspose1d(
            44,
            16,
            kernel,
            stride,
            padding=self.conv3.padding,
            output_padding=1,
        )
        self.bnu2 = nn.BatchNorm1d(16)

        self.up3 = nn.ConvTranspose1d(
            32, 11, kernel, stride, padding=self.conv2.padding
        )
        self.bnu3 = nn.BatchNorm1d(11)

        self.up4 = nn.ConvTranspose1d(22, 8, kernel, stride, padding=3)
        self.bnu4 = nn.BatchNorm1d(8)

        # Output layer
        self.out = nn.ConvTranspose1d(16, self.classes, 1)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x, logits=False):
        """
        Run the network on a batch of inputs.

        :param x: Input tensor.
        :param logits: If true, return the output of the last layer without applying the softmax.
        :return: Output of the last layer, with softmax over dimension 1 applied unless ``logits`` is true.
        """
        act = self.activation

        down_layers = (
            (self.conv1, self.bnd1),
            (self.conv2, self.bnd2),
            (self.conv3, self.bnd3),
            (self.conv4, self.bnd4),
        )
        up_layers = (
            (self.up1, self.bnu1),
            (self.up2, self.bnu2),
            (self.up3, self.bnu3),
            (self.up4, self.bnu4),
        )

        # Stack of activations: the input layer output followed by every downsampling stage
        stages = [act(self.in_bn(self.inc(x)))]
        for conv, bn in down_layers:
            stages.append(act(bn(conv(stages[-1]))))

        # Start from the deepest activation and merge the remaining ones in reverse order
        x = stages.pop()
        for up, bn in up_layers:
            x = torch.cat([act(bn(up(x))), stages.pop()], dim=1)

        x = self.out(x)
        return x if logits else self.softmax(x)
class VariableLengthPhaseNet(PhaseNet):
    """
    This version of PhaseNet has extended functionality:

    - The number of input samples can be changed.
      However, the number of layers in the model does not change, i.e., the receptive field stays unchanged.
      In addition, models will usually not perform well if applied to a different input length than trained on.
    - Output activation can be switched between softmax (all components sum to 1, i.e., no overlapping phases)
      and sigmoid (each component is normed individually between 0 and 1).
    - The axis for normalizing the waveforms before passing them to the model can be specified explicitly.

    .. document_args:: seisbench.models VariableLengthPhaseNet
    """

    _annotate_args = PhaseNet._annotate_args.copy()
    _annotate_args["overlap"] = (_annotate_args["overlap"][0], 0.5)

    def __init__(
        self,
        in_samples=600,
        in_channels=3,
        classes=3,
        phases="PSN",
        sampling_rate=100,
        norm="peak",
        norm_axis=(-1,),
        output_activation="softmax",
        empty=False,
        **kwargs,
    ):
        # Initialise through WaveformModel directly; PhaseNet.__init__ is not called
        WaveformModel.__init__(
            self,
            citation=(
                "Zhu, W., & Beroza, G. C. (2019). "
                "PhaseNet: a deep-neural-network-based seismic arrival-time picking method. "
                "Geophysical Journal International, 216(1), 261-273. "
                "https://doi.org/10.1093/gji/ggy423"
            ),
            in_samples=in_samples,
            output_type="array",
            pred_sample=(0, in_samples),
            labels=phases,
            sampling_rate=sampling_rate,
            **kwargs,
        )

        self.in_channels = in_channels
        self.classes = classes
        self.norm = norm
        self.norm_axis = tuple(norm_axis)
        self.depth = 5
        self.kernel_size = 7
        self.stride = 4
        self.filters_root = 8
        self.activation = torch.relu

        if output_activation not in ("softmax", "sigmoid"):
            raise ValueError("Output activation needs to be softmax or sigmoid")
        if output_activation == "softmax":
            self.output_activation = torch.nn.Softmax(dim=1)
        else:
            self.output_activation = torch.nn.Sigmoid()

        # PhaseNet extra arguments
        self.norm_amp_per_comp = False
        self.norm_detrend = False

        if empty:
            # No layers are built; every layer attribute is left as None
            for slot in ("inc", "in_bn", "down_branch", "up_branch", "out"):
                setattr(self, slot, None)
            return

        root = self.filters_root
        kernel, stride = self.kernel_size, self.stride

        self.inc = nn.Conv1d(self.in_channels, root, kernel, padding="same")
        self.in_bn = nn.BatchNorm1d(8, eps=1e-3)

        self.down_branch = nn.ModuleList()
        self.up_branch = nn.ModuleList()

        # Encoder levels; the deepest level carries no strided convolution
        previous = root
        deepest = self.depth - 1
        for level in range(self.depth):
            width = int(2**level * root)
            same = nn.Conv1d(previous, width, kernel, padding="same", bias=False)
            same_bn = nn.BatchNorm1d(width, eps=1e-3)
            down, down_bn = None, None
            if level != deepest:
                down = nn.Conv1d(
                    width,
                    width,
                    kernel,
                    stride,
                    padding=kernel // 2,
                    bias=False,
                )
                down_bn = nn.BatchNorm1d(width, eps=1e-3)
            previous = width

            self.down_branch.append(nn.ModuleList([same, same_bn, down, down_bn]))

        # Decoder levels; the "same" convolution sees the skip tensor concatenated, hence 2 * width
        for step in range(self.depth - 1):
            width = int(2 ** (3 - step) * root)
            up = nn.ConvTranspose1d(previous, width, kernel, stride, bias=False)
            up_bn = nn.BatchNorm1d(width, eps=1e-3)
            merge = nn.Conv1d(2 * width, width, kernel, padding="same", bias=False)
            merge_bn = nn.BatchNorm1d(width, eps=1e-3)
            previous = width

            self.up_branch.append(nn.ModuleList([up, up_bn, merge, merge_bn]))

        self.out = nn.Conv1d(previous, self.classes, 1, padding="same")

    def forward(self, x, logits=False):
        """
        Run the network on a batch of inputs.

        :param x: Input tensor.
        :param logits: If true, return the raw output of the last convolution.
        :return: Raw output if ``logits`` is true, otherwise the output passed through the output activation.
        """
        raw = self._forward_single(x)
        return raw if logits else self.output_activation(raw)

    def _forward_single(self, x):
        """
        Apply all layers except the output activation.

        :param x: Input tensor.
        :return: Output of the last convolution.
        """
        act = self.activation
        x = act(self.in_bn(self.inc(x)))

        skips = []
        for same, same_bn, down, down_bn in self.down_branch:
            x = act(same_bn(same(x)))
            if down is None:
                continue
            skips.append(x)
            x = act(down_bn(down(x)))

        # Skip tensors are merged from the deepest level upwards
        for (up, up_bn, merge, merge_bn), skip in zip(self.up_branch, reversed(skips)):
            x = act(up_bn(up(x)))
            x = self._merge_skip(skip, x)
            x = act(merge_bn(merge(x)))

        return self.out(x)

    def annotate_batch_pre(
        self, batch: torch.Tensor, argdict: dict[str, Any]
    ) -> torch.Tensor:
        """
        Demean a batch along the last axis and normalize it along ``norm_axis``.

        :param batch: Input batch.
        :param argdict: Annotation arguments. Not used by this implementation.
        :return: Normalized batch.
        """
        demeaned = batch - batch.mean(axis=-1, keepdims=True)
        if self.norm_detrend:
            demeaned = sbu.torch_detrend(demeaned)

        if self.norm == "std":
            scale = demeaned.std(axis=self.norm_axis, keepdims=True)
        elif self.norm == "peak":
            scale = demeaned.abs().amax(axis=self.norm_axis, keepdims=True)
        else:
            # Any other setting leaves the amplitudes untouched
            return demeaned

        # Small offset avoids division by zero
        return demeaned / (scale + 1e-10)

    def get_model_args(self):
        """
        Collect the constructor arguments describing this model.

        :return: Dictionary of model arguments
        """
        # NOTE(review): the inherited arguments include "norm_amp_per_comp" and "norm_detrend", which this
        # class' constructor does not name explicitly - confirm that passing them back in is accepted.
        model_args = super().get_model_args()
        model_args.update(
            {
                "in_samples": self.in_samples,
                "in_channels": self.in_channels,
                "classes": self.classes,
                "phases": self.labels,
                "sampling_rate": self.sampling_rate,
                "norm": self.norm,
                "norm_axis": self.norm_axis,
                # Lower-cased class name of the activation module, i.e., "softmax" or "sigmoid"
                "output_activation": self.output_activation.__class__.__name__.lower(),
            }
        )
        return model_args