Source code for seisbench.models.seisdae

from typing import Any, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import istft, stft

from .base import WaveformModel


class AttentionGate(nn.Module):
    """
    Additive attention gate for the skip connections of a U-Net.

    The encoder feature map (skip connection) and the decoder feature map (gating signal) are each
    projected to ``inter_channels`` channels by a 1x1 convolution followed by batch normalisation.
    The two projections are summed, passed through a ReLU and reduced to a single-channel
    attention map with values in (0, 1) by a further 1x1 convolution, batch normalisation and a
    sigmoid. The encoder feature map is finally multiplied element-wise by this attention map.
    The mechanism follows *Attention U-Net: Learning Where to Look for the Pancreas*
    (Oktay et al., 2018).

        .. admonition:: Citation
          Ozan Oktay, Jo Schlemper, Loic Le Folgoc, Matthew Lee, Mattias Heinrich, Kazunari Misawa,
          Kensaku Mori, Steven McDonagh, Nils Y Hammerla, Bernhard Kainz, Ben Glocker, Daniel Rueckert (2018)
          Attention U-Net: Learning Where to Look for the Pancreas

          https://arxiv.org/abs/1804.03999

    :param in_channels_encoder: Number of channels of the encoder feature map (skip connection).
    :param in_channels_decoder: Number of channels of the decoder feature map (gating signal).
    :param inter_channels: Number of channels of the intermediate projections.
    :param bias: If True, the 1x1 convolutions get a learnable bias. Default is False.
    """

    def __init__(
        self,
        in_channels_encoder: int,
        in_channels_decoder: int,
        inter_channels: int,
        bias: bool = False,
    ):
        super().__init__()

        def pointwise(n_in: int, n_out: int) -> nn.Conv2d:
            # 1x1 convolution that only mixes channels and keeps the spatial size
            return nn.Conv2d(
                n_in, n_out, kernel_size=1, stride=1, padding=0, bias=bias
            )

        # Projection of the gating signal coming from the decoder
        self.W_g = nn.Sequential(
            pointwise(in_channels_decoder, inter_channels),
            nn.BatchNorm2d(inter_channels),
        )
        # Projection of the skip connection coming from the encoder
        self.W_x = nn.Sequential(
            pointwise(in_channels_encoder, inter_channels),
            nn.BatchNorm2d(inter_channels),
        )
        # Collapse the combined projection into one attention coefficient per pixel
        self.psi = nn.Sequential(
            pointwise(inter_channels, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x_encoder, g_decoder):
        """
        Gate the encoder feature map with attention coefficients derived from both inputs.

        :param x_encoder: Encoder feature map (skip connection).
        :param g_decoder: Decoder feature map (gating signal). Its projection is added to the
                          projection of ``x_encoder``, so both must have compatible shapes.
        :return: ``x_encoder`` scaled element-wise by the attention map.
        """
        combined = self.W_g(g_decoder) + self.W_x(x_encoder)
        attention_map = self.psi(self.relu(combined))

        # The single-channel map broadcasts over all encoder channels
        return x_encoder * attention_map

class ConvBlock(nn.Module):
    """
    2D convolution followed by batch normalisation, ReLU and dropout.

    The convolution always uses a padding of 1.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param kernel_size: Kernel size of the convolution. Default is 3.
    :param stride: Stride of the convolution. Default is 1.
    :param drop_rate: Dropout probability. Default is 0.3.
    :param use_bias: If True, the convolution gets a learnable bias. Default is False.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple[int, int]] = 3,
        stride: Union[int, tuple[int, int]] = 1,
        drop_rate: float = 0.3,
        use_bias: bool = False,
    ):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=1,
                bias=use_bias,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(drop_rate),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """
        Apply convolution, batch normalisation, ReLU and dropout to ``x``.
        """
        out = self.conv(x)
        return out


class TransposeConvBlock(nn.Module):
    """
    2D transposed convolution followed by batch normalisation, ReLU and dropout.

    The transposed convolution always uses a padding of 1. The output padding is derived from the
    stride: it is 1 along every axis with a stride of at least 2 and 0 along axes with stride 1.
    A fixed output padding of 1 is rejected by PyTorch when the stride is 1, because the output
    padding has to be smaller than the stride (or dilation).

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param kernel_size: Kernel size of the transposed convolution. Default is 3.
    :param stride: Stride of the transposed convolution. Default is 1.
    :param drop_rate: Dropout probability. Default is 0.3.
    :param use_bias: If True, the transposed convolution gets a learnable bias. Default is False.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple[int, int]] = 3,
        stride: Union[int, tuple[int, int]] = 1,
        drop_rate: float = 0.3,
        use_bias: bool = False,
    ):
        super().__init__()

        # Keep the output padding of 1 for upsampling strides, but drop it for stride 1,
        # where PyTorch would raise an error in the forward pass.
        if isinstance(stride, int):
            output_padding = min(1, stride - 1)
        else:
            output_padding = tuple(min(1, s - 1) for s in stride)

        self.block = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=1,
                output_padding=output_padding,
                bias=use_bias,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
            nn.Dropout(drop_rate),
        )

    def forward(self, x):
        """
        Apply transposed convolution, batch normalisation, ReLU and dropout to ``x``.
        """
        return self.block(x)


