from typing import Any
import numpy as np
import obspy
import torch
import torch.nn as nn
import seisbench.util as sbu
from .eqtransformer import (
EQTransformer,
SeqSelfAttention,
Decoder,
Encoder,
ResCNNStack,
Transformer,
ActivationLSTMCell,
CustomLSTM,
)
class EQTP(EQTransformer):
    """
    The EQTP model from Peng et al. (2025).

    EQTP is an extended version of the EQTransformer model. It keeps the phase
    picking capabilities of EQTransformer and adds P-wave polarity determination.
    The model processes three-component seismic waveform data and simultaneously
    outputs picking results for phases such as P-waves/S-waves and polarity
    determination results for P-waves (Up ``U`` / Down ``D`` / Unknown ``N``).

    Compared to EQTransformer, EQTP uses its own five-layer encoder configuration,
    skips the BiLSTM stack and has no detection branch. Its output channels are,
    in order, ``Polarity_U``, ``Polarity_D``, followed by one channel per phase.

    The implementation is adapted from the EQTransformer implementation in the
    SeisBench GitHub repository (https://github.com/seisbench/seisbench).
    The EQTP model can be instantiated via the ``from_pretrained("ncedc")`` method.

    .. document_args:: seisbench.models EQTP
    """

    _annotate_args = EQTransformer._annotate_args.copy()
    _annotate_args["polarity_threshold"] = ("Polarity threshold", 0.3)

    def __init__(
        self,
        in_samples=12000,
        **kwargs,
    ):
        """
        :param in_samples: Number of input samples per window.
        :param kwargs: Forwarded unchanged to the EQTransformer constructor.
        """
        citation = (
            "Peng L, Li L, Zeng X. "
            "A Microseismic Phase Picking and Polarity Determination Model Based "
            "on the Earthquake Transformer[J]. Applied Sciences, 2025, 15(7): 3424. "
            "https://doi.org/10.3390/app15073424"
        )
        super().__init__(in_samples=in_samples, **kwargs)

        # Override citation and labels for EQTP. The label order must match the
        # order of the tensors returned by ``forward`` (polarity first, then picks).
        self._citation = citation
        self.labels = ["Polarity_U", "Polarity_D"] + list(self.phases)

        # EQTP-specific encoder / residual CNN configuration.
        self.filters = [8, 16, 16, 32, 64]
        self.kernel_sizes = [11, 9, 7, 7, 3]
        self.res_cnn_kernels = [3, 3, 3, 3, 2]
        self._rebuild_eqtp_components()

        # EQTP has no detection branch.
        del self.decoder_d
        del self.conv_d

        # Since the BiLSTM stack is skipped, the transformers consume the residual
        # CNN output directly, i.e. filters[-1] (= 64) channels.
        eps = 1e-7 if self.original_compatible else 1e-5
        self.transformer_d0 = self._create_transformer(
            self.filters[-1], self.drop_rate, eps
        )
        self.transformer_d = self._create_transformer(
            self.filters[-1], self.drop_rate, eps
        )

        self._add_polarity_branches(self.original_compatible, eps)
        self._rebuild_picking_branches(self.original_compatible, eps)

    def _rebuild_eqtp_components(self):
        """Recreate encoder and residual CNN stack with the EQTP filter settings."""
        self.encoder = Encoder(
            input_channels=self.in_channels,
            filters=self.filters,
            kernel_sizes=self.kernel_sizes,
            in_samples=self.in_samples,
        )
        self.res_cnn_stack = ResCNNStack(
            kernel_sizes=self.res_cnn_kernels,
            filters=self.filters[-1],
            drop_rate=self.drop_rate,
        )

    def _create_transformer(self, input_size, drop_rate, eps):
        """Create a ``Transformer`` block with the given input size."""
        return Transformer(input_size=input_size, drop_rate=drop_rate, eps=eps)

    def _build_branch(self, original_compatible, eps):
        """
        Create the modules of one output branch.

        Modules are created in the fixed order lstm, attention, decoder, conv.

        :param original_compatible: If ``"conservative"``, use ``CustomLSTM`` with
            ``ActivationLSTMCell``; otherwise use ``torch.nn.LSTM``.
        :param eps: Epsilon passed to the attention layer.
        :return: Tuple ``(lstm, attention, decoder, conv)``.
        """
        hidden = self.filters[-1]
        if original_compatible == "conservative":
            lstm = CustomLSTM(ActivationLSTMCell, hidden, hidden, bidirectional=False)
        else:
            lstm = nn.LSTM(hidden, hidden, bidirectional=False)
        attention = SeqSelfAttention(input_size=hidden, attention_width=3, eps=eps)
        decoder = Decoder(
            input_channels=hidden,
            filters=self.filters[::-1],
            kernel_sizes=self.kernel_sizes[::-1],
            out_samples=self.in_samples,
            original_compatible=original_compatible,
        )
        conv = nn.Conv1d(
            in_channels=self.filters[0], out_channels=1, kernel_size=11, padding=5
        )
        return lstm, attention, decoder, conv

    def _build_branches(self, count, original_compatible, eps):
        """
        Create ``count`` output branches.

        :return: Four ``nn.ModuleList`` objects ``(lstms, attentions, decoders,
            convs)``, each holding one module per branch.
        """
        lstms = nn.ModuleList()
        attentions = nn.ModuleList()
        decoders = nn.ModuleList()
        convs = nn.ModuleList()
        for _ in range(count):
            lstm, attention, decoder, conv = self._build_branch(
                original_compatible, eps
            )
            lstms.append(lstm)
            attentions.append(attention)
            decoders.append(decoder)
            convs.append(conv)
        return lstms, attentions, decoders, convs

    def _add_polarity_branches(self, original_compatible, eps):
        """Create the two polarity branches (``Polarity_U``, ``Polarity_D``)."""
        (
            self.pol_lstms,
            self.pol_attentions,
            self.pol_decoders,
            self.pol_convs,
        ) = self._build_branches(2, original_compatible, eps)

    def _rebuild_picking_branches(self, original_compatible, eps):
        """Replace the inherited picking branches with EQTP-sized ones (one per class)."""
        (
            self.pick_lstms,
            self.pick_attentions,
            self.pick_decoders,
            self.pick_convs,
        ) = self._build_branches(self.classes, original_compatible, eps)

    def forward(self, x, logits=False):
        """
        Run the model.

        :param x: Input tensor of shape ``(batch, in_channels, in_samples)``.
        :param logits: If True, return the raw output of the final convolution;
            otherwise apply a sigmoid.
        :return: Tuple of tensors in the order of ``self.labels``
            (``Polarity_U``, ``Polarity_D``, then one per phase). The single
            output channel of the final convolution is squeezed away.
        """
        assert x.ndim == 3
        assert x.shape[1:] == (self.in_channels, self.in_samples)

        # Shared encoder part; EQTP skips the BiLSTM stack.
        x = self.encoder(x)
        x = self.res_cnn_stack(x)
        x, _ = self.transformer_d0(x)
        x, _ = self.transformer_d(x)

        # Polarity branches first, then picking branches -- matches self.labels.
        branch_groups = (
            zip(self.pol_lstms, self.pol_attentions, self.pol_decoders, self.pol_convs),
            zip(
                self.pick_lstms,
                self.pick_attentions,
                self.pick_decoders,
                self.pick_convs,
            ),
        )
        outputs = []
        for branches in branch_groups:
            for lstm, attention, decoder, conv in branches:
                outputs.append(
                    self._run_branch(x, lstm, attention, decoder, conv, logits)
                )
        return tuple(outputs)

    def _run_branch(self, x, lstm, attention, decoder, conv, logits):
        """Apply one output branch to the shared features ``x``."""
        # (batch, channels, sequence) -> (sequence, batch, channels): the LSTM is
        # not batch-first.
        y = x.permute(2, 0, 1)
        y = lstm(y)[0]
        y = self.dropout(y)
        y = y.permute(1, 2, 0)  # back to (batch, channels, sequence)
        y, _ = attention(y)
        y = decoder(y)
        y = conv(y)
        if not logits:
            y = torch.sigmoid(y)
        return torch.squeeze(y, dim=1)

    def classify_aggregate(self, annotations, argdict) -> sbu.ClassifyOutput:
        """
        Converts the annotations to discrete picks using
        :py:func:`~seisbench.models.base.WaveformModel.picks_from_annotations`
        and assigns a polarity to each P pick.

        Trigger onset thresholds for picks are derived from the argdict at keys
        "[phase]_threshold". The polarity threshold is derived from the argdict at
        key "polarity_threshold". EQTP has no detection branch, so no detections
        are produced.

        :param annotations: See description in superclass
        :param argdict: See description in superclass
        :return: ClassifyOutput containing the sorted picks; P picks carry
            ``polarity`` / ``polarity_value`` where they could be determined.
        """
        picks = sbu.PickList()
        for phase in self.phases:
            picks += self.picks_from_annotations(
                annotations.select(channel=f"{self.__class__.__name__}_{phase}"),
                argdict.get(
                    f"{phase}_threshold", self._annotate_args.get("*_threshold")[1]
                ),
                phase,
            )
        picks = sbu.PickList(sorted(picks))
        self._extract_polarities(annotations, picks, argdict)
        return sbu.ClassifyOutput(self.name, picks=picks)

    def _extract_polarities(
        self, annotations: obspy.Stream, picks: sbu.PickList, argdict: dict[str, Any]
    ) -> sbu.PickList:
        """
        Assign a polarity to every P pick, in place.

        For each P pick, the maximum of the ``Polarity_U`` and ``Polarity_D``
        annotation traces within a few samples of the pick's peak time is used as
        the score of that polarity. If the larger score exceeds the polarity
        threshold, the pick gets that polarity (``"U"`` or ``"D"``). Otherwise it is
        marked ``"N"`` with value ``1 - score_U - score_D``. Picks for which either
        polarity trace is missing, not unique, or empty around the pick are left
        untouched.

        :param annotations: Stream of annotation traces whose channel names are
            ``<ModelName>_<label>``.
        :param picks: Picks to update; modified in place.
        :param argdict: Annotation arguments; may contain ``polarity_threshold``.
        :return: The same ``picks`` object.
        """
        # A plain scalar (no trailing comma!) with this model's own default.
        polarity_threshold = argdict.get(
            "polarity_threshold", self._annotate_args.get("polarity_threshold")[1]
        )
        # Only a few samples around each pick are inspected, so cut a symmetric
        # window of 5 samples on either side of the pick time.
        half_window = 5 / self.sampling_rate
        model_name = self.__class__.__name__

        for pick in picks:
            if pick.phase != "P":
                continue
            t = pick.peak_time
            scores = {}
            for pol in "UD":
                stream = annotations.select(
                    id=f"{pick.trace_id}.{model_name}_Polarity_{pol}"
                ).slice(t - half_window, t + half_window)
                if len(stream) != 1:
                    continue
                trace = stream[0]
                sample = int((t - trace.stats.starttime) * trace.stats.sampling_rate)
                # Take a small tolerance around the pick sample.
                segment = trace.data[max(0, sample - 3) : sample + 3]
                if len(segment) == 0:
                    continue
                scores[pol] = np.max(segment)
            if len(scores) != 2:
                continue

            polarity = max(scores, key=scores.get)
            if scores[polarity] > polarity_threshold:
                pick.polarity = polarity
                pick.polarity_value = scores[polarity]
            else:
                pick.polarity = "N"
                # NOTE(review): can become negative if both scores are large but
                # below a high threshold; confirm whether clipping is desired.
                pick.polarity_value = 1 - scores["U"] - scores["D"]
        return picks