|
|
| """ SincNet model """ |
| from functools import lru_cache |
| import numpy as np |
| import logging |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
class SincNetFilterConvLayer(nn.Module):
    """SincNet fast convolution filter layer.

    A 1-D convolution whose kernels are windowed-sinc band-pass filters. Each
    output channel is parametrised by two scalars, a low cutoff and a bandwidth
    in Hz (``_low_hz`` and ``_band_hz``), initialised on a mel-spaced grid.
    The kernels are rebuilt from those parameters on every access of
    ``filters`` (i.e. on every forward pass).
    """

    def __init__(self, out_channels: int, kernel_size: int, sample_rate=16000,
                 stride=1, padding=0, dilation=1, min_low_hz=50, min_band_hz=50,
                 in_channels=1, requires_grad=False):
        """
        Args:
            out_channels : `int` number of filters (one band per output channel).
            kernel_size : `int` filter length. An even length is increased by
                one so that every filter has a center tap.
            sample_rate : `int`, optional sample rate. Defaults to 16000.
            stride : optional convolution stride. Defaults to 1.
            padding : optional convolution padding. Defaults to 0.
            dilation : optional convolution dilation. Defaults to 1.
            min_low_hz : optional floor (Hz) added to every low cutoff. Defaults to 50.
            min_band_hz : optional floor (Hz) added to every bandwidth. Defaults to 50.
            in_channels : optional number of input channels; only 1 is supported.
            requires_grad : optional, whether the cutoff frequencies are
                trainable. Defaults to False (fixed filter bank).

        Raises:
            ValueError: if ``in_channels`` is not 1.
        """
        super().__init__()

        if in_channels != 1:
            raise ValueError(f"SincNetFilterConvLayer only support in_channels = 1, was in_channels = {in_channels}")

        self._out_channels = out_channels
        # Odd length: each filter is built as left half + center tap + mirrored left half.
        self._kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size

        self._stride = stride
        self._padding = padding
        self._dilation = dilation
        self._sample_rate = sample_rate
        self._min_low_hz = min_low_hz
        self._min_band_hz = min_band_hz

        # Initial band edges, evenly spaced on the mel scale between 30 Hz and
        # just below Nyquist. out_channels + 1 edges give exactly out_channels
        # contiguous bands, matching the (out_channels, 1, kernel_size) shape
        # that `filters` returns.
        low_hz = 30
        high_hz = self._sample_rate / 2 - (self._min_low_hz + self._min_band_hz)
        mel = np.linspace(
            2595 * np.log10(1 + low_hz / 700),
            2595 * np.log10(1 + high_hz / 700),
            self._out_channels + 1,
        )
        hz = 700 * (10 ** (mel / 2595) - 1)

        # (out_channels, 1) columns so they broadcast against `_n` in `filters`.
        self._low_hz = nn.Parameter(
            torch.tensor(hz[:-1], dtype=torch.float32).view(-1, 1),
            requires_grad=requires_grad,
        )
        self._band_hz = nn.Parameter(
            torch.tensor(np.diff(hz), dtype=torch.float32).view(-1, 1),
            requires_grad=requires_grad,
        )
        # Left half of a Hamming window (the kernel_size // 2 taps before the center).
        self.register_buffer(
            "_window",
            torch.from_numpy(np.hamming(self._kernel_size)[: self._kernel_size // 2]).float(),
        )
        # 2*pi*t for the left-half tap times t = -(K // 2)/sr, ..., -1/sr; shape (1, K // 2).
        self.register_buffer(
            "_n",
            2 * np.pi * torch.arange(-(self._kernel_size // 2), 0.0).view(1, -1) / self._sample_rate,
        )

    @property
    def filters(self) -> torch.Tensor:
        """Band-pass kernels of shape (out_channels, 1, kernel_size).

        Deliberately not cached: the kernels are rebuilt from the current
        parameters on each access, so they follow optimizer updates,
        device/dtype moves and `load_state_dict`, and every forward pass gets
        its own autograd graph.
        """
        # Effective cutoffs: enforce the minimum low frequency and bandwidth,
        # and cap the high cutoff at Nyquist (sample_rate / 2).
        low = self._min_low_hz + torch.abs(self._low_hz)
        high = torch.clamp(
            low + self._min_band_hz + torch.abs(self._band_hz),
            self._min_low_hz,
            self._sample_rate / 2,
        )
        band = (high - low)[:, 0]

        f_times_t_low = torch.matmul(low, self._n)
        f_times_t_high = torch.matmul(high, self._n)

        # (sin(2*pi*f_high*t) - sin(2*pi*f_low*t)) / (pi*t): difference of two
        # low-pass sincs (a band-pass), tapered by the half window.
        band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self._n / 2)) * self._window
        # Value of the expression above in the limit t -> 0.
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
        # Normalise so that every filter's center tap equals 1.
        band_pass = band_pass / (2 * band[:, None])
        return band_pass.view(self._out_channels, 1, self._kernel_size)

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms : (batch_size, 1, n_samples) batch of waveforms.

        Returns:
            features : (batch_size, out_channels, n_samples_out) magnitude of
                the sinc filter responses.
        """
        responses = F.conv1d(
            waveforms,
            self.filters,
            stride=self._stride,
            padding=self._padding,
            dilation=self._dilation,
        )
        return responses.abs()
|
|
class SincNet(nn.Module):
    """SincNet feature extractor.

    Instance-normalises the input waveform, then applies three stages of
    convolution -> max-pooling -> instance normalisation -> leaky ReLU. The
    first convolution is a `SincNetFilterConvLayer`; the other two are plain
    `nn.Conv1d` layers.
    """

    def __init__(
        self,
        num_sinc_filters: int = 80,
        sinc_filter_length: int = 251,
        num_conv_filters: int = 60,
        conv_filter_length: int = 5,
        pool_kernel_size: int = 3,
        pool_stride: int = 3,
        sample_rate: int = 16000,
        sinc_filter_stride: int = 10,
        sinc_filter_padding: int = 0,
        sinc_filter_dilation: int = 1,
        min_low_hz: int = 50,
        min_band_hz: int = 50,
        sinc_filter_in_channels: int = 1,
        num_wavform_channels: int = 1,
    ):
        """
        Args:
            num_sinc_filters : number of band-pass filters in the first (sinc) layer.
            sinc_filter_length : kernel size of the sinc layer.
            num_conv_filters : output channels of each of the two `nn.Conv1d` layers.
            conv_filter_length : kernel size of the two `nn.Conv1d` layers.
            pool_kernel_size : kernel size of the max-pooling after each convolution.
            pool_stride : stride of the max-pooling after each convolution.
            sample_rate : audio sample rate; only 16000 is supported.
            sinc_filter_stride, sinc_filter_padding, sinc_filter_dilation,
            min_low_hz, min_band_hz, sinc_filter_in_channels : forwarded to
                `SincNetFilterConvLayer`.
            num_wavform_channels : number of channels normalised by the input
                instance normalisation.

        Raises:
            NotImplementedError: if ``sample_rate`` is not 16000.
        """
        super().__init__()

        if sample_rate != 16000:
            raise NotImplementedError(f"SincNet only supports 16kHz audio (sample_rate = 16000), was sample_rate = {sample_rate}")

        self.wav_norm1d = nn.InstanceNorm1d(num_wavform_channels, affine=True)

        # Stage i uses conv1d[i], pool1d[i] and norm1d[i] together.
        self.conv1d = nn.ModuleList([
            SincNetFilterConvLayer(
                num_sinc_filters,
                sinc_filter_length,
                sample_rate=sample_rate,
                stride=sinc_filter_stride,
                padding=sinc_filter_padding,
                dilation=sinc_filter_dilation,
                min_low_hz=min_low_hz,
                min_band_hz=min_band_hz,
                in_channels=sinc_filter_in_channels,
            ),
            nn.Conv1d(num_sinc_filters, num_conv_filters, conv_filter_length),
            nn.Conv1d(num_conv_filters, num_conv_filters, conv_filter_length),
        ])
        self.pool1d = nn.ModuleList([
            nn.MaxPool1d(pool_kernel_size, stride=pool_stride),
            nn.MaxPool1d(pool_kernel_size, stride=pool_stride),
            nn.MaxPool1d(pool_kernel_size, stride=pool_stride),
        ])
        self.norm1d = nn.ModuleList([
            nn.InstanceNorm1d(num_sinc_filters, affine=True),
            nn.InstanceNorm1d(num_conv_filters, affine=True),
            nn.InstanceNorm1d(num_conv_filters, affine=True),
        ])

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms : (batch, channel, sample) batch of waveforms. `channel`
                must match `num_wavform_channels`, and the sinc layer only
                accepts a single channel.

        Returns:
            features : (batch, num_conv_filters, frame) activations of the last
                stage; the number of frames depends on the convolution and
                pooling sizes/strides.
        """
        outputs = self.wav_norm1d(waveforms)

        for conv1d, pool1d, norm1d in zip(self.conv1d, self.pool1d, self.norm1d):
            outputs = F.leaky_relu(norm1d(pool1d(conv1d(outputs))))

        return outputs
|
|