Source code for seisbench.models.eqtransformer

import warnings
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import seisbench.util as sbu

from .base import ActivationLSTMCell, CustomLSTM, WaveformModel


# For implementation, potentially follow: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
class EQTransformer(WaveformModel):
    """
    The EQTransformer from Mousavi et al. (2020)

    Implementation adapted from the Github repository https://github.com/smousavi05/EQTransformer
    Assumes padding="same" and activation="relu" as in the pretrained EQTransformer models

    By instantiating the model with `from_pretrained("original")` a binary compatible version of the original
    EQTransformer with the original weights from Mousavi et al. (2020) can be loaded.

    .. document_args:: seisbench.models EQTransformer

    :param in_channels: Number of input channels, by default 3.
    :param in_samples: Number of input samples per channel, by default 6000.
                       The model expects input shape (in_channels, in_samples)
    :param classes: Number of output classes, by default 2. The detection channel is not counted.
    :param phases: Phase hints for the classes, by default "PS". Can be None, in which case the
                   classes are labelled by their integer index (0, 1, ...).
    :param lstm_blocks: Number of LSTM blocks
    :param drop_rate: Dropout rate
    :param original_compatible: If True, uses a few custom layers for binary compatibility with original model
                                from Mousavi et al. (2020).
                                This option defaults to False.
                                It is usually recommended to stick to the default value, as the custom layers show
                                slightly worse performance than the PyTorch builtins.
                                The exception is when loading the original weights using :py:func:`from_pretrained`.
    :param sampling_rate: Sampling rate passed on to :py:class:`WaveformModel`, by default 100.
    :param norm: Data normalization strategy, either "peak" or "std".
    :param kwargs: Keyword arguments passed to the constructor of :py:class:`WaveformModel`.
                   The keys "norm_amp_per_comp" and "norm_detrend" are consumed by this class
                   and stored as attributes (both default to False).
    """

    _annotate_args = WaveformModel._annotate_args.copy()
    _annotate_args["*_threshold"] = ("Detection threshold for the provided phase", 0.1)
    _annotate_args["detection_threshold"] = ("Detection threshold", 0.3)
    _annotate_args["blinding"] = (
        "Number of prediction samples to discard on each side of each window prediction",
        (500, 500),
    )
    # Overwrite default stacking method
    _annotate_args["stacking"] = (
        "Stacking method for overlapping windows (only for window prediction models). "
        "Options are 'max' and 'avg'. ",
        "max",
    )
    _annotate_args["overlap"] = (_annotate_args["overlap"][0], 3000)

    _weight_warnings = [
        (
            "ethz|geofon|instance|iquique|lendb|neic|scedc|stead",
            "1",
            "The normalization for this weight version is incorrect and will lead to degraded performance. "
            "Run from_pretrained with update=True once to solve this issue. "
            "For details, see https://github.com/seisbench/seisbench/pull/188 .",
        ),
    ]

    def __init__(
        self,
        in_channels=3,
        in_samples=6000,
        classes=2,
        phases="PS",
        lstm_blocks=3,
        drop_rate=0.1,
        original_compatible=False,
        sampling_rate=100,
        norm="std",
        **kwargs,
    ):
        citation = (
            "Mousavi, S.M., Ellsworth, W.L., Zhu, W., Chuang, L, Y., and Beroza, G, C. "
            "Earthquake transformer—an attentive deep-learning model for simultaneous earthquake "
            "detection and phase picking. Nat Commun 11, 3952 (2020). "
            "https://doi.org/10.1038/s41467-020-17591-w"
        )

        # PickBlue options: consumed here so they are not forwarded to WaveformModel
        for option in ("norm_amp_per_comp", "norm_detrend"):
            setattr(self, option, kwargs.pop(option, False))

        # `phases` may be None. In that case fall back to the integer class indices,
        # i.e. the same values the `phases` property returns, so that the labels agree
        # with the channel names looked up in `classify_aggregate`.
        if phases is None:
            phase_labels = list(range(classes))
        else:
            phase_labels = list(phases)

        # Blinding defines how many samples at beginning and end of the prediction should be ignored
        # This is usually required to mitigate prediction problems from training properties, e.g.,
        # if all picks in the training fall between seconds 5 and 55.
        super().__init__(
            citation=citation,
            output_type="array",
            in_samples=in_samples,
            pred_sample=(0, in_samples),
            labels=["Detection"] + phase_labels,
            sampling_rate=sampling_rate,
            **kwargs,
        )

        self.in_channels = in_channels
        self.classes = classes
        self.lstm_blocks = lstm_blocks
        self.drop_rate = drop_rate
        self.norm = norm

        # Add options for conservative and the true original - see
        # https://github.com/seisbench/seisbench/issues/96#issuecomment-1155158224
        if isinstance(original_compatible, bool) and original_compatible:
            warnings.warn(
                "Using the non-conservative 'original' model, set `original_compatible='conservative' to use the more conservative model"
            )
            original_compatible = "non-conservative"

        if original_compatible:
            eps = 1e-7  # See Issue #96 - original models use tensorflow default epsilon of 1e-7
        else:
            eps = 1e-5
        self.original_compatible = original_compatible

        if original_compatible and in_samples != 6000:
            raise ValueError("original_compatible=True requires in_samples=6000.")

        self._phases = phases
        if phases is not None and len(phases) != classes:
            raise ValueError(
                f"Number of classes ({classes}) does not match number of phases ({len(phases)})."
            )

        # Parameters from EQTransformer repository
        self.filters = [
            8,
            16,
            16,
            32,
            32,
            64,
            64,
        ]  # Number of filters for the convolutions
        self.kernel_sizes = [11, 9, 7, 7, 5, 5, 3]  # Kernel sizes for the convolutions
        self.res_cnn_kernels = [3, 3, 3, 3, 2, 3, 2]

        # TODO: Add regularizers when training model
        # kernel_regularizer=keras.regularizers.l2(1e-6),
        # bias_regularizer=keras.regularizers.l1(1e-4),

        # Encoder stack
        self.encoder = Encoder(
            input_channels=self.in_channels,
            filters=self.filters,
            kernel_sizes=self.kernel_sizes,
            in_samples=self.in_samples,
        )

        # Res CNN Stack
        self.res_cnn_stack = ResCNNStack(
            kernel_sizes=self.res_cnn_kernels,
            filters=self.filters[-1],
            drop_rate=self.drop_rate,
        )

        # BiLSTM stack
        self.bi_lstm_stack = BiLSTMStack(
            blocks=self.lstm_blocks,
            input_size=self.filters[-1],
            drop_rate=self.drop_rate,
            original_compatible=original_compatible,
        )

        # Global attention - two transformers
        self.transformer_d0 = Transformer(
            input_size=16, drop_rate=self.drop_rate, eps=eps
        )
        self.transformer_d = Transformer(
            input_size=16, drop_rate=self.drop_rate, eps=eps
        )

        # Detection decoder and final Conv
        self.decoder_d = Decoder(
            input_channels=16,
            filters=self.filters[::-1],
            kernel_sizes=self.kernel_sizes[::-1],
            out_samples=in_samples,
            original_compatible=original_compatible,
        )
        self.conv_d = nn.Conv1d(
            in_channels=self.filters[0], out_channels=1, kernel_size=11, padding=5
        )

        # Picking branches: one LSTM / attention / decoder / conv chain per class
        self.pick_lstms = []
        self.pick_attentions = []
        self.pick_decoders = []
        self.pick_convs = []
        self.dropout = nn.Dropout(drop_rate)

        for _ in range(self.classes):
            if original_compatible == "conservative":
                # Only the conservative variant uses the custom LSTM here;
                # all other variants use the builtin nn.LSTM
                lstm = CustomLSTM(ActivationLSTMCell, 16, 16, bidirectional=False)
            else:
                lstm = nn.LSTM(16, 16, bidirectional=False)
            self.pick_lstms.append(lstm)

            attention = SeqSelfAttention(input_size=16, attention_width=3, eps=eps)
            self.pick_attentions.append(attention)

            decoder = Decoder(
                input_channels=16,
                filters=self.filters[::-1],
                kernel_sizes=self.kernel_sizes[::-1],
                out_samples=in_samples,
                original_compatible=original_compatible,
            )
            self.pick_decoders.append(decoder)

            conv = nn.Conv1d(
                in_channels=self.filters[0], out_channels=1, kernel_size=11, padding=5
            )
            self.pick_convs.append(conv)

        # Wrap in ModuleList so the branch parameters are registered with the model
        self.pick_lstms = nn.ModuleList(self.pick_lstms)
        self.pick_attentions = nn.ModuleList(self.pick_attentions)
        self.pick_decoders = nn.ModuleList(self.pick_decoders)
        self.pick_convs = nn.ModuleList(self.pick_convs)

    def forward(self, x, logits=False):
        """
        Runs the shared encoder, the detection branch and one picking branch per class.

        :param x: Input tensor of shape (batch, in_channels, in_samples)
        :param logits: If True, the outputs of the final convolutions are returned directly.
                       Otherwise, a sigmoid is applied to them.
        :return: Tuple with the detection output first, followed by one output per class.
                 The channel dimension of each output is removed.
        """
        assert x.ndim == 3
        assert x.shape[1:] == (self.in_channels, self.in_samples)

        # Shared encoder part
        x = self.encoder(x)
        x = self.res_cnn_stack(x)
        x = self.bi_lstm_stack(x)
        x, _ = self.transformer_d0(x)
        x, _ = self.transformer_d(x)

        # Detection part
        detection = self.decoder_d(x)
        if logits:
            detection = self.conv_d(detection)
        else:
            detection = torch.sigmoid(self.conv_d(detection))
        detection = torch.squeeze(detection, dim=1)  # Remove channel dimension

        outputs = [detection]

        # Pick parts
        for lstm, attention, decoder, conv in zip(
            self.pick_lstms, self.pick_attentions, self.pick_decoders, self.pick_convs
        ):
            px = x.permute(
                2, 0, 1
            )  # From batch, channels, sequence to sequence, batch, channels
            px = lstm(px)[0]
            px = self.dropout(px)
            px = px.permute(
                1, 2, 0
            )  # From sequence, batch, channels to batch, channels, sequence
            px, _ = attention(px)
            px = decoder(px)
            if logits:
                pred = conv(px)
            else:
                pred = torch.sigmoid(conv(px))
            pred = torch.squeeze(pred, dim=1)  # Remove channel dimension

            outputs.append(pred)

        return tuple(outputs)

    def annotate_batch_post(
        self, batch: torch.Tensor, piggyback: Any, argdict: dict[str, Any]
    ) -> torch.Tensor:
        """
        Stacks the model outputs along a new last axis and blinds the window edges.

        :param batch: Model outputs; passed to ``torch.stack``, i.e., a sequence of equally shaped tensors
        :param piggyback: Unused
        :param argdict: Annotation arguments. The key "blinding" gives the number of samples
                        set to NaN at the start and at the end of axis 1.
        :return: Stacked tensor with the blinded samples replaced by NaN
        """
        # Transpose predictions to correct shape
        batch = torch.stack(batch, dim=-1)
        prenan, postnan = argdict.get(
            "blinding", self._annotate_args.get("blinding")[1]
        )
        # The guards are required: `-0:` would select the full axis
        if prenan > 0:
            batch[:, :prenan] = np.nan
        if postnan > 0:
            batch[:, -postnan:] = np.nan
        return batch

    def annotate_batch_pre(
        self, batch: torch.Tensor, argdict: dict[str, Any]
    ) -> torch.Tensor:
        """
        Demeans, optionally detrends, normalizes and tapers a batch of waveforms.

        :param batch: Waveform tensor; indexed as a three-dimensional tensor with samples on the last axis
        :param argdict: Annotation arguments (not used here)
        :return: Preprocessed tensor
        """
        batch = batch - batch.mean(axis=-1, keepdims=True)

        if self.norm_detrend:
            batch = sbu.torch_detrend(batch)

        if self.norm_amp_per_comp:
            # Peak normalization along the last axis only
            peak = batch.abs().max(axis=-1, keepdims=True)[0]
            batch = batch / (peak + 1e-10)
        else:
            if self.norm == "std":
                # Standard deviation taken jointly over the last two axes
                std = batch.std(axis=(-1, -2), keepdims=True)
                batch = batch / (std + 1e-10)
            elif self.norm == "peak":
                peak = batch.abs().max(axis=-1, keepdims=True)[0]
                batch = batch / (peak + 1e-10)

        # Cosine taper (very short, i.e., only six samples on each side)
        tap = 0.5 * (
            1 + torch.cos(torch.linspace(np.pi, 2 * np.pi, 6, device=batch.device))
        )
        batch[:, :, :6] *= tap
        batch[:, :, -6:] *= tap.flip(dims=(0,))

        return batch

    @property
    def phases(self):
        """
        The phase hints passed to the constructor, or the integer class indices if none were given.
        """
        if self._phases is not None:
            return self._phases
        else:
            return list(range(self.classes))

    def classify_aggregate(self, annotations, argdict) -> sbu.ClassifyOutput:
        """
        Converts the annotations to discrete picks using
        :py:func:`~seisbench.models.base.WaveformModel.picks_from_annotations`
        and to discrete detections using
        :py:func:`~seisbench.models.base.WaveformModel.detections_from_annotations`.
        Trigger onset thresholds for picks are derived from the argdict at keys "[phase]_threshold".
        Trigger onset thresholds for detections are derived from the argdict at key "detection_threshold".

        :param annotations: See description in superclass
        :param argdict: See description in superclass
        :return: List of picks, list of detections
        """
        picks = sbu.PickList()
        for phase in self.phases:
            picks += self.picks_from_annotations(
                annotations.select(channel=f"{self.__class__.__name__}_{phase}"),
                argdict.get(
                    f"{phase}_threshold", self._annotate_args.get("*_threshold")[1]
                ),
                phase,
            )
        picks = sbu.PickList(sorted(picks))

        detections = self.detections_from_annotations(
            annotations.select(channel=f"{self.__class__.__name__}_Detection"),
            argdict.get(
                "detection_threshold", self._annotate_args.get("detection_threshold")[1]
            ),
        )

        return sbu.ClassifyOutput(self.name, picks=picks, detections=detections)

    def get_model_args(self):
        """
        Returns the constructor arguments of this model.

        The entries the constructor sets itself are removed from the superclass arguments
        and the arguments specific to this class are added.
        """
        model_args = super().get_model_args()
        for key in [
            "citation",
            "output_type",
            "default_args",
            "pred_sample",
            "labels",
        ]:
            del model_args[key]

        model_args["in_channels"] = self.in_channels
        model_args["in_samples"] = self.in_samples
        model_args["classes"] = self.classes
        model_args["phases"] = self.phases
        model_args["lstm_blocks"] = self.lstm_blocks
        model_args["drop_rate"] = self.drop_rate
        model_args["original_compatible"] = self.original_compatible
        model_args["sampling_rate"] = self.sampling_rate
        model_args["norm"] = self.norm
        model_args["norm_amp_per_comp"] = self.norm_amp_per_comp
        model_args["norm_detrend"] = self.norm_detrend

        return model_args