class SeisDAE(WaveformModel):
    """
    Seismic Denoising Autoencoder using U-Net Architecture with additional attention gates.

    A configurable denoising autoencoder for seismic waveform data that operates in the
    time-frequency domain using the Short-Time Fourier Transform (STFT). The model is based on a
    U-Net structure with optional attention gates and skip connections.

    :param in_samples: Length of the input waveform in samples. Default is 3000 samples.
    :param in_channels: Number of input channels (e.g., 2 for real and imaginary STFT components).
                        Default is 2 channels.
    :param sampling_rate: Sampling rate of the waveform data in Hz. Default sampling rate is 100 Hz.
    :param filters_root: Number of filters in the first convolutional layer (doubles with depth).
                         Default is 8.
    :param depth: Number of encoding/decoding levels in the U-Net. Default is 6.
    :param kernel_size: Kernel size for convolutional layers. Default is (3, 3).
    :param strides: Stride size used for down/upsampling using Conv2D and transpose Conv2D layers.
                    Default is (2, 2).
    :param output_activation: Activation function applied to the final output. Pass None to skip
                              the activation. Default is Softmax over the channel dimension.
    :param drop_rate: Dropout rate used throughout the network. Default is 0.
    :param use_bias: Whether to use bias in convolutional layers. Default is False.
    :param norm: Type of normalization applied to traces ("peak" or "std"). Default is "peak".
    :param eps: Value added to the normalization factor to avoid division by zero.
                Default is 1e-13.
    :param nfft: Length of the FFT used, if a zero padded FFT is desired. If None, the FFT length
                 is nperseg. Default is 60.
    :param nperseg: Length of each STFT segment. Default is 30.
    :param attention: Whether to use attention gates in skip connections or not. Default is False.
    :param kwargs: Additional arguments passed to the base `WaveformModel`.
    """

    _annotate_args = WaveformModel._annotate_args.copy()
    _annotate_args["overlap"] = (_annotate_args["overlap"][0], 0.5)
    _annotate_args["blinding"] = (
        "Number of prediction samples to discard on each side of each window prediction",
        (0, 0),
    )

    def __init__(
        self,
        in_samples: int = 3000,
        in_channels: int = 2,
        sampling_rate: float = 100,
        filters_root: int = 8,
        depth: int = 6,
        kernel_size: tuple[int, int] = (3, 3),
        strides: tuple[int, int] = (2, 2),
        output_activation=torch.nn.Softmax(dim=1),
        drop_rate: float = 0.0,
        use_bias: bool = False,
        norm: str = "peak",
        eps: float = 1e-13,
        nfft: Union[int, None] = 60,
        nperseg: int = 30,
        attention: bool = False,
        **kwargs,
    ):
        citation = (
            "Heuel, J., & Friederich, W. (2022). "
            "Suppression of wind turbine noise from seismological data using nonlinear thresholding "
            "and denoising autoencoder. "
            "Journal of Seismology, 26(5), 913-934. "
            "https://doi.org/10.1007/s10950-022-10097-6"
        )

        WaveformModel.__init__(
            self,
            citation=citation,
            in_samples=in_samples,
            output_type="array",
            pred_sample=(0, in_samples),
            sampling_rate=sampling_rate,
            labels=self.generate_label,
            grouping="channel",
            **kwargs,
        )

        self.in_samples = in_samples
        self.sampling_rate = sampling_rate
        self.in_channels = in_channels
        self.filters_root = filters_root
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = strides
        self.output_activation = output_activation
        self.drop_rate = drop_rate
        self.use_bias = use_bias
        self.norm = norm
        self.eps = eps
        self.norm_factors = None
        self.attention = attention

        # Resolve the documented ``nfft=None`` fallback here, since torch.stft (used for
        # prediction) needs an explicit FFT length.
        self.nfft = nfft if nfft is not None else nperseg
        self.nperseg = nperseg

        # Determine input shape from STFT and check that STFT followed by ISTFT
        # restores the original number of samples
        _, _, dummystft = stft(
            x=np.random.rand(self.in_samples),
            fs=self.sampling_rate,
            nfft=self.nfft,
            nperseg=self.nperseg,
        )
        _, dummy_x = istft(
            Zxx=dummystft,
            fs=self.sampling_rate,
            nfft=self.nfft,
            nperseg=self.nperseg,
        )

        if len(dummy_x) != self.in_samples:
            msg = (
                f"If data with length {self.in_samples} are transformed with STFT and back transformed with "
                f"ISTFT, the output lenght of ISTFT ({len(dummy_x)}) does not match. Choose different values "
                f"for nfft={self.nfft} and nperseg={self.nperseg}."
            )
            raise ValueError(msg)

        self.input_shape = dummystft.shape

        # Write STFT values to default args dictionary
        self.default_args["nfft"] = self.nfft
        self.default_args["nperseg"] = self.nperseg

        # Containers for all branches of the autoencoder
        self.encoder_blocks = nn.ModuleList()
        self.down_blocks = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        self.up_blocks = nn.ModuleList()
        self.attention_gates = nn.ModuleList()

        # Encoder: the number of filters doubles with every level; all but the deepest level
        # are followed by a strided convolution for downsampling
        cur_channels = in_channels
        for i in range(depth):
            out_channels = 2**i * self.filters_root
            self.encoder_blocks.append(
                ConvBlock(
                    in_channels=cur_channels,
                    out_channels=out_channels,
                    kernel_size=self.kernel_size,
                    drop_rate=self.drop_rate,
                    use_bias=self.use_bias,
                )
            )
            if i < depth - 1:
                self.down_blocks.append(
                    ConvBlock(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel_size=self.kernel_size,
                        stride=self.stride,
                        drop_rate=self.drop_rate,
                        use_bias=self.use_bias,
                    )
                )
            cur_channels = out_channels

        # Decoder: mirrors the encoder from the second deepest level up to the first level
        for i in range(depth - 2, -1, -1):
            up_in_channels = 2 ** (i + 1) * self.filters_root
            out_channels = 2**i * self.filters_root

            self.up_blocks.append(
                TransposeConvBlock(
                    in_channels=up_in_channels,
                    out_channels=out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    drop_rate=self.drop_rate,
                    use_bias=self.use_bias,
                )
            )

            if self.attention:
                self.attention_gates.append(
                    AttentionGate(
                        in_channels_encoder=out_channels,
                        in_channels_decoder=out_channels,
                        inter_channels=out_channels // 2,
                    )
                )
            else:
                # Placeholder keeps the gate indices aligned with the decoder levels
                self.attention_gates.append(None)

            # Skip connections are added (not concatenated), so the channel count is unchanged
            self.decoder_blocks.append(
                ConvBlock(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=self.kernel_size,
                    drop_rate=self.drop_rate,
                    use_bias=self.use_bias,
                )
            )

        self.output_conv = nn.Conv2d(
            in_channels=self.filters_root, out_channels=self.in_channels, kernel_size=1
        )

    def forward(self, x):
        """
        Run the U-Net on a batch of STFT inputs.

        :param x: Input tensor. The first convolution expects ``in_channels`` channels in dim 1.
        :return: Output of the final 1x1 convolution with ``in_channels`` channels, passed through
                 ``output_activation`` if one is set.
        """
        enc_features = []  # Encoder features for the skip connections

        # Encoder
        for level, encode in enumerate(self.encoder_blocks):
            x = encode(x)
            enc_features.append(x)
            if level < self.depth - 1:
                x = self.down_blocks[level](x)

        # Decoder: the deepest encoder feature is the bottleneck, all others are used as skip
        # connections in reverse order
        decoder_levels = zip(
            self.up_blocks,
            self.attention_gates,
            self.decoder_blocks,
            reversed(enc_features[:-1]),
        )
        for upsample, gate, decode, skip in decoder_levels:
            x = upsample(x)

            # Pad if needed to match encoder feature size
            diff_y = skip.size(2) - x.size(2)
            diff_x = skip.size(3) - x.size(3)
            x = F.pad(
                x,
                [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
            )

            if self.attention:
                # Gated skip connection
                x = x + gate(skip, x)
            else:
                # Plain skip connection
                x = x + skip

            x = decode(x)

        x = self.output_conv(x)

        # Apply output activation function
        if self.output_activation is not None:
            x = self.output_activation(x)

        return x

    @staticmethod
    def generate_label(stations):
        """
        Use the channel code, i.e. the last dot-separated part of the first entry, as label.
        """
        return stations[0].split(".")[-1]

    def annotate_batch_pre(
        self, batch: torch.Tensor, argdict: dict[str, Any]
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Does preprocessing for prediction.

        Each trace is demeaned and normalized, transformed with the STFT, and the real and
        imaginary parts of the STFT are scaled and stacked as two input channels.

        :param batch: Batch of waveforms.
        :param argdict: Dictionary of annotation arguments (not used here).
        :return: Tuple of the network input and the piggyback information
                 ``(noisy_stft, norm_factors)`` required by :py:func:`annotate_batch_post`.
        """
        batch, norm_factors = self._normalize_trace(batch=batch)

        # STFT of each trace
        noisy_stft = torch.stft(
            input=batch,
            n_fft=self.nfft,
            win_length=self.nperseg,
            window=torch.hann_window(self.nperseg).to(batch.device),
            hop_length=self.nperseg // 2,  # for 50% overlap like SciPy default
            pad_mode="constant",
            return_complex=True,
            normalized=False,
        )

        # Normalize real and imaginary input to range [-1, 1].
        # Normalizing real and imaginary part separately might distort amplitude and phase, but
        # was found to give more accurate denoising results. If a SeisDAE model is trained without
        # this normalization, it also has to be removed in the training
        # (labeling.STFTDenoiserLabeller).
        # NOTE(review): the maximum is taken over the whole batch, so the scaling of a trace
        # depends on the other traces in its batch - confirm this matches the training setup.
        noisy_stft_real = noisy_stft.real / torch.max(torch.abs(noisy_stft.real))
        noisy_stft_imag = noisy_stft.imag / torch.max(torch.abs(noisy_stft.imag))

        noisy_input = torch.stack(tensors=[noisy_stft_real, noisy_stft_imag], dim=1)

        # Replace nans and infs (e.g. from an all-zero batch)
        noisy_input[~torch.isfinite(noisy_input)] = 0

        return noisy_input, (noisy_stft, norm_factors)

    def _normalize_trace(self, batch: torch.Tensor):
        """
        Demean and normalize each trace and return the factors to scale back denoised traces.

        :param batch: Batch of waveforms; mean and norm are computed along dim 1.
        :return: Tuple of the normalized batch and the normalization factors
                 (without the added ``eps``).
        :raises ValueError: If ``self.norm`` is neither "peak" nor "std".
        """
        # Demean each trace
        batch = batch - batch.mean(dim=1, keepdim=True)

        # Compute norm factor
        if self.norm == "peak":
            norm = batch.abs().max(dim=1, keepdim=True).values
        elif self.norm == "std":
            norm = batch.std(dim=1, keepdim=True)
        else:
            raise ValueError(
                f"Unknown normalization '{self.norm}'. Choose either 'peak' or 'std'."
            )

        # Normalize (and avoid division by zero)
        batch = batch / (norm + self.eps)

        return batch, norm.squeeze(1)

    def annotate_batch_post(
        self, batch: torch.Tensor, piggyback: Any, argdict: dict[str, Any]
    ) -> torch.Tensor:
        """
        Does postprocessing when predicting datasets.

        The first output channel of the network is used as mask on the STFT of the noisy signal.
        The masked STFT is transformed back to the time domain and rescaled to the original
        amplitude.

        :param batch: Output of the network.
        :param piggyback: Tuple ``(noisy_stft, norm_factors)`` from :py:func:`annotate_batch_pre`.
        :param argdict: Dictionary of annotation arguments; ``blinding`` is read from it.
        :return: Denoised waveforms.
        """
        noisy_stft, norm_factors = piggyback[0], piggyback[1]

        # Multiply STFT of noisy signal with predicted mask for the signal
        signal_stft = noisy_stft * batch[:, 0, :]

        denoised_signal = torch.istft(
            input=signal_stft,
            n_fft=self.nfft,
            win_length=self.nperseg,
            hop_length=self.nperseg // 2,  # match overlap used in stft
            window=torch.hann_window(self.nperseg).to(batch.device),
            normalized=False,
        )

        # Apply blinding
        prenan, postnan = argdict.get(
            "blinding", self._annotate_args.get("blinding")[1]
        )
        if prenan > 0:
            denoised_signal[:, :prenan] = np.nan
        if postnan > 0:
            denoised_signal[:, -postnan:] = np.nan

        # Convert denoised traces back to original amplitude
        return denoised_signal * norm_factors[:, None]

    def get_model_args(self):
        """
        Collect the constructor arguments of this model.

        Arguments that are set by this class itself in the call to the base constructor are
        removed from the arguments of the base class. All constructor arguments that define the
        preprocessing and the network architecture are added, so that a model created from the
        returned arguments has the same layers as this one. ``output_activation`` is a module
        and is therefore not included.
        """
        model_args = super().get_model_args()
        for key in [
            "citation",
            "in_samples",
            "output_type",
            "default_args",
            "pred_sample",
            "labels",
            "grouping",
        ]:
            model_args.pop(key, None)

        model_args["sampling_rate"] = self.sampling_rate
        model_args["norm"] = self.norm
        model_args["in_samples"] = self.in_samples
        model_args["nfft"] = self.nfft
        model_args["nperseg"] = self.nperseg

        # Architecture and normalization settings; without them a non-default model would be
        # rebuilt with the default architecture
        model_args["in_channels"] = self.in_channels
        model_args["filters_root"] = self.filters_root
        model_args["depth"] = self.depth
        model_args["kernel_size"] = self.kernel_size
        model_args["strides"] = self.stride
        model_args["drop_rate"] = self.drop_rate
        model_args["use_bias"] = self.use_bias
        model_args["eps"] = self.eps
        model_args["attention"] = self.attention

        return model_args