# Source code for seisbench.models.eqcct

"""
EQCCT P- and S-wave phase pickers as SeisBench :class:`~seisbench.models.base.WaveformModel` classes.

EQCCT uses separate models for P and S phase picking. Load the P-branch with
:py:class:`EQCCTP` and the S-branch with :py:class:`EQCCTS`, each via
:py:meth:`~seisbench.models.base.SeisBenchModel.from_pretrained`.
Pretrained weights are stored under ``<cache_model_root>/eqcct/`` and
``<cache_model_root>/eqccts/`` respectively.
"""

from __future__ import annotations

from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import seisbench.util as sbu
from .base import WaveformModel

# Reference for the EQCCT method; forwarded to WaveformModel as the model citation.
_EQCCT_CITATION = (
    "Saad, O. M., Chen, Y., Siervo, D., Zhang, F., Savvaidis, A., Huang, G.-c., "
    "Igonin, N., Fomel, S., & Chen, Y. (2023). "
    "EQCCT: A Production-Ready Earthquake Detection and Phase-Picking Method Using "
    "the Compact Convolutional Transformer. "
    "IEEE Transactions on Geoscience and Remote Sensing, 61, 1-15. "
    "https://doi.org/10.1109/TGRS.2023.3319440"
)

# Model hyper-parameters (matching the TensorFlow / EQCCT reference). They fix the
# layer shapes of the backbones below, so they must agree with any pretrained weights.
image_size = 6000  # samples per input window
patch_size = 40  # samples per transformer token
num_patches = image_size // patch_size  # 150 tokens per window
projection_dim = 40  # token embedding size
num_heads = 4  # attention heads per transformer block
transformer_layers = 4  # number of transformer blocks
# Base drop-path rate; transformer layer i uses rate * i / transformer_layers.
stochastic_depth_rate = 0.1
# Flattened patch length: 40 conv-stem channels x unit width (1) x samples per patch.
patch_dim = 40 * 1 * patch_size


class ConvF1Block(nn.Module):
    """
    Residual convolutional block used in the EQCCT stems.

    Applies ``conv -> BatchNorm -> GELU`` twice at ``in_channels`` width, adds the
    block input back (residual connection), then a third ``conv -> BatchNorm ->
    GELU`` that maps to ``out_channels``, followed by dropout.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param kernel_size: Convolution kernel size. ``padding=kernel_size // 2`` keeps
        the sequence length unchanged for odd kernel sizes.
    :param dropout_rate: Dropout probability applied to the block output (only
        active in training mode).
    """

    def __init__(self, in_channels, out_channels, kernel_size=11, dropout_rate=0.1):
        super().__init__()
        padding = kernel_size // 2
        # eps=1e-3 matches the Keras BatchNormalization default epsilon.
        self.conv1 = nn.Conv1d(in_channels, in_channels, kernel_size, padding=padding)
        self.bn1 = nn.BatchNorm1d(in_channels, eps=0.001)

        self.conv2 = nn.Conv1d(in_channels, in_channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm1d(in_channels, eps=0.001)

        self.conv3 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn3 = nn.BatchNorm1d(out_channels, eps=0.001)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        :param x: Tensor of shape ``(batch, in_channels, length)``.
        :return: Tensor of shape ``(batch, out_channels, length)``.
        """
        out = F.gelu(self.bn1(self.conv1(x)))
        out = F.gelu(self.bn2(self.conv2(out)))
        out = out + x  # residual connection around the first two convolutions

        out = F.gelu(self.bn3(self.conv3(out)))
        # nn.Dropout returns a new tensor; its result must be used, otherwise dropout
        # is silently skipped during training. It is a no-op in eval mode.
        return self.dropout(out)


class Patches(nn.Module):
    """
    Split a tensor into non-overlapping patches along axis 1 and flatten each patch.

    Each patch vector is laid out as (samples within patch, axis 2, axis 3) in
    row-major order.

    :param patch_size: Number of samples per patch; also used as the unfold stride.
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, images):
        """
        :param images: 4D tensor whose axis 1 is cut into patches. In this module it
            receives ``(batch, samples, 1, channels)``.
        :return: Tensor ``(batch, n_patches, patch_size * dim2 * dim3)``. Trailing
            samples that do not fill a whole patch are dropped by ``unfold``.
        """
        step = self.patch_size
        # unfold appends the window axis last: (B, n, dim2, dim3, step)
        windows = images.unfold(1, step, step)
        # bring the window axis next to the patch index: (B, n, step, dim2, dim3)
        windows = windows.permute(0, 1, 4, 2, 3).contiguous()
        n_batch, n_windows = windows.shape[0], windows.shape[1]
        return windows.view(n_batch, n_windows, -1)