class Encoder(nn.Module):
    """
    Encoder stack: alternating Conv1d + ReLU and max pooling (factor 2) stages.

    :param input_channels: Number of channels of the input
    :param filters: List with the number of output channels of each convolution
    :param kernel_sizes: Kernel size of each convolution
    :param in_samples: Number of input samples, used to precompute the pooling padding
    """

    def __init__(self, input_channels, filters, kernel_sizes, in_samples):
        super().__init__()

        conv_layers = []
        pool_layers = []
        self.paddings = []

        remaining = in_samples
        stage_inputs = [input_channels] + filters[:-1]
        for c_in, c_out, width in zip(stage_inputs, filters, kernel_sizes):
            conv_layers.append(nn.Conv1d(c_in, c_out, width, padding=width // 2))

            # To be consistent with the behaviour in tensorflow,
            # an odd number of samples needs one sample of padding before pooling.
            # The padding is applied manually in forward, not by MaxPool1d.
            odd = remaining % 2
            self.paddings.append(odd)
            pool_layers.append(nn.MaxPool1d(2, padding=0))
            remaining = (remaining + odd) // 2

        self.convs = nn.ModuleList(conv_layers)
        self.pools = nn.ModuleList(pool_layers)

    def forward(self, x):
        for stage, conv in enumerate(self.convs):
            x = torch.relu(conv(x))
            pad = self.paddings[stage]
            if pad != 0:
                # Only pad right, use -1e10 as negative infinity
                x = F.pad(x, (0, pad), "constant", -1e10)
            x = self.pools[stage](x)

        return x


class Decoder(nn.Module):
    """
    Decoder stack: repeated upsampling (factor 2), optional cropping and Conv1d + ReLU.

    :param input_channels: Number of channels of the input
    :param filters: List with the number of output channels of each convolution
    :param kernel_sizes: Kernel size of each convolution
    :param out_samples: Target number of output samples, used to precompute the crops
    :param original_compatible: If truthy, one sample is removed on each side after the
                                fourth upsampling instead of using the precomputed crops
    """

    def __init__(
        self,
        input_channels,
        filters,
        kernel_sizes,
        out_samples,
        original_compatible=False,
    ):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.original_compatible = original_compatible

        # We need to trim off the final sample sometimes to get to the right number of output samples.
        # Replay the halving from out_samples and remember the (reversed) stages with odd length.
        self.crops = []
        n_stages = len(filters)
        length = out_samples
        for stage in range(n_stages):
            odd = length % 2
            length = (length + odd) // 2
            if odd == 1:
                self.crops.append(n_stages - 1 - stage)

        stage_inputs = [input_channels] + filters[:-1]
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(c_in, c_out, width, padding=width // 2)
                for c_in, c_out, width in zip(stage_inputs, filters, kernel_sizes)
            ]
        )

    def forward(self, x):
        for stage, conv in enumerate(self.convs):
            x = self.upsample(x)

            if self.original_compatible:
                if stage == 3:
                    x = x[:, :, 1:-1]
            elif stage in self.crops:
                x = x[:, :, :-1]

            x = F.relu(conv(x))

        return x


class ResCNNStack(nn.Module):
    """
    Sequence of :py:class:`ResCNNBlock`, one per entry in `kernel_sizes`.
    """

    def __init__(self, kernel_sizes, filters, drop_rate):
        super().__init__()

        self.members = nn.ModuleList(
            [ResCNNBlock(filters, ker, drop_rate) for ker in kernel_sizes]
        )

    def forward(self, x):
        for block in self.members:
            x = block(x)

        return x


class ResCNNBlock(nn.Module):
    """
    Residual block with two (BatchNorm, ReLU, spatial dropout, Conv1d) sub-blocks.

    :param filters: Number of input and output channels
    :param ker: Kernel size of both convolutions; the branches distinguish 3 from anything else
    :param drop_rate: Dropout rate
    """

    def __init__(self, filters, ker, drop_rate):
        super().__init__()

        # For ker != 3 (i.e. ker == 2) the padding is applied manually in forward
        # to emulate the padding in tensorflow
        self.manual_padding = ker != 3
        conv_padding = 1 if ker == 3 else 0

        self.dropout = SpatialDropout1d(drop_rate)

        self.norm1 = nn.BatchNorm1d(filters, eps=1e-3)
        self.conv1 = nn.Conv1d(filters, filters, ker, padding=conv_padding)

        self.norm2 = nn.BatchNorm1d(filters, eps=1e-3)
        self.conv2 = nn.Conv1d(filters, filters, ker, padding=conv_padding)

    def forward(self, x):
        y = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            y = self.dropout(F.relu(norm(y)))
            if self.manual_padding:
                y = F.pad(y, (0, 1), "constant", 0)
            y = conv(y)

        return x + y


class BiLSTMStack(nn.Module):
    """
    Sequence of `blocks` :py:class:`BiLSTMBlock`.
    """

    def __init__(
        self, blocks, input_size, drop_rate, hidden_size=16, original_compatible=False
    ):
        super().__init__()

        # First LSTM has a different input size as the subsequent ones
        block_inputs = [input_size] + [hidden_size] * (blocks - 1)
        self.members = nn.ModuleList(
            [
                BiLSTMBlock(
                    size,
                    hidden_size,
                    drop_rate,
                    original_compatible=original_compatible,
                )
                for size in block_inputs
            ]
        )

    def forward(self, x):
        for block in self.members:
            x = block(x)

        return x


class BiLSTMBlock(nn.Module):
    """
    LSTM followed by dropout, a 1x1 convolution and batch normalization.

    :param input_size: Number of input channels
    :param hidden_size: Hidden size of the LSTM and number of output channels
    :param drop_rate: Dropout rate
    :param original_compatible: Selects the LSTM implementation ("conservative",
                                "non-conservative" or anything else for the builtin nn.LSTM)
    """

    def __init__(self, input_size, hidden_size, drop_rate, original_compatible=False):
        super().__init__()

        if original_compatible == "conservative":
            # Custom LSTM with its default gate activation
            self.lstm = CustomLSTM(ActivationLSTMCell, input_size, hidden_size)
        elif original_compatible == "non-conservative":
            # Custom LSTM with an explicit sigmoid gate activation
            self.lstm = CustomLSTM(
                ActivationLSTMCell,
                input_size,
                hidden_size,
                gate_activation=torch.sigmoid,
            )
        else:
            self.lstm = nn.LSTM(input_size, hidden_size, bidirectional=True)

        self.dropout = nn.Dropout(drop_rate)
        self.conv = nn.Conv1d(2 * hidden_size, hidden_size, 1)
        self.norm = nn.BatchNorm1d(hidden_size, eps=1e-3)

    def forward(self, x):
        # From batch, channels, sequence to sequence, batch, channels
        sequence_first = x.permute(2, 0, 1)
        hidden = self.dropout(self.lstm(sequence_first)[0])
        # From sequence, batch, channels to batch, channels, sequence
        channels_first = hidden.permute(1, 2, 0)

        return self.norm(self.conv(channels_first))


class Transformer(nn.Module):
    """
    Self attention and feed forward layer, each with a residual connection
    followed by layer normalization.
    """

    def __init__(self, input_size, drop_rate, attention_width=None, eps=1e-5):
        super().__init__()

        self.attention = SeqSelfAttention(
            input_size, attention_width=attention_width, eps=eps
        )
        self.norm1 = LayerNormalization(input_size)
        self.ff = FeedForward(input_size, drop_rate)
        self.norm2 = LayerNormalization(input_size)

    def forward(self, x):
        attended, weight = self.attention(x)
        attended = self.norm1(x + attended)

        out = self.norm2(attended + self.ff(attended))

        return out, weight


class SeqSelfAttention(nn.Module):
    """
    Additive self attention

    :param input_size: Number of input channels
    :param units: Size of the hidden attention representation
    :param attention_width: If not None, attention is restricted to a band of this width
    :param eps: Constant added to the normalization denominator
    """

    def __init__(self, input_size, units=32, attention_width=None, eps=1e-5):
        super().__init__()
        self.attention_width = attention_width

        self.Wx = nn.Parameter(uniform(-0.02, 0.02, input_size, units))
        self.Wt = nn.Parameter(uniform(-0.02, 0.02, input_size, units))
        self.bh = nn.Parameter(torch.zeros(units))

        self.Wa = nn.Parameter(uniform(-0.02, 0.02, units, 1))
        self.ba = nn.Parameter(torch.zeros(1))

        self.eps = eps

    def forward(self, x):
        # x.shape == (batch, channels, time)
        x = x.permute(0, 2, 1)  # to (batch, time, channels)

        q = (x @ self.Wt).unsqueeze(2)  # Shape (batch, time, 1, units)
        k = (x @ self.Wx).unsqueeze(1)  # Shape (batch, 1, time, units)
        h = torch.tanh(q + k + self.bh)

        # Emissions, shape (batch, time, time)
        e = (h @ self.Wa + self.ba).squeeze(-1)

        # This is essentially softmax with an additional attention component.
        # In versions <= 0.2.1 e was incorrectly normalized by max(x)
        e = torch.exp(e - torch.max(e, dim=-1, keepdim=True).values)

        if self.attention_width is not None:
            positions = torch.arange(0, e.shape[1], device=e.device)
            lower = positions - self.attention_width // 2
            upper = lower + self.attention_width
            indices = positions.unsqueeze(1)
            mask = torch.logical_and(lower <= indices, indices < upper)
            # Zero out all emissions outside the band
            e = torch.where(mask, e, torch.zeros_like(e))

        a = e / (e.sum(dim=-1, keepdim=True) + self.eps)

        v = (a @ x).permute(0, 2, 1)  # to (batch, channels, time)

        return v, a


def uniform(a, b, *args):
    """
    Returns a tensor of shape `args` with values drawn via ``torch.rand`` and rescaled to [a, b).
    """
    width = b - a
    return a + width * torch.rand(*args)


class LayerNormalization(nn.Module):
    """
    Normalization over dimension 1 with a learnable scale (gamma) and shift (beta).
    """

    def __init__(self, filters, eps=1e-14):
        super().__init__()

        self.gamma = nn.Parameter(torch.ones(filters, 1))
        self.beta = nn.Parameter(torch.zeros(filters, 1))
        self.eps = eps

    def forward(self, x):
        mean = torch.mean(x, 1, keepdim=True)
        centered = x - mean
        var = torch.mean(centered**2, 1, keepdim=True) + self.eps

        return centered / torch.sqrt(var) * self.gamma + self.beta


class FeedForward(nn.Module):
    """
    Two linear layers with ReLU and dropout in between, applied along the channel dimension.
    """

    def __init__(self, io_size, drop_rate, hidden_size=128):
        super().__init__()

        self.lin1 = nn.Linear(io_size, hidden_size)
        self.lin2 = nn.Linear(hidden_size, io_size)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        time_major = x.permute(0, 2, 1)  # To (batch, time, channel)
        hidden = self.dropout(F.relu(self.lin1(time_major)))
        out = self.lin2(hidden)

        return out.permute(0, 2, 1)  # To (batch, channel, time)


class SpatialDropout1d(nn.Module):
    """
    Applies nn.Dropout2d to a 1D input by temporarily appending a trailing dimension.
    """

    def __init__(self, drop_rate):
        super().__init__()

        self.drop_rate = drop_rate
        self.dropout = nn.Dropout2d(drop_rate)

    def forward(self, x):
        expanded = x.unsqueeze(dim=-1)  # Add fake dimension
        dropped = self.dropout(expanded)
        return dropped.squeeze(dim=-1)  # Remove fake dimension