refactor: better runQueryStream
developer239 committed Jul 27, 2024
1 parent e18cb8c commit 4122006
Showing 8 changed files with 150 additions and 29 deletions.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
@@ -15,6 +15,8 @@ add_library(${PROJECT_NAME} SHARED
  src/main.cpp
  src/LlamaCPPBinding.h
  src/LlamaCPPBinding.cpp
+  src/TokenStream.h
+  src/TokenStream.cpp
  ${CMAKE_JS_SRC}
)

24 changes: 15 additions & 9 deletions cpp/src/LlamaCPPBinding.cpp
@@ -1,5 +1,6 @@
#include "LlamaCPPBinding.h"
#include <napi.h>
#include <thread>

Napi::FunctionReference LlamaCPPBinding::constructor;

@@ -62,21 +63,26 @@ Napi::Value LlamaCPPBinding::RunQuery(const Napi::CallbackInfo& info) {

Napi::Value LlamaCPPBinding::RunQueryStream(const Napi::CallbackInfo& info) {
  Napi::Env env = info.Env();
-  if (info.Length() < 2 || !info[0].IsString() || !info[1].IsFunction()) {
-    Napi::TypeError::New(env, "String and function expected").ThrowAsJavaScriptException();
+  if (info.Length() < 1 || !info[0].IsString()) {
+    Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
    return env.Null();
  }

  std::string prompt = info[0].As<Napi::String>().Utf8Value();
-  Napi::Function callback = info[1].As<Napi::Function>();
  size_t max_tokens = 1000;
-  if (info.Length() > 2 && info[2].IsNumber()) {
-    max_tokens = info[2].As<Napi::Number>().Uint32Value();
+  if (info.Length() > 1 && info[1].IsNumber()) {
+    max_tokens = info[1].As<Napi::Number>().Uint32Value();
  }

-  llama_->RunQueryStream(prompt, max_tokens, [&env, &callback](const std::string& token) {
-    callback.Call(env.Global(), {Napi::String::New(env, token)});
-  });
+  Napi::Object streamObj = TokenStream::NewInstance(env, env.Null());
+  TokenStream* stream = Napi::ObjectWrap<TokenStream>::Unwrap(streamObj);
+
+  std::thread([this, prompt, max_tokens, stream]() {
+    llama_->RunQueryStream(prompt, max_tokens, [stream](const std::string& token) {
+      stream->Push(token);
+    });
+    stream->End();
+  }).detach();

-  return env.Undefined();
+  return streamObj;
}
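The shape of the refactor, in isolation: instead of invoking a JS callback from inside the generation loop, runQueryStream now returns a stream object immediately and produces tokens from a detached worker thread. A minimal sketch of that pattern, with a hypothetical Sink type standing in for TokenStream (illustrative only, not part of the addon):

#include <functional>
#include <string>
#include <thread>

// Hypothetical stand-in for TokenStream: anything offering push/end.
struct Sink {
  std::function<void(const std::string&)> push;
  std::function<void()> end;
};

// Mirrors the new RunQueryStream: return to the caller right away,
// generate tokens on a detached thread, then signal completion.
void runStream(Sink sink) {
  std::thread([sink]() {
    for (const char* token : {"a ", "few ", "tokens"}) {
      sink.push(token);
    }
    sink.end();
  }).detach();
}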
1 change: 1 addition & 0 deletions cpp/src/LlamaCPPBinding.h
@@ -2,6 +2,7 @@

#include <napi.h>
#include "llama-wrapper.h"
+#include "TokenStream.h"

class LlamaCPPBinding : public Napi::ObjectWrap<LlamaCPPBinding> {
public:
72 changes: 72 additions & 0 deletions cpp/src/TokenStream.cpp
@@ -0,0 +1,72 @@
#include "TokenStream.h"

Napi::FunctionReference TokenStream::constructor;

Napi::Object TokenStream::Init(Napi::Env env, Napi::Object exports) {
Napi::Function func = DefineClass(env, "TokenStream", {
InstanceMethod("read", &TokenStream::Read),
InstanceMethod("push", &TokenStream::Push),
InstanceMethod("end", &TokenStream::End),
});

constructor = Napi::Persistent(func);
constructor.SuppressDestruct();

exports.Set("TokenStream", func);
return exports;
}

Napi::Object TokenStream::NewInstance(Napi::Env env, Napi::Value arg) {
Napi::EscapableHandleScope scope(env);
Napi::Object obj = constructor.New({arg});
return scope.Escape(napi_value(obj)).ToObject();
}

TokenStream::TokenStream(const Napi::CallbackInfo& info) : Napi::ObjectWrap<TokenStream>(info) {}

// Blocks the calling thread until a token is available or End() has been called;
// returns null once the queue is drained and the stream is finished.
Napi::Value TokenStream::Read(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [this] { return !tokenQueue.empty() || finished; });

if (tokenQueue.empty() && finished) {
return env.Null();
}

std::string token = tokenQueue.front();
tokenQueue.pop();
return Napi::String::New(env, token);
}

Napi::Value TokenStream::Push(const Napi::CallbackInfo& info) {
if (info.Length() < 1 || !info[0].IsString()) {
Napi::TypeError::New(info.Env(), "String expected").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

std::string token = info[0].As<Napi::String>().Utf8Value();
Push(token);
return info.Env().Undefined();
}

// Producer-side push; safe to call from the worker thread, wakes a blocked Read().
void TokenStream::Push(const std::string& token) {
{
std::lock_guard<std::mutex> lock(mutex);
tokenQueue.push(token);
}
cv.notify_one();
}

Napi::Value TokenStream::End(const Napi::CallbackInfo& info) {
End();
return info.Env().Undefined();
}

// Signals that no more tokens will arrive and wakes all blocked readers.
void TokenStream::End() {
{
std::lock_guard<std::mutex> lock(mutex);
finished = true;
}
cv.notify_all();
}

28 changes: 28 additions & 0 deletions cpp/src/TokenStream.h
@@ -0,0 +1,28 @@
#pragma once

#include <napi.h>
#include <queue>
#include <mutex>
#include <condition_variable>

// A thread-safe FIFO of tokens: a worker thread pushes, the JS-side read() pops.
class TokenStream : public Napi::ObjectWrap<TokenStream> {
public:
static Napi::Object Init(Napi::Env env, Napi::Object exports);
static Napi::Object NewInstance(Napi::Env env, Napi::Value arg);
TokenStream(const Napi::CallbackInfo& info);

void Push(const std::string& token);
void End();

private:
static Napi::FunctionReference constructor;

Napi::Value Read(const Napi::CallbackInfo& info);
Napi::Value Push(const Napi::CallbackInfo& info);
Napi::Value End(const Napi::CallbackInfo& info);

std::queue<std::string> tokenQueue;
std::mutex mutex;
std::condition_variable cv;
bool finished = false;
};
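Taken together, TokenStream.h and TokenStream.cpp implement a small single-producer/single-consumer blocking queue: Push and End run on the worker thread, while Read blocks its caller on the condition variable until a token is queued or the stream is marked finished. The same handshake, stripped of the N-API plumbing, as a self-contained sketch (names are illustrative, not part of the addon):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <optional>
#include <queue>
#include <string>
#include <thread>

class BlockingTokenQueue {
 public:
  void Push(std::string token) {
    { std::lock_guard<std::mutex> lock(mutex_); queue_.push(std::move(token)); }
    cv_.notify_one();
  }
  void End() {
    { std::lock_guard<std::mutex> lock(mutex_); finished_ = true; }
    cv_.notify_all();
  }
  // Blocks until a token arrives or the stream is finished and drained;
  // std::nullopt plays the role of TokenStream's null sentinel.
  std::optional<std::string> Read() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return !queue_.empty() || finished_; });
    if (queue_.empty()) return std::nullopt;
    std::string token = std::move(queue_.front());
    queue_.pop();
    return token;
  }

 private:
  std::queue<std::string> queue_;
  std::mutex mutex_;
  std::condition_variable cv_;
  bool finished_ = false;
};

int main() {
  BlockingTokenQueue stream;
  std::thread producer([&] {
    for (const char* t : {"three ", "short ", "tokens"}) stream.Push(t);
    stream.End();
  });
  while (auto token = stream.Read()) std::cout << *token;  // prints as tokens arrive
  std::cout << '\n';
  producer.join();
  return 0;
}

The wait predicate (!queue.empty() || finished) is what lets End() release a reader that would otherwise block forever on an exhausted stream.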
2 changes: 2 additions & 0 deletions cpp/src/main.cpp
@@ -1,7 +1,9 @@
#include <napi.h>
#include "LlamaCPPBinding.h"
+#include "TokenStream.h"

Napi::Object InitAll(Napi::Env env, Napi::Object exports) {
+  TokenStream::Init(env, exports);
  return LlamaCPPBinding::Init(env, exports);
}

44 changes: 25 additions & 19 deletions example/index.js
@@ -1,30 +1,36 @@
const Llama = require('../index.js');

async function main() {
-  const llama = new Llama();
+  //
+  // Initialize
+  const llama = new Llama();
+  const modelPath = __dirname + "/models/Meta-Llama-3.1-8B-Instruct-Q3_K_S.gguf";

-  llama.initialize(__dirname + "/models/Meta-Llama-3.1-8B-Instruct-Q3_K_S.gguf");
+  if (!llama.initialize(modelPath)) {
+    console.error("Failed to initialize the model");
+    return;
+  }

-  console.log("\nRunning a simple query:");
-  const response = llama.runQuery("Tell me a short story.", 100);
-  console.log("Response:", response);
+  //
+  // Query
+  const query = "Hello.";

-  console.log("\nRunning a streaming query:");
-  let streamingResponse = "";
-  llama.runQueryStream(
-    "List 5 interesting facts about space.",
-    (token) => {
-      process.stdout.write(token);
-      streamingResponse += token;
-    },
-    200
-  );
+  //
+  // Sync query
+  const response = llama.runQuery(query, 100);
+  console.log(response)

-  // Wait for the streaming to finish
-  await new Promise(resolve => setTimeout(resolve, 5000));
+  //
+  // Stream query
+  const tokenStream = llama.runQueryStream(query, 200);
+  let streamingResponse = "";

-  console.log("\n\nFull streaming response:");
-  console.log(streamingResponse);
+  while (true) {
+    const token = await tokenStream.read();
+    if (token === null) break;
+    process.stdout.write(token);
+    streamingResponse += token;
+  }
}

main().catch(console.error);
6 changes: 5 additions & 1 deletion index.d.ts
@@ -1,6 +1,10 @@
+export class TokenStream {
+  read(): Promise<string | null>;
+}
+
export class Llama {
  constructor();
  initialize(modelPath: string, contextSize?: number): boolean;
  runQuery(prompt: string, maxTokens?: number): string;
-  runQueryStream(prompt: string, callback: (token: string) => void, maxTokens?: number): void;
+  runQueryStream(prompt: string, maxTokens?: number): TokenStream;
}