class PatchEncoder(nn.Module):
    """
    Project flattened patches linearly and add a learned positional embedding.

    :param num_patches: Size of the positional-embedding table, i.e. the maximum
        number of patches ``forward`` accepts.
    :param projection_dim: Output feature size per patch.
    :param patch_dim: Length of each flattened input patch.
    """

    def __init__(self, num_patches, projection_dim, patch_dim):
        super().__init__()
        self.projection = nn.Linear(patch_dim, projection_dim)
        self.position_embedding = nn.Embedding(num_patches, projection_dim)

    def forward(self, x):
        """
        :param x: Tensor ``(batch, n_patches, patch_dim)``.
        :return: Tensor ``(batch, n_patches, projection_dim)``.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        index = torch.arange(seq_len, device=x.device)
        index = index.unsqueeze(0).expand(batch_size, -1)
        projected = self.projection(x)
        return projected + self.position_embedding(index)


class StochasticDepth(nn.Module):
    """
    Per-sample stochastic depth (drop path) for residual branches.

    In training mode, each sample along the first axis is either zeroed or scaled
    by ``1 / (1 - drop_prob)``. In eval mode, or when ``drop_prob == 0``, the input
    tensor itself is returned.

    :param drop_prob: Probability of dropping a sample's branch output.
    """

    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.training and self.drop_prob != 0.0:
            survival = 1 - self.drop_prob
            # One uniform draw per sample, broadcast over all remaining axes.
            mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
            noise = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
            mask = torch.floor(survival + noise)
            return (x / survival) * mask
        return x


class TransformerMLP(nn.Module):
    """
    Two-layer feed-forward network used inside each transformer block.

    Both dense layers keep the feature size at ``dim``. Each is followed by GELU
    and dropout.

    :param dim: Feature size of input, hidden layer and output.
    :param dropout_rate: Dropout probability after each activation.
    """

    def __init__(self, dim, dropout_rate=0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()

    def forward(self, x):
        for dense in (self.fc1, self.fc2):
            x = self.dropout(self.gelu(dense(x)))
        return x


class TransformerBlock(nn.Module):
    """
    Pre-norm transformer encoder block.

    Computes ``x = x + drop_path1(attn(norm1(x)))`` and then
    ``x = x + drop_path2(mlp(norm2(x)))``. Attention is :class:`KerasMHA` with
    ``key_dim=40``, and the MLP uses a dropout rate of 0.1.

    :param dim: Token feature size.
    :param num_heads: Number of attention heads.
    :param drop_prob: Stochastic-depth probability for both residual branches.
    """

    def __init__(self, dim, num_heads, drop_prob=0.1):
        super().__init__()
        # Attention sub-layer
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = KerasMHA(embed_dim=dim, num_heads=num_heads, key_dim=40)
        self.drop_path1 = StochasticDepth(drop_prob)

        # Feed-forward sub-layer
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = TransformerMLP(dim, dropout_rate=0.1)
        self.drop_path2 = StochasticDepth(drop_prob)

    def forward(self, x):
        attended = self.drop_path1(self.attn(self.norm1(x)))
        x = x + attended
        fed = self.drop_path2(self.mlp(self.norm2(x)))
        return x + fed


class OutputHead(nn.Module):
    """
    Final 1D convolution followed by a sigmoid, giving one probability per sample.

    :param in_channels: Number of input features per time step.
    :param kernel_size: Convolution kernel size. ``padding=kernel_size // 2`` keeps
        the length unchanged for odd sizes.
    """

    def __init__(self, in_channels=1, kernel_size=15):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2
        )
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """
        :param x: Channels-last tensor ``(batch, length, in_channels)``.
        :return: Tensor ``(batch, length, 1)`` with values in ``[0, 1]``.
        """
        channels_first = torch.transpose(x, 1, 2)
        probs = self.activation(self.conv(channels_first))
        return torch.transpose(probs, 1, 2)


class KerasMHA(nn.Module):
    """
    Self-attention layer mirroring ``tf.keras.layers.MultiHeadAttention``.

    Queries, keys and values are each projected from ``embed_dim`` to
    ``num_heads * key_dim`` features (160 with the defaults). Attention is computed
    per head with scale ``1 / sqrt(key_dim)``, and the concatenated heads are
    projected back to ``embed_dim``.

    :param embed_dim: Input and output feature size.
    :param num_heads: Number of attention heads.
    :param key_dim: Per-head feature size of queries, keys and values.
    """

    def __init__(self, embed_dim=40, num_heads=4, key_dim=40):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.inner_dim = num_heads * key_dim
        self.scale = 1.0 / np.sqrt(key_dim)

        self.q = nn.Linear(embed_dim, self.inner_dim, bias=True)
        self.k = nn.Linear(embed_dim, self.inner_dim, bias=True)
        self.v = nn.Linear(embed_dim, self.inner_dim, bias=True)
        self.o = nn.Linear(self.inner_dim, embed_dim, bias=True)

    def _split(self, x):
        """Reshape ``(B, T, H*D)`` into the per-head layout ``(B, H, T, D)``."""
        n_batch, n_tokens, _ = x.shape
        per_head = x.view(n_batch, n_tokens, self.num_heads, self.key_dim)
        return per_head.transpose(1, 2)

    def _merge(self, x):
        """Reshape the per-head layout ``(B, H, T, D)`` back to ``(B, T, H*D)``."""
        n_batch, n_heads, n_tokens, head_dim = x.shape
        return x.transpose(1, 2).reshape(n_batch, n_tokens, n_heads * head_dim)

    def forward(self, x):
        """
        :param x: Tensor ``(batch, tokens, embed_dim)``, used as query, key and value.
        :return: Tensor ``(batch, tokens, embed_dim)``.
        """
        query, key, value = (self._split(proj(x)) for proj in (self.q, self.k, self.v))
        logits = query @ key.transpose(-2, -1) * self.scale
        attn = torch.softmax(logits, dim=-1)
        return self.o(self._merge(attn @ value))


class EQCCTModelP(nn.Module):
    """
    EQCCT P-branch backbone.

    The pipeline is:

    - three :class:`ConvF1Block` stems (3 -> 10 -> 20 -> 40 channels);
    - patch extraction and encoding;
    - ``transformer_layers`` transformer blocks with linearly increasing
      stochastic-depth rate;
    - a final LayerNorm and a convolutional sigmoid output head.

    Input is a channels-last tensor ``(batch, 6000, 3)``; output is
    ``(batch, 6000, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = ConvF1Block(3, 10)
        self.conv2 = ConvF1Block(10, 20)
        self.conv3 = ConvF1Block(20, 40)

        self.patch = Patches(patch_size)
        self.encoder = PatchEncoder(num_patches, projection_dim, patch_dim)

        drop_rates = [
            stochastic_depth_rate * (layer / transformer_layers)
            for layer in range(transformer_layers)
        ]
        self.transformer = nn.Sequential(
            *(
                TransformerBlock(projection_dim, num_heads, drop_prob=rate)
                for rate in drop_rates
            )
        )
        self.norm = nn.LayerNorm(projection_dim, eps=1e-6)
        self.head = OutputHead(in_channels=1)

    def forward(self, x):
        feats = x.transpose(1, 2)  # channels-first for Conv1d
        for stem in (self.conv1, self.conv2, self.conv3):
            feats = stem(feats)

        # (B, 40, T) -> (B, T, 1, 40): channels-last layout with a unit width axis
        feats = feats.unsqueeze(2).permute(0, 3, 2, 1)
        tokens = self.encoder(self.patch(feats))
        tokens = self.norm(self.transformer(tokens))

        # (B, 150, 40) tokens -> (B, 6000, 1): one feature per time sample
        return self.head(tokens.reshape(tokens.size(0), 6000, 1))


class EQCCTModelS(nn.Module):
    """
    EQCCT S-branch backbone.

    It uses the same stems, patching, final LayerNorm and output head as
    :class:`EQCCTModelP`. In addition, every transformer layer gets two extra
    :class:`ConvF1Block` convolutions:

    - one applied to the tokens before the layer;
    - one applied to the attention output before the first residual addition.

    Input is a channels-last tensor ``(batch, 6000, 3)``; output is
    ``(batch, 6000, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = ConvF1Block(3, 10)
        self.conv2 = ConvF1Block(10, 20)
        self.conv3 = ConvF1Block(20, 40)

        self.patch = Patches(patch_size)
        self.encoder = PatchEncoder(num_patches, projection_dim, patch_dim)

        # Per-layer convolutions around each transformer layer (S-branch only).
        self.extra_pre = nn.ModuleList(
            ConvF1Block(40, 40) for _ in range(transformer_layers)
        )
        self.extra_post = nn.ModuleList(
            ConvF1Block(40, 40) for _ in range(transformer_layers)
        )

        drop_rates = [
            stochastic_depth_rate * (layer / transformer_layers)
            for layer in range(transformer_layers)
        ]
        self.transformers = nn.ModuleList(
            TransformerBlock(projection_dim, num_heads, drop_prob=rate)
            for rate in drop_rates
        )

        self.norm = nn.LayerNorm(projection_dim, eps=1e-6)
        self.head = OutputHead(in_channels=1)

    def forward(self, x):
        feats = x.transpose(1, 2)  # channels-first for Conv1d
        for stem in (self.conv1, self.conv2, self.conv3):
            feats = stem(feats)

        feats = feats.unsqueeze(2).permute(0, 3, 2, 1)
        tokens = self.encoder(self.patch(feats))

        layers = zip(self.extra_pre, self.transformers, self.extra_post)
        for pre_conv, block, post_conv in layers:
            # ConvF1Block is channels-first; tokens are (B, N, D).
            tokens = pre_conv(tokens.transpose(1, 2)).transpose(1, 2)

            attn = block.attn(block.norm1(tokens))
            attn = post_conv(attn.transpose(1, 2)).transpose(1, 2)
            tokens = tokens + block.drop_path1(attn)

            tokens = tokens + block.drop_path2(block.mlp(block.norm2(tokens)))

        tokens = self.norm(tokens)
        return self.head(tokens.reshape(tokens.size(0), 6000, 1))


class _EQCCTBranchWaveform(WaveformModel):
    """
    Shared windowing, preprocessing, and I/O for EQCCT P/S branches.

    Expects 6000-sample (60 s at 100 Hz) three-component windows. Subclasses supply
    the PyTorch backbone and phase label list (:class:`EQCCTP` for P, :class:`EQCCTS`
    for S).

    :param backbone: PyTorch module applied to channels-last windows
        ``(batch, in_samples, 3)``. It must return ``(batch, in_samples, n_labels)``.
    :param labels: Output phase labels, e.g. ``["P"]`` or ``["S"]``.
    :param citation: Citation string forwarded to :class:`WaveformModel`.
    :param sampling_rate: Model input sampling rate in Hz.
    :param norm: Amplitude normalization, ``"std"`` or ``"peak"``. Any other value
        skips amplitude normalization; only the mean is removed.
    :param norm_amp_per_comp: If True, divide each component by its own peak
        amplitude. This takes precedence over ``norm``.
    :param norm_detrend: If True, apply ``sbu.torch_detrend`` before normalization.
    :param kwargs: Passed through to :class:`WaveformModel`.
    """

    _annotate_args = WaveformModel._annotate_args.copy()
    _annotate_args["blinding"] = (
        "Number of prediction samples to discard on each side of each window prediction",
        (500, 500),
    )
    # Registered alongside S_threshold so the P threshold is documented and
    # discoverable. The value matches the 0.3 fallback used in classify_aggregate,
    # so the effective default is unchanged.
    _annotate_args["P_threshold"] = (
        "Detection threshold for P-phase probability curve",
        0.3,
    )
    _annotate_args["S_threshold"] = (
        "Detection threshold for S-phase probability curve",
        0.5,
    )
    _annotate_args["stacking"] = (
        "Stacking method for overlapping windows (only for window prediction models). "
        "Options are 'max' and 'avg'. ",
        "max",
    )
    _annotate_args["overlap"] = (_annotate_args["overlap"][0], 0.5)

    def __init__(
        self,
        backbone: nn.Module,
        labels: list[str],
        citation: str,
        sampling_rate: float = 100,
        norm: str = "std",
        norm_amp_per_comp: bool = False,
        norm_detrend: bool = False,
        **kwargs,
    ):
        super().__init__(
            citation=citation,
            output_type="array",
            in_samples=6000,
            pred_sample=(0, 6000),
            labels=labels,
            sampling_rate=sampling_rate,
            **kwargs,
        )

        self.norm = norm
        self.norm_amp_per_comp = norm_amp_per_comp
        self.norm_detrend = norm_detrend
        # Attribute name fixes the state_dict prefix of the pretrained weights.
        self._eqcct_backbone = backbone

    def forward(self, x):
        """
        Run the EQCCT backbone on a SeisBench-formatted batch.

        :param x: Tensor of shape ``(batch, 3, in_samples)``.
        :return: Phase probability curves of shape ``(batch, n_labels, in_samples)``.
        :raises ValueError: If ``x`` is not 3D, does not have 3 components, or its
            last dimension differs from ``in_samples``.
        """
        if x.ndim != 3:
            raise ValueError(f"Expected 3D input (B, 3, T), got shape {tuple(x.shape)}")
        if x.shape[2] != self.in_samples:
            raise ValueError(
                f"Expected last dim in_samples={self.in_samples}, got {x.shape[2]}"
            )
        if x.shape[1] != 3:
            # Previously a non-3-component input skipped the transpose and failed
            # deep inside the backbone's first convolution with an unclear error.
            raise ValueError(
                f"Expected 3 components in dim 1 (B, 3, T), got {x.shape[1]}"
            )
        # The backbone expects channels-last input (B, T, 3) and returns (B, T, n_labels).
        y = self._eqcct_backbone(x.transpose(1, 2))
        return y.transpose(1, 2)

    def annotate_batch_pre(
        self, batch: torch.Tensor, argdict: dict[str, Any]
    ) -> torch.Tensor:
        """
        EQCCT preprocessing applied to each window before inference.

        Steps, in order:

        1. Remove the per-trace mean.
        2. Optionally apply ``sbu.torch_detrend``.
        3. Normalize amplitudes:

           - ``norm_amp_per_comp``: per-component peak;
           - otherwise ``norm="std"``: std over components and samples jointly;
           - otherwise ``norm="peak"``: per-component peak.

        4. Taper the first and last six samples with a half-cosine ramp that is 0
           at the trace edge and 1 at the inner end.

        :param batch: Tensor ``(batch, components, samples)``.
        :param argdict: Annotation arguments (unused here).
        :return: Preprocessed tensor of the same shape. The input tensor is not
            modified, because the mean removal creates a new tensor.
        """
        batch = batch - batch.mean(axis=-1, keepdims=True)
        if self.norm_detrend:
            batch = sbu.torch_detrend(batch)
        if self.norm_amp_per_comp:
            peak = batch.abs().max(axis=-1, keepdims=True)[0]
            batch = batch / (peak + 1e-10)
        else:
            if self.norm == "std":
                std = batch.std(axis=(-1, -2), keepdims=True)
                batch = batch / (std + 1e-10)
            elif self.norm == "peak":
                # NOTE(review): this is still a per-component peak, identical to the
                # norm_amp_per_comp=True path. Kept as-is to preserve behaviour;
                # confirm whether a joint peak over all components was intended.
                peak = batch.abs().max(axis=-1, keepdims=True)[0]
                batch = batch / (peak + 1e-10)

        # 0.5 * (1 + cos(t)) for t in [pi, 2*pi] rises from 0 to 1 over 6 samples.
        tap = 0.5 * (
            1
            + torch.cos(
                torch.linspace(
                    np.pi,
                    2 * np.pi,
                    6,
                    device=batch.device,
                    dtype=batch.dtype,
                )
            )
        )
        batch[:, :, :6] *= tap
        batch[:, :, -6:] *= tap.flip(dims=(0,))

        return batch

    def annotate_batch_post(
        self, batch: torch.Tensor, piggyback: Any, argdict: dict[str, Any]
    ) -> torch.Tensor:
        """
        Transpose model outputs to SeisBench channel layout and blind edge samples.

        Samples at the beginning and end of each prediction (see ``blinding`` in
        :py:func:`annotate` / :py:func:`classify`) are set to NaN before stacking
        overlapping windows.

        :param batch: Model output ``(batch, n_labels, samples)``.
        :return: Tensor ``(batch, samples, n_labels)`` with blinded edges set to NaN.
        """
        batch = torch.transpose(batch, -1, -2)
        prenan, postnan = argdict.get(
            "blinding", self._annotate_args.get("blinding")[1]
        )
        if prenan > 0:
            batch[:, :prenan] = np.nan
        if postnan > 0:
            batch[:, -postnan:] = np.nan
        return batch

    def classify_aggregate(self, annotations, argdict) -> sbu.ClassifyOutput:
        """
        Convert stacked probability annotations into discrete phase picks.

        For each label, the threshold is ``argdict[f"{phase}_threshold"]`` if given.
        Otherwise it is the registered ``_annotate_args`` default, falling back
        to 0.3.

        :return: :class:`sbu.ClassifyOutput` with time-sorted picks.
        """
        picks = sbu.PickList()
        for phase in self.labels:
            if phase == "N":
                continue
            threshold_key = f"{phase}_threshold"
            default_threshold = self._annotate_args.get(threshold_key, (None, 0.3))[1]
            picks += self.picks_from_annotations(
                annotations.select(channel=f"{self.__class__.__name__}_{phase}"),
                argdict.get(threshold_key, default_threshold),
                phase,
            )
        picks = sbu.PickList(sorted(picks))
        return sbu.ClassifyOutput(self.name, picks=picks)

    def get_model_args(self):
        """
        Return constructor arguments for serialization.

        Arguments fixed by the subclass constructors (citation, in_samples,
        output_type, default_args, pred_sample, labels) are removed. The EQCCT
        preprocessing options are added.
        """
        model_args = super().get_model_args()
        for key in [
            "citation",
            "in_samples",
            "output_type",
            "default_args",
            "pred_sample",
            "labels",
        ]:
            del model_args[key]
        model_args["sampling_rate"] = self.sampling_rate
        model_args["norm"] = self.norm
        model_args["norm_amp_per_comp"] = self.norm_amp_per_comp
        model_args["norm_detrend"] = self.norm_detrend
        return model_args


class EQCCTP(_EQCCTBranchWaveform):
    """
    The EQCCT P-wave phase picker from Saad et al. (2023).

    EQCCT uses separate compact convolutional-transformer models for P and S
    picking. This class wraps the P-branch architecture for SeisBench
    :py:func:`annotate` and :py:func:`classify`. It works on 6000-sample (60 s)
    three-component windows at 100 Hz and emits a single ``"P"`` output label.

    Instantiating the model with ``from_pretrained("original")`` loads the PyTorch
    weights converted from the institutional TensorFlow EQCCT P-branch checkpoint.

    .. document_args:: seisbench.models EQCCTP

    :param sampling_rate: Target sampling rate in Hz, by default 100. Incoming
        traces are resampled automatically when this differs.
    :param norm: Data normalization strategy, either ``"peak"`` or ``"std"``, by
        default ``"std"``.
    :param norm_amp_per_comp: If True, normalize each component independently by its
        peak amplitude. Defaults to False.
    :param norm_detrend: If True, apply detrending before normalization. Defaults
        to False.
    :param kwargs: Keyword arguments passed to the constructor of
        :py:class:`~seisbench.models.base.WaveformModel`.
    """

    def __init__(
        self,
        sampling_rate=100,
        norm="std",
        norm_amp_per_comp=False,
        norm_detrend=False,
        **kwargs,
    ):
        super().__init__(
            backbone=EQCCTModelP(),
            labels=["P"],
            citation=_EQCCT_CITATION,
            sampling_rate=sampling_rate,
            norm=norm,
            norm_amp_per_comp=norm_amp_per_comp,
            norm_detrend=norm_detrend,
            **kwargs,
        )
class EQCCTS(_EQCCTBranchWaveform):
    """
    The EQCCT S-wave phase picker from Saad et al. (2023).

    EQCCT uses separate compact convolutional-transformer models for P and S
    picking. This class wraps the deeper S-branch architecture, which adds
    convolutional blocks around each transformer layer, for SeisBench
    :py:func:`annotate` and :py:func:`classify`. It works on 6000-sample (60 s)
    three-component windows at 100 Hz and emits a single ``"S"`` output label.

    Use a **separate** S-branch checkpoint; do not load P-branch weights into this
    model. Instantiating the model with ``from_pretrained("original")`` loads the
    PyTorch weights converted from the institutional TensorFlow EQCCT S-branch
    checkpoint.

    .. document_args:: seisbench.models EQCCTS

    :param sampling_rate: Target sampling rate in Hz, by default 100. Incoming
        traces are resampled automatically when this differs.
    :param norm: Data normalization strategy, either ``"peak"`` or ``"std"``, by
        default ``"std"``.
    :param norm_amp_per_comp: If True, normalize each component independently by its
        peak amplitude. Defaults to False.
    :param norm_detrend: If True, apply detrending before normalization. Defaults
        to False.
    :param kwargs: Keyword arguments passed to the constructor of
        :py:class:`~seisbench.models.base.WaveformModel`.
    """

    def __init__(
        self,
        sampling_rate=100,
        norm="std",
        norm_amp_per_comp=False,
        norm_detrend=False,
        **kwargs,
    ):
        super().__init__(
            backbone=EQCCTModelS(),
            labels=["S"],
            citation=_EQCCT_CITATION,
            sampling_rate=sampling_rate,
            norm=norm,
            norm_amp_per_comp=norm_amp_per_comp,
            norm_detrend=norm_detrend,
            **kwargs,
        )