From b85dda63bc14d74b7b2acb925bef331dc425660c Mon Sep 17 00:00:00 2001 From: gosha20777 Date: Tue, 4 Dec 2018 20:43:23 +0300 Subject: [PATCH 1/7] add CMakeLists.txt file --- src/CMakeLists.txt | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 src/CMakeLists.txt diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 00000000..5ca50a1a --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,35 @@ +cmake_minimum_required(VERSION 3.1) + +project(lpcnet) + +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +add_executable(dump_data + denoise.c + kiss_fft.c + pitch.c + celt_lpc.c +) +target_compile_definitions(dump_data PRIVATE "TRAINING=1") +target_include_directories(dump_data PRIVATE ${CMAKE_CURRENT_SOURCE}/../include) +target_link_libraries(dump_data PRIVATE + "m" +) + + +add_executable(test_lpcnet + lpcnet.c + nnet.c + nnet_data.c +) +target_compile_definitions(test_lpcnet PRIVATE "mfma" "mavx2") +target_include_directories(test_lpcnet PRIVATE ${CMAKE_CURRENT_SOURCE}/../include) +target_link_libraries(test_lpcnet PRIVATE + "m" +) + +# gcc -DTRAINING=1 -Wall -W -O3 -g -I../include denoise.c kiss_fft.c pitch.c celt_lpc.c -o dump_data -lm +# gcc -mfma -msse3 -mavx2 -o test_lpcnet -g -O2 -Wall -W -Wextra lpcnet.c nnet.c nnet_data.c -lm + From d4d132da683e4b17e87caf95cd6a0f6f46fc8910 Mon Sep 17 00:00:00 2001 From: gosha20777 Date: Wed, 26 Dec 2018 15:35:16 +0300 Subject: [PATCH 2/7] sync with original --- COPYING | 31 ------------------------------- src/CMakeLists.txt | 35 ----------------------------------- 2 files changed, 66 deletions(-) delete mode 100644 COPYING delete mode 100644 src/CMakeLists.txt diff --git a/COPYING b/COPYING deleted file mode 100644 index feef112b..00000000 --- a/COPYING +++ /dev/null @@ -1,31 +0,0 @@ -Copyright (c) 2017-2018, Mozilla -Copyright (c) 2007-2017, Jean-Marc Valin -Copyright (c) 2005-2017, Xiph.Org Foundation -Copyright (c) 2003-2004, Mark Borgerding - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - -- Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. - -- Neither the name of the Xiph.Org Foundation nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION -OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt deleted file mode 100644 index 5ca50a1a..00000000 --- a/src/CMakeLists.txt +++ /dev/null @@ -1,35 +0,0 @@ -cmake_minimum_required(VERSION 3.1) - -project(lpcnet) - -set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) - -add_executable(dump_data - denoise.c - kiss_fft.c - pitch.c - celt_lpc.c -) -target_compile_definitions(dump_data PRIVATE "TRAINING=1") -target_include_directories(dump_data PRIVATE ${CMAKE_CURRENT_SOURCE}/../include) -target_link_libraries(dump_data PRIVATE - "m" -) - - -add_executable(test_lpcnet - lpcnet.c - nnet.c - nnet_data.c -) -target_compile_definitions(test_lpcnet PRIVATE "mfma" "mavx2") -target_include_directories(test_lpcnet PRIVATE ${CMAKE_CURRENT_SOURCE}/../include) -target_link_libraries(test_lpcnet PRIVATE - "m" -) - -# gcc -DTRAINING=1 -Wall -W -O3 -g -I../include denoise.c kiss_fft.c pitch.c celt_lpc.c -o dump_data -lm -# gcc -mfma -msse3 -mavx2 -o test_lpcnet -g -O2 -Wall -W -Wextra lpcnet.c nnet.c nnet_data.c -lm - From 4c3a0e4434dc7f9145d9c777b345f57f962c0225 Mon Sep 17 00:00:00 2001 From: gosha20777 Date: Wed, 26 Dec 2018 17:19:47 +0300 Subject: [PATCH 3/7] add speed up version for training lpcnet --- src/train_lpcnet.py | 6 +- src/train_lpcnet_boost.py | 156 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 5 deletions(-) mode change 100755 => 100644 src/train_lpcnet.py create mode 100755 src/train_lpcnet_boost.py diff --git a/src/train_lpcnet.py b/src/train_lpcnet.py old mode 100755 new mode 100644 index 81d7f5a1..2f822fe5 --- a/src/train_lpcnet.py +++ b/src/train_lpcnet.py @@ -1,17 +1,13 @@ #!/usr/bin/python3 '''Copyright (c) 2018 Mozilla - Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -147,4 +143,4 @@ #model.load_weights('lpcnet9b_384_10_G16_01.h5') model.compile(optimizer=Adam(0.001, amsgrad=True, decay=5e-5), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) -model.fit([in_data, in_exc, features, periods], out_data, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, lpcnet.Sparsify(2000, 40000, 400, (0.1, 0.1, 0.1))]) +model.fit([in_data, in_exc, features, periods], out_data, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, lpcnet.Sparsify(2000, 40000, 400, (0.1, 0.1, 0.1))]) \ No newline at end of file diff --git a/src/train_lpcnet_boost.py b/src/train_lpcnet_boost.py new file mode 100755 index 00000000..aaf8a937 --- /dev/null +++ b/src/train_lpcnet_boost.py @@ -0,0 +1,156 @@ +#!/usr/bin/python3 +'''Copyright (c) 2018 Mozilla + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR + CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +''' + +# Train a LPCNet model (note not a Wavenet model) + +import lpcnet +import sys +import numpy as np +from keras.optimizers import Adam +from keras.callbacks import ModelCheckpoint +from ulaw import ulaw2lin, lin2ulaw +import keras.backend as K +import h5py + +import tensorflow as tf +from keras.backend.tensorflow_backend import set_session +config = tf.ConfigProto() + +# use this option to reserve GPU memory, e.g. for running more than +# one thing at a time. Best to disable for GPUs with small memory +config.gpu_options.per_process_gpu_memory_fraction = 0.44 + +set_session(tf.Session(config=config)) + +nb_epochs = 120 + +# Try reducing batch_size if you run out of memory on your GPU +batch_size = 64 + +model, _, _ = lpcnet.new_lpcnet_model() + +model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) +model.summary() + +feature_file = sys.argv[1] +pcm_file = sys.argv[2] # 16 bit unsigned short PCM samples +frame_size = 160 +nb_features = 55 +nb_used_features = model.nb_used_features +feature_chunk_size = 15 +pcm_chunk_size = frame_size*feature_chunk_size + +# u for unquantised, load 16 bit PCM samples and convert to mu-law + +class Dataset(object): + def __init__(self, pcm_file, feature_file, batch_size): + print("start loading data from files...") + self._udata = np.fromfile(pcm_file, dtype='int16') + self._features = np.fromfile(feature_file, dtype='float32') + _nb_frames = len(self._udata) // pcm_chunk_size + self.loopcount = _nb_frames // batch_size + print("loopcount: ", self.loopcount) + # limit to discrete number of frames + self._udata = self._udata[:_nb_frames * pcm_chunk_size] + self._features = self._features[:_nb_frames * feature_chunk_size * nb_features] + + def __len__(self): + return self.loopcount + + def __iter__(self): + return self + + def __next__(self): + i = np.random.randint(0, self.loopcount) + nb_frames = batch_size + udata = self._udata[i * batch_size * pcm_chunk_size: + (i + 1) * batch_size * pcm_chunk_size] + features = self._features[i * batch_size * feature_chunk_size * nb_features: + (i + 1) * batch_size * feature_chunk_size * nb_features] + + data = lin2ulaw(udata) + in_data = np.concatenate([data[0:1], data[:-1]]) + noise = np.concatenate([np.zeros((len(data) * 1 // 5)), np.random.randint(-3, 3, len(data) * 1 // 5), + np.random.randint(-2, 2, len(data) * 1 // 5), + np.random.randint(-1, 1, len(data) * 2 // 5)]) + # noise = np.round(np.concatenate([np.zeros((len(data)*1//5)), np.random.laplace(0, 1.2, len(data)*1//5), np.random.laplace(0, .77, len(data)*1//5), np.random.laplace(0, .33, len(data)*1//5), np.random.randint(-1, 1, len(data)*1//5)])) + in_data = in_data + noise + in_data = np.clip(in_data, 0, 255) + + features = np.reshape(features, (nb_frames * feature_chunk_size, nb_features)) + + # Note: the LPC predictor output is now calculated by the loop below, this code was + # for an ealier version that implemented the prediction filter in C + + upred = np.zeros((nb_frames * pcm_chunk_size,), dtype='float32') + + # Use 16th order LPC to generate LPC prediction output upred[] and (in + # mu-law form) pred[] + + pred_in = ulaw2lin(in_data) + for i in range(2, nb_frames * feature_chunk_size): + upred[i * frame_size:(i + 1) * frame_size] = 0 + for k in range(16): + upred[i * frame_size:(i + 1) * frame_size] = upred[i * frame_size:(i + 1) * frame_size] - \ + pred_in[i * frame_size - k:(i + 1) * frame_size - k] * \ + features[i, nb_features - 16 + k] + + pred = lin2ulaw(upred) + in_data = np.reshape(in_data, (nb_frames, pcm_chunk_size, 1)) + in_data = in_data.astype('uint8') + + # LPC residual, which is the difference between the input speech and + # the predictor output, with a slight time shift this is also the + # ideal excitation in_exc + + out_data = lin2ulaw(udata - upred) + in_exc = np.concatenate([out_data[0:1], out_data[:-1]]) + out_data = np.reshape(out_data, (nb_frames, pcm_chunk_size, 1)) + out_data = out_data.astype('uint8') + in_exc = np.reshape(in_exc, (nb_frames, pcm_chunk_size, 1)) + in_exc = in_exc.astype('uint8') + features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features)) + features = features[:, :, :nb_used_features] + features[:, :, 18:36] = 0 + pred = np.reshape(pred, (nb_frames, pcm_chunk_size, 1)) + pred = pred.astype('uint8') + periods = (.1 + 50 * features[:, :, 36:37] + 100).astype('int16') + in_data = np.concatenate([in_data, pred], axis=-1) + return [in_data, in_exc, features, periods], out_data + + +# dump models to disk as we go +checkpoint = ModelCheckpoint('lpcnet15_384_10_G16_{epoch:02d}.h5') + +# model.load_weights('lpcnet9b_384_10_G16_01.h5') +model.compile(optimizer=Adam(0.001, amsgrad=True, decay=5e-5), loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + +g = Dataset(pcm_file, feature_file, batch_size) +model.fit_generator(g, epochs=nb_epochs, + steps_per_epoch=len(g), + callbacks=[checkpoint, lpcnet.Sparsify(2000, 40000, 400, (0.1, 0.1, 0.1))]) \ No newline at end of file From c6fa41114ec05d0fbc9f42cf160f4a55a343e90e Mon Sep 17 00:00:00 2001 From: gosha20777 Date: Wed, 2 Jan 2019 23:26:34 +0300 Subject: [PATCH 4/7] sync last updates with original branch --- COPYING | 31 ++++++++ README.md | 4 +- src/common.h | 48 +++++++++++- src/dump_data.c | 49 ++++++++++-- src/lpcnet.c | 23 ------ src/train_lpcnet.py | 78 +++++-------------- src/train_lpcnet_boost.py | 156 -------------------------------------- 7 files changed, 138 insertions(+), 251 deletions(-) create mode 100644 COPYING mode change 100644 => 100755 src/train_lpcnet.py delete mode 100755 src/train_lpcnet_boost.py diff --git a/COPYING b/COPYING new file mode 100644 index 00000000..feef112b --- /dev/null +++ b/COPYING @@ -0,0 +1,31 @@ +Copyright (c) 2017-2018, Mozilla +Copyright (c) 2007-2017, Jean-Marc Valin +Copyright (c) 2005-2017, Xiph.Org Foundation +Copyright (c) 2003-2004, Mark Borgerding + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +- Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +- Neither the name of the Xiph.Org Foundation nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION +OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md index c9b2f0a9..6108ed25 100644 --- a/README.md +++ b/README.md @@ -19,13 +19,13 @@ This software is an open source starting point for WaveRNN-based speech synthesi 1. Generate training data: ``` make dump_data - ./dump_data -train input.s16 features.f32 pcm.s16 + ./dump_data -train input.s16 features.f32 data.u8 ``` where the first file contains 16 kHz 16-bit raw PCM audio (no header) and the other files are output files. This program makes several passes over the data with different filters to generate a large amount of training data. 1. Now that you have your files, train with: ``` - ./train_lpcnet.py features.f32 pcm.s16 + ./train_lpcnet.py features.f32 data.u8 ``` and it will generate a wavenet*.h5 file for each iteration. If it stops with a "Failed to allocate RNN reserve space" message try reducing the *batch\_size* variable in train_wavenet_audio.py. diff --git a/src/common.h b/src/common.h index 1e24a28f..9510072b 100644 --- a/src/common.h +++ b/src/common.h @@ -3,14 +3,58 @@ #ifndef COMMON_H #define COMMON_H -#include "stdlib.h" -#include "string.h" +#include +#include +#include #define RNN_INLINE inline #define OPUS_INLINE inline float lpc_from_cepstrum(float *lpc, const float *cepstrum); +#define LOG256 5.5451774445f +static RNN_INLINE float log2_approx(float x) +{ + int integer; + float frac; + union { + float f; + int i; + } in; + in.f = x; + integer = (in.i>>23)-127; + in.i -= integer<<23; + frac = in.f - 1.5f; + frac = -0.41445418f + frac*(0.95909232f + + frac*(-0.33951290f + frac*0.16541097f)); + return 1+integer+frac; +} + +#define log_approx(x) (0.69315f*log2_approx(x)) + +static RNN_INLINE float ulaw2lin(float u) +{ + float s; + float scale_1 = 32768.f/255.f; + u = u - 128; + s = u >= 0 ? 1 : -1; + u = fabs(u); + return s*scale_1*(exp(u/128.*LOG256)-1); +} + +static RNN_INLINE int lin2ulaw(float x) +{ + float u; + float scale = 255.f/32768.f; + int s = x >= 0 ? 1 : -1; + x = fabs(x); + u = (s*(128*log_approx(1+scale*x)/LOG256)); + u = 128 + u; + if (u < 0) u = 0; + if (u > 255) u = 255; + return (int)floor(.5 + u); +} + /** RNNoise wrapper for malloc(). To do your own dynamic allocation, all you need t o do is replace this function and rnnoise_free */ diff --git a/src/dump_data.c b/src/dump_data.c index 5ec671f2..1db653d0 100644 --- a/src/dump_data.c +++ b/src/dump_data.c @@ -55,13 +55,12 @@ typedef struct { float analysis_mem[OVERLAP_SIZE]; float cepstral_mem[CEPS_MEM][NB_BANDS]; - int memid; float pitch_buf[PITCH_BUF_SIZE]; - float pitch_enh_buf[PITCH_BUF_SIZE]; float last_gain; int last_period; - float mem_hp_x[2]; - float lastg[NB_BANDS]; + float lpc[LPC_ORDER]; + float sig_mem[LPC_ORDER]; + int exc_mem; } DenoiseState; static int rnnoise_get_size() { @@ -112,7 +111,6 @@ static void compute_frame_features(DenoiseState *st, kiss_fft_cpx *X, kiss_fft_c int i; float E = 0; float Ly[NB_BANDS]; - float lpc[LPC_ORDER]; float p[WINDOW_SIZE]; float pitch_buf[PITCH_BUF_SIZE]; int pitch_index; @@ -154,7 +152,7 @@ static void compute_frame_features(DenoiseState *st, kiss_fft_cpx *X, kiss_fft_c } dct(features, Ly); features[0] -= 4; - g = lpc_from_cepstrum(lpc, features); + g = lpc_from_cepstrum(st->lpc, features); #if 0 for (i=0;ilpc[i]; #if 0 for (i=0;ilpc[j]*st->sig_mem[j]; + e = lin2ulaw(pcm[i] - p); + /* Signal. */ + data[4*i] = lin2ulaw(st->sig_mem[0]); + /* Prediction. */ + data[4*i+1] = lin2ulaw(p); + /* Excitation in. */ + data[4*i+2] = st->exc_mem; + /* Excitation out. */ + data[4*i+3] = e; + /* Simulate error on excitation. */ + noise = (int)floor(.5 + noise_std*.707*(log_approx((float)rand()/RAND_MAX)-log_approx((float)rand()/RAND_MAX))); + e += noise; + e = IMIN(255, IMAX(0, e)); + + RNN_MOVE(&st->sig_mem[1], &st->sig_mem[0], LPC_ORDER-1); + st->sig_mem[0] = p + ulaw2lin(e); + st->exc_mem = e; + } + fwrite(data, 4*FRAME_SIZE, 1, file); +} + int main(int argc, char **argv) { int i; int count=0; @@ -225,6 +253,7 @@ int main(int argc, char **argv) { float old_speech_gain = 1; int one_pass_completed = 0; DenoiseState *st; + float noise_std=0; int training = -1; st = rnnoise_create(); if (argc == 5 && strcmp(argv[1], "-train")==0) training = 1; @@ -286,11 +315,14 @@ int main(int argc, char **argv) { } if (count>=5000000 && one_pass_completed) break; if (training && ++gain_change_count > 2821) { + float tmp; speech_gain = pow(10., (-20+(rand()%40))/20.); if (rand()%20==0) speech_gain *= .01; if (rand()%100==0) speech_gain = 0; gain_change_count = 0; rand_resp(a_sig, b_sig); + tmp = (float)rand()/RAND_MAX; + noise_std = 4*tmp*tmp; } biquad(x, mem_hp_x, x, b_hp, a_hp, FRAME_SIZE); biquad(x, mem_resp_x, x, b_sig, a_sig, FRAME_SIZE); @@ -306,7 +338,8 @@ int main(int argc, char **argv) { fwrite(features, sizeof(float), NB_FEATURES, ffeat); /* PCM is delayed by 1/2 frame to make the features centered on the frames. */ for (i=0;i= 0 ? 1 : -1; - u = fabs(u); - return s*scale_1*(exp(u/128.*log(256))-1); -} - -static int lin2ulaw(float x) -{ - float u; - float scale = 255.f/32768.f; - int s = x >= 0 ? 1 : -1; - x = fabs(x); - u = (s*(128*log(1+scale*x)/log(256))); - u = 128 + u; - if (u < 0) u = 0; - if (u > 255) u = 255; - return (int)floor(.5 + u); -} - #if 0 static void print_vector(float *x, int N) { diff --git a/src/train_lpcnet.py b/src/train_lpcnet.py old mode 100644 new mode 100755 index 2f822fe5..837778f0 --- a/src/train_lpcnet.py +++ b/src/train_lpcnet.py @@ -1,13 +1,17 @@ #!/usr/bin/python3 '''Copyright (c) 2018 Mozilla + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -62,85 +66,39 @@ # u for unquantised, load 16 bit PCM samples and convert to mu-law -udata = np.fromfile(pcm_file, dtype='int16') -data = lin2ulaw(udata) -nb_frames = len(data)//pcm_chunk_size +data = np.fromfile(pcm_file, dtype='uint8') +nb_frames = len(data)//(4*pcm_chunk_size) features = np.fromfile(feature_file, dtype='float32') # limit to discrete number of frames -data = data[:nb_frames*pcm_chunk_size] -udata = udata[:nb_frames*pcm_chunk_size] +data = data[:nb_frames*4*pcm_chunk_size] features = features[:nb_frames*feature_chunk_size*nb_features] -# Noise injection: the idea is that the real system is going to be -# predicting samples based on previously predicted samples rather than -# from the original. Since the previously predicted samples aren't -# expected to be so good, I add noise to the training data. Exactly -# how the noise is added makes a huge difference - -in_data = np.concatenate([data[0:1], data[:-1]]); -noise = np.concatenate([np.zeros((len(data)*1//5)), np.random.randint(-3, 3, len(data)*1//5), np.random.randint(-2, 2, len(data)*1//5), np.random.randint(-1, 1, len(data)*2//5)]) -#noise = np.round(np.concatenate([np.zeros((len(data)*1//5)), np.random.laplace(0, 1.2, len(data)*1//5), np.random.laplace(0, .77, len(data)*1//5), np.random.laplace(0, .33, len(data)*1//5), np.random.randint(-1, 1, len(data)*1//5)])) -del data -in_data = in_data + noise -del noise -in_data = np.clip(in_data, 0, 255) - features = np.reshape(features, (nb_frames*feature_chunk_size, nb_features)) -# Note: the LPC predictor output is now calculated by the loop below, this code was -# for an ealier version that implemented the prediction filter in C - -upred = np.zeros((nb_frames*pcm_chunk_size,), dtype='float32') - -# Use 16th order LPC to generate LPC prediction output upred[] and (in -# mu-law form) pred[] - -pred_in = ulaw2lin(in_data) -for i in range(2, nb_frames*feature_chunk_size): - upred[i*frame_size:(i+1)*frame_size] = 0 - for k in range(16): - upred[i*frame_size:(i+1)*frame_size] = upred[i*frame_size:(i+1)*frame_size] - \ - pred_in[i*frame_size-k:(i+1)*frame_size-k]*features[i, nb_features-16+k] -del pred_in - -pred = lin2ulaw(upred) - -in_data = np.reshape(in_data, (nb_frames, pcm_chunk_size, 1)) -in_data = in_data.astype('uint8') - -# LPC residual, which is the difference between the input speech and -# the predictor output, with a slight time shift this is also the -# ideal excitation in_exc - -out_data = lin2ulaw(udata-upred) -del upred -del udata -in_exc = np.concatenate([out_data[0:1], out_data[:-1]]); - -out_data = np.reshape(out_data, (nb_frames, pcm_chunk_size, 1)) -out_data = out_data.astype('uint8') - -in_exc = np.reshape(in_exc, (nb_frames, pcm_chunk_size, 1)) -in_exc = in_exc.astype('uint8') +sig = np.reshape(data[0::4], (nb_frames, pcm_chunk_size, 1)) +pred = np.reshape(data[1::4], (nb_frames, pcm_chunk_size, 1)) +in_exc = np.reshape(data[2::4], (nb_frames, pcm_chunk_size, 1)) +out_exc = np.reshape(data[3::4], (nb_frames, pcm_chunk_size, 1)) +del data +print("ulaw std = ", np.std(out_exc)) features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features)) features = features[:, :, :nb_used_features] features[:,:,18:36] = 0 -pred = np.reshape(pred, (nb_frames, pcm_chunk_size, 1)) -pred = pred.astype('uint8') periods = (.1 + 50*features[:,:,36:37]+100).astype('int16') -in_data = np.concatenate([in_data, pred], axis=-1) +in_data = np.concatenate([sig, pred], axis=-1) +del sig del pred # dump models to disk as we go -checkpoint = ModelCheckpoint('lpcnet15_384_10_G16_{epoch:02d}.h5') +checkpoint = ModelCheckpoint('lpcnet18_384_10_G16_{epoch:02d}.h5') -#model.load_weights('lpcnet9b_384_10_G16_01.h5') +model.load_weights('lpcnet9b_384_10_G16_01.h5') model.compile(optimizer=Adam(0.001, amsgrad=True, decay=5e-5), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) -model.fit([in_data, in_exc, features, periods], out_data, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, lpcnet.Sparsify(2000, 40000, 400, (0.1, 0.1, 0.1))]) \ No newline at end of file +model.fit([in_data, in_exc, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))]) diff --git a/src/train_lpcnet_boost.py b/src/train_lpcnet_boost.py deleted file mode 100755 index aaf8a937..00000000 --- a/src/train_lpcnet_boost.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/python3 -'''Copyright (c) 2018 Mozilla - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR - CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -''' - -# Train a LPCNet model (note not a Wavenet model) - -import lpcnet -import sys -import numpy as np -from keras.optimizers import Adam -from keras.callbacks import ModelCheckpoint -from ulaw import ulaw2lin, lin2ulaw -import keras.backend as K -import h5py - -import tensorflow as tf -from keras.backend.tensorflow_backend import set_session -config = tf.ConfigProto() - -# use this option to reserve GPU memory, e.g. for running more than -# one thing at a time. Best to disable for GPUs with small memory -config.gpu_options.per_process_gpu_memory_fraction = 0.44 - -set_session(tf.Session(config=config)) - -nb_epochs = 120 - -# Try reducing batch_size if you run out of memory on your GPU -batch_size = 64 - -model, _, _ = lpcnet.new_lpcnet_model() - -model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) -model.summary() - -feature_file = sys.argv[1] -pcm_file = sys.argv[2] # 16 bit unsigned short PCM samples -frame_size = 160 -nb_features = 55 -nb_used_features = model.nb_used_features -feature_chunk_size = 15 -pcm_chunk_size = frame_size*feature_chunk_size - -# u for unquantised, load 16 bit PCM samples and convert to mu-law - -class Dataset(object): - def __init__(self, pcm_file, feature_file, batch_size): - print("start loading data from files...") - self._udata = np.fromfile(pcm_file, dtype='int16') - self._features = np.fromfile(feature_file, dtype='float32') - _nb_frames = len(self._udata) // pcm_chunk_size - self.loopcount = _nb_frames // batch_size - print("loopcount: ", self.loopcount) - # limit to discrete number of frames - self._udata = self._udata[:_nb_frames * pcm_chunk_size] - self._features = self._features[:_nb_frames * feature_chunk_size * nb_features] - - def __len__(self): - return self.loopcount - - def __iter__(self): - return self - - def __next__(self): - i = np.random.randint(0, self.loopcount) - nb_frames = batch_size - udata = self._udata[i * batch_size * pcm_chunk_size: - (i + 1) * batch_size * pcm_chunk_size] - features = self._features[i * batch_size * feature_chunk_size * nb_features: - (i + 1) * batch_size * feature_chunk_size * nb_features] - - data = lin2ulaw(udata) - in_data = np.concatenate([data[0:1], data[:-1]]) - noise = np.concatenate([np.zeros((len(data) * 1 // 5)), np.random.randint(-3, 3, len(data) * 1 // 5), - np.random.randint(-2, 2, len(data) * 1 // 5), - np.random.randint(-1, 1, len(data) * 2 // 5)]) - # noise = np.round(np.concatenate([np.zeros((len(data)*1//5)), np.random.laplace(0, 1.2, len(data)*1//5), np.random.laplace(0, .77, len(data)*1//5), np.random.laplace(0, .33, len(data)*1//5), np.random.randint(-1, 1, len(data)*1//5)])) - in_data = in_data + noise - in_data = np.clip(in_data, 0, 255) - - features = np.reshape(features, (nb_frames * feature_chunk_size, nb_features)) - - # Note: the LPC predictor output is now calculated by the loop below, this code was - # for an ealier version that implemented the prediction filter in C - - upred = np.zeros((nb_frames * pcm_chunk_size,), dtype='float32') - - # Use 16th order LPC to generate LPC prediction output upred[] and (in - # mu-law form) pred[] - - pred_in = ulaw2lin(in_data) - for i in range(2, nb_frames * feature_chunk_size): - upred[i * frame_size:(i + 1) * frame_size] = 0 - for k in range(16): - upred[i * frame_size:(i + 1) * frame_size] = upred[i * frame_size:(i + 1) * frame_size] - \ - pred_in[i * frame_size - k:(i + 1) * frame_size - k] * \ - features[i, nb_features - 16 + k] - - pred = lin2ulaw(upred) - in_data = np.reshape(in_data, (nb_frames, pcm_chunk_size, 1)) - in_data = in_data.astype('uint8') - - # LPC residual, which is the difference between the input speech and - # the predictor output, with a slight time shift this is also the - # ideal excitation in_exc - - out_data = lin2ulaw(udata - upred) - in_exc = np.concatenate([out_data[0:1], out_data[:-1]]) - out_data = np.reshape(out_data, (nb_frames, pcm_chunk_size, 1)) - out_data = out_data.astype('uint8') - in_exc = np.reshape(in_exc, (nb_frames, pcm_chunk_size, 1)) - in_exc = in_exc.astype('uint8') - features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features)) - features = features[:, :, :nb_used_features] - features[:, :, 18:36] = 0 - pred = np.reshape(pred, (nb_frames, pcm_chunk_size, 1)) - pred = pred.astype('uint8') - periods = (.1 + 50 * features[:, :, 36:37] + 100).astype('int16') - in_data = np.concatenate([in_data, pred], axis=-1) - return [in_data, in_exc, features, periods], out_data - - -# dump models to disk as we go -checkpoint = ModelCheckpoint('lpcnet15_384_10_G16_{epoch:02d}.h5') - -# model.load_weights('lpcnet9b_384_10_G16_01.h5') -model.compile(optimizer=Adam(0.001, amsgrad=True, decay=5e-5), loss='sparse_categorical_crossentropy', - metrics=['sparse_categorical_accuracy']) - -g = Dataset(pcm_file, feature_file, batch_size) -model.fit_generator(g, epochs=nb_epochs, - steps_per_epoch=len(g), - callbacks=[checkpoint, lpcnet.Sparsify(2000, 40000, 400, (0.1, 0.1, 0.1))]) \ No newline at end of file From 2dcd5dd3e8a7a8eb85e3fff1819a1f50927bb69e Mon Sep 17 00:00:00 2001 From: gosha20777 Date: Fri, 4 Jan 2019 02:11:51 +0300 Subject: [PATCH 5/7] CMakeLists.txt --- CMakeLists.txt | 46 +++++++++++++++++++++++++++++++++++++ src/compile.sh | 4 ++-- src/vec_avx.h | 61 +++++++++++++++++++++++++++++++++----------------- 3 files changed, 88 insertions(+), 23 deletions(-) create mode 100644 CMakeLists.txt mode change 100755 => 100644 src/compile.sh diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..ad9fc7b3 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,46 @@ +# gcc -DTRAINING=1 -Wall -W -O3 -g -I../include denoise.c kiss_fft.c pitch.c celt_lpc.c -o dump_data -lm +# gcc -mfma -msse3 -mavx2 -o test_lpcnet -g -O2 -Wall -W -Wextra lpcnet.c nnet.c nnet_data.c -lm + +cmake_minimum_required(VERSION 3.1) + +project(lpcnet) + +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +add_executable(dump_data + src/dump_data.c + src/freq.c + src/kiss_fft.c + src/pitch.c + src/celt_lpc.c +) + + +target_compile_definitions(dump_data PRIVATE "TRAINING=1") +target_include_directories(dump_data PRIVATE ${CMAKE_CURRENT_SOURCE}/) +target_link_libraries(dump_data PRIVATE + "m" +) + + +add_executable(test_lpcnet + src/test_lpcnet.c + src/lpcnet.c + src/nnet.c + src/nnet_data.c + src/freq.c + src/kiss_fft.c + src/pitch.c + src/celt_lpc.c +) +add_custom_command(OUTPUT nnet_data.c + COMMAND dump_lpcnet.py `ls -t lpcnet*.h5 | head -1` + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE}/src) + +target_compile_options (test_lpcnet PRIVATE -march=native) +target_include_directories(test_lpcnet PRIVATE ${CMAKE_CURRENT_SOURCE}/) +target_link_libraries(test_lpcnet PRIVATE + "m" +) diff --git a/src/compile.sh b/src/compile.sh old mode 100755 new mode 100644 index 327f23fa..c17db720 --- a/src/compile.sh +++ b/src/compile.sh @@ -1,4 +1,4 @@ #!/bin/sh -gcc -Wall -W -O3 -g -I../include dump_data.c freq.c kiss_fft.c pitch.c celt_lpc.c -o dump_data -lm -gcc -o test_lpcnet -mavx2 -mfma -g -O3 -Wall -W -Wextra test_lpcnet.c lpcnet.c nnet.c nnet_data.c freq.c kiss_fft.c pitch.c celt_lpc.c -lm +gcc -Wall -fopenmp -W -O3 -g -I../include dump_data.c freq.c kiss_fft.c pitch.c celt_lpc.c -o dump_data -lm -Wunused-result +gcc -o test_lpcnet -mavx2 -mfma -g -O3 -Wall -W -Wextra test_lpcnet.c lpcnet.c nnet.c nnet_data.c freq.c kiss_fft.c pitch.c celt_lpc.c -lm -fopenmp -Wunknown-pragmas -Wunused-result diff --git a/src/vec_avx.h b/src/vec_avx.h index 1e58f8d1..aa6f6cd5 100644 --- a/src/vec_avx.h +++ b/src/vec_avx.h @@ -187,33 +187,52 @@ static void sgemv_accum16(float *out, const float *weights, int rows, int cols, } static void sparse_sgemv_accum16(float *out, const float *weights, int rows, const int *idx, const float *x) { - int i, j; - for (i=0;i=rows) + flag=0; + else + { + y = &out[i]; + cols = *idx++; + idx+=cols; + lc_weights=weights; + weights+=16*cols; + } + } + if (flag) + { + vy0 = _mm256_loadu_ps(&y[0]); + vy8 = _mm256_loadu_ps(&y[8]); + for (j=0;j Date: Fri, 4 Jan 2019 02:22:27 +0300 Subject: [PATCH 6/7] add OpetMP support --- src/vec_avx.h | 76 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 23 deletions(-) diff --git a/src/vec_avx.h b/src/vec_avx.h index aa6f6cd5..f7c538cb 100644 --- a/src/vec_avx.h +++ b/src/vec_avx.h @@ -187,41 +187,71 @@ static void sgemv_accum16(float *out, const float *weights, int rows, int cols, } static void sparse_sgemv_accum16(float *out, const float *weights, int rows, const int *idx, const float *x) { - int i=-16; - #pragma omp parallel - while (i=rows) - flag=0; - else - { - y = &out[i]; - cols = *idx++; - idx+=cols; - lc_weights=weights; - weights+=16*cols; - } + int id; + __m256 vxj; + __m256 vw; + id = *idx++; + vxj = _mm256_broadcast_ss(&x[id]); + + vw = _mm256_loadu_ps(&weights[0]); + vy0 = _mm256_fmadd_ps(vw, vxj, vy0); + + vw = _mm256_loadu_ps(&weights[8]); + vy8 = _mm256_fmadd_ps(vw, vxj, vy8); + weights += 16; } - if (flag) - { + _mm256_storeu_ps (&y[0], vy0); + _mm256_storeu_ps (&y[8], vy8); + } + */ + + int i, j; + //initialization + const int *precomputed_idx[rows]; + const float *precomputed_weights[rows]; + for (i=0;i Date: Thu, 10 Jan 2019 17:00:20 +0300 Subject: [PATCH 7/7] remove-OMP --- src/compile.sh | 4 ++-- src/vec_avx.h | 50 -------------------------------------------------- 2 files changed, 2 insertions(+), 52 deletions(-) diff --git a/src/compile.sh b/src/compile.sh index c17db720..1c469b84 100644 --- a/src/compile.sh +++ b/src/compile.sh @@ -1,4 +1,4 @@ #!/bin/sh -gcc -Wall -fopenmp -W -O3 -g -I../include dump_data.c freq.c kiss_fft.c pitch.c celt_lpc.c -o dump_data -lm -Wunused-result -gcc -o test_lpcnet -mavx2 -mfma -g -O3 -Wall -W -Wextra test_lpcnet.c lpcnet.c nnet.c nnet_data.c freq.c kiss_fft.c pitch.c celt_lpc.c -lm -fopenmp -Wunknown-pragmas -Wunused-result +gcc -Wall -W -O3 -g -I../include dump_data.c freq.c kiss_fft.c pitch.c celt_lpc.c -o dump_data -lm -Wunused-result +gcc -o test_lpcnet -mavx2 -mfma -g -O3 -Wall -W -Wextra test_lpcnet.c lpcnet.c nnet.c nnet_data.c freq.c kiss_fft.c pitch.c celt_lpc.c -lm -Wunknown-pragmas -Wunused-result diff --git a/src/vec_avx.h b/src/vec_avx.h index f7c538cb..6ad082f6 100644 --- a/src/vec_avx.h +++ b/src/vec_avx.h @@ -187,7 +187,6 @@ static void sgemv_accum16(float *out, const float *weights, int rows, int cols, } static void sparse_sgemv_accum16(float *out, const float *weights, int rows, const int *idx, const float *x) { - /* original code int i, j; for (i=0;i