diff --git a/examples/common-sdl.cpp b/examples/common-sdl.cpp index b61f8cff5fd..6272ce838a7 100644 --- a/examples/common-sdl.cpp +++ b/examples/common-sdl.cpp @@ -130,6 +130,7 @@ bool audio_async::clear() { m_audio_pos = 0; m_audio_len = 0; + m_audio_nxt = 0; } return true; @@ -172,6 +173,28 @@ void audio_async::callback(uint8_t * stream, int len) { } void audio_async::get(int ms, std::vector & result) { + if (ms <= 0) { + ms = m_len_ms; + } + + size_t n_samples = std::min(m_audio_len, (m_sample_rate * ms) / 1000); + + get_n(n_samples, result); +} + +void audio_async::next(std::vector & result) { + size_t n_samples; + + if (m_audio_pos >= m_audio_nxt) { + n_samples = m_audio_pos - m_audio_nxt; + } else { + n_samples = m_audio_len - m_audio_nxt + m_audio_pos; + } + + get_n(n_samples, result); +} + +void audio_async::get_n(size_t n_samples, std::vector & result) { if (!m_dev_id_in) { fprintf(stderr, "%s: no audio device to get audio from!\n", __func__); return; @@ -182,20 +205,9 @@ void audio_async::get(int ms, std::vector & result) { return; } - result.clear(); - { std::lock_guard lock(m_mutex); - if (ms <= 0) { - ms = m_len_ms; - } - - size_t n_samples = (m_sample_rate * ms) / 1000; - if (n_samples > m_audio_len) { - n_samples = m_audio_len; - } - result.resize(n_samples); int s0 = m_audio_pos - n_samples; @@ -205,10 +217,12 @@ void audio_async::get(int ms, std::vector & result) { if (s0 + n_samples > m_audio.size()) { const size_t n0 = m_audio.size() - s0; + m_audio_nxt = n_samples - n0; memcpy(result.data(), &m_audio[s0], n0 * sizeof(float)); - memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float)); + memcpy(&result[n0], &m_audio[0], m_audio_nxt * sizeof(float)); } else { + m_audio_nxt = s0 + n_samples; memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float)); } } diff --git a/examples/common-sdl.h b/examples/common-sdl.h index 9ee8a320724..746493f7c83 100644 --- a/examples/common-sdl.h +++ b/examples/common-sdl.h @@ -30,6 +30,8 @@ class audio_async { // get audio data from the circular buffer void get(int ms, std::vector & audio); + void next(std::vector & audio); + void get_n(size_t n_samples, std::vector & audio); private: SDL_AudioDeviceID m_dev_id_in = 0; @@ -43,6 +45,7 @@ class audio_async { std::vector m_audio; size_t m_audio_pos = 0; size_t m_audio_len = 0; + size_t m_audio_nxt = 0; }; // Return false if need to quit diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 190f68a2c3b..ac2116bca12 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -7,7 +7,10 @@ #include "whisper.h" #include +#include #include +#include +#include #include #include #include @@ -21,7 +24,7 @@ struct whisper_params { int32_t length_ms = 10000; int32_t keep_ms = 200; int32_t capture_id = -1; - int32_t max_tokens = 32; + int32_t max_tokens = 128; int32_t audio_ctx = 0; float vad_thold = 0.6f; @@ -36,6 +39,7 @@ struct whisper_params { bool save_audio = false; // save audio to wav file bool use_gpu = true; bool flash_attn = false; + bool interim = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -65,6 +69,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } @@ -72,6 +77,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-int" || arg == "--interim") { params.interim = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); @@ -102,6 +108,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); @@ -109,6 +116,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true"); fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -int, --interim [%-7s] show interim report in vad every step\n", params.interim ? "true" : "false"); fprintf(stderr, "\n"); } @@ -122,19 +130,16 @@ int main(int argc, char ** argv) { params.keep_ms = std::min(params.keep_ms, params.step_ms); params.length_ms = std::max(params.length_ms, params.step_ms); - const int n_samples_step = (1e-3*params.step_ms )*WHISPER_SAMPLE_RATE; - const int n_samples_len = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE; - const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE; - const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE; + const int n_samples_step = (1e-3*abs(params.step_ms))*WHISPER_SAMPLE_RATE; + const int n_samples_len = (1e-3*params.length_ms )*WHISPER_SAMPLE_RATE; + const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE; + const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE; + const int n_samples_100ms= (1e-3*100.0 )*WHISPER_SAMPLE_RATE; - const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD + const bool use_vad = params.step_ms <= 0; // sliding window mode uses VAD const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line - params.no_timestamps = !use_vad; - params.no_context |= use_vad; - params.max_tokens = 0; - // init audio audio_async audio(params.length_ms); @@ -159,9 +164,10 @@ int main(int argc, char ** argv) { struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); - std::vector pcmf32 (n_samples_30s, 0.0f); - std::vector pcmf32_old; - std::vector pcmf32_new(n_samples_30s, 0.0f); + std::vector pcmf32(n_samples_30s, 0.0f); + std::deque pcmf32_deque; + int n_samples_new = 0; + int n_samples_old = 0; std::vector prompt_tokens; @@ -219,17 +225,17 @@ int main(int argc, char ** argv) { wavWriter.open(filename, WHISPER_SAMPLE_RATE, 16, 1); } - printf("[Start speaking]\n"); - fflush(stdout); + fprintf(stderr, "[Start speaking]\n"); + fflush(stderr); auto t_last = std::chrono::high_resolution_clock::now(); + auto t_interim = t_last; + bool is_interim = false; const auto t_start = t_last; + std::string s_to_delete = ""; // main audio loop while (is_running) { - if (params.save_audio) { - wavWriter.write(pcmf32_new.data(), pcmf32_new.size()); - } // handle Ctrl + C is_running = sdl_poll_events(); @@ -238,62 +244,74 @@ int main(int argc, char ** argv) { } // process new audio + const auto t_now = std::chrono::high_resolution_clock::now(); + const auto t_diff = std::chrono::duration_cast(t_now - t_last).count(); + + // get new audio + if (n_samples_new > n_samples_step) { + pcmf32.clear(); + } else if (t_diff < abs(params.step_ms)) { + std::this_thread::sleep_for(std::chrono::milliseconds(abs(params.step_ms) - t_diff)); + continue; + } else { + audio.next(pcmf32); + } - if (!use_vad) { - while (true) { - audio.get(params.step_ms, pcmf32_new); - - if ((int) pcmf32_new.size() > 2*n_samples_step) { - fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__); - audio.clear(); - continue; - } - - if ((int) pcmf32_new.size() >= n_samples_step) { - audio.clear(); - break; - } - - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - - const int n_samples_new = pcmf32_new.size(); - - // take up to params.length_ms audio from previous iteration - const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new)); + const int n_samples_buf = pcmf32.size(); - //printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size()); + if (params.save_audio && n_samples_buf > 0) { + wavWriter.write(pcmf32.data(), n_samples_buf); + } - pcmf32.resize(n_samples_new + n_samples_take); + copy(pcmf32.begin(), pcmf32.end(), std::back_inserter(pcmf32_deque)); + if (pcmf32_deque.size() > n_samples_30s) { + pcmf32_deque.erase(pcmf32_deque.begin(), pcmf32_deque.end() - n_samples_30s); + } - for (int i = 0; i < n_samples_take; i++) { - pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i]; - } + n_samples_new += n_samples_buf; + if (!is_interim && n_samples_new > 2*n_samples_step) { + fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n", __func__); + fprintf(stderr, "t_diff = %.2fs, new = %.2fs, buf = %.2fs\n\n", 1e-3*t_diff, float(n_samples_new)/WHISPER_SAMPLE_RATE, float(n_samples_buf)/WHISPER_SAMPLE_RATE); + n_samples_old = 0; + n_samples_new = 0; + t_last = t_now; + continue; + } + is_interim = false; - memcpy(pcmf32.data() + n_samples_take, pcmf32_new.data(), n_samples_new*sizeof(float)); + if (!use_vad){ + n_samples_old += n_samples_new; + n_samples_new = 0; + pcmf32.resize(n_samples_old); + copy(pcmf32_deque.end() - n_samples_old, pcmf32_deque.end(), pcmf32.begin()); - pcmf32_old = pcmf32; + t_last = t_now; } else { - const auto t_now = std::chrono::high_resolution_clock::now(); - const auto t_diff = std::chrono::duration_cast(t_now - t_last).count(); - - if (t_diff < 2000) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - continue; - } - - audio.get(2000, pcmf32_new); - - if (::vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) { - audio.get(params.length_ms, pcmf32); + pcmf32.resize(n_samples_step); + copy(pcmf32_deque.end() - n_samples_step, pcmf32_deque.end(), pcmf32.begin()); + if (::vad_simple(pcmf32, WHISPER_SAMPLE_RATE, std::min(1000, abs(params.step_ms) / 2), params.vad_thold, params.freq_thold, false)) { + pcmf32.resize(n_samples_old + n_samples_new); + copy(pcmf32_deque.end() - n_samples_old - n_samples_new, pcmf32_deque.end(), pcmf32.begin()); + n_samples_new = 0; + n_samples_old = 0; + + t_last = t_now; } else { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - continue; + const auto n_interim_diff_ms = std::chrono::duration_cast(t_now - t_interim).count(); + + if (params.interim && n_interim_diff_ms > abs(params.step_ms)) { + is_interim = (n_interim_diff_ms < params.length_ms - abs(params.step_ms)); + n_samples_old += n_samples_new; + n_samples_new = 0; + pcmf32.resize(n_samples_old); + copy(pcmf32_deque.end() - n_samples_old, pcmf32_deque.end(), pcmf32.begin()); + } else { + n_samples_new -= n_samples_100ms; + n_samples_old = std::min(n_samples_len, n_samples_old + n_samples_100ms); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } } - - t_last = t_now; } // run the inference @@ -325,80 +343,109 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 6; } + t_interim = std::chrono::high_resolution_clock::now(); // print result; + int n_segments; + bool is_unconfirmed = false; + std::ostringstream text; { - if (!use_vad) { + if (!use_vad || params.interim && params.no_timestamps && s_to_delete.size()) { printf("\33[2K\r"); // print long empty line to clear the previous line - printf("%s", std::string(100, ' ').c_str()); + printf("%s", std::string(s_to_delete.size(), ' ').c_str()); printf("\33[2K\r"); - } else { + } else if (use_vad && !params.no_timestamps) { const int64_t t1 = (t_last - t_start).count()/1000000; const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE); - printf("\n"); - printf("### Transcription %d START | t0 = %d ms | t1 = %d ms\n", n_iter, (int) t0, (int) t1); - printf("\n"); + text << std::endl; + text << "### Transcription " << n_iter << " START | t0 = " << t0 << " ms | t1 = " << t1 << " ms" << std::endl; + text << std::endl; } - const int n_segments = whisper_full_n_segments(ctx); + n_segments = whisper_full_n_segments(ctx); + if (is_interim) { + if (n_segments < 2) { + is_unconfirmed = true; + } else { + n_segments--; + const int64_t t1_ms = whisper_full_get_segment_t1(ctx, n_segments - 1) * 10; + t_last += std::chrono::milliseconds(t1_ms); + const auto n_confirmed = (1e-3*t1_ms)*WHISPER_SAMPLE_RATE; + pcmf32.resize(n_confirmed); + n_samples_old -= n_confirmed; + } + } for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - - if (params.no_timestamps) { - printf("%s", text); - fflush(stdout); + std::string i_text = whisper_full_get_segment_text(ctx, i); - if (params.fname_out.length() > 0) { - fout << text; + if (!use_vad || params.no_timestamps) { + if (i > 0) { + text << std::endl; } + text << i_text; } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + const int64_t t_end = (t_last - t_start).count()/1000000; + const int64_t t_beg = std::max(0.0, t_end - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE); + const int64_t t0 = t_beg/10 + whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = t_beg/10 + whisper_full_get_segment_t1(ctx, i); - std::string output = "[" + to_timestamp(t0, false) + " --> " + to_timestamp(t1, false) + "] " + text; + text << "[" << to_timestamp(t0, false) << " --> " << to_timestamp(t1, false) << "] " << i_text; if (whisper_full_get_segment_speaker_turn_next(ctx, i)) { - output += " [SPEAKER_TURN]"; + text << " [SPEAKER_TURN]"; } - output += "\n"; - - printf("%s", output.c_str()); - fflush(stdout); - - if (params.fname_out.length() > 0) { - fout << output; - } + text << std::endl; } } - if (params.fname_out.length() > 0) { - fout << std::endl; + if (use_vad && !params.no_timestamps) { + text << std::endl; + text << "### Transcription " << n_iter << " END"; + text << std::endl; } + } - if (use_vad) { - printf("\n"); - printf("### Transcription %d END\n", n_iter); - } + if (params.fname_out.length() > 0) { + fout << text.str(); + fout << std::endl; } ++n_iter; - if (!use_vad && (n_iter % n_new_line) == 0) { + if (is_unconfirmed) { + --n_iter; + // utf-8 cannot be simply cut into two + std::wstring_convert, char32_t> conv; + auto t_u32 = conv.from_bytes(text.str()); + auto t_sub = conv.to_bytes(t_u32.substr(0, t_u32.size() / 2)); + text.str(t_sub + "…"); + } + + printf("%s", text.str().c_str()); + + if (is_unconfirmed || !use_vad && n_samples_old < n_samples_len - n_samples_step) { + s_to_delete = text.str(); + } else { printf("\n"); + s_to_delete = ""; - // keep part of the audio for next iteration to try to mitigate word boundary issues - pcmf32_old = std::vector(pcmf32.end() - n_samples_keep, pcmf32.end()); + if (!use_vad) { + n_iter = 0; + if (n_samples_keep < n_samples_old) { + // keep part of the audio for next iteration to try to mitigate word boundary issues + n_samples_old = n_samples_keep; + } + } // Add tokens of the last full length segment as the prompt if (!params.no_context) { prompt_tokens.clear(); - const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const int token_count = whisper_full_n_tokens(ctx, i); for (int j = 0; j < token_count; ++j) {