Source code for seisbench.models.skynet

from typing import Any

import numpy as np
import torch
import torch.nn as nn

import seisbench.util as sbu

from .base import WaveformModel


[docs] class Skynet(WaveformModel): """ Skynet is a phase picker with a U-Net architecture with a focus on regional phases and a long receptive field. The model offers first arrival pickers and multiphase pickers for Pn, Pg, Sn, and Sg phases. .. document_args:: seisbench.models Skynet """ _annotate_args = WaveformModel._annotate_args.copy() _annotate_args["*_threshold"] = ("Detection threshold for the provided phase", 0.5) _annotate_args["blinding"] = ( "Number of prediction samples to discard on each side of each window prediction", (0, 0), ) _annotate_args["overlap"] = (_annotate_args["overlap"][0], 15000) _annotate_args["stacking"] = (_annotate_args["stacking"][0], "max") def __init__( self, in_channels=3, classes=3, phases="PSN", sampling_rate=100, norm="peak", **kwargs, ): citation = ( "Aguilar Suarez, A. L., & Beroza, G. (2025). " "Picking Regional Seismic Phase Arrival Times with Deep Learning. " "Seismica, 4(1). " "https://doi.org/10.26443/seismica.v4i1.1431" ) super().__init__( citation=citation, in_samples=30000, output_type="array", pred_sample=(0, 30000), labels=phases, sampling_rate=sampling_rate, **kwargs, ) self.in_channels = in_channels self.classes = classes self.norm = norm kernel_size = 7 stride = 4 pad_size = int(kernel_size / 2) self.conv1 = nn.Conv1d( in_channels, 8, kernel_size=kernel_size, stride=1, padding="same" ) self.bn1 = nn.BatchNorm1d(num_features=8, eps=1e-3) self.conv2 = nn.Conv1d(8, 8, kernel_size=kernel_size, stride=1, padding="same") self.bn2 = nn.BatchNorm1d(num_features=8, eps=1e-3) self.conv3 = nn.Conv1d( 8, 8, kernel_size=kernel_size, stride=stride, padding=pad_size ) self.bn3 = nn.BatchNorm1d(num_features=8, eps=1e-3) self.conv4 = nn.Conv1d(8, 11, kernel_size=kernel_size, stride=1, padding="same") self.bn4 = nn.BatchNorm1d(num_features=11, eps=1e-3) self.conv5 = nn.Conv1d( 11, 11, kernel_size=kernel_size, stride=stride, padding=pad_size ) self.bn5 = nn.BatchNorm1d(num_features=11, eps=1e-3) self.conv6 = nn.Conv1d( 11, 16, kernel_size=kernel_size, stride=1, padding="same" ) self.bn6 = nn.BatchNorm1d(num_features=16, eps=1e-3) self.conv7 = nn.Conv1d( 16, 16, kernel_size=kernel_size, stride=stride, padding=pad_size ) self.bn7 = nn.BatchNorm1d(num_features=16, eps=1e-3) self.conv8 = nn.Conv1d( 16, 22, kernel_size=kernel_size, stride=1, padding="same" ) self.bn8 = nn.BatchNorm1d(num_features=22, eps=1e-3) self.conv9 = nn.Conv1d( 22, 22, kernel_size=kernel_size, stride=stride, padding=pad_size ) self.bn9 = nn.BatchNorm1d(num_features=22, eps=1e-3) self.conv10 = nn.Conv1d( 22, 32, kernel_size=kernel_size, stride=1, padding="same" ) self.bn10 = nn.BatchNorm1d(num_features=32, eps=1e-3) # extra from original UNet self.conv11 = nn.Conv1d( 32, 32, kernel_size=kernel_size, stride=stride, padding=pad_size ) self.bn11 = nn.BatchNorm1d(num_features=32, eps=1e-3) self.conv12 = nn.Conv1d( 32, 40, kernel_size=kernel_size, stride=1, padding="same" ) self.bn12 = nn.BatchNorm1d(num_features=40, eps=1e-3) self.dconv0 = nn.ConvTranspose1d( 40, 32, kernel_size=kernel_size, stride=stride, padding=pad_size ) self.bnd0 = nn.BatchNorm1d(num_features=32, eps=1e-3) self.dconv01 = nn.Conv1d( 64, 32, kernel_size=kernel_size, stride=1, padding="same" ) self.bnd01 = nn.BatchNorm1d(num_features=32, eps=1e-3) # self.dconv1 = nn.ConvTranspose1d( 32, 22, kernel_size=kernel_size, stride=stride, padding=pad_size ) self.bnd1 = nn.BatchNorm1d(num_features=22, eps=1e-3) self.dconv2 = nn.Conv1d( 44, 22, kernel_size=kernel_size, stride=1, padding="same" ) self.bnd2 = nn.BatchNorm1d(num_features=22, eps=1e-3) self.dconv3 = nn.ConvTranspose1d( 22, 16, kernel_size=kernel_size, stride=stride, padding=pad_size - 1 ) self.bnd3 = nn.BatchNorm1d(num_features=16, eps=1e-3) self.dconv4 = nn.Conv1d( 32, 16, kernel_size=kernel_size, stride=1, padding="same" ) self.bnd4 = nn.BatchNorm1d(num_features=16, eps=1e-3) self.dconv5 = nn.ConvTranspose1d( 16, 11, kernel_size=kernel_size, stride=stride, padding=pad_size - 1 ) self.bnd5 = nn.BatchNorm1d(num_features=11, eps=1e-3) self.dconv6 = nn.Conv1d( 22, 11, kernel_size=kernel_size, stride=1, padding="same" ) self.bnd6 = nn.BatchNorm1d(num_features=11, eps=1e-3) self.dconv7 = nn.ConvTranspose1d( 11, 8, kernel_size=kernel_size, stride=stride, padding=pad_size - 1 ) self.bnd7 = nn.BatchNorm1d(num_features=8, eps=1e-3) self.dconv8 = nn.Conv1d( 16, 8, kernel_size=kernel_size, stride=1, padding="same" ) self.bnd8 = nn.BatchNorm1d(num_features=8, eps=1e-3) self.dconv9 = nn.Conv1d( 8, classes, kernel_size=kernel_size, stride=1, padding="same" ) self.softmax = nn.Softmax(dim=1)
[docs] def forward(self, X, logits=False): X1 = torch.relu(self.bn1(self.conv1(X))) X2 = torch.relu(self.bn2(self.conv2(X1))) X3 = torch.relu(self.bn3(self.conv3(X2))) X4 = torch.relu(self.bn4(self.conv4(X3))) X5 = torch.relu(self.bn5(self.conv5(X4))) X6 = torch.relu(self.bn6(self.conv6(X5))) X7 = torch.relu(self.bn7(self.conv7(X6))) X8 = torch.relu(self.bn8(self.conv8(X7))) X9 = torch.relu(self.bn9(self.conv9(X8))) X10 = torch.relu(self.bn10(self.conv10(X9))) # extra from original UNet X10_a = torch.relu(self.bn11(self.conv11(X10))) X10_b = torch.relu(self.bn12(self.conv12(X10_a))) X10_c = torch.relu(self.bnd0(self.dconv0(X10_b))) X10_c = torch.cat( ( X10_c, torch.zeros((X10_c.shape[0], X10_c.shape[1], 1), device=X10_c.device), ), dim=-1, ) X10_c = torch.cat((X10, X10_c), dim=1) X10_d = torch.relu(self.bnd01(self.dconv01(X10_c))) X11 = torch.relu(self.bnd1(self.dconv1(X10_d))) X12 = torch.cat((X11, X8), dim=1) X12 = torch.relu(self.bnd2(self.dconv2(X12))) X13 = torch.relu(self.bnd3(self.dconv3(X12))) X14 = torch.relu(self.bnd4(self.dconv4(torch.cat((X13, X6), dim=1)))) X15 = torch.relu(self.bnd5(self.dconv5(X14))) X15 = torch.cat( (X15, torch.zeros((X15.shape[0], X15.shape[1], 1), device=X15.device)), dim=2, ) X16 = torch.relu(self.bnd6(self.dconv6(torch.cat((X15, X4), dim=1)))) X17 = torch.relu(self.bnd7(self.dconv7(X16))) X17 = torch.cat( (X17, torch.zeros((X17.shape[0], X17.shape[1], 1), device=X17.device)), dim=2, ) X18 = torch.relu(self.bnd8(self.dconv8(torch.cat((X17, X2), dim=1)))) X19 = self.dconv9(X18) if logits: return X19 else: return self.softmax(X19)
[docs] def annotate_batch_pre( self, batch: torch.Tensor, argdict: dict[str, Any] ) -> torch.Tensor: batch = batch - batch.mean(axis=-1, keepdims=True) if self.norm == "std": std = batch.std(axis=-1, keepdims=True).mean(axis=-2, keepdims=True) batch = batch / (std + 1e-10) elif self.norm == "peak": peak = ( batch.abs() .max(axis=-2, keepdims=True)[0] .max(axis=-1, keepdims=True)[0] ) batch = batch / (peak + 1e-10) return batch
[docs] def annotate_batch_post( self, batch: torch.Tensor, piggyback: Any, argdict: dict[str, Any] ) -> torch.Tensor: # Transpose predictions to correct shape batch = torch.transpose(batch, -1, -2) prenan, postnan = argdict.get( "blinding", self._annotate_args.get("blinding")[1] ) if prenan > 0: batch[:, :prenan] = np.nan if postnan > 0: batch[:, -postnan:] = np.nan return batch
[docs] def classify_aggregate(self, annotations, argdict) -> sbu.ClassifyOutput: """ Converts the annotations to discrete thresholds using :py:func:`~seisbench.models.base.WaveformModel.picks_from_annotations`. Trigger onset thresholds for picks are derived from the argdict at keys "[phase]_threshold". :param annotations: See description in superclass :param argdict: See description in superclass :return: List of picks """ picks = sbu.PickList() for phase in self.labels: if phase == "N": # Don't pick noise continue picks += self.picks_from_annotations( annotations.select(channel=f"{self.__class__.__name__}_{phase}"), argdict.get( f"{phase}_threshold", self._annotate_args.get("*_threshold")[1] ), phase, ) picks = sbu.PickList(sorted(picks)) return sbu.ClassifyOutput(self.name, picks=picks)
[docs] def get_model_args(self): model_args = super().get_model_args() for key in [ "citation", "in_samples", "output_type", "default_args", "pred_sample", "labels", ]: del model_args[key] model_args["in_channels"] = self.in_channels model_args["classes"] = self.classes model_args["phases"] = self.labels model_args["sampling_rate"] = self.sampling_rate model_args["norm"] = self.norm return model_args