diff --git a/examples/stream/stream.cc b/examples/stream/stream.cc
index 68e4780..63588a3 100644
--- a/examples/stream/stream.cc
+++ b/examples/stream/stream.cc
@@ -1,6 +1,7 @@
 #include "common-sdl.h"
 #include "common.h"
 #include "sense-voice.h"
+#include "silero-vad.h"
 #include <cassert>
 #include <cstdio>
 #include <string>
@@ -9,6 +10,41 @@
 #include <thread>
 #include <vector>
 #include <fstream>
+#include <cstdint> // NOTE(review): system header names were lost in extraction (<...> stripped); reconstructed from usage -- verify against upstream stream.cc
+
+// WAV file header (canonical 44-byte PCM RIFF layout).
+struct WAVHeader {
+    char riff[4] = {'R', 'I', 'F', 'F'};
+    uint32_t file_size;
+    char wave[4] = {'W', 'A', 'V', 'E'};
+    char fmt[4] = {'f', 'm', 't', ' '};
+    uint32_t fmt_size = 16;
+    uint16_t audio_format = 1; // PCM
+    uint16_t num_channels = 1; // mono
+    uint32_t sample_rate = SENSE_VOICE_SAMPLE_RATE;
+    uint32_t byte_rate;
+    uint16_t block_align;
+    uint16_t bits_per_sample = 16;
+    char data[4] = {'d', 'a', 't', 'a'};
+    uint32_t data_size;
+
+    WAVHeader() {
+        byte_rate = sample_rate * num_channels * bits_per_sample / 8;
+        block_align = num_channels * bits_per_sample / 8;
+        file_size = 0; // patched when the stream is closed
+        data_size = 0; // patched when the stream is closed
+    }
+};
+
+// Convert float audio samples (clamped to [-1.0, 1.0]) to 16-bit signed PCM.
+void float_to_pcm16(const std::vector<float> &float_audio, std::vector<int16_t> &pcm16_audio) {
+    pcm16_audio.resize(float_audio.size());
+    for (size_t i = 0; i < float_audio.size(); ++i) {
+        // clamp to [-1.0, 1.0], then scale to the 16-bit range
+        float sample = std::max(-1.0f, std::min(1.0f, float_audio[i]));
+        pcm16_audio[i] = static_cast<int16_t>(sample * 32767.0f);
+    }
+}
 
 struct sense_voice_stream_params {
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@@ -27,6 +63,10 @@ struct sense_voice_stream_params {
     std::string language = "auto";
     std::string model = "models/ggml-base.en.bin";
     std::string fname_out;
+    std::string audio_out; // optional capture file for the raw audio stream (WAV)
+
+    // silero-vad parameters
+    float speech_prob_threshold = 0.1f;
 };
 
@@ -46,9 +86,11 @@ void sense_voice_stream_usage(int /*argc*/, char **argv, const sense_voice_strea
     fprintf(stderr, " -l LANG, --language LANG [%-7s] [SenseVoice] spoken language\n", params.language.c_str());
     fprintf(stderr, " -m FNAME, --model FNAME [%-7s] [SenseVoice] model path\n", params.model.c_str());
     fprintf(stderr, " -f FNAME, --file FNAME [%-7s] [IO] text output file name\n", params.fname_out.c_str());
+    fprintf(stderr, " -o FNAME, --output-audio FNAME [%-7s] [IO] audio output file name\n", params.audio_out.c_str());
     fprintf(stderr, " -ng, --no-gpu [%-7s] [SenseVoice] disable GPU inference\n", params.use_gpu ? "false" : "true");
     fprintf(stderr, " -fa, --flash-attn [%-7s] [SenseVoice] flash attention during inference\n", params.flash_attn ? "true" : "false");
     fprintf(stderr, " --use-itn [%-7s] [SenseVoice] Filter duplicate tokens when outputting\n", params.use_itn ? "true" : "false");
+    fprintf(stderr, " -spt --speech-prob-threshold [%-7.2f] [VAD] speech probability threshold for VAD\n", params.speech_prob_threshold);
     fprintf(stderr, "\n");
 }
@@ -70,6 +112,8 @@ static bool get_stream_params(int argc, char **argv, sense_voice_stream_params &
             params.model = argv[++i];
         } else if (arg == "-f" || arg == "--file") {
             params.fname_out = argv[++i];
+        } else if (arg == "-o" || arg == "--output-audio") {
+            params.audio_out = argv[++i];
         } else if (arg == "-ng" || arg == "--no-gpu") {
             params.use_gpu = false;
         } else if (arg == "-fa" || arg == "--flash-attn") {
@@ -88,6 +132,8 @@ static bool get_stream_params(int argc, char **argv, sense_voice_stream_params &
             params.chunk_size = std::stoi(argv[++i]);
         } else if (arg == "--use-itn") {
             params.use_itn = true;
+        } else if (arg == "--speech-prob-threshold" || arg == "-spt") {
+            params.speech_prob_threshold = std::stof(argv[++i]);
         } else {
@@ -106,6 +152,22 @@ void sense_voice_free(struct sense_voice_context *ctx) {
         ggml_backend_buffer_free(ctx->model.buffer);
 
+        // Release VAD resources, guarded against null pointers.
+        // NOTE(review): confirm sense_voice_free_state() does not also free
+        // these members, otherwise this is a double free.
+        if (ctx->state) {
+            if (ctx->state->vad_ctx) {
+                ggml_free(ctx->state->vad_ctx);
+                ctx->state->vad_ctx = nullptr;
+            }
+            if (ctx->state->vad_lstm_hidden_state_buffer) {
+                ggml_backend_buffer_free(ctx->state->vad_lstm_hidden_state_buffer);
+                ctx->state->vad_lstm_hidden_state_buffer = nullptr;
+            }
+            if (ctx->state->vad_lstm_context_buffer) {
+                ggml_backend_buffer_free(ctx->state->vad_lstm_context_buffer);
+                ctx->state->vad_lstm_context_buffer = nullptr;
+            }
+        }
+
         sense_voice_free_state(ctx->state);
 
         delete ctx->model.model->encoder;
@@ -118,6 +180,11 @@ void sense_voice_free(struct sense_voice_context *ctx) {
 int main(int argc, char **argv) {
     sense_voice_stream_params params;
     if (get_stream_params(argc, argv, params) == false) return 1;
+
+    // VAD constants (kept consistent with zcr_main/main.cc)
+    const int VAD_LSTM_STATE_MEMORY_SIZE = 2048;
+    const int VAD_LSTM_STATE_DIM = 128;
+
     const int n_sample_step = params.chunk_size * 1e-3 * SENSE_VOICE_SAMPLE_RATE;
     const int keep_nomute_step = params.chunk_size * params.min_mute_chunks * 1e-3 * SENSE_VOICE_SAMPLE_RATE;
     const int max_nomute_step = params.chunk_size * params.max_nomute_chunks * 1e-3 * SENSE_VOICE_SAMPLE_RATE;
@@ -136,16 +203,75 @@ int main(int argc, char **argv) {
     bool is_running = true;
 
     struct sense_voice_context *ctx = sense_voice_small_init_from_file_with_params(params.model.c_str(), cparams);
+
+    if (ctx == nullptr) {
+        fprintf(stderr, "error: failed to initialize sense voice context\n");
+        return 3;
+    }
+
+    // Set the language id (required for correct recognition).
+    ctx->language_id = sense_voice_lang_id(params.language.c_str());
+    if (ctx->language_id == -1) {
+        fprintf(stderr, "warning: unknown language '%s', using auto detection\n", params.language.c_str());
+        ctx->language_id = sense_voice_lang_id("auto");
+    }
+
+    fprintf(stderr, "Language: %s (ID: %d)\n", params.language.c_str(), ctx->language_id);
+
     std::vector<float> pcmf32_audio;
     std::vector<double> pcmf32;
     std::vector<double> pcmf32_tmp;// handed to the model
+
+    // Reserve sensible capacities up front to avoid repeated reallocation.
+    pcmf32.reserve(max_nomute_step * 2);
+    pcmf32_audio.reserve(params.chunk_size * SENSE_VOICE_SAMPLE_RATE / 1000 * 4); // room for 4 chunks
+    pcmf32_tmp.reserve(max_nomute_step); // scratch space for trimming
+
+    // text output file
     std::ofstream fout;
     if (params.fname_out.length() > 0) {
         fout.open(params.fname_out);
         if (!fout.is_open()) {
-            fprintf(stderr, "%s: failed to open output file '%s'!\n", __func__, params.fname_out.c_str());
+            fprintf(stderr, "%s: failed to open text output file '%s'!\n", __func__, params.fname_out.c_str());
             return 1;
         }
     }
+
+    // audio output file
+    std::ofstream audio_fout;
+    WAVHeader wav_header;
+    uint32_t total_samples = 0;
+    if (params.audio_out.length() > 0) {
+        audio_fout.open(params.audio_out, std::ios::binary);
+        if (!audio_fout.is_open()) {
+            fprintf(stderr, "%s: failed to open audio output file '%s'!\n", __func__, params.audio_out.c_str());
            return 1;
+        }
+        // Write a placeholder WAV header; the size fields are patched on close.
+        audio_fout.write(reinterpret_cast<const char *>(&wav_header), sizeof(WAVHeader));
+    }
+
+    // Initialize the VAD state (kept consistent with main.cc).
+    if (params.use_vad) {
+        // init state
+        ctx->state->vad_ctx = ggml_init({VAD_LSTM_STATE_MEMORY_SIZE, nullptr, true});
+        ctx->state->vad_lstm_context = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);
+        ctx->state->vad_lstm_hidden_state = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);
+
+        ctx->state->vad_lstm_context_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
+                                                                        ggml_nbytes(ctx->state->vad_lstm_context) +
+                                                                        ggml_backend_get_alignment(ctx->state->backends[0]));
+        ctx->state->vad_lstm_hidden_state_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
+                                                                             ggml_nbytes(ctx->state->vad_lstm_hidden_state) +
+                                                                             ggml_backend_get_alignment(ctx->state->backends[0]));
+        auto context_alloc = ggml_tallocr_new(ctx->state->vad_lstm_context_buffer);
+        ggml_tallocr_alloc(&context_alloc, ctx->state->vad_lstm_context);
+
+        auto state_alloc = ggml_tallocr_new(ctx->state->vad_lstm_hidden_state_buffer);
+        ggml_tallocr_alloc(&state_alloc, ctx->state->vad_lstm_hidden_state);
+
+        ggml_set_zero(ctx->state->vad_lstm_context);
+        ggml_set_zero(ctx->state->vad_lstm_hidden_state);
     }
 
     {
@@ -186,6 +312,22 @@ int main(int argc, char **argv) {
             // fetch new audio; drain the capture queue whether or not speech was detected
             std::this_thread::sleep_for(std::chrono::milliseconds(params.chunk_size));
             audio.get(params.chunk_size, pcmf32_audio);
+
+            // if audio capture is enabled, stream the samples to disk immediately
+            if (audio_fout.is_open() && !pcmf32_audio.empty()) {
+                // convert to 16-bit PCM
+                std::vector<int16_t> pcm16_data;
+                float_to_pcm16(pcmf32_audio, pcm16_data);
+
+                // write the PCM payload
+                audio_fout.write(reinterpret_cast<const char *>(pcm16_data.data()),
+                                 pcm16_data.size() * sizeof(int16_t));
+                audio_fout.flush(); // keep the file usable even if we stop abruptly
+
+                // update the sample counter
+                total_samples += pcm16_data.size();
+            }
+
             // move into pcmf32, which is what gets recognized
             pcmf32.insert(pcmf32.end(), pcmf32_audio.begin(), pcmf32_audio.end());
             pcmf32_audio.clear();
@@ -212,7 +354,8 @@ int main(int argc, char **argv) {
             // the pending span got too long -- flush the line and restart
             if (R_new_chunk >= max_nomute_step + idenitified_floats) {
                 printf("\n");
-                pcmf32_tmp = std::vector<double>(pcmf32.begin() + R_new_chunk - idenitified_floats, pcmf32.end());
+                // optimization: assign() reuses capacity instead of constructing a fresh vector
+                pcmf32_tmp.assign(pcmf32.begin() + R_new_chunk - idenitified_floats, pcmf32.end());
                 pcmf32 = pcmf32_tmp;
                 idenitified_floats = R_new_chunk;
             }
@@ -224,7 +367,43 @@ int main(int argc, char **argv) {
             // newly arrived chunks may split the sequence -- handle with care
             for (int i = L_new_chunk; i < R_new_chunk; i += n_sample_step) {
                 // int R_this_chunk = i + n_sample_step;
-                bool isnomute = vad_energy_zcr(pcmf32.begin() + i - idenitified_floats, n_sample_step, SENSE_VOICE_SAMPLE_RATE, 1e-5, 0.2);
+                // silero-vad replaces vad_energy_zcr
+                bool isnomute = false;
+
+                // VAD detection, following the approach in zcr_main/main.cc
+                int actual_chunk_size = n_sample_step;
+                int vad_chunk_size = std::max(640, actual_chunk_size);
+                std::vector<float> vad_chunk(vad_chunk_size, 0);
+
+                int start_idx = i - idenitified_floats;
+
+                // guard against out-of-bounds access; cast size() to int so a
+                // negative start_idx does not wrap in the unsigned comparison
+                for (int j = 0; j < actual_chunk_size && start_idx + j < (int) pcmf32.size(); j++) {
+                    if (start_idx + j >= 0) {
+                        // NOTE(review): pcmf32 appears to hold normalized [-1, 1]
+                        // samples (see float_to_pcm16 above); dividing by 32768
+                        // again may shrink the signal -- confirm the input scale
+                        // silero_vad_encode_internal expects.
+                        vad_chunk[j] = static_cast<float>(pcmf32[start_idx + j]) / 32768.0f;
+                    }
+                }
+
+                // pad short chunks (< 640 samples) with the last sample value
+                if (actual_chunk_size < 640) {
+                    float last_sample = (actual_chunk_size > 0) ? vad_chunk[actual_chunk_size - 1] : 0.0f;
+                    for (int j = actual_chunk_size; j < 640; j++) {
+                        vad_chunk[j] = last_sample;
+                    }
+                }
+
+                float speech_prob = 0;
+                if (silero_vad_encode_internal(*ctx, *ctx->state, vad_chunk, params.n_threads, speech_prob)) {
+                    isnomute = (speech_prob >= params.speech_prob_threshold);
+                    // debug: show VAD results
+                    // if (speech_prob > 0.1) { // only show meaningful probabilities
+                    //     fprintf(stderr, "VAD: prob=%.3f, threshold=%.3f, isnomute=%d\n",
+                    //             speech_prob, params.speech_prob_threshold, isnomute);
+                    // }
+                } else {
+                    // VAD failed -- fall back to vad_energy_zcr with the same
+                    // thresholds the removed call above used
+                    isnomute = vad_energy_zcr(pcmf32.begin() + start_idx, n_sample_step, SENSE_VOICE_SAMPLE_RATE, 1e-5, 0.2);
+                }
                 // fprintf(stderr, "Mute || isnomute = %d, ML = %d, MR = %d, NML = %d, NMR = %d, R_new_chunk = %d, i = %d, size = %d, idenitified = %d\n", isnomute, mute.first, mute.second, nomute.first, nomute.second, R_new_chunk, i, pcmf32.size(), idenitified_floats);
                 if (nomute.first == -1) {
                     if (isnomute) nomute.first = i;
@@ -290,19 +469,48 @@ int main(int argc, char **argv) {
                 }
                 // adjust idenitified_floats and shrink pcmf32
                 if (nomute.second > 0) {
-                    pcmf32_tmp = std::vector<double>(pcmf32.begin() + (nomute.second - idenitified_floats), pcmf32.end());
+                    // optimization: assign() reuses capacity instead of constructing a fresh vector
+                    pcmf32_tmp.assign(pcmf32.begin() + (nomute.second - idenitified_floats), pcmf32.end());
                     pcmf32 = pcmf32_tmp;
                     idenitified_floats = nomute.second;
                     nomute.second = 0;
                 } else if (nomute.first == -1) {
-                    pcmf32_tmp = std::vector<double>(pcmf32.begin() + (R_new_chunk - idenitified_floats), pcmf32.end());
+                    // optimization: assign() reuses capacity instead of constructing a fresh vector
+                    pcmf32_tmp.assign(pcmf32.begin() + (R_new_chunk - idenitified_floats), pcmf32.end());
                     pcmf32 = pcmf32_tmp;
                     idenitified_floats = R_new_chunk;
                 }
+
+                // warn when the buffer grows beyond the recommended limit
+                if (pcmf32.size() > 2 * max_nomute_step) {
+                    fprintf(stderr, "Warning: Audio buffer size (%.2f MB, %.2f sec) exceeds recommended limit. Consider optimizing processing speed.\n",
+                            pcmf32.size() * sizeof(double) / 1e6,
+                            pcmf32.size() / (double)SENSE_VOICE_SAMPLE_RATE);
+                }
             }
             fflush(stdout);
         }
     }
     audio.pause();
+
+    // close the output files
+    if (fout.is_open()) {
+        fout.close();
+    }
+    if (audio_fout.is_open()) {
+        // patch the size fields in the WAV header
+        wav_header.data_size = total_samples * sizeof(int16_t);
+        wav_header.file_size = sizeof(WAVHeader) - 8 + wav_header.data_size;
+
+        // seek back to the start and rewrite the completed header
+        audio_fout.seekp(0, std::ios::beg);
+        audio_fout.write(reinterpret_cast<const char *>(&wav_header), sizeof(WAVHeader));
+        audio_fout.close();
+
+        fprintf(stderr, "Audio saved to '%s' (%u samples, %.2f seconds)\n",
+                params.audio_out.c_str(), total_samples,
+                (float)total_samples / SENSE_VOICE_SAMPLE_RATE);
+    }
+
     sense_voice_free(ctx);
     return 0;
 }
diff --git a/sense-voice/csrc/sense-voice.cc b/sense-voice/csrc/sense-voice.cc
index 1545441..23889c5 100644
--- a/sense-voice/csrc/sense-voice.cc
+++ b/sense-voice/csrc/sense-voice.cc
@@ -687,6 +687,17 @@ int sense_voice_pcm_to_feature_with_state(struct sense_voice_context *ctx,
 
     // set input
     {
+        // free the previous per-call feature context/buffer to prevent a leak
+        if (state->feature.ctx) {
+            ggml_free(state->feature.ctx);
+            state->feature.ctx = nullptr;
+        }
+        if (state->feature.buffer) {
+            ggml_backend_buffer_free(state->feature.buffer);
+            state->feature.buffer = nullptr;
+        }
+        state->feature.tensor = nullptr;
+
         // init features
         state->feature.n_len = state->feature.data.size() / (state->feature.n_mel * state->feature.lfr_m);
         state->feature.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
@@ -816,6 +827,17 @@ int sense_voice_batch_pcm_to_feature_with_state(struct sense_voice_context *ctx,
 
     // set input
     {
+        // free the previous per-call feature context/buffer to prevent a leak
+        if (state->feature.ctx) {
+            ggml_free(state->feature.ctx);
+            state->feature.ctx = nullptr;
+        }
+        if (state->feature.buffer) {
+            ggml_backend_buffer_free(state->feature.buffer);
+            state->feature.buffer = nullptr;
+        }
+        state->feature.tensor = nullptr;
+
         // init features
         state->feature.n_len = state->feature.data.size() / (state->feature.n_mel * state->feature.lfr_m);
         state->feature.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});