from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import WaveformModel
[docs]
class DeepDenoiser(WaveformModel):
"""
.. document_args:: seisbench.models DeepDenoiser
"""
_annotate_args = WaveformModel._annotate_args.copy()
_annotate_args["overlap"] = (_annotate_args["overlap"][0], 1500)
def __init__(self, sampling_rate=100, **kwargs):
citation = (
"Zhu, W., Mousavi, S. M., & Beroza, G. C. (2019). "
"Seismic signal denoising and decomposition using deep neural networks. "
"IEEE Transactions on Geoscience and Remote Sensing, 57.11(2019), 9476 - 9488. "
"https://doi.org/10.1109/TGRS.2019.2926772"
)
super().__init__(
citation=citation,
in_samples=3000,
output_type="array",
pred_sample=(0, 3000),
labels=self.generate_label,
sampling_rate=sampling_rate,
grouping="channel",
**kwargs,
)
self.inc = nn.Conv2d(2, 8, (3, 3), padding=(1, 1), bias=False)
self.in_bn = nn.BatchNorm2d(8, eps=1e-3)
self.down_conv_blocks = nn.ModuleList(
[DownConvBlock(8 * 2 ** max(0, i - 1), 8 * 2**i) for i in range(5)]
)
self.conv5 = nn.Conv2d(128, 256, (3, 3), padding=(1, 1), bias=False)
self.bn5 = nn.BatchNorm2d(256, eps=1e-3)
self.up_conv_blocks = nn.ModuleList(
[UpConvBlock(8 * 2 ** (5 - i), 8 * 2 ** (4 - i)) for i in range(5)]
)
self.outc = nn.Conv2d(8, 2, (1, 1), bias=True)
[docs]
def forward(self, x):
x = torch.relu(self.in_bn(self.inc(x)))
mids = []
for layer in self.down_conv_blocks:
x, mid = layer(x)
mids.append(mid)
x = torch.relu(self.bn5(self.conv5(x)))
for layer, mid in zip(self.up_conv_blocks, mids[::-1]):
x = layer(x, mid)
logits = self.outc(x)
preds = torch.softmax(logits, dim=1)
return preds
[docs]
@staticmethod
def generate_label(stations):
# Simply use channel as label
return stations[0].split(".")[-1]
[docs]
def annotate_batch_pre(
self, batch: torch.Tensor, argdict: dict[str, Any]
) -> torch.Tensor:
# Reproduces scipy.signal.stft(window, fs=self.sampling_rate, nperseg=30, nfft=60, boundary="zeros")
# Note that the outputs need to be rotated by a factor of i ** dim to match numpy
tmp_signal = (
2
/ 30
* torch.stft(
batch,
n_fft=60,
return_complex=True,
win_length=30,
window=torch.hann_window(30).to(batch.device),
pad_mode="constant",
normalized=False,
)
* torch.pow(1j, torch.arange(31, device=batch.device)).unsqueeze(-1)
)
noisy_signal = torch.stack([tmp_signal.real, tmp_signal.imag], dim=1)
noisy_signal[torch.isnan(noisy_signal)] = 0
noisy_signal[torch.isinf(noisy_signal)] = 0
return self._normalize_batch(noisy_signal), noisy_signal
@staticmethod
def _normalize_batch(data: torch.Tensor, window: int = 200) -> torch.Tensor:
"""
Adapted from original DeepDenoiser implementation available at
https://github.com/wayneweiqiang/DeepDenoiser/blob/7bd9284ece73e25c99db2ad101aacda2a215a41a/deepdenoiser/app.py#L72
data shape: 2, nf, nt
data: nbn, nf, nt, 2
"""
data = data.permute(0, 2, 3, 1) # nbn, nf, nt, 2
assert len(data.shape) == 4
shift = window // 2
nbt, nf, nt, nimg = data.shape
# std in slide windows
data_pad = torch.nn.functional.pad(
data, (0, 0, window // 2, window // 2), mode="reflect"
)
t = torch.arange(
0, nt + shift - 1, shift, device=data.device
) # 201 => 0, 100, 200
std = torch.zeros([nbt, len(t)], dtype=data.dtype, device=data.device)
mean = torch.zeros([nbt, len(t)], dtype=data.dtype, device=data.device)
for i in range(std.shape[1]):
std[:, i] = torch.std(
data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3)
)
mean[:, i] = torch.mean(
data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3)
)
std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]
# normalize data with interpolated std
interp_matrix = torch.zeros(3, 201, dtype=std.dtype, device=std.device)
interp_matrix[0, :101] = 1 - torch.linspace(
0, 1, 101, dtype=std.dtype, device=std.device
)
interp_matrix[1, :101] = torch.linspace(
0, 1, 101, dtype=std.dtype, device=std.device
)
interp_matrix[1, 100:] = 1 - torch.linspace(
0, 1, 101, dtype=std.dtype, device=std.device
)
interp_matrix[2, 100:] = torch.linspace(
0, 1, 101, dtype=std.dtype, device=std.device
)
std_interp = std @ interp_matrix
std_interp[std_interp == 0] = 1.0
mean_interp = mean @ interp_matrix
data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[
:, np.newaxis, :, np.newaxis
]
if len(t) > 3: # need to address this normalization issue in training
data /= 2.0
data = data.permute(0, 3, 1, 2) # 1 (nbt), 2, nf, nt
return data
[docs]
def annotate_batch_post(
self, batch: torch.Tensor, piggyback: Any, argdict: dict[str, Any]
) -> torch.Tensor:
signal = (piggyback[:, 0] + piggyback[:, 1] * 1j) * batch[:, 0]
signal = signal / torch.pow(
1j, torch.arange(signal.shape[-2], device=signal.device).unsqueeze(-1)
)
denoised_signal = 15 * torch.istft(
signal,
n_fft=60,
win_length=30,
window=torch.hann_window(30).to(batch.device),
normalized=False,
)
return denoised_signal
[docs]
def get_model_args(self):
model_args = super().get_model_args()
for key in [
"citation",
"in_samples",
"output_type",
"default_args",
"pred_sample",
"labels",
"sampling_rate",
"grouping",
]:
del model_args[key]
model_args["sampling_rate"] = self.sampling_rate
return model_args
class DownConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, (3, 3), padding=(1, 1), bias=False
)
self.bn1 = nn.BatchNorm2d(out_channels, eps=1e-3)
self.conv2 = nn.Conv2d(
out_channels,
out_channels,
(3, 3),
stride=(2, 2),
padding=(1, 1),
bias=False,
)
self.bn2 = nn.BatchNorm2d(out_channels, eps=1e-3)
def forward(self, x):
x = torch.relu(self.bn1(self.conv1(x)))
mid = x
# Required for compatibility with tensorflow version - stride is treated differently otherwise
remove2 = False
remove3 = False
if x.shape[2] % 2 == 0:
x = F.pad(x, (0, 0, 1, 0), "constant", 0)
remove2 = True
if x.shape[3] % 2 == 0:
x = F.pad(x, (1, 0), "constant", 0)
remove3 = True
x = self.conv2(x)
if remove2:
x = x[:, :, 1:] # Remove padding
if remove3:
x = x[:, :, :, 1:] # Remove padding
x = torch.relu(self.bn2(x))
return x, mid
class UpConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1 = nn.ConvTranspose2d(
in_channels, out_channels, (3, 3), stride=(2, 2), padding=(0, 0), bias=False
)
self.bn1 = nn.BatchNorm2d(out_channels, eps=1e-3)
# Again in_channels to account for the added residual connections
self.conv2 = nn.Conv2d(
in_channels, out_channels, (3, 3), padding=(1, 1), bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels, eps=1e-3)
def forward(self, x, mid):
x = self.conv1(x)
x = torch.relu(self.bn1(x))
# Truncation is necessary to get correct shapes and be compatible with tensorflow implementation
if mid.shape[2] % 2 == 0:
x = x[:, :, :-1]
else:
x = x[:, :, :-2]
if mid.shape[3] % 2 == 0:
x = x[:, :, :, :-1]
else:
x = x[:, :, :, :-2]
x = torch.cat([mid, x], dim=1)
x = torch.relu(self.bn2(self.conv2(x)))
return x