From 4122006aed2b887e037167a9184d46842f811517 Mon Sep 17 00:00:00 2001
From: michaljarnot
Date: Sat, 27 Jul 2024 19:43:41 +0200
Subject: [PATCH] refactor: better runQueryStream

Replace the callback-based runQueryStream() with a pull-based
TokenStream: the binding now returns the stream object immediately,
generation runs on a detached worker thread that pushes tokens into a
mutex-guarded queue, and consumers call read() until it returns null.
---
 cpp/CMakeLists.txt          |  2 ++
 cpp/src/LlamaCPPBinding.cpp | 24 ++++++++-----
 cpp/src/LlamaCPPBinding.h   |  1 +
 cpp/src/TokenStream.cpp     | 72 +++++++++++++++++++++++++++++++++++++
 cpp/src/TokenStream.h       | 28 +++++++++++++++
 cpp/src/main.cpp            |  2 ++
 example/index.js            | 44 +++++++++++++----------
 index.d.ts                  |  6 +++-
 8 files changed, 150 insertions(+), 29 deletions(-)
 create mode 100644 cpp/src/TokenStream.cpp
 create mode 100644 cpp/src/TokenStream.h

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index a55332e..5484dde 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -15,6 +15,8 @@ add_library(${PROJECT_NAME} SHARED
     src/main.cpp
     src/LlamaCPPBinding.h
     src/LlamaCPPBinding.cpp
+    src/TokenStream.h
+    src/TokenStream.cpp
     ${CMAKE_JS_SRC}
 )

diff --git a/cpp/src/LlamaCPPBinding.cpp b/cpp/src/LlamaCPPBinding.cpp
index a3963a3..485471e 100644
--- a/cpp/src/LlamaCPPBinding.cpp
+++ b/cpp/src/LlamaCPPBinding.cpp
@@ -1,5 +1,6 @@
 #include "LlamaCPPBinding.h"
 #include <string>
+#include <thread>

 Napi::FunctionReference LlamaCPPBinding::constructor;

@@ -62,21 +63,26 @@ Napi::Value LlamaCPPBinding::RunQuery(const Napi::CallbackInfo& info) {

 Napi::Value LlamaCPPBinding::RunQueryStream(const Napi::CallbackInfo& info) {
     Napi::Env env = info.Env();
-    if (info.Length() < 2 || !info[0].IsString() || !info[1].IsFunction()) {
-        Napi::TypeError::New(env, "String and function expected").ThrowAsJavaScriptException();
+    if (info.Length() < 1 || !info[0].IsString()) {
+        Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
         return env.Null();
     }

     std::string prompt = info[0].As<Napi::String>().Utf8Value();
-    Napi::Function callback = info[1].As<Napi::Function>();
     size_t max_tokens = 1000;
-    if (info.Length() > 2 && info[2].IsNumber()) {
-        max_tokens = info[2].As<Napi::Number>().Uint32Value();
+    if (info.Length() > 1 && info[1].IsNumber()) {
+        max_tokens = info[1].As<Napi::Number>().Uint32Value();
     }

-    llama_->RunQueryStream(prompt, max_tokens, [&env, &callback](const std::string& token) {
-        callback.Call(env.Global(), {Napi::String::New(env, token)});
-    });
+    Napi::Object streamObj = TokenStream::NewInstance(env, env.Null());
+    TokenStream* stream = Napi::ObjectWrap<TokenStream>::Unwrap(streamObj);
+
+    std::thread([this, prompt, max_tokens, stream]() {
+        llama_->RunQueryStream(prompt, max_tokens, [stream](const std::string& token) {
+            stream->Push(token);
+        });
+        stream->End();
+    }).detach();

-    return env.Undefined();
+    return streamObj;
 }

diff --git a/cpp/src/LlamaCPPBinding.h b/cpp/src/LlamaCPPBinding.h
index 6991990..f2faff2 100644
--- a/cpp/src/LlamaCPPBinding.h
+++ b/cpp/src/LlamaCPPBinding.h
@@ -2,6 +2,7 @@

 #include <napi.h>
 #include "llama-wrapper.h"
+#include "TokenStream.h"

 class LlamaCPPBinding : public Napi::ObjectWrap<LlamaCPPBinding> {
 public:

diff --git a/cpp/src/TokenStream.cpp b/cpp/src/TokenStream.cpp
new file mode 100644
index 0000000..db8fb9a
--- /dev/null
+++ b/cpp/src/TokenStream.cpp
@@ -0,0 +1,72 @@
+#include "TokenStream.h"
+
+Napi::FunctionReference TokenStream::constructor;
+
+Napi::Object TokenStream::Init(Napi::Env env, Napi::Object exports) {
+    Napi::Function func = DefineClass(env, "TokenStream", {
+        InstanceMethod("read", &TokenStream::Read),
+        InstanceMethod("push", &TokenStream::Push),
+        InstanceMethod("end", &TokenStream::End),
+    });
+
+    constructor = Napi::Persistent(func);
+    constructor.SuppressDestruct();
+
+    exports.Set("TokenStream", func);
+    return exports;
+}
+
+Napi::Object TokenStream::NewInstance(Napi::Env env, Napi::Value arg) {
+    Napi::EscapableHandleScope scope(env);
+    Napi::Object obj = constructor.New({arg});
+    return scope.Escape(napi_value(obj)).ToObject();
+}
+
+TokenStream::TokenStream(const Napi::CallbackInfo& info) : Napi::ObjectWrap<TokenStream>(info) {}
+
+Napi::Value TokenStream::Read(const Napi::CallbackInfo& info) {
+    Napi::Env env = info.Env();
+    std::unique_lock<std::mutex> lock(mutex);
+    cv.wait(lock, [this] { return !tokenQueue.empty() || finished; });
+
+    if (tokenQueue.empty() && finished) {
+        return env.Null();
+    }
+
+    std::string token = tokenQueue.front();
+    tokenQueue.pop();
+    return Napi::String::New(env, token);
+}
+
+Napi::Value TokenStream::Push(const Napi::CallbackInfo& info) {
+    if (info.Length() < 1 || !info[0].IsString()) {
+        Napi::TypeError::New(info.Env(), "String expected").ThrowAsJavaScriptException();
+        return info.Env().Undefined();
+    }
+
+    std::string token = info[0].As<Napi::String>().Utf8Value();
+    Push(token);
+    return info.Env().Undefined();
+}
+
+void TokenStream::Push(const std::string& token) {
+    {
+        std::lock_guard<std::mutex> lock(mutex);
+        tokenQueue.push(token);
+    }
+    cv.notify_one();
+}
+
+Napi::Value TokenStream::End(const Napi::CallbackInfo& info) {
+    End();
+    return info.Env().Undefined();
+}
+
+void TokenStream::End() {
+    {
+        std::lock_guard<std::mutex> lock(mutex);
+        finished = true;
+    }
+    cv.notify_all();
+}
+
diff --git a/cpp/src/TokenStream.h b/cpp/src/TokenStream.h
new file mode 100644
index 0000000..7fc9330
--- /dev/null
+++ b/cpp/src/TokenStream.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <napi.h>
+#include <queue>
+#include <mutex>
+#include <condition_variable>
+
+class TokenStream : public Napi::ObjectWrap<TokenStream> {
+public:
+    static Napi::Object Init(Napi::Env env, Napi::Object exports);
+    static Napi::Object NewInstance(Napi::Env env, Napi::Value arg);
+    TokenStream(const Napi::CallbackInfo& info);
+
+    void Push(const std::string& token);
+    void End();
+
+private:
+    static Napi::FunctionReference constructor;
+
+    Napi::Value Read(const Napi::CallbackInfo& info);
+    Napi::Value Push(const Napi::CallbackInfo& info);
+    Napi::Value End(const Napi::CallbackInfo& info);
+
+    std::queue<std::string> tokenQueue;
+    std::mutex mutex;
+    std::condition_variable cv;
+    bool finished = false;
+};

diff --git a/cpp/src/main.cpp b/cpp/src/main.cpp
index df92284..d88ed02 100644
--- a/cpp/src/main.cpp
+++ b/cpp/src/main.cpp
@@ -1,7 +1,9 @@
 #include <napi.h>
 #include "LlamaCPPBinding.h"
+#include "TokenStream.h"

 Napi::Object InitAll(Napi::Env env, Napi::Object exports) {
+    TokenStream::Init(env, exports);
     return LlamaCPPBinding::Init(env, exports);
 }

diff --git a/example/index.js b/example/index.js
index 804aa3b..b3d6872 100644
--- a/example/index.js
+++ b/example/index.js
@@ -1,30 +1,36 @@
 const Llama = require('../index.js');

 async function main() {
-    const llama = new Llama();
+    //
+    // Initialize
+    const llama = new Llama();
+    const modelPath = __dirname + "/models/Meta-Llama-3.1-8B-Instruct-Q3_K_S.gguf";

-    llama.initialize(__dirname + "/models/Meta-Llama-3.1-8B-Instruct-Q3_K_S.gguf");
+    if (!llama.initialize(modelPath)) {
+        console.error("Failed to initialize the model");
+        return;
+    }

-    console.log("\nRunning a simple query:");
-    const response = llama.runQuery("Tell me a short story.", 100);
-    console.log("Response:", response);
+    //
+    // Query
+    const query = "Hello.";

-    console.log("\nRunning a streaming query:");
-    let streamingResponse = "";
-    llama.runQueryStream(
-        "List 5 interesting facts about space.",
-        (token) => {
-            process.stdout.write(token);
-            streamingResponse += token;
-        },
-        200
-    );
+    //
+    // Sync query
+    const response = llama.runQuery(query, 100);
+    console.log(response);

-    // Wait for the streaming to finish
-    await new Promise(resolve => setTimeout(resolve, 5000));
+    //
+    // Stream query
+    const tokenStream = llama.runQueryStream(query, 200);
+    let streamingResponse = "";

-    console.log("\n\nFull streaming response:");
-    console.log(streamingResponse);
+    while (true) {
+        const token = await tokenStream.read();
+        if (token === null) break;
+        process.stdout.write(token);
+        streamingResponse += token;
+    }
 }

 main().catch(console.error);

diff --git a/index.d.ts b/index.d.ts
index 77237bc..223c6b2 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -1,6 +1,10 @@
+export class TokenStream {
+    read(): Promise<string | null>;
+}
+
 export class Llama {
     constructor();
     initialize(modelPath: string, contextSize?: number): boolean;
     runQuery(prompt: string, maxTokens?: number): string;
-    runQueryStream(prompt: string, callback: (token: string) => void, maxTokens?: number): void;
+    runQueryStream(prompt: string, maxTokens?: number): TokenStream;
 }
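
Note: not part of the commit — the pull-based API also composes with an
async iterator. A minimal consumer-side sketch, assuming only the
TokenStream.read() contract declared in index.d.ts (it resolves to the
next token, or to null once the stream has ended); the tokens() helper
and the model path are hypothetical, not defined anywhere in this patch:

    const Llama = require('./index.js');

    // Hypothetical adapter: expose a TokenStream as an async iterable.
    async function* tokens(stream) {
        for (;;) {
            const token = await stream.read(); // null marks end-of-stream
            if (token === null) return;
            yield token;
        }
    }

    async function demo() {
        const llama = new Llama();
        if (!llama.initialize("./models/model.gguf")) return; // placeholder path
        // Stream a completion token by token as it is generated.
        for await (const token of tokens(llama.runQueryStream("Hello.", 200))) {
            process.stdout.write(token);
        }
    }

    demo().catch(console.error);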