from typing import Any
import torch
import torch.nn as nn
import seisbench.util as sbu
from .base import WaveformModel
[docs]
class GPD(WaveformModel):
"""
.. document_args:: seisbench.models GPD
"""
_annotate_args = WaveformModel._annotate_args.copy()
_annotate_args["*_threshold"] = ("Detection threshold for the provided phase", 0.7)
_annotate_args["stride"] = (_annotate_args["stride"][0], 10)
def __init__(
self,
in_channels=3,
classes=3,
phases=None,
eps=1e-10,
sampling_rate=100,
pred_sample=200,
original_compatible=False,
**kwargs,
):
citation = (
"Ross, Z. E., Meier, M.-A., Hauksson, E., & Heaton, T. H. (2018). "
"Generalized Seismic Phase Detection with Deep Learning. "
"ArXiv:1805.01075 [Physics]. https://arxiv.org/abs/1805.01075"
)
super().__init__(
citation=citation,
output_type="point",
in_samples=400,
pred_sample=pred_sample,
labels=phases,
sampling_rate=sampling_rate,
**kwargs,
)
self.in_channels = in_channels
self.classes = classes
self.eps = eps
self._phases = phases
self.original_compatible = original_compatible
if phases is not None and len(phases) != classes:
raise ValueError(
f"Number of classes ({classes}) does not match number of phases ({len(phases)})."
)
self.conv1 = nn.Conv1d(in_channels, 32, 21, padding=10)
self.bn1 = nn.BatchNorm1d(32, eps=1e-3)
self.conv2 = nn.Conv1d(32, 64, 15, padding=7)
self.bn2 = nn.BatchNorm1d(64, eps=1e-3)
self.conv3 = nn.Conv1d(64, 128, 11, padding=5)
self.bn3 = nn.BatchNorm1d(128, eps=1e-3)
self.conv4 = nn.Conv1d(128, 256, 9, padding=4)
self.bn4 = nn.BatchNorm1d(256, eps=1e-3)
self.fc1 = nn.Linear(6400, 200)
self.bn5 = nn.BatchNorm1d(200, eps=1e-3)
self.fc2 = nn.Linear(200, 200)
self.bn6 = nn.BatchNorm1d(200, eps=1e-3)
self.fc3 = nn.Linear(200, classes)
self.activation = torch.relu
self.pool = nn.MaxPool1d(2, 2)
[docs]
def forward(self, x, logits=False):
# Max normalization
x = x / (
torch.max(
torch.max(torch.abs(x), dim=-1, keepdims=True)[0], dim=-2, keepdims=True
)[0]
+ self.eps
)
x = self.pool(self.activation(self.bn1(self.conv1(x))))
x = self.pool(self.activation(self.bn2(self.conv2(x))))
x = self.pool(self.activation(self.bn3(self.conv3(x))))
x = self.pool(self.activation(self.bn4(self.conv4(x))))
if self.original_compatible:
# Permutation is required to be consistent with the following fully connected layer
x = x.permute(0, 2, 1)
x = torch.flatten(x, 1)
x = self.activation(self.bn5(self.fc1(x)))
x = self.activation(self.bn6(self.fc2(x)))
x = self.fc3(x)
if logits:
return x
else:
if self.classes == 1:
return torch.sigmoid(x)
else:
return torch.softmax(x, -1)
@property
def phases(self):
if self._phases is not None:
return self._phases
else:
return list(range(self.classes))
[docs]
def annotate_batch_pre(
self, batch: torch.Tensor, argdict: dict[str, Any]
) -> torch.Tensor:
return batch - batch.mean(axis=-1, keepdims=True)
[docs]
def classify_aggregate(self, annotations, argdict) -> sbu.ClassifyOutput:
"""
Converts the annotations to discrete picks using
:py:func:`~seisbench.models.base.WaveformModel.picks_from_annotations`.
Trigger onset thresholds for picks are derived from the argdict at keys "[phase]_threshold".
:param annotations: See description in superclass
:param argdict: See description in superclass
:return: List of picks
"""
picks = sbu.PickList()
for phase in self.phases:
if phase == "N":
# Don't pick noise
continue
picks += self.picks_from_annotations(
annotations.select(channel=f"{self.__class__.__name__}_{phase}"),
argdict.get(
f"{phase}_threshold", self._annotate_args.get("*_threshold")[1]
),
phase,
)
picks = sbu.PickList(sorted(picks))
return sbu.ClassifyOutput(self.name, picks=picks)
[docs]
def get_model_args(self):
model_args = super().get_model_args()
for key in [
"citation",
"in_samples",
"output_type",
"default_args",
"pred_sample",
"labels",
"sampling_rate",
]:
del model_args[key]
model_args["sampling_rate"] = self.sampling_rate
model_args["in_channels"] = self.in_channels
model_args["classes"] = self.classes
model_args["phases"] = self._phases
model_args["eps"] = self.eps
model_args["sampling_rate"] = self.sampling_rate
model_args["pred_sample"] = self.pred_sample
model_args["original_compatible"] = self.original_compatible
return model_args