diff --git a/CMakeLists.txt b/CMakeLists.txt
index be6db903c4a..2ffa09e3b47 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -83,6 +83,7 @@ option(WHISPER_BUILD_SERVER "whisper: build server example" ${WHISPER_STANDALO
 # 3rd party libs
 option(WHISPER_CURL "whisper: use libcurl to download model from an URL" OFF)
 option(WHISPER_SDL2 "whisper: support for libSDL2" OFF)
+option(WEBSOCKET "whisper: support for websocket" OFF)
 
 if (CMAKE_SYSTEM_NAME MATCHES "Linux")
     option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF)
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index e4265affe97..35b3f75bb19 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -100,8 +100,10 @@ if (EMSCRIPTEN)
     add_subdirectory(bench.wasm)
 elseif(CMAKE_JS_VERSION)
     add_subdirectory(addon.node)
 else()
     add_subdirectory(cli)
+    # The WebSocket example needs ixwebsocket and libcurl, so it is opt-in.
+    # The existing cli example stays unconditional.
+    if (WEBSOCKET)
+        add_subdirectory(websocket-stream)
+    endif()
     add_subdirectory(bench)
     add_subdirectory(server)
     add_subdirectory(quantize)
diff --git a/examples/websocket-stream/CMakeLists.txt b/examples/websocket-stream/CMakeLists.txt
new file mode 100644
index 00000000000..22de60c84ca
--- /dev/null
+++ b/examples/websocket-stream/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(TARGET whisper-websocket-stream)
+add_executable(${TARGET} main.cpp whisper-server.cpp message-buffer.cpp)
+find_package(ixwebsocket)
+find_package(CURL REQUIRED)
+include(DefaultTargetOptions)
+target_link_libraries(${TARGET} PRIVATE common whisper ixwebsocket z CURL::libcurl ${CMAKE_THREAD_LIBS_INIT})
+
+install(TARGETS ${TARGET} RUNTIME)
diff --git a/examples/websocket-stream/README.md b/examples/websocket-stream/README.md
new file mode 100644
index 00000000000..5c6e7ee560a
--- /dev/null
+++ b/examples/websocket-stream/README.md
@@ -0,0 +1,90 @@
+# WebSocket Whisper Stream Example
+
+This example demonstrates a WebSocket-based real-time audio transcription service using the Whisper model.
The server receives audio from clients, processes it using the Whisper model, and sends transcriptions back through WebSocket connections.
+
+## Features
+
+- Real-time audio transcription
+- WebSocket communication for audio and transcription data
+- Configurable parameters for model, language, and processing settings
+- Integration with backend services via HTTP requests
+
+## Usage
+
+Run the server with the following command:
+
+```bash
+./build/bin/whisper-websocket-stream -m ./models/ggml-large-v3-turbo.bin -t 8 --host 0.0.0.0 --port 9002 --forward-url http://localhost:8080/completion
+```
+
+### Parameters
+
+- `-m` or `--model`: Path to the Whisper model file.
+- `-t` or `--threads`: Number of threads for processing.
+- `-H` or `--host`: Hostname or IP address to bind the server to.
+- `-p` or `--port`: Port number for the server.
+- `-f` or `--forward-url`: URL to forward transcriptions to a backend service.
+- `-nm` or `--max-messages`: Maximum number of messages before sending to the backend.
+- `-l` or `--language`: Spoken language for transcription.
+- `-vth` or `--vad-thold`: Voice activity detection threshold.
+- `-tr` or `--translate`: Enable translation to English.
+- `-ng` or `--no-gpu`: Disable GPU usage.
+- `-bs` or `--beam-size`: Beam size for beam search.
+
+## Building
+
+To build the server, follow these steps:
+
+```bash
+# Install dependencies
+git clone --depth 1 https://github.com/machinezone/IXWebSocket/
+cd IXWebSocket
+mkdir -p build && cd build && cmake -GNinja .. && sudo ninja -j$(nproc) install
+# Build the project
+# CUDA is optional: drop -DGGML_CUDA=ON for a CPU-only build
+git clone --depth 1 https://github.com/ggml-org/whisper.cpp
+cd whisper.cpp
+mkdir -p build && cd build
+cmake -GNinja -DCMAKE_BUILD_TYPE=Release -DWEBSOCKET=ON -DGGML_CUDA=ON ..
+ninja -j$(nproc)
+
+# Run the server
+./bin/whisper-websocket-stream --help
+```
+
+## Client Integration
+
+Clients can connect to the WebSocket server and send audio data.
The server processes the audio and sends transcriptions back through the WebSocket connection.
+
+### Example Client Code (JavaScript)
+
+```javascript
+const socket = new WebSocket('ws://localhost:9002');
+
+socket.onopen = () => {
+    console.log('Connected to WebSocket server');
+};
+
+socket.onmessage = (event) => {
+    console.log('Transcription:', event.data);
+};
+
+socket.onclose = () => {
+    console.log('Disconnected from WebSocket server');
+};
+
+// Send audio to the server as a binary frame (ArrayBuffer or typed array);
+// the server only processes binary messages.
+function sendAudioData(audioData) {
+    socket.send(audioData);
+}
+```
+
+## Backend Integration
+
+The server can forward transcriptions to a backend service via HTTP requests. Use the `-f` / `--forward-url` option to specify the backend service URL.
+
+## Dependencies
+- whisper.cpp
+- ixwebsocket for WebSocket communication
+- libcurl for HTTP requests
diff --git a/examples/websocket-stream/client-session.h b/examples/websocket-stream/client-session.h
new file mode 100644
index 00000000000..e2606e564bb
--- /dev/null
+++ b/examples/websocket-stream/client-session.h
@@ -0,0 +1,15 @@
+#ifndef CLIENT_SESSION_H
+#define CLIENT_SESSION_H
+#include <atomic>
+#include <mutex>
+#include <vector>
+#include "ixwebsocket/IXWebSocketServer.h"
+#include "message-buffer.h"
+// Per-connection state kept by the server for one WebSocket client.
+struct ClientSession {
+    // Audio samples received from the client and not yet transcribed.
+    std::vector<float> pcm_buffer;
+    // Guards pcm_buffer.
+    std::mutex mtx;
+    // Cleared when the connection is closed.
+    std::atomic<bool> active{true};
+    // Non-owning; set when the connection opens. Null until then.
+    ix::WebSocket *connection = nullptr;
+    // Transcriptions waiting to be forwarded to the backend.
+    MessageBuffer buffToBackend;
+};
+#endif
diff --git a/examples/websocket-stream/index.html b/examples/websocket-stream/index.html
new file mode 100644
index 00000000000..5bc644824ca
--- /dev/null
+++ b/examples/websocket-stream/index.html
@@ -0,0 +1,61 @@
+
+
+
+    Mic to WebSocket
+
+
+
+
+
+
+
+
diff --git a/examples/websocket-stream/main.cpp b/examples/websocket-stream/main.cpp
new file mode 100644
index 00000000000..1a4e1292c2b
--- /dev/null
+++ b/examples/websocket-stream/main.cpp
@@ -0,0 +1,74 @@
+#include <cstdio>
+#include <cstdlib>
+#include <stdexcept>
+#include <string>
+#include "whisper.h"
+#include "server-params.h"
+#include "whisper-server.h"
+
+// NOTE(review): this macro is only defined in this translation unit. The
+// `#ifdef CONVERT_FROM_PCM_16` in whisper-server.cpp does not appear to see it;
+// confirm and move it to a shared header or a compile definition.
+#define CONVERT_FROM_PCM_16
+// Referenced from message-buffer.cpp through `extern` declarations.
+std::string forward_url = "http://127.0.0.1:8080/completion";
+size_t max_messages = 1000;
+
+// Prints the command-line help to stderr, showing current values in brackets.
+void print_usage(int argc, char** argv, const ServerParams& params) {
+    (void) argc;
+    fprintf(stderr, "\n");
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, " -h, --help show this help message and exit\n");
+    fprintf(stderr, " -H HOST, --host HOST [%-7s] hostname or ip\n", params.host.c_str());
+    fprintf(stderr, " -p PORT, --port PORT [%-7d] server port\n", params.port);
+    fprintf(stderr, " -f FORWARD_URL, --forward-url FORWARD_URL [%-7s] forward url\n", forward_url.c_str());
+    fprintf(stderr, " -t N, --threads N [%-7d] number of threads\n", params.n_threads);
+    // max_messages is a size_t, so it needs %zu rather than %d.
+    fprintf(stderr, " -nm max_messages, --max-messages max_messages [%-7zu] max messages before send to backend\n", max_messages);
+    fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
+    fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
+    fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity threshold\n", params.vad_thold);
+    fprintf(stderr, " -tr, --translate [%-7s] translate to english\n", params.translate ? "true" : "false");
+    fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
+    fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
+    fprintf(stderr, "\n");
+}
+
+// Parses the command line into `params` and the forward_url / max_messages globals.
+// Returns false after printing an error and the usage text when an argument is
+// unknown, has no value, or has a value that cannot be parsed.
+// `-h` / `--help` prints the usage text and exits the process with status 0.
+bool parse_params(int argc, char** argv, ServerParams& params) {
+    for (int i = 1; i < argc; i++) {
+        const std::string arg = argv[i];
+
+        // Returns the value that follows the current flag. Throws when the flag
+        // is the last argument, instead of reading past the end of argv.
+        auto next_value = [&]() -> const char* {
+            if (i + 1 >= argc) {
+                throw std::invalid_argument("missing value");
+            }
+            return argv[++i];
+        };
+
+        try {
+            if (arg == "-h" || arg == "--help") {
+                print_usage(argc, argv, params);
+                exit(0);
+            }
+            else if (arg == "-H" || arg == "--host") { params.host = next_value(); }
+            else if (arg == "-p" || arg == "--port") { params.port = std::stoi(next_value()); }
+            else if (arg == "-f" || arg == "--forward-url") { forward_url = next_value(); }
+            else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(next_value()); }
+            else if (arg == "-nm" || arg == "--max-messages") {
+                // A negative count would wrap around when stored in a size_t.
+                const int count = std::stoi(next_value());
+                if (count < 0) {
+                    throw std::invalid_argument("negative count");
+                }
+                max_messages = static_cast<size_t>(count);
+            }
+            else if (arg == "-m" || arg == "--model") { params.model = next_value(); }
+            else if (arg == "-l" || arg == "--language") { params.language = next_value(); }
+            else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(next_value()); }
+            else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
+            else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(next_value()); }
+            else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
+            else {
+                fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+                print_usage(argc, argv, params);
+                return false;
+            }
+        } catch (const std::exception&) {
+            // Raised by next_value() above or by std::stoi / std::stof.
+            fprintf(stderr, "error: missing or invalid value for argument: %s\n", arg.c_str());
+            print_usage(argc, argv, params);
+            return false;
+        }
+    }
+    return true;
+}
+
+// Entry point: validates the parameters, then runs the server.
+// Returns 1 on invalid arguments, 0 when the server loop ends.
+int main(int argc, char** argv) {
+    ServerParams params;
+    if (!parse_params(argc, argv, params)) {
+        return 1;
+    }
+    if (params.port < 1 || params.port > 65535) {
+        // Report like the other argument errors instead of throwing an
+        // exception that nothing catches.
+        fprintf(stderr, "error: invalid port number %d\n", params.port);
+        return 1;
+    }
+    if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
+        fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+        return 1;
+    }
+
+    WhisperServer server(params);
+    server.run();
+    return 0;
+}
diff --git a/examples/websocket-stream/message-buffer.cpp
b/examples/websocket-stream/message-buffer.cpp
new file mode 100644
index 00000000000..00eb04bbd62
--- /dev/null
+++ b/examples/websocket-stream/message-buffer.cpp
@@ -0,0 +1,79 @@
+#include <cstdio>
+#include <mutex>
+#include <string>
+#include <curl/curl.h>
+
+#include "message-buffer.h"
+// Defined in main.cpp and set from the command line.
+extern std::string forward_url;
+extern size_t max_messages;
+namespace {
+    // libcurl write callback: appends the received bytes to the std::string
+    // passed as userdata and reports all of them as consumed.
+    size_t write_callback(char* ptr, size_t size, size_t nmemb, void* userdata) {
+        static_cast<std::string*>(userdata)->append(ptr, size * nmemb);
+        return size * nmemb;
+    }
+}
+
+// Appends `msg` plus a newline to this connection's pending payload and
+// forwards the payload once max_messages messages have accumulated.
+void MessageBuffer::add_message(const char* msg) {
+    if (msg == nullptr) {
+        return;
+    }
+    bool limit_reached = false;
+    {
+        std::lock_guard<std::mutex> lock(mtx_);
+        pending_ += msg;
+        pending_ += '\n';
+        limit_reached = ++pending_count_ >= max_messages;
+    }
+    // flush() locks mtx_ itself and std::mutex is not recursive, so it must
+    // be called after the lock above has been released.
+    if (limit_reached) {
+        flush();
+    }
+}
+
+// Returns a copy of the payload that has not been forwarded yet.
+std::string MessageBuffer::get_payload() {
+    std::lock_guard<std::mutex> lock(mtx_);
+    return pending_;
+}
+
+// Takes the pending payload, resets the buffer and, when the payload is not
+// empty, posts it to the backend. The payload is dropped even if the request fails.
+void MessageBuffer::flush() {
+    std::string payload;
+    {
+        std::lock_guard<std::mutex> lock(mtx_);
+        payload.swap(pending_);
+        pending_count_ = 0;
+    }
+    // The HTTP request runs without the lock held.
+    if (!payload.empty()) {
+        send_via_http(payload);
+    }
+}
+
+// POSTs `data` as text/plain to forward_url with an X-Connection-ID header,
+// trying up to three times. Progress and errors are printed to stdout.
+void MessageBuffer::send_via_http(const std::string& data) {
+    CURL* curl = curl_easy_init();
+    if (!curl) {
+        printf("CURL init failed");
+        return;
+    }
+
+    //make headers
+    struct curl_slist* headers = nullptr;
+    headers = curl_slist_append(headers, "Content-Type: text/plain");
+    const std::string cid_header = "X-Connection-ID: " + connection_id;
+    headers = curl_slist_append(headers, cid_header.c_str());
+
+    //config curl
+    std::string response;
+    printf("sending to %s\n", forward_url.c_str());
+    curl_easy_setopt(curl, CURLOPT_URL, forward_url.c_str());
+    curl_easy_setopt(curl, CURLOPT_POST, 1L);
+    curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
+    curl_easy_setopt(curl, CURLOPT_POSTFIELDS, data.c_str());
+    // CURLOPT_POSTFIELDSIZE expects a long; curl_easy_setopt is variadic, so
+    // a size_t would not be converted automatically.
+    curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, static_cast<long>(data.size()));
+    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_callback);
+    curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response);
+    curl_easy_setopt(curl, CURLOPT_TIMEOUT, 5L);
+    curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 2L);
+
+    //run curl
+    for (int retry = 0; retry < 3; ++retry) {
+        // Discard whatever a failed attempt already wrote.
+        response.clear();
+        const CURLcode res = curl_easy_perform(curl);
+        if (res == CURLE_OK) {
+            printf("[Response (%s): %s\n", connection_id.c_str(), response.c_str());
+            break;
+        }
+        printf("[CURL error: %s\n", curl_easy_strerror(res));
+    }
+
+    //clean
+    curl_slist_free_all(headers);
+    curl_easy_cleanup(curl);
+}
diff --git a/examples/websocket-stream/message-buffer.h b/examples/websocket-stream/message-buffer.h
new file mode 100644
index 00000000000..0d0d480c2af
--- /dev/null
+++ b/examples/websocket-stream/message-buffer.h
@@ -0,0 +1,14 @@
+#ifndef MESSAGE_BUFFER_H
+#define MESSAGE_BUFFER_H
+#include <cstddef>
+#include <mutex>
+#include <string>
+// Collects the transcriptions of one connection and forwards them to the
+// backend over HTTP.
+class MessageBuffer {
+public:
+    // Sent to the backend in the X-Connection-ID header.
+    std::string connection_id;
+    void add_message(const char* msg);
+
+    std::string get_payload();
+
+    void flush();
+
+    void send_via_http(const std::string& data);
+
+private:
+    // Guards pending_ and pending_count_.
+    std::mutex mtx_;
+    // Newline-terminated messages not yet forwarded.
+    std::string pending_;
+    // Number of messages currently held in pending_.
+    size_t pending_count_ = 0;
+};
+#endif
diff --git a/examples/websocket-stream/server-params.h b/examples/websocket-stream/server-params.h
new file mode 100644
index 00000000000..55fd11bc514
--- /dev/null
+++ b/examples/websocket-stream/server-params.h
@@ -0,0 +1,23 @@
+#ifndef SERVER_PARAMS_H
+#define SERVER_PARAMS_H
+#include <algorithm>
+#include <cstdint>
+#include <string>
+#include <thread>
+// Command-line configurable settings of the WebSocket transcription server.
+struct ServerParams {
+    int32_t port = 9002;
+    // hardware_concurrency() may return 0, so keep the default within [1, 4].
+    int32_t n_threads = std::max<int32_t>(1, std::min<int32_t>(4, static_cast<int32_t>(std::thread::hardware_concurrency())));
+    int32_t audio_ctx = 0;
+    int32_t beam_size = -1;
+
+    float vad_thold = 0.6f;
+
+    bool translate = false;
+    bool print_special = false;
+    bool no_timestamps = true;
+    bool tinydiarize = false;
+    bool use_gpu = true;
+    bool flash_attn = true;
+
+    std::string language = "en";
+    std::string model = "ggml-large-v3-turbo.bin";
+    std::string host = "0.0.0.0";
+};
+#endif
diff --git a/examples/websocket-stream/whisper-server.cpp b/examples/websocket-stream/whisper-server.cpp
new file mode 100644
index 00000000000..53a5333a8d2
--- /dev/null
+++ b/examples/websocket-stream/whisper-server.cpp
@@ -0,0 +1,177 @@
+#include <atomic>
+#include <chrono>
+#include <cstdio>
+#include <memory>
+#include <mutex>
+#include <random>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+#include "whisper-server.h"
+#include "client-session.h"
+#include "whisper.h"
+#include <stdexcept>
+
+namespace {
+    // Copy of the parameters passed to the WhisperServer constructor.
+    ServerParams params;
+    // Open connections keyed by the connection-state id; guarded by clients_mtx.
+    std::unordered_map<std::string, std::unique_ptr<ClientSession>> clients;
+    std::mutex clients_mtx;
+    std::thread processor_thread;
+    std::atomic<bool> running{true};
+    // Serialises use of the single whisper context.
+    std::mutex g_ctx_mtx;
+    whisper_context* g_ctx = nullptr;
+    // Samples transcribed per call; presumably 3 s of 16 kHz audio - TODO confirm.
+    constexpr int CHUNK_SIZE = 3 * 16000;
+}
+
+// Builds a random string in the textual layout of a version-4 UUID.
+std::string generate_uuid_v4() {
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_int_distribution<> dis(0, 15);
+    std::uniform_int_distribution<> dis2(8, 11);
+
+    std::stringstream ss;
+    ss << std::hex;
+
+    for (int i = 0; i < 8; i++) ss << dis(gen);
+    ss << "-";
+    for (int i = 0; i < 4; i++) ss << dis(gen);
+    ss << "-4"; // v4
+    for (int i = 0; i < 3; i++) ss << dis(gen);
+    ss << "-";
+    ss << dis2(gen);
+    for (int i = 0; i < 3; i++) ss << dis(gen);
+    ss << "-";
+    for (int i = 0; i < 12; i++) ss << dis(gen);
+
+    return ss.str();
+}
+
+// WebSocket callback for every client event.
+//  Open:   registers a session and sends "CONNECTION_ID:<uuid>" to the client.
+//  Close:  removes the session and forwards its pending transcriptions.
+//  Binary: appends the received samples to the session's PCM buffer.
+// Other message types are ignored.
+void handleMessage(std::shared_ptr<ix::ConnectionState> state,
+                   ix::WebSocket& ws,
+                   const ix::WebSocketMessagePtr& msg) {
+    const std::string client_id = state->getId();
+
+    if (msg->type == ix::WebSocketMessageType::Open) {
+        printf("[%s] new client\n", client_id.c_str());
+        std::lock_guard<std::mutex> lock(clients_mtx);
+        auto& session = clients[client_id];
+        session = std::make_unique<ClientSession>();
+        // UUID v4
+        session->buffToBackend.connection_id = generate_uuid_v4();
+        ws.sendText("CONNECTION_ID:" + session->buffToBackend.connection_id);
+
+        session->connection = &ws;
+    }
+    else if (msg->type == ix::WebSocketMessageType::Close) {
+        printf("[%s] delete client\n", client_id.c_str());
+        // Detach the session under the lock and only touch it if it exists.
+        // operator[] on a missing id would insert and dereference a null pointer.
+        std::unique_ptr<ClientSession> closed;
+        {
+            std::lock_guard<std::mutex> lock(clients_mtx);
+            auto it = clients.find(client_id);
+            if (it == clients.end()) return;
+            closed = std::move(it->second);
+            clients.erase(it);
+        }
+        closed->active = false;
+        closed->buffToBackend.flush();
+    }
+    else if (msg->type == ix::WebSocketMessageType::Message && msg->binary) {
+        std::lock_guard<std::mutex> lock(clients_mtx);
+        auto it = clients.find(client_id);
+        if (it == clients.end()) return;
+
+        auto& session = *it->second;
+        const auto& data = msg->str;
+        // NOTE(review): CONVERT_FROM_PCM_16 is defined in main.cpp only, so the
+        // #else branch appears to be the one compiled here. That branch converts
+        // 32-bit integers to float without scaling; confirm the sample format
+        // the clients actually send.
+        #ifdef CONVERT_FROM_PCM_16
+        const int16_t* pcm16 = reinterpret_cast<const int16_t*>(data.data());
+        size_t n_samples = data.size() / sizeof(int16_t);
+
+        std::lock_guard<std::mutex> session_lock(session.mtx);
+        for (size_t i = 0; i < n_samples; i++) {
+            session.pcm_buffer.push_back(pcm16[i] / 32768.0f);
+        }
+        #else
+        const int32_t* pcm32 = reinterpret_cast<const int32_t*>(data.data());
+        //also we may use memcpy ))
+        size_t n_samples = data.size() / sizeof(int32_t);
+
+        std::lock_guard<std::mutex> session_lock(session.mtx);
+        for (size_t i = 0; i < n_samples; i++) {
+            session.pcm_buffer.push_back(pcm32[i]);
+        }
+        #endif
+    }
+}
+
+// Transcribes one chunk of samples with the shared whisper context. When any
+// text is produced it is printed, sent to the client and queued for the backend.
+void processChunk(std::vector<float> &chunk, const std::string &id, ClientSession *session) {
+    std::lock_guard<std::mutex> ctx_lock(g_ctx_mtx);
+    whisper_full_params wparams = whisper_full_default_params(
+        params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH
+                             : WHISPER_SAMPLING_GREEDY);
+
+    wparams.print_progress = false;
+    wparams.print_special = params.print_special;
+    wparams.print_realtime = false;
+    wparams.print_timestamps = !params.no_timestamps;
+    wparams.translate = params.translate;
+    wparams.language = params.language.c_str();
+    wparams.n_threads = params.n_threads;
+    wparams.beam_search.beam_size = params.beam_size;
+    wparams.audio_ctx = params.audio_ctx;
+    wparams.tdrz_enable = params.tinydiarize;
+
+    if (whisper_full(g_ctx, wparams, chunk.data(), static_cast<int>(chunk.size())) == 0) {
+        // Join all segments: a chunk can yield several, and index 0 is only
+        // valid when at least one segment was produced.
+        std::string text;
+        const int n_segments = whisper_full_n_segments(g_ctx);
+        for (int i = 0; i < n_segments; ++i) {
+            text += whisper_full_get_segment_text(g_ctx, i);
+        }
+        if (!text.empty()) {
+            printf("[%s] %s\n", id.c_str(), text.c_str());
+            session->connection->sendText(text);
+            session->buffToBackend.add_message(text.c_str());
+        }
+    }
+    whisper_reset_timings(g_ctx);
+}
+
+// Worker loop: every 100 ms, takes one CHUNK_SIZE block from each session that
+// has buffered enough samples and transcribes it.
+// NOTE(review): clients_mtx stays locked while processChunk runs, so
+// handleMessage waits on it for the whole transcription.
+void process() {
+    while (running) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        std::lock_guard<std::mutex> lock(clients_mtx);
+        for (auto& [id, session] : clients) {
+            std::lock_guard<std::mutex> session_lock(session->mtx);
+            if (session->pcm_buffer.size() < CHUNK_SIZE) continue;
+
+            std::vector<float> chunk(
+                session->pcm_buffer.begin(),
+                session->pcm_buffer.begin() + CHUNK_SIZE
+            );
+            session->pcm_buffer.erase(
+                session->pcm_buffer.begin(),
+                session->pcm_buffer.begin() + CHUNK_SIZE
+            );
+
+            processChunk(chunk, id, session.get());
+        }
+    }
+}
+
+// Loads the model, installs the WebSocket callback and starts the worker thread.
+// Throws std::runtime_error when the model cannot be loaded.
+// The server member is initialised from `_params`: the file-level `params`
+// still holds its defaults while the member initialiser list runs.
+WhisperServer::WhisperServer(const ServerParams& _params) : server(_params.port, _params.host) {
+    params = _params;
+
+    whisper_context_params cparams = whisper_context_default_params();
+    cparams.use_gpu = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
+
+    g_ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
+    if (g_ctx == nullptr) {
+        // Fail here rather than crash later when the context is first used.
+        throw std::runtime_error("failed to load whisper model: " + params.model);
+    }
+
+    server.setTLSOptions({});
+    server.setOnClientMessageCallback([this](auto&&... args) {
+        handleMessage(args...);
+    });
+
+    processor_thread = std::thread([this] { process(); });
+}
+
+// Stops the server and the worker thread, forwards what is still buffered for
+// the remaining sessions, then frees the whisper context.
+WhisperServer::~WhisperServer() {
+    running = false;
+    server.stop();
+    if (processor_thread.joinable()) processor_thread.join();
+    std::lock_guard<std::mutex> lock(clients_mtx);
+    for (auto& [id, session] : clients) {
+        session->buffToBackend.flush();
+    }
+    whisper_free(g_ctx);
+}
+
+// Starts listening and blocks, polling once per second, while `running` is true.
+void WhisperServer::run() {
+    server.listenAndStart();
+    while (running) std::this_thread::sleep_for(std::chrono::seconds(1));
+}
diff --git a/examples/websocket-stream/whisper-server.h b/examples/websocket-stream/whisper-server.h
new file mode 100644
index 00000000000..88bba420108
--- /dev/null
+++ b/examples/websocket-stream/whisper-server.h
@@ -0,0 +1,13 @@
+#ifndef WHISPER_SERVER_H
+#define WHISPER_SERVER_H
+#include "server-params.h"
+#include "ixwebsocket/IXWebSocketServer.h"
+// WebSocket server that transcribes the audio it receives from clients.
+class WhisperServer {
+
+    ix::WebSocketServer server;
+public:
+    explicit WhisperServer(const ServerParams& params);
+    ~WhisperServer();
+    void run();
+};
+#endif