paper: https://arxiv.org/pdf/1703.10135
audio sample: https://google.github.io/tacotron/publications/tacotron/index.html
code: [original - tf] https://github.com/keithito/tacotron/blob/master/models/tacotron.py
[torch] https://github.com/r9y9/tacotron_pytorch/blob/master/tacotron_pytorch/tacotron.py
작성
[참고] https://khw11044.github.io/blog/papers/paper-etc/2021-01-31-tacotron1_summary/
attention 기반의 sequence-to-sequence 모델을 사용한 end-to-end TTS Tacotron 제안
WaveNet (2016)
DeepVoice (2017)
Wang et al (2016)
:"First step towards end-to-end parametric TTS synthesis:
Generating spectral parameters with neural attention."
Char2Wav (2017)
encoder | attention-based decoder | post-processing net
Model architecture
Tacotron의 backbone은 attention기반 seq2seq model이다.
input: characters
encoder:
decoder:
합성:
class Tacotron(nn.Module):
    """Full Tacotron network: embedding, encoder, attention decoder, post-net.

    The forward pass embeds the input ids, encodes them, decodes mel frames
    with the attention decoder, and finally maps the mel frames to
    ``linear_dim`` features through a CBHG block followed by a linear layer.

    Args:
        n_vocab: Number of rows in the input embedding table.
        embedding_dim: Size of each embedding vector.
        mel_dim: Feature size of one mel frame.
        linear_dim: Feature size of the post-net output.
        r: Reduction factor handed to the decoder.
        padding_idx: Forwarded to ``nn.Embedding``.
        use_memory_mask: When true, ``input_lengths`` is passed to the
            decoder as ``memory_lengths`` so attention can be masked.
    """

    def __init__(self, n_vocab, embedding_dim=256, mel_dim=80, linear_dim=1025,
                 r=5, padding_idx=None, use_memory_mask=False):
        super().__init__()
        self.mel_dim = mel_dim
        self.linear_dim = linear_dim
        self.use_memory_mask = use_memory_mask

        embedding = nn.Embedding(n_vocab, embedding_dim,
                                 padding_idx=padding_idx)
        # Re-initialize the embedding table with a smaller std (0.3).
        # NOTE(review): this touches every row of the table, including the
        # padding_idx row if one is given - confirm that is intended.
        embedding.weight.data.normal_(0, 0.3)
        self.embedding = embedding

        self.encoder = Encoder(embedding_dim)
        self.decoder = Decoder(mel_dim, r)
        # Post-net: the CBHG output is twice mel_dim wide (bidirectional GRU),
        # hence the mel_dim * 2 input of the final linear layer.
        self.CBHG = CBHG(mel_dim, K=8, projections=[256, mel_dim])
        self.last_linear = nn.Linear(mel_dim * 2, linear_dim)

    def forward(self, inputs, targets=None, input_lengths=None):
        """Run the model.

        Args:
            inputs: Integer ids fed to the embedding layer; dim 0 is the batch.
            targets: Decoder inputs for teacher forcing; ``None`` makes the
                decoder run greedily on its own outputs.
            input_lengths: Optional lengths forwarded to the encoder (and to
                the decoder when ``use_memory_mask`` is set).

        Returns:
            Tuple ``(mel_outputs, linear_outputs, alignments)`` where
            ``mel_outputs`` has been reshaped to ``(B, -1, mel_dim)``.
        """
        batch_size = inputs.size(0)

        embedded = self.embedding(inputs)
        memory = self.encoder(embedded, input_lengths)

        # Only hand lengths to the decoder when attention masking is enabled.
        memory_lengths = input_lengths if self.use_memory_mask else None

        # The decoder emits r mel frames per step: (B, T', mel_dim * r).
        mel_outputs, alignments = self.decoder(
            memory, targets, memory_lengths=memory_lengths)

        # Ungroup the frames: (B, T, mel_dim).
        mel_outputs = mel_outputs.view(batch_size, -1, self.mel_dim)

        # Post-net: mel frames -> linear_dim features.
        post_features = self.CBHG(mel_outputs)
        linear_outputs = self.last_linear(post_features)

        return mel_outputs, linear_outputs, alignments
Encoder 목표: 텍스트의 robust sequential representations 추출
작동순서
class Encoder(nn.Module):
    """Text encoder: a pre-net followed by a CBHG block.

    Args:
        in_dim: Feature size of the inputs given to the pre-net.
    """

    def __init__(self, in_dim):
        super().__init__()
        # The pre-net ends at 128 features, which is the CBHG input size.
        self.prenet = Prenet(in_dim, sizes=[256, 128])
        self.cbhg = CBHG(128, K=16, projections=[128, 128])

    def forward(self, inputs, input_lengths=None):
        """Apply the pre-net and then the CBHG block.

        Args:
            inputs: Input features whose last dimension is ``in_dim``.
            input_lengths: Optional lengths passed through to the CBHG block.

        Returns:
            Whatever the CBHG block returns for the pre-net output.
        """
        prenet_out = self.prenet(inputs)
        encoded = self.cbhg(prenet_out, input_lengths)
        return encoded
class Prenet(nn.Module):
    """Stack of ``Linear -> ReLU -> Dropout(0.5)`` layers.

    Args:
        in_dim: Feature size of the input.
        sizes: Output size of each successive linear layer. Any sequence
            (list or tuple) is accepted; it is copied, never mutated.
    """

    def __init__(self, in_dim, sizes=(256, 128)):
        super().__init__()
        # Copy into a list so tuples work and the caller's object is untouched
        # (the previous list default was a shared mutable default argument).
        sizes = list(sizes)
        # Each layer's input size is the previous layer's output size.
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [nn.Linear(in_size, out_size)
             for in_size, out_size in zip(in_sizes, sizes)])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, inputs):
        """Apply every layer in order and return the final activations."""
        for linear in self.layers:
            inputs = self.dropout(self.relu(linear(inputs)))
        return inputs
Figure 2: Lee et al. (2016)에 채택된 The CBHG (1-D convolution bank + highway network + bidirectional GRU) module
class BatchNormConv1d(nn.Module):
    """1-D convolution, optional activation, then batch normalization.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Convolution kernel width.
        stride: Convolution stride.
        padding: Padding added by the convolution.
        activation: Optional callable applied between the convolution and
            the batch norm; skipped when ``None``.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
                 activation=None):
        super().__init__()
        # The convolution carries no bias term.
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride,
                           padding=padding, bias=False)
        self.conv1d = nn.Conv1d(in_dim, out_dim, **conv_kwargs)
        self.bn = nn.BatchNorm1d(out_dim)
        self.activation = activation

    def forward(self, x):
        """Return ``bn(activation(conv1d(x)))`` (activation is optional)."""
        features = self.conv1d(x)
        activation = self.activation
        if activation is None:
            return self.bn(features)
        return self.bn(activation(features))
class Highway(nn.Module):
    """Highway layer: a gated mix of a transformed input and the input itself.

    The output is ``H(x) * T(x) + x * (1 - T(x))`` where ``H`` is a linear
    layer followed by ReLU and ``T`` is a linear layer followed by sigmoid.

    Args:
        in_size: Input feature size of both linear layers.
        out_size: Output feature size of both linear layers.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        transform = nn.Linear(in_size, out_size)
        transform.bias.data.zero_()
        self.H = transform

        gate = nn.Linear(in_size, out_size)
        # The gate bias starts at -1, so sigmoid(bias) is below 0.5.
        gate.bias.data.fill_(-1)
        self.T = gate

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Blend the transformed input with the raw input using the gate."""
        gate = self.sigmoid(self.T(inputs))
        transformed = self.relu(self.H(inputs))
        return transformed * gate + inputs * (1.0 - gate)
class CBHG(nn.Module):
    """CBHG module: a recurrent neural network composed of:
        - 1-d convolution banks
        - Highway networks + residual connections
        - Bidirectional gated recurrent units

    Args:
        in_dim: Feature size of the input; also the highway and GRU size.
        K: Number of convolutions in the bank (kernel widths 1..K).
        projections: Output channel counts of the projection convolutions.
    """

    def __init__(self, in_dim, K=16, projections=[128, 128]):
        super().__init__()
        self.in_dim = in_dim
        self.relu = nn.ReLU()

        # A-1) Convolution bank: one conv per kernel width 1..K, padded with
        # width // 2 on each side.
        bank = []
        for width in range(1, K + 1):
            bank.append(BatchNormConv1d(in_dim, in_dim, kernel_size=width,
                                        stride=1, padding=width // 2,
                                        activation=self.relu))
        self.conv1d_banks = nn.ModuleList(bank)

        # A-2) Max pooling along time with stride 1.
        self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        # A-3) Projection convolutions: ReLU on all but the last one.
        n_projections = len(projections)
        projection_inputs = [K * in_dim] + projections[:-1]
        projection_acts = [self.relu] * (n_projections - 1) + [None]
        projection_layers = []
        for n_in, n_out, act in zip(projection_inputs, projections,
                                    projection_acts):
            projection_layers.append(
                BatchNormConv1d(n_in, n_out, kernel_size=3, stride=1,
                                padding=1, activation=act))
        self.conv1d_projections = nn.ModuleList(projection_layers)

        # B) Highway stack; pre_highway maps the last projection width back
        # to in_dim when the two differ (see forward).
        self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
        self.highways = nn.ModuleList(
            [Highway(in_dim, in_dim) for _ in range(4)])

        # C) Single-layer bidirectional GRU.
        self.gru = nn.GRU(
            in_dim, in_dim, 1, batch_first=True, bidirectional=True)

    def forward(self, inputs, input_lengths=None):
        """Apply the CBHG stack.

        Args:
            inputs: Tensor documented upstream as ``(B, T_in, in_dim)``.
            input_lengths: Optional sequence lengths; when given, the GRU
                runs on a packed sequence.
                NOTE(review): ``pack_padded_sequence`` is called with its
                defaults - confirm callers provide lengths in the order it
                expects.

        Returns:
            GRU outputs, ``(B, T, in_dim * 2)`` given ``batch_first=True``
            and ``bidirectional=True``.
        """
        feats = inputs

        # Conv1d runs over the last axis, so move time there:
        # (B, T_in, in_dim) -> (B, in_dim, T_in).
        if feats.size(-1) == self.in_dim:
            feats = feats.transpose(1, 2)
        n_frames = feats.size(-1)

        # Run every bank conv and trim each result back to n_frames before
        # concatenating along channels: (B, in_dim * K, T_in).
        bank_outputs = []
        for bank_conv in self.conv1d_banks:
            bank_outputs.append(bank_conv(feats)[:, :, :n_frames])
        feats = torch.cat(bank_outputs, dim=1)
        assert feats.size(1) == self.in_dim * len(self.conv1d_banks)

        # The pooled output is trimmed back to n_frames as well.
        feats = self.max_pool1d(feats)[:, :, :n_frames]

        for projection in self.conv1d_projections:
            feats = projection(feats)

        # Back to (B, T_in, channels).
        feats = feats.transpose(1, 2)
        if feats.size(-1) != self.in_dim:
            feats = self.pre_highway(feats)

        # Residual connection with the block input.
        feats += inputs

        for highway in self.highways:
            feats = highway(feats)

        use_packing = input_lengths is not None
        rnn_utils = nn.utils.rnn
        if use_packing:
            feats = rnn_utils.pack_padded_sequence(
                feats, input_lengths, batch_first=True)

        outputs, _ = self.gru(feats)

        if use_packing:
            outputs, _ = rnn_utils.pad_packed_sequence(
                outputs, batch_first=True)

        return outputs
(더 적은 bands나 cepstrum과 같은 더 간결한(concise) target을 사용할 수 있다)
speech signal과 text 사이에 alignment를 학습 잘 하는게 중요함.
raw spectrogram은 불필요한 representation이 많음. (길이 너무 길고, 용량도 너무 큼)
압축을 효율적으로 하기 위해서는 seq2seq output이 inversion process에 대한 prosody(운율) 정보와 sufficient intelligibility를 제공해야함.
>- 비교:
> - Raw 스펙트로그램: 모든 주파수를 동일하게 다루기 때문에 높은 주파수에서 불필요한 정보를 많이 포함할 수 있습니다. 이는 데이터의 크기와 처리 시간을 증가시킬 수 있습니다.
> - 80 밴드 mel-scale 스펙트로그램: 주파수 정보를 압축하여 중요한 음성 정보를 강조하는 반면, 높은 주파수에서의 세부 정보는 상대적으로 덜 표현합니다. 이로 인해 데이터는 더 작고, 처리는 더 빠르며, 효율성이 높아집니다.
class Decoder(nn.Module):
    """Attention-based autoregressive decoder emitting ``r`` frames per step.

    Args:
        in_dim: Feature size of a single output frame.
        r: Number of frames grouped into one decoder step.
    """

    def __init__(self, in_dim, r):
        super().__init__()
        self.in_dim = in_dim
        self.r = r
        self.prenet = Prenet(in_dim * r, sizes=[256, 128])
        # Attention RNN input: pre-net output (128) + attention context (256).
        self.attention_rnn = AttentionWrapper(
            nn.GRUCell(256 + 128, 256),
            BahdanauAttention(256)
        )
        self.memory_layer = nn.Linear(256, 256, bias=False)
        # (attention RNN hidden + attention context) -> decoder RNN input.
        self.project_to_decoder_in = nn.Linear(512, 256)
        self.decoder_rnns = nn.ModuleList(
            [nn.GRUCell(256, 256) for _ in range(2)])
        self.proj_to_mel = nn.Linear(256, in_dim * r)
        # Upper bound on greedy decoding steps.
        self.max_decoder_steps = 200

    def forward(self, encoder_outputs, inputs=None, memory_lengths=None):
        """
        Decoder forward step.

        If decoder inputs are not given (e.g., at testing time), as noted in
        Tacotron paper, greedy decoding is adapted.

        Args:
            encoder_outputs: Encoder outputs. (B, T_encoder, dim)
            inputs: Decoder inputs. i.e., mel-spectrogram. If None (at eval-time),
              decoder outputs are used as decoder inputs.
            memory_lengths: Encoder output (memory) lengths. If not None, used for
              attention masking.

        Returns:
            Tuple ``(outputs, alignments)``, both batch-first: the stacked
            per-step projections and the stacked per-step attention weights.
        """
        B = encoder_outputs.size(0)

        processed_memory = self.memory_layer(encoder_outputs)
        if memory_lengths is not None:
            mask = get_mask_from_lengths(processed_memory, memory_lengths)
        else:
            mask = None

        # Run greedy decoding if inputs is None
        greedy = inputs is None

        if inputs is not None:
            # Group r frames into one decoder step if the inputs are ungrouped.
            if inputs.size(-1) == self.in_dim:
                inputs = inputs.view(B, inputs.size(1) // self.r, -1)
            assert inputs.size(-1) == self.in_dim * self.r
            T_decoder = inputs.size(1)

        # All-zero "go" frame and initial states. new_zeros keeps the dtype and
        # device of encoder_outputs; it replaces the deprecated
        # Variable(x.data.new(...).zero_()) idiom.
        initial_input = encoder_outputs.new_zeros(B, self.in_dim * self.r)
        attention_rnn_hidden = encoder_outputs.new_zeros(B, 256)
        decoder_rnn_hiddens = [encoder_outputs.new_zeros(B, 256)
                               for _ in range(len(self.decoder_rnns))]
        current_attention = encoder_outputs.new_zeros(B, 256)

        # Time first (T_decoder, B, in_dim * r)
        if inputs is not None:
            inputs = inputs.transpose(0, 1)

        outputs = []
        alignments = []

        t = 0
        current_input = initial_input
        while True:
            if t > 0:
                # Feed back the previous prediction (greedy) or the previous
                # ground-truth step (teacher forcing).
                current_input = outputs[-1] if greedy else inputs[t - 1]
            # Prenet
            current_input = self.prenet(current_input)

            # Attention RNN
            attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
                current_input, current_attention, attention_rnn_hidden,
                encoder_outputs, processed_memory=processed_memory, mask=mask)

            # Concat RNN output and attention context vector
            decoder_input = self.project_to_decoder_in(
                torch.cat((attention_rnn_hidden, current_attention), -1))

            # Pass through the decoder RNNs
            for idx, decoder_rnn in enumerate(self.decoder_rnns):
                decoder_rnn_hiddens[idx] = decoder_rnn(
                    decoder_input, decoder_rnn_hiddens[idx])
                # Residual connection
                decoder_input = decoder_rnn_hiddens[idx] + decoder_input

            # Project to r grouped output frames.
            output = self.proj_to_mel(decoder_input)

            outputs.append(output)
            alignments.append(alignment)

            t += 1

            if greedy:
                if t > 1 and is_end_of_frames(output):
                    break
                elif t > self.max_decoder_steps:
                    print("Warning! doesn't seems to be converged")
                    break
            else:
                if t >= T_decoder:
                    break

        assert greedy or len(outputs) == T_decoder

        # Back to batch first
        alignments = torch.stack(alignments).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()

        return outputs, alignments
# 바다나우 어텐션 (additive attention: decoder 상태를 query로, encoder 출력을 memory로 사용하므로 self-attention과는 다름)
class BahdanauAttention(nn.Module):
    """Additive attention scorer: ``v(tanh(W * query + processed_memory))``.

    Args:
        dim: Feature size of the query and of the processed memory.
    """

    def __init__(self, dim):
        super().__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, query, processed_memory):
        """Score every memory position against the query.

        Args:
            query: (batch, 1, dim) or (batch, dim)
            processed_memory: (batch, max_time, dim)

        Returns:
            Unnormalized scores of shape (batch, max_time).
        """
        # A 2-D query gets a time axis so it broadcasts over max_time.
        query_3d = query.unsqueeze(1) if query.dim() == 2 else query

        # (batch, 1, dim) + (batch, max_time, dim) -> (batch, max_time, dim)
        energies = self.tanh(self.query_layer(query_3d) + processed_memory)

        # (batch, max_time, 1) -> (batch, max_time)
        scores = self.v(energies)
        return scores.squeeze(-1)
def get_mask_from_lengths(memory, memory_lengths):
    """Build a boolean padding mask from a list of sequence lengths.

    Args:
        memory: (batch, max_time, dim); only its first two sizes and its
            device are used.
        memory_lengths: array like, one valid length per batch entry.

    Returns:
        Bool tensor of shape (batch, max_time) that is ``True`` at padded
        positions (index >= length) and ``False`` at valid positions.
    """
    # Start fully masked and clear the valid prefix of each row. A bool mask
    # is used because ``~`` on a uint8 tensor is a bitwise NOT (1 -> 254,
    # 0 -> 255), which would mark every position as padding.
    mask = torch.ones(memory.size(0), memory.size(1),
                      dtype=torch.bool, device=memory.device)
    for idx, length in enumerate(memory_lengths):
        mask[idx, :length] = False
    return mask
class AttentionWrapper(nn.Module):
    """Couple an RNN cell with an attention mechanism for one decoding step.

    Args:
        rnn_cell: Cell called as ``rnn_cell(cell_input, cell_state)``.
        attention_mechanism: Callable returning (batch, max_time) scores from
            ``(cell_output, processed_memory)``.
        score_mask_value: Score written at masked positions before softmax.
    """

    def __init__(self, rnn_cell, attention_mechanism,
                 score_mask_value=-float("inf")):
        super().__init__()
        self.rnn_cell = rnn_cell
        self.attention_mechanism = attention_mechanism
        self.score_mask_value = score_mask_value

    def forward(self, query, attention, cell_state, memory,
                processed_memory=None, mask=None, memory_lengths=None):
        """Run one attention step.

        Args:
            query: Current step input, (batch, query_dim).
            attention: Previous attention context, concatenated with ``query``.
            cell_state: Previous hidden state of ``rnn_cell``.
            memory: (batch, max_time, dim) values the context is built from.
            processed_memory: Memory as seen by the attention mechanism;
                defaults to ``memory``.
            mask: Optional mask, truthy at positions to exclude. Bool is
                preferred; other dtypes are converted with ``.bool()``.
            memory_lengths: Optional lengths used to build ``mask`` when no
                mask is given.

        Returns:
            Tuple ``(cell_output, attention, alignment)``: the new RNN state,
            the new context (batch, dim), and the attention weights
            (batch, max_time).
        """
        if processed_memory is None:
            processed_memory = memory
        if memory_lengths is not None and mask is None:
            mask = get_mask_from_lengths(memory, memory_lengths)

        # Concat input query and previous attention context, feed it to RNN.
        cell_input = torch.cat((query, attention), -1)
        cell_output = self.rnn_cell(cell_input, cell_state)

        # Alignment scores: (batch, max_time)
        alignment = self.attention_mechanism(cell_output, processed_memory)

        if mask is not None:
            # masked_fill needs a bool mask; the out-of-place call keeps the
            # operation visible to autograd instead of editing .data in place.
            mask = mask.view(query.size(0), -1).bool()
            alignment = alignment.masked_fill(mask, self.score_mask_value)

        # Normalize over the time axis (explicit dim; the implicit-dim form
        # is deprecated).
        alignment = F.softmax(alignment, dim=-1)

        # Attention context vector: (batch, 1, dim) -> (batch, dim)
        attention = torch.bmm(alignment.unsqueeze(1), memory).squeeze(1)

        return cell_output, attention, alignment
목적:
Motivation
특징
Neural Vocoder 이전에 전통적인 방법의 Vocoder 기술로 Griffin-Lim 알고리즘이 있다.
Griffin-Lim 알고리즘은 Mel-spectrogram으로 계산된 STFT magnitude 값만 가지고 원본 음성을 예측하는 rule-based 알고리즘이다. 원본 음성 신호를 복원하기 위해서는 STFT magnitude 값과 STFT phase 값이 필요하기 때문에 이 phase(위상) 값을 임의로 두고 예측을 시작한다. 그렇게 예측된 음성의 STFT magnitude 값과 원래 Mel-spectrogram으로 계산된 STFT magnitude 값의 mean squared error(MSE)가 최소가 되도록 반복 수행하여 원본 음성을 찾아낸다.
# Synthesis
# Greedy decoding: the model is called with the input sequence only
# (presumably a Tacotron instance, so no targets means greedy decoding - confirm).
# The call returns (mel_outputs, linear_outputs, alignments).
mel_outputs, linear_outputs, alignments = model(sequence)
# Take the first batch item and convert it to a NumPy array on the CPU.
linear_output = linear_outputs[0].cpu().data.numpy()
spectrogram = audio._denormalize(linear_output)
alignment = alignments[0].cpu().data.numpy()
# Predicted audio signal, reconstructed from the transposed linear output
# (note: the un-denormalized array is passed; inv_spectrogram is expected to denormalize it).
waveform = audio.inv_spectrogram(linear_output.T)
def inv_spectrogram(spectrogram):
    """Convert a normalized spectrogram back into a waveform.

    The input is denormalized, shifted by ``hparams.ref_level_db``, converted
    from dB to amplitude, raised to ``hparams.power``, passed through
    Griffin-Lim to reconstruct the phase, and finally de-emphasized.
    """
    # Convert back to linear amplitude.
    level_db = _denormalize(spectrogram) + hparams.ref_level_db
    magnitudes = _db_to_amp(level_db)
    # Reconstruct the phase, then undo the pre-emphasis.
    signal = _griffin_lim(magnitudes ** hparams.power)
    return inv_preemphasis(signal)
def _griffin_lim(S):
    """Estimate a waveform from a magnitude spectrogram with Griffin-Lim.

    Based on https://github.com/librosa/librosa/issues/434

    Args:
        S: Magnitude spectrogram array; only its absolute values are used.

    Returns:
        The signal returned by ``_istft`` after ``hparams.griffin_lim_iters``
        phase-refinement iterations.
    """
    # Start from a uniformly random phase for every bin.
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    # Builtin ``complex`` replaces the removed ``np.complex`` alias.
    S_complex = np.abs(S).astype(complex)
    y = _istft(S_complex * angles)
    for _ in range(hparams.griffin_lim_iters):
        # Keep the target magnitudes and take the phase of the current estimate.
        angles = np.exp(1j * np.angle(_stft(y)))
        y = _istft(S_complex * angles)
    return y
Table 1: Hyper-parameters와 network architectures. “conv-k-c-ReLU”는 width가 k이고 output channel 수가 c인 1-D convolution에 ReLU activation을 적용한 것을 의미한다.
Setting
결과
seq2seq:
Tacotron은 Figure 3(c)과 같이 깨끗하고 부드러운 alignment을 학습한다.
Figure 3: Attention alignments on a test phrase. Tacotron에서 decoder 길이는 output reduction factor(출력감소계수) r=5를 사용하기 때문에 더 짧다.
setting
모델의 나머지는 encoder pre-net을 포함하여 전부 같다.
결과
audio-text pair만 있으면 학습이 가능한 end-to-end 모델이다. 이것이 가능한 이유는 alignment를 attention을 통해서 학습하기 때문이다. 처음 제시한 Tacotron에서는 Bahdanau attention을 사용하였다. Alignment를 직접 학습하는 모델의 특성상 훈련 시의 alignment가 잘되면 generalization의 성능이 좋다고 한다. 그래서 학습이 잘 안될 땐 location-sensitive attention과 같이 다양한 시도를 해보는 게 좋다고 한다.
이 계열 모델의 단점은 auto-regressive한 방식으로 frame별로 생성하기 때문에 느리다는 것이다. 그리고 alignment가 잘 안되면 아무 소용이 없다. alignment를 시각화했을 때 중간이 끊어지거나 반복되기 쉽다고 한다.
audio sample: https://google.github.io/tacotron/publications/tacotron/index.html
paper: https://arxiv.org/pdf/1712.05884
audio sample: https://google.github.io/tacotron/publications/tacotron2/
code:
음절 또는 음소마다의 음성파일을 이어붙이는 방법인 Concatenative Approach
https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/
skip connection
수렴 속도 향상: 위의 모든 이유들로 인해, skip connection은 전반적으로 수렴 속도를 향상시키는 효과가 있습니다. 그래디언트의 효과적인 전달과 피쳐의 재사용은 모델이 더 빠르게 안정적인 성능에 도달하도록 돕습니다.