From af21dcab53db3fc503339ab967afbc41882e06bf Mon Sep 17 00:00:00 2001 From: coltonpierson Date: Fri, 20 Oct 2023 15:36:21 -0700 Subject: [PATCH] handling cycles wip, includes a language for defining graphs wip --- ...__tests__test_traverse_single_node.run.xml | 21 + toolchain/Cargo.lock | 262 +- toolchain/Cargo.toml | 3 +- toolchain/book.toml | 10 + toolchain/chidori/src/translations/rust.rs | 645 +++-- toolchain/prompt-graph-core/Cargo.toml | 23 + toolchain/prompt-graph-core/README.md | 10 + toolchain/prompt-graph-core/build.rs | 7 +- .../examples/simple/arithmetic.rs | 1 + .../protobufs/.idea/.gitignore | 8 - .../protobufs/.idea/misc.xml | 6 - .../protobufs/.idea/modules.xml | 8 - .../protobufs/.idea/protobufs.iml | 9 - .../prompt-graph-core/protobufs/.idea/vcs.xml | 6 - .../prompt-graph-core/protobufs/DSL_v1.proto | 536 ---- .../src/build_runtime_graph/graph_parse.rs | 574 ---- .../src/build_runtime_graph/mod.rs | 1 - .../execution/execution/execution_graph.rs | 842 ++++++ .../execution/execution/execution_state.rs | 176 ++ .../src/execution/execution/mod.rs | 59 + .../execution/mutate_active_execution.rs | 153 + .../src/execution/integration/mod.rs | 1 + .../integration}/triggerable.rs | 66 +- .../src/execution/language/compile.rs | 103 + .../src/execution/language/mod.rs | 5 + .../src/execution/language/parser.rs | 755 +++++ .../language}/typechecker.rs | 0 .../prompt-graph-core/src/execution/mod.rs | 4 + .../src/execution/primitives/identifiers.rs | 3 + .../src/execution/primitives/mod.rs | 3 + .../primitives}/operation.rs | 87 +- .../execution/primitives/serialized_value.rs | 153 + .../src/execution/sdk/entry.rs | 10 + .../src/execution/sdk/mod.rs | 0 .../prompt-graph-core/src/execution_router.rs | 355 --- .../src/generated_protobufs/promptgraph.rs | 2569 ----------------- .../prompt-graph-core/src/graph_definition.rs | 437 --- toolchain/prompt-graph-core/src/lib.rs | 29 +- .../prompt-graph-core/src/library/mod.rs | 1 + .../src/library/std/code/mod.rs | 2 + 
.../src/library/std/code/runtime_deno.rs | 61 + .../src/library/std/code/runtime_pyo3.rs | 0 .../src/library/std/code/runtime_starlark.rs | 42 + .../src/library/std/io/mod.rs | 1 + .../src/library/std/io/zip/mod.rs | 46 + .../src/library/std/memory/in_memory/mod.rs} | 81 +- .../src/library/std/memory/mod.rs | 61 + .../src/library/std/memory/qdrant/mod.rs | 228 ++ .../prompt-graph-core/src/library/std/mod.rs | 5 + .../src/library/std/prompt/mod.rs | 7 + .../src/library/std/prompt}/openai/batch.rs | 28 +- .../src/library/std/prompt/openai/mod.rs | 0 .../library/std/prompt}/openai/streaming.rs | 71 +- .../src/library/std/prompt/prompt.rs | 14 + .../src/library/std/schedule/README.md | 24 + .../src/library/std/schedule/mod.rs | 1 + .../src/prompt_composition/mod.rs | 2 +- .../src/prompt_composition/templates.rs | 552 ++-- toolchain/prompt-graph-core/src/proto.rs | 2 - .../src/reactivity/database.rs | 323 --- .../prompt-graph-core/src/reactivity/mod.rs | 5 - .../src/reactivity/reactive_sql.rs | 253 -- toolchain/prompt-graph-core/src/utils/mod.rs | 29 - .../tests/data/files_and_dirs.zip | Bin .../tests/nodejs/main.test.js | 0 toolchain/prompt-graph-exec/src/executor.rs | 651 ----- .../prompt-graph-exec/src/integrations/mod.rs | 1 - .../src/integrations/openai/mod.rs | 2 - toolchain/prompt-graph-exec/src/lib.rs | 26 +- toolchain/prompt-graph-exec/src/main.rs | 7 +- .../src/runtime_nodes/mod.rs | 9 - .../src/runtime_nodes/node_code/deno.rs | 94 - .../src/runtime_nodes/node_code/mod.rs | 5 - .../src/runtime_nodes/node_code/node.rs | 77 - .../src/runtime_nodes/node_code/starlark.rs | 60 - .../src/runtime_nodes/node_component/mod.rs | 1 - .../src/runtime_nodes/node_component/node.rs | 7 - .../src/runtime_nodes/node_custom.rs | 41 - .../src/runtime_nodes/node_join.rs | 39 - .../src/runtime_nodes/node_loader/mod.rs | 1 - .../src/runtime_nodes/node_loader/node.rs | 103 - .../src/runtime_nodes/node_map.rs | 31 - .../src/runtime_nodes/node_memory/mod.rs | 2 - 
.../src/runtime_nodes/node_memory/node.rs | 312 -- .../src/runtime_nodes/node_prompt/mod.rs | 1 - .../src/runtime_nodes/node_prompt/node.rs | 73 - .../src/runtime_nodes/node_schedule/mod.rs | 1 - .../src/runtime_nodes/node_schedule/node.rs | 37 - toolchain/prompt-graph-std/Cargo.toml | 12 - toolchain/prompt-graph-std/src/main.rs | 7 - .../prompt-graph-ui/src-tauri/src/main.rs | 105 +- 91 files changed, 3989 insertions(+), 7495 deletions(-) create mode 100644 toolchain/.run/Test execution__database__tests__test_traverse_single_node.run.xml create mode 100644 toolchain/book.toml create mode 100644 toolchain/prompt-graph-core/examples/simple/arithmetic.rs delete mode 100644 toolchain/prompt-graph-core/protobufs/.idea/.gitignore delete mode 100644 toolchain/prompt-graph-core/protobufs/.idea/misc.xml delete mode 100644 toolchain/prompt-graph-core/protobufs/.idea/modules.xml delete mode 100644 toolchain/prompt-graph-core/protobufs/.idea/protobufs.iml delete mode 100644 toolchain/prompt-graph-core/protobufs/.idea/vcs.xml delete mode 100644 toolchain/prompt-graph-core/protobufs/DSL_v1.proto delete mode 100644 toolchain/prompt-graph-core/src/build_runtime_graph/graph_parse.rs delete mode 100644 toolchain/prompt-graph-core/src/build_runtime_graph/mod.rs create mode 100644 toolchain/prompt-graph-core/src/execution/execution/execution_graph.rs create mode 100644 toolchain/prompt-graph-core/src/execution/execution/execution_state.rs create mode 100644 toolchain/prompt-graph-core/src/execution/execution/mod.rs create mode 100644 toolchain/prompt-graph-core/src/execution/execution/mutate_active_execution.rs create mode 100644 toolchain/prompt-graph-core/src/execution/integration/mod.rs rename toolchain/prompt-graph-core/src/{reactivity => execution/integration}/triggerable.rs (58%) create mode 100644 toolchain/prompt-graph-core/src/execution/language/compile.rs create mode 100644 toolchain/prompt-graph-core/src/execution/language/mod.rs create mode 100644 
toolchain/prompt-graph-core/src/execution/language/parser.rs rename toolchain/prompt-graph-core/src/{reactivity => execution/language}/typechecker.rs (100%) create mode 100644 toolchain/prompt-graph-core/src/execution/mod.rs create mode 100644 toolchain/prompt-graph-core/src/execution/primitives/identifiers.rs create mode 100644 toolchain/prompt-graph-core/src/execution/primitives/mod.rs rename toolchain/prompt-graph-core/src/{reactivity => execution/primitives}/operation.rs (58%) create mode 100644 toolchain/prompt-graph-core/src/execution/primitives/serialized_value.rs create mode 100644 toolchain/prompt-graph-core/src/execution/sdk/entry.rs create mode 100644 toolchain/prompt-graph-core/src/execution/sdk/mod.rs delete mode 100644 toolchain/prompt-graph-core/src/execution_router.rs delete mode 100644 toolchain/prompt-graph-core/src/generated_protobufs/promptgraph.rs delete mode 100644 toolchain/prompt-graph-core/src/graph_definition.rs create mode 100644 toolchain/prompt-graph-core/src/library/mod.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/code/mod.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/code/runtime_deno.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/code/runtime_pyo3.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/code/runtime_starlark.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/io/mod.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/io/zip/mod.rs rename toolchain/{prompt-graph-exec/src/runtime_nodes/node_memory/in_memory.rs => prompt-graph-core/src/library/std/memory/in_memory/mod.rs} (56%) create mode 100644 toolchain/prompt-graph-core/src/library/std/memory/mod.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/memory/qdrant/mod.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/mod.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/prompt/mod.rs rename 
toolchain/{prompt-graph-exec/src/integrations => prompt-graph-core/src/library/std/prompt}/openai/batch.rs (75%) create mode 100644 toolchain/prompt-graph-core/src/library/std/prompt/openai/mod.rs rename toolchain/{prompt-graph-exec/src/integrations => prompt-graph-core/src/library/std/prompt}/openai/streaming.rs (73%) create mode 100644 toolchain/prompt-graph-core/src/library/std/prompt/prompt.rs create mode 100644 toolchain/prompt-graph-core/src/library/std/schedule/README.md create mode 100644 toolchain/prompt-graph-core/src/library/std/schedule/mod.rs delete mode 100644 toolchain/prompt-graph-core/src/proto.rs delete mode 100644 toolchain/prompt-graph-core/src/reactivity/database.rs delete mode 100644 toolchain/prompt-graph-core/src/reactivity/mod.rs delete mode 100644 toolchain/prompt-graph-core/src/reactivity/reactive_sql.rs rename toolchain/{prompt-graph-exec => prompt-graph-core}/tests/data/files_and_dirs.zip (100%) rename toolchain/{prompt-graph-exec => prompt-graph-core}/tests/nodejs/main.test.js (100%) delete mode 100644 toolchain/prompt-graph-exec/src/executor.rs delete mode 100644 toolchain/prompt-graph-exec/src/integrations/mod.rs delete mode 100644 toolchain/prompt-graph-exec/src/integrations/openai/mod.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/mod.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_code/deno.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_code/mod.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_code/node.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_code/starlark.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_component/mod.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_component/node.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_custom.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_join.rs delete mode 100644 
toolchain/prompt-graph-exec/src/runtime_nodes/node_loader/mod.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_loader/node.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_map.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_memory/mod.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_memory/node.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_prompt/mod.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_prompt/node.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_schedule/mod.rs delete mode 100644 toolchain/prompt-graph-exec/src/runtime_nodes/node_schedule/node.rs delete mode 100644 toolchain/prompt-graph-std/Cargo.toml delete mode 100644 toolchain/prompt-graph-std/src/main.rs diff --git a/toolchain/.run/Test execution__database__tests__test_traverse_single_node.run.xml b/toolchain/.run/Test execution__database__tests__test_traverse_single_node.run.xml new file mode 100644 index 0000000..b99c73b --- /dev/null +++ b/toolchain/.run/Test execution__database__tests__test_traverse_single_node.run.xml @@ -0,0 +1,21 @@ + + + + \ No newline at end of file diff --git a/toolchain/Cargo.lock b/toolchain/Cargo.lock index e831755..7ce1bb8 100644 --- a/toolchain/Cargo.lock +++ b/toolchain/Cargo.lock @@ -193,6 +193,16 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "ariadne" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72fe02fc62033df9ba41cba57ee19acf5e742511a140c7dbc3a873e19a19a1bd" +dependencies = [ + "unicode-width", + "yansi", +] + [[package]] name = "arrayvec" version = "0.7.4" @@ -272,9 +282,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.68" +version = "0.1.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" +checksum = 
"7b2d0f03b3640e3a630367e40c468cb7f309529c708ed1d88597047b0e7c6ef7" dependencies = [ "proc-macro2 1.0.60", "quote 1.0.28", @@ -790,6 +800,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "chumsky" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eebd66744a15ded14960ab4ccdbfb51ad3b81f51f3f04a80adac98c985396c9" +dependencies = [ + "hashbrown 0.14.2", + "stacker", +] + [[package]] name = "cipher" version = "0.4.4" @@ -1241,6 +1261,16 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" +[[package]] +name = "debugid" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" +dependencies = [ + "serde", + "uuid", +] + [[package]] name = "debugserver-types" version = "0.5.0" @@ -1283,7 +1313,7 @@ checksum = "093052d481d5e8ee9bb9b2f08a4d9207b2fbf77585a418938dfb2e4cec350d20" dependencies = [ "anyhow", "bytes", - "deno_ops", + "deno_ops 0.70.0", "futures 0.3.28", "indexmap 1.9.3", "libc", @@ -1293,12 +1323,38 @@ dependencies = [ "pin-project", "serde", "serde_json", - "serde_v8", + "serde_v8 0.103.0", "smallvec", - "sourcemap", + "sourcemap 6.2.3", "tokio", "url", - "v8", + "v8 0.74.0", +] + +[[package]] +name = "deno_core" +version = "0.236.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ea0ab6f78d50bc3c9730f3a7faa3b9b32463b25f4af3dd0f02c6a18d995047e" +dependencies = [ + "anyhow", + "bytes", + "deno_ops 0.112.0", + "deno_unsync", + "futures 0.3.28", + "libc", + "log", + "parking_lot 0.12.1", + "pin-project", + "serde", + "serde_json", + "serde_v8 0.145.0", + "smallvec", + "sourcemap 7.0.1", + "static_assertions", + "tokio", + "url", + "v8 0.82.0", ] [[package]] @@ -1315,12 +1371,36 @@ dependencies = [ "proc-macro2 1.0.60", "quote 1.0.28", "regex", - "strum", - 
"strum_macros", + "strum 0.24.1", + "strum_macros 0.24.3", "syn 1.0.109", "syn 2.0.18", "thiserror", - "v8", + "v8 0.74.0", +] + +[[package]] +name = "deno_ops" +version = "0.112.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8050c4964e689fb05cac12df6c52727950b44ce48bdd6b5e4d3c0332f2e7aa76" +dependencies = [ + "proc-macro-rules", + "proc-macro2 1.0.60", + "quote 1.0.28", + "strum 0.25.0", + "strum_macros 0.25.3", + "syn 2.0.18", + "thiserror", +] + +[[package]] +name = "deno_unsync" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8a8f3722afd50e566ecfc783cc8a3a046bc4dd5eb45007431dfb2776aeb8993" +dependencies = [ + "tokio", ] [[package]] @@ -2175,7 +2255,7 @@ dependencies = [ "serde", "serde_json", "sqlparser 0.30.0", - "strum_macros", + "strum_macros 0.24.3", "thiserror", "uuid", ] @@ -2377,9 +2457,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" dependencies = [ "ahash 0.8.3", "allocator-api2", @@ -2439,7 +2519,7 @@ dependencies = [ "bincode 1.3.3", "cpu-time", "env_logger", - "hashbrown 0.14.0", + "hashbrown 0.14.2", "lazy_static", "log", "mmap-rs", @@ -2740,7 +2820,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown 0.14.2", ] [[package]] @@ -4138,6 +4218,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "priority-queue" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fff39edfcaec0d64e8d0da38564fad195d2d51b680940295fcc307366e101e61" +dependencies = [ + "autocfg", + "indexmap 1.9.3", +] + 
[[package]] name = "proc-macro-crate" version = "0.1.5" @@ -4187,6 +4277,29 @@ version = "0.5.20+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" +[[package]] +name = "proc-macro-rules" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07c277e4e643ef00c1233393c673f655e3672cf7eb3ba08a00bdd0ea59139b5f" +dependencies = [ + "proc-macro-rules-macros", + "proc-macro2 1.0.60", + "syn 2.0.18", +] + +[[package]] +name = "proc-macro-rules-macros" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "207fffb0fe655d1d47f6af98cc2793405e85929bdbc420d685554ff07be27ac7" +dependencies = [ + "once_cell", + "proc-macro2 1.0.60", + "quote 1.0.28", + "syn 2.0.18", +] + [[package]] name = "proc-macro2" version = "0.4.30" @@ -4237,25 +4350,41 @@ name = "prompt-graph-core" version = "0.1.28" dependencies = [ "anyhow", + "ariadne", + "async-trait", + "base64 0.21.2", + "chumsky", + "crossbeam-utils", + "deno_core 0.236.0", "env_logger", "futures 0.3.28", "gluesql", "handlebars", + "hnsw_rs_thousand_birds", + "http-body-util", "im", "indoc", "log", "num_cpus", "petgraph", + "priority-queue", "prost", "protobuf", + "qdrant-client", + "quote 1.0.28", + "rand 0.8.5", + "rkyv", "serde", "serde_json", "serde_yaml", "sqlparser 0.34.0", + "starlark", + "syn 1.0.109", "tokio", "tonic", "tonic-build", "typescript-type-def", + "zip", ] [[package]] @@ -4270,7 +4399,7 @@ dependencies = [ "bincode 2.0.0-rc.3", "bytes", "dashmap", - "deno_core", + "deno_core 0.192.0", "env_logger", "futures 0.3.28", "futures-core", @@ -4317,7 +4446,7 @@ dependencies = [ "bincode 2.0.0-rc.3", "bytes", "dashmap", - "deno_core", + "deno_core 0.192.0", "env_logger", "futures 0.3.28", "futures-core", @@ -4354,10 +4483,6 @@ dependencies = [ "zip", ] -[[package]] -name = "prompt-graph-std" -version = "0.1.28" - [[package]] name = 
"prompt-graph-ui" version = "0.1.0" @@ -4454,6 +4579,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "psm" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5787f7cda34e3033a72192c018bc5883100330f362ef279a8cbccfce8bb4e874" +dependencies = [ + "cc", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -5336,7 +5470,22 @@ dependencies = [ "serde_bytes", "smallvec", "thiserror", - "v8", + "v8 0.74.0", +] + +[[package]] +name = "serde_v8" +version = "0.145.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbd782806b3088c7083a142be36ceb734ccfb1da6f82b5eb84a2bff2b4a68efe" +dependencies = [ + "bytes", + "derive_more", + "num-bigint", + "serde", + "smallvec", + "thiserror", + "v8 0.82.0", ] [[package]] @@ -5583,6 +5732,22 @@ dependencies = [ "url", ] +[[package]] +name = "sourcemap" +version = "7.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10da010a590ed2fa9ca8467b00ce7e9c5a8017742c0c09c45450efc172208c4b" +dependencies = [ + "data-encoding", + "debugid", + "if_chain", + "rustc_version 0.2.3", + "serde", + "serde_json", + "unicode-id", + "url", +] + [[package]] name = "spin" version = "0.5.2" @@ -5615,6 +5780,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c886bd4480155fd3ef527d45e9ac8dd7118a898a46530b7b94c3e21866259fce" +dependencies = [ + "cc", + "cfg-if 1.0.0", + "libc", + "psm", + "winapi", +] + [[package]] name = "starlark" version = "0.9.0" @@ -5747,7 +5925,16 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" dependencies = [ - "strum_macros", + "strum_macros 0.24.3", +] + 
+[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros 0.25.3", ] [[package]] @@ -5763,6 +5950,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +dependencies = [ + "heck", + "proc-macro2 1.0.60", + "quote 1.0.28", + "rustversion", + "syn 2.0.18", +] + [[package]] name = "subtle" version = "2.4.1" @@ -6848,6 +7048,18 @@ dependencies = [ "which", ] +[[package]] +name = "v8" +version = "0.82.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f53dfb242f4c0c39ed3fc7064378a342e57b5c9bd774636ad34ffe405b808121" +dependencies = [ + "bitflags", + "fslock", + "once_cell", + "which", +] + [[package]] name = "valuable" version = "0.1.0" @@ -7526,6 +7738,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + [[package]] name = "zip" version = "0.6.6" diff --git a/toolchain/Cargo.toml b/toolchain/Cargo.toml index 05560f0..e36b23d 100644 --- a/toolchain/Cargo.toml +++ b/toolchain/Cargo.toml @@ -2,7 +2,6 @@ members = [ "prompt-graph-core", "prompt-graph-exec", - "prompt-graph-std", "prompt-graph-ui/src-tauri", "chidori", ] @@ -19,6 +18,8 @@ repository = "https://github.com/ThousandBirdsInc/chidori" lto = true [workspace.dependencies] +rkyv = {version = "0.7.42", features = ["validation"]} + gluesql = "0.14.0" protobuf = "3.2.0" sqlparser = "0.34.0" diff --git a/toolchain/book.toml b/toolchain/book.toml new file mode 100644 index 0000000..82653e6 --- /dev/null +++ b/toolchain/book.toml @@ -0,0 +1,10 @@ +[book] +authors 
= ["Colton Pierson"] +language = "en" +multilingual = false +src = "book_src" +title = "chidori" + +[output.html] +cname = "chidoriai.com" +edit-url-template = "https://github.com/ThousandBirdsInc/chidori/edit/main/{path}" \ No newline at end of file diff --git a/toolchain/chidori/src/translations/rust.rs b/toolchain/chidori/src/translations/rust.rs index bb67fd5..5c85f73 100644 --- a/toolchain/chidori/src/translations/rust.rs +++ b/toolchain/chidori/src/translations/rust.rs @@ -1,43 +1,64 @@ -use std::cell::RefCell; -use std::collections::{HashMap, VecDeque}; -use std::future::Future; -use std::hash::Hash; -use std::marker::PhantomData; -use std::pin::Pin; -use std::sync::Arc; +use crate::translations::shared::json_value_to_paths; use anyhow::Error; use futures::future::BoxFuture; use futures::StreamExt; use log::{debug, info}; +use neon_serde3; use once_cell::sync::OnceCell; -use tokio::runtime::Runtime; -use prompt_graph_core::build_runtime_graph::graph_parse::{CleanedDefinitionGraph, CleanIndividualNode, construct_query_from_output_type, derive_for_individual_node}; -use prompt_graph_core::graph_definition::{create_code_node, create_custom_node, create_prompt_node, create_vector_memory_node, SourceNodeType}; -use prompt_graph_core::proto::{ChangeValue, ChangeValueWithCounter, Empty, ExecutionStatus, File, FileAddressedChangeValueWithCounter, FilteredPollNodeWillExecuteEventsRequest, Item, ListBranchesRes, NodeWillExecute, NodeWillExecuteOnBranch, Path, Query, QueryAtFrame, QueryAtFrameResponse, RequestAckNodeWillExecuteEvent, RequestAtFrame, RequestFileMerge, RequestListBranches, RequestNewBranch, RequestOnlyId, RespondPollNodeWillExecuteEvents, SerializedValue, SerializedValueArray, SerializedValueObject}; +use prompt_graph_core::build_runtime_graph::graph_parse::{ + construct_query_from_output_type, derive_for_individual_node, CleanIndividualNode, + CleanedDefinitionGraph, +}; +use prompt_graph_core::graph_definition::{ + create_code_node, create_custom_node, 
create_prompt_node, create_vector_memory_node, + SourceNodeType, +}; +use prompt_graph_core::prompt_composition::templates::json_value_to_serialized_value; use prompt_graph_core::proto::execution_runtime_client::ExecutionRuntimeClient; use prompt_graph_core::proto::serialized_value::Val; +use prompt_graph_core::proto::{ + ChangeValue, ChangeValueWithCounter, Empty, ExecutionStatus, File, + FileAddressedChangeValueWithCounter, FilteredPollNodeWillExecuteEventsRequest, Item, + ListBranchesRes, NodeWillExecute, NodeWillExecuteOnBranch, Path, Query, QueryAtFrame, + QueryAtFrameResponse, RequestAckNodeWillExecuteEvent, RequestAtFrame, RequestFileMerge, + RequestListBranches, RequestNewBranch, RequestOnlyId, RespondPollNodeWillExecuteEvents, + SerializedValue, SerializedValueArray, SerializedValueObject, +}; use prompt_graph_exec::tonic_runtime::run_server; -use neon_serde3; use serde::{Deserialize, Serialize}; +use std::cell::RefCell; +use std::collections::{HashMap, VecDeque}; +use std::future::Future; +use std::hash::Hash; +use std::marker::PhantomData; +use std::pin::Pin; +use std::sync::Arc; +use tokio::runtime::Runtime; use tonic::Status; -use prompt_graph_core::prompt_composition::templates::json_value_to_serialized_value; -use crate::translations::shared::json_value_to_paths; -pub use prompt_graph_core::utils::serialized_value_to_string; -async fn get_client(url: String) -> Result, tonic::transport::Error> { +async fn get_client( + url: String, +) -> Result, tonic::transport::Error> { ExecutionRuntimeClient::connect(url.clone()).await } -type CallbackHandler = Box BoxFuture<'static, anyhow::Result> + Send + Sync>; +type CallbackHandler = Box< + dyn Fn(NodeWillExecuteOnBranch) -> BoxFuture<'static, anyhow::Result> + + Send + + Sync, +>; pub struct Handler { - pub(crate) callback: CallbackHandler + pub(crate) callback: CallbackHandler, } impl Handler { pub fn new(f: F) -> Self - where - F: Fn(NodeWillExecuteOnBranch) -> BoxFuture<'static, anyhow::Result> + Send + 
Sync + 'static + where + F: Fn(NodeWillExecuteOnBranch) -> BoxFuture<'static, anyhow::Result> + + Send + + Sync + + 'static, { Handler { callback: Box::new(f), @@ -45,24 +66,27 @@ impl Handler { } } - #[derive(Clone)] pub struct Chidori { file_id: String, current_head: u64, current_branch: u64, url: String, - pub(crate) custom_node_handlers: HashMap> + pub(crate) custom_node_handlers: HashMap>, } impl Chidori { - pub fn new(file_id: String, url: String) -> Self { if !url.contains("://") { panic!("Invalid url, must include protocol"); } // let api_token = cx.argument_opt(2)?.value(&mut cx); - debug!("Creating new Chidori instance with file_id={}, url={}, api_token={:?}", file_id, url, "".to_string()); + debug!( + "Creating new Chidori instance with file_id={}, url={}, api_token={:?}", + file_id, + url, + "".to_string() + ); Chidori { file_id, current_head: 0, @@ -79,10 +103,10 @@ impl Chidori { match result { Ok(_) => { println!("Server exited"); - }, + } Err(e) => { println!("Error running server: {}", e); - }, + } } }); @@ -92,9 +116,13 @@ impl Chidori { Ok(connection) => { eprintln!("Connection successfully established {:?}", &url); return Ok(()); - }, + } Err(e) => { - eprintln!("Error connecting to server: {} with Error {}. Retrying...", &url, &e.to_string()); + eprintln!( + "Error connecting to server: {} with Error {}. 
Retrying...", + &url, + &e.to_string() + ); std::thread::sleep(std::time::Duration::from_millis(1000)); } } @@ -105,11 +133,13 @@ impl Chidori { let file_id = self.file_id.clone(); let url = self.url.clone(); let mut client = get_client(url).await?; - let result = client.play(RequestAtFrame { - id: file_id, - frame, - branch, - }).await?; + let result = client + .play(RequestAtFrame { + id: file_id, + frame, + branch, + }) + .await?; Ok(result.into_inner()) } @@ -119,59 +149,70 @@ impl Chidori { let branch = self.current_branch.clone(); let mut client = get_client(url).await?; - let result = client.pause(RequestAtFrame { - id: file_id, - frame, - branch, - }).await?; + let result = client + .pause(RequestAtFrame { + id: file_id, + frame, + branch, + }) + .await?; Ok(result.into_inner()) } - pub async fn query( &self, query: String, branch: u64, frame: u64, ) -> anyhow::Result { + pub async fn query( + &self, + query: String, + branch: u64, + frame: u64, + ) -> anyhow::Result { let file_id = self.file_id.clone(); let url = self.url.clone(); let mut client = get_client(url).await?; - let result = client.run_query(QueryAtFrame { - id: file_id, - query: Some(Query { - query: Some(query) - }), - frame, - branch, - }).await?; + let result = client + .run_query(QueryAtFrame { + id: file_id, + query: Some(Query { query: Some(query) }), + frame, + branch, + }) + .await?; Ok(result.into_inner()) } - pub async fn branch( &self, branch: u64, frame: u64, ) -> anyhow::Result { + pub async fn branch(&self, branch: u64, frame: u64) -> anyhow::Result { let file_id = self.file_id.clone(); let url = self.url.clone(); let mut client = get_client(url).await?; - let result = client.branch(RequestNewBranch { - id: file_id, - source_branch_id: branch, - diverges_at_counter: frame - }).await?; + let result = client + .branch(RequestNewBranch { + id: file_id, + source_branch_id: branch, + diverges_at_counter: frame, + }) + .await?; Ok(result.into_inner()) } - pub async fn list_branches( 
&self) -> anyhow::Result { + pub async fn list_branches(&self) -> anyhow::Result { let file_id = self.file_id.clone(); let url = self.url.clone(); let mut client = get_client(url).await?; - let result = client.list_branches(RequestListBranches { - id: file_id, - }).await?; + let result = client + .list_branches(RequestListBranches { id: file_id }) + .await?; Ok(result.into_inner()) } - pub async fn display_graph_structure( &self, branch: u64) -> anyhow::Result { + pub async fn display_graph_structure(&self, branch: u64) -> anyhow::Result { let file_id = self.file_id.clone(); let url = self.url.clone(); let mut client = get_client(url).await?; - let file = client.current_file_state(RequestOnlyId { - id: file_id, - branch - }).await?; + let file = client + .current_file_state(RequestOnlyId { + id: file_id, + branch, + }) + .await?; let mut file = file.into_inner(); let mut g = CleanedDefinitionGraph::zero(); g.merge_file(&mut file).unwrap(); @@ -182,93 +223,103 @@ impl Chidori { let file_id = self.file_id.clone(); let url = self.url.clone(); let mut client = get_client(url).await?; - let resp = client.list_registered_graphs(Empty { }).await?; + let resp = client.list_registered_graphs(Empty {}).await?; let mut graphs = resp.into_inner(); info!("Registered Graphs = {:?}", graphs); Ok(()) } -// -// // TODO: need to figure out how to handle callbacks -// // fn list_input_proposals<'a>( -// // mut self_: PyRefMut<'_, Self>, -// // py: Python<'a>, -// // callback: PyObject -// // ) -> PyResult<&'a PyAny> { -// // let file_id = self_.file_id.clone(); -// // let url = self_.url.clone(); -// // let branch = self_.current_branch; -// // pyo3_asyncio::tokio::future_into_py(py, async move { -// // let mut client = get_client(url).await?; -// // let resp = client.list_input_proposals(RequestOnlyId { -// // id: file_id, -// // branch, -// // }).await.map_err(PyErrWrapper::from)?; -// // let mut stream = resp.into_inner(); -// // while let Some(x) = stream.next().await { -// // // 
callback.call(py, (x,), None); -// // info!("InputProposals = {:?}", x); -// // }; -// // Ok(()) -// // }) -// // } -// -// // fn respond_to_input_proposal(mut self_: PyRefMut<'_, Self>) -> PyResult<()> { -// // Ok(()) -// // } -// -// // TODO: need to figure out how to handle callbacks -// // fn list_change_events<'a>( -// // mut self_: PyRefMut<'_, Self>, -// // py: Python<'a>, -// // callback: PyObject -// // ) -> PyResult<&'a PyAny> { -// // let file_id = self_.file_id.clone(); -// // let url = self_.url.clone(); -// // let branch = self_.current_branch; -// // pyo3_asyncio::tokio::future_into_py(py, async move { -// // let mut client = get_client(url).await?; -// // let resp = client.list_change_events(RequestOnlyId { -// // id: file_id, -// // branch, -// // }).await.map_err(PyErrWrapper::from)?; -// // let mut stream = resp.into_inner(); -// // while let Some(x) = stream.next().await { -// // Python::with_gil(|py| pyo3_asyncio::tokio::into_future(callback.as_ref(py).call((x.map(ChangeValueWithCounterWrapper).map_err(PyErrWrapper::from)?,), None)?))? 
-// // .await?; -// // }; -// // Ok(()) -// // }) -// // } -// -// -// -// // TODO: this should accept an "Object" instead of args -// // TODO: nodes that are added should return a clean definition of what their addition looks like -// // TODO: adding a node should also display any errors + // + // // TODO: need to figure out how to handle callbacks + // // fn list_input_proposals<'a>( + // // mut self_: PyRefMut<'_, Self>, + // // py: Python<'a>, + // // callback: PyObject + // // ) -> PyResult<&'a PyAny> { + // // let file_id = self_.file_id.clone(); + // // let url = self_.url.clone(); + // // let branch = self_.current_branch; + // // pyo3_asyncio::tokio::future_into_py(py, async move { + // // let mut client = get_client(url).await?; + // // let resp = client.list_input_proposals(RequestOnlyId { + // // id: file_id, + // // branch, + // // }).await.map_err(PyErrWrapper::from)?; + // // let mut stream = resp.into_inner(); + // // while let Some(x) = stream.next().await { + // // // callback.call(py, (x,), None); + // // info!("InputProposals = {:?}", x); + // // }; + // // Ok(()) + // // }) + // // } + // + // // fn respond_to_input_proposal(mut self_: PyRefMut<'_, Self>) -> PyResult<()> { + // // Ok(()) + // // } + // + // // TODO: need to figure out how to handle callbacks + // // fn list_change_events<'a>( + // // mut self_: PyRefMut<'_, Self>, + // // py: Python<'a>, + // // callback: PyObject + // // ) -> PyResult<&'a PyAny> { + // // let file_id = self_.file_id.clone(); + // // let url = self_.url.clone(); + // // let branch = self_.current_branch; + // // pyo3_asyncio::tokio::future_into_py(py, async move { + // // let mut client = get_client(url).await?; + // // let resp = client.list_change_events(RequestOnlyId { + // // id: file_id, + // // branch, + // // }).await.map_err(PyErrWrapper::from)?; + // // let mut stream = resp.into_inner(); + // // while let Some(x) = stream.next().await { + // // Python::with_gil(|py| 
pyo3_asyncio::tokio::into_future(callback.as_ref(py).call((x.map(ChangeValueWithCounterWrapper).map_err(PyErrWrapper::from)?,), None)?))? + // // .await?; + // // }; + // // Ok(()) + // // }) + // // } + // + // + // + // // TODO: this should accept an "Object" instead of args + // // TODO: nodes that are added should return a clean definition of what their addition looks like + // // TODO: adding a node should also display any errors pub fn register_custom_node_handle(&mut self, key: String, handler: Handler) { self.custom_node_handlers.insert(key, Arc::new(handler)); } - pub async fn poll_local_code_node_execution(&self) -> anyhow::Result { + pub async fn poll_local_code_node_execution( + &self, + ) -> anyhow::Result { let file_id = self.file_id.clone(); let url = self.url.clone(); let mut client = get_client(url).await?; - let req = FilteredPollNodeWillExecuteEventsRequest { id: file_id.clone() }; + let req = FilteredPollNodeWillExecuteEventsRequest { + id: file_id.clone(), + }; let result = client.poll_custom_node_will_execute_events(req).await?; Ok(result.into_inner()) } - pub async fn ack_local_code_node_execution(&self, branch: u64, counter : u64) -> anyhow::Result { + pub async fn ack_local_code_node_execution( + &self, + branch: u64, + counter: u64, + ) -> anyhow::Result { let file_id = self.file_id.clone(); let url = self.url.clone(); let mut client = get_client(url).await?; - let result = client.ack_node_will_execute_event(RequestAckNodeWillExecuteEvent { - id: file_id.clone(), - branch, - counter, - }).await?; + let result = client + .ack_node_will_execute_event(RequestAckNodeWillExecuteEvent { + id: file_id.clone(), + branch, + counter, + }) + .await?; Ok(result.into_inner()) } @@ -277,22 +328,21 @@ impl Chidori { branch: u64, counter: u64, node_name: String, - response: T + response: T, ) -> anyhow::Result { let file_id = self.file_id.clone(); let url = self.url.clone(); let json_object = serde_json::to_value(response)?; let response_paths = 
json_value_to_paths(&json_object); - let filled_values = response_paths.into_iter().map(|path| { - ChangeValue { - path: Some(Path { - address: path.0, - }), + let filled_values = response_paths + .into_iter() + .map(|path| ChangeValue { + path: Some(Path { address: path.0 }), value: Some(path.1), branch, - } - }).collect(); + }) + .collect(); // TODO: need parent counters from the original change // TODO: need source node @@ -300,19 +350,22 @@ impl Chidori { // TODO: need to add the output table paths to these // TODO: this needs to look more like a real change - Ok(client.push_worker_event(FileAddressedChangeValueWithCounter { - branch, - counter, - node_name: node_name.clone(), - id: file_id.clone(), - change: Some(ChangeValueWithCounter { - filled_values, - parent_monotonic_counters: vec![], - monotonic_counter: counter, + Ok(client + .push_worker_event(FileAddressedChangeValueWithCounter { branch, - source_node: node_name.clone(), + counter, + node_name: node_name.clone(), + id: file_id.clone(), + change: Some(ChangeValueWithCounter { + filled_values, + parent_monotonic_counters: vec![], + monotonic_counter: counter, + branch, + source_node: node_name.clone(), + }), }) - }).await?.into_inner()) + .await? + .into_inner()) } pub async fn run_custom_node_loop(&self) -> anyhow::Result<()> { @@ -325,15 +378,30 @@ impl Chidori { continue; } else { backoff = 2; - for ev in &events.node_will_execute_events { + for ev in &events.node_will_execute_events { // ACK messages - let NodeWillExecuteOnBranch { branch, counter, node, ..} = ev; + let NodeWillExecuteOnBranch { + branch, + counter, + node, + .. 
+ } = ev; let node_name = &node.as_ref().unwrap().source_node; - if let Some(x) = self.custom_node_handlers.get(&ev.custom_node_type_name.clone().unwrap()) { - self.ack_local_code_node_execution(*branch, *counter).await?; + if let Some(x) = self + .custom_node_handlers + .get(&ev.custom_node_type_name.clone().unwrap()) + { + self.ack_local_code_node_execution(*branch, *counter) + .await?; let result = (x.as_ref().callback)(ev.clone()).await?; dbg!(&result); - self.respond_local_code_node_execution(*branch, *counter, node_name.clone(), result).await?; + self.respond_local_code_node_execution( + *branch, + *counter, + node_name.clone(), + result, + ) + .await?; } } } @@ -341,7 +409,6 @@ impl Chidori { } } - fn default_triggers() -> Option> { Some(vec!["None".to_string()]) } @@ -352,7 +419,7 @@ pub struct PromptNodeCreateOpts { pub triggers: Option>, pub output_tables: Option>, pub template: String, - pub model: Option + pub model: Option, } impl Default for PromptNodeCreateOpts { @@ -376,18 +443,15 @@ impl PromptNodeCreateOpts { } } - - #[derive(serde::Serialize, serde::Deserialize)] pub struct CustomNodeCreateOpts { pub name: String, pub triggers: Option>, pub output_tables: Option>, pub output: Option, - pub node_type_name: String + pub node_type_name: String, } - impl Default for CustomNodeCreateOpts { fn default() -> Self { CustomNodeCreateOpts { @@ -409,8 +473,6 @@ impl CustomNodeCreateOpts { } } - - #[derive(serde::Serialize, serde::Deserialize)] pub struct DenoCodeNodeCreateOpts { pub name: String, @@ -418,7 +480,7 @@ pub struct DenoCodeNodeCreateOpts { pub output_tables: Option>, pub output: Option, pub code: String, - pub is_template: Option + pub is_template: Option, } impl Default for DenoCodeNodeCreateOpts { @@ -452,13 +514,12 @@ pub struct VectorMemoryNodeCreateOpts { pub output_tables: Option>, pub output: Option, pub template: Option, // TODO: default is the contents of the query - pub action: Option, // TODO: default WRITE + pub action: Option, // 
TODO: default WRITE pub embedding_model: Option, // TODO: default TEXT_EMBEDDING_ADA_002 pub db_vendor: Option, // TODO: default QDRANT pub collection_name: String, } - impl Default for VectorMemoryNodeCreateOpts { fn default() -> Self { VectorMemoryNodeCreateOpts { @@ -475,7 +536,6 @@ impl Default for VectorMemoryNodeCreateOpts { } } - impl VectorMemoryNodeCreateOpts { pub fn merge(&mut self, other: VectorMemoryNodeCreateOpts) { self.name = other.name; @@ -492,13 +552,16 @@ impl VectorMemoryNodeCreateOpts { fn remap_triggers(triggers: Option>) -> Vec> { let triggers: Vec> = if let Some(triggers) = triggers { - triggers.into_iter().map(|q| { - if q == "None".to_string() { - None - } else { - Some(q) - } - }).collect() + triggers + .into_iter() + .map(|q| { + if q == "None".to_string() { + None + } else { + Some(q) + } + }) + .collect() } else { vec![] }; @@ -513,7 +576,7 @@ pub struct GraphBuilder { impl GraphBuilder { pub fn new() -> Self { GraphBuilder { - clean_graph: CleanedDefinitionGraph::zero() + clean_graph: CleanedDefinitionGraph::zero(), } } pub fn prompt_node(&mut self, arg: PromptNodeCreateOpts) -> anyhow::Result { @@ -524,8 +587,12 @@ impl GraphBuilder { remap_triggers(def.triggers), def.template, def.model.unwrap_or("GPT_3_5_TURBO".to_string()), - def.output_tables.unwrap_or(vec![]))?; - self.clean_graph.merge_file(&File { nodes: vec![node.clone()], ..Default::default() })?; + def.output_tables.unwrap_or(vec![]), + )?; + self.clean_graph.merge_file(&File { + nodes: vec![node.clone()], + ..Default::default() + })?; Ok(NodeHandle::from(node)?) 
} @@ -537,13 +604,15 @@ impl GraphBuilder { remap_triggers(def.triggers.clone()), def.output.unwrap_or("{}".to_string()), def.node_type_name, - def.output_tables.unwrap_or(vec![]) + def.output_tables.unwrap_or(vec![]), ); - self.clean_graph.merge_file(&File { nodes: vec![node.clone()], ..Default::default() })?; + self.clean_graph.merge_file(&File { + nodes: vec![node.clone()], + ..Default::default() + })?; Ok(NodeHandle::from(node)?) } - pub fn deno_code_node(&mut self, arg: DenoCodeNodeCreateOpts) -> anyhow::Result { let mut def = DenoCodeNodeCreateOpts::default(); def.merge(arg); @@ -551,15 +620,24 @@ impl GraphBuilder { def.name.clone(), remap_triggers(def.triggers.clone()), def.output.unwrap_or("{}".to_string()), - SourceNodeType::Code("DENO".to_string(), def.code, def.is_template.unwrap_or(false)), - def.output_tables.unwrap_or(vec![]) + SourceNodeType::Code( + "DENO".to_string(), + def.code, + def.is_template.unwrap_or(false), + ), + def.output_tables.unwrap_or(vec![]), ); - self.clean_graph.merge_file(&File { nodes: vec![node.clone()], ..Default::default() })?; + self.clean_graph.merge_file(&File { + nodes: vec![node.clone()], + ..Default::default() + })?; Ok(NodeHandle::from(node)?) 
} - - pub fn vector_memory_node(&mut self, arg: VectorMemoryNodeCreateOpts) -> anyhow::Result { + pub fn vector_memory_node( + &mut self, + arg: VectorMemoryNodeCreateOpts, + ) -> anyhow::Result { let mut def = VectorMemoryNodeCreateOpts::default(); def.merge(arg); let node = create_vector_memory_node( @@ -567,67 +645,70 @@ impl GraphBuilder { remap_triggers(def.triggers.clone()), def.output.unwrap_or("{}".to_string()), def.action.unwrap_or("READ".to_string()), - def.embedding_model.unwrap_or("TEXT_EMBEDDING_ADA_002".to_string()), + def.embedding_model + .unwrap_or("TEXT_EMBEDDING_ADA_002".to_string()), def.template.unwrap_or("".to_string()), def.db_vendor.unwrap_or("QDRANT".to_string()), def.collection_name, - def.output_tables.unwrap_or(vec![]) + def.output_tables.unwrap_or(vec![]), )?; - self.clean_graph.merge_file(&File { nodes: vec![node.clone()], ..Default::default() })?; + self.clean_graph.merge_file(&File { + nodes: vec![node.clone()], + ..Default::default() + })?; Ok(NodeHandle::from(node)?) 
} -// -// -// // -// // fn observation_node(mut self_: PyRefMut<'_, Self>, name: String, query_def: Option, template: String, model: String) -> PyResult<()> { -// // let file_id = self_.file_id.clone(); -// // let node = create_observation_node( -// // "".to_string(), -// // None, -// // "".to_string(), -// // ); -// // executor::block_on(self_.client.merge(RequestFileMerge { -// // id: file_id, -// // file: Some(File { -// // nodes: vec![node], -// // ..Default::default() -// // }), -// // branch: 0, -// // })); -// // Ok(()) -// // } + // + // + // // + // // fn observation_node(mut self_: PyRefMut<'_, Self>, name: String, query_def: Option, template: String, model: String) -> PyResult<()> { + // // let file_id = self_.file_id.clone(); + // // let node = create_observation_node( + // // "".to_string(), + // // None, + // // "".to_string(), + // // ); + // // executor::block_on(self_.client.merge(RequestFileMerge { + // // id: file_id, + // // file: Some(File { + // // nodes: vec![node], + // // ..Default::default() + // // }), + // // branch: 0, + // // })); + // // Ok(()) + // // } // // TODO: need to figure out passing a buffer of bytes -// // TODO: nodes that are added should return a clean definition of what their addition looks like -// // TODO: adding a node should also display any errors -// /// x = None -// /// with open("/Users/coltonpierson/Downloads/files_and_dirs.zip", "rb") as zip_file: -// /// contents = zip_file.read() -// /// x = await p.load_zip_file("LoadZip", """ output: String """, contents) -// /// x -// // #[pyo3(signature = (name=String::new(), output_tables=vec![], output=String::new(), bytes=vec![]))] -// // fn load_zip_file<'a>( -// // mut self_: PyRefMut<'_, Self>, -// // py: Python<'a>, -// // name: String, -// // output_tables: Vec, -// // output: String, -// // bytes: Vec -// // ) -> PyResult<&'a PyAny> { -// // let file_id = self_.file_id.clone(); -// // let url = self_.url.clone(); -// // pyo3_asyncio::tokio::future_into_py(py, 
async move { -// // let node = create_loader_node( -// // name, -// // vec![], -// // output, -// // LoadFrom::ZipfileBytes(bytes), -// // output_tables -// // ); -// // Ok(push_file_merge(&url, &file_id, node).await?) -// // }) -// // } - + // // TODO: nodes that are added should return a clean definition of what their addition looks like + // // TODO: adding a node should also display any errors + // /// x = None + // /// with open("/Users/coltonpierson/Downloads/files_and_dirs.zip", "rb") as zip_file: + // /// contents = zip_file.read() + // /// x = await p.load_zip_file("LoadZip", """ output: String """, contents) + // /// x + // // #[pyo3(signature = (name=String::new(), output_tables=vec![], output=String::new(), bytes=vec![]))] + // // fn load_zip_file<'a>( + // // mut self_: PyRefMut<'_, Self>, + // // py: Python<'a>, + // // name: String, + // // output_tables: Vec, + // // output: String, + // // bytes: Vec + // // ) -> PyResult<&'a PyAny> { + // // let file_id = self_.file_id.clone(); + // // let url = self_.url.clone(); + // // pyo3_asyncio::tokio::future_into_py(py, async move { + // // let node = create_loader_node( + // // name, + // // vec![], + // // output, + // // LoadFrom::ZipfileBytes(bytes), + // // output_tables + // // ); + // // Ok(push_file_merge(&url, &file_id, node).await?) + // // }) + // // } pub fn serialize_yaml(&self) -> anyhow::Result { Ok(self.clean_graph.serialize_to_yaml()) @@ -637,35 +718,41 @@ impl GraphBuilder { let url = &c.url; let file_id = &c.file_id; let mut client = get_client(url.clone()).await?; - let nodes = self.clean_graph.node_by_name.clone().into_values().collect(); - - Ok(client.merge(RequestFileMerge { - id: file_id.clone(), - file: Some(File { nodes, ..Default::default() }), - branch: 0, - }).await.map(|x| x.into_inner())?) 
+ let nodes = self + .clean_graph + .node_by_name + .clone() + .into_values() + .collect(); + + Ok(client + .merge(RequestFileMerge { + id: file_id.clone(), + file: Some(File { + nodes, + ..Default::default() + }), + branch: 0, + }) + .await + .map(|x| x.into_inner())?) } } - // Node handle #[derive(Clone)] pub struct NodeHandle { pub node: Item, - indiv: CleanIndividualNode + indiv: CleanIndividualNode, } impl NodeHandle { fn from(node: Item) -> anyhow::Result { let indiv = derive_for_individual_node(&node)?; - Ok(NodeHandle { - node, - indiv - }) + Ok(NodeHandle { node, indiv }) } } - impl NodeHandle { pub(crate) fn get_name(&self) -> String { self.node.core.as_ref().unwrap().name.clone() @@ -675,7 +762,11 @@ impl NodeHandle { self.indiv.output_paths.clone() } - pub fn run_when(&mut self, graph_builder: &mut GraphBuilder, other_node: &NodeHandle) -> anyhow::Result { + pub fn run_when( + &mut self, + graph_builder: &mut GraphBuilder, + other_node: &NodeHandle, + ) -> anyhow::Result { let triggers = &mut self.node.core.as_mut().unwrap().triggers; // Remove null query if it is the only one present @@ -686,26 +777,36 @@ impl NodeHandle { let q = construct_query_from_output_type( &other_node.get_name(), &other_node.get_name(), - &other_node.get_output_type() - ).unwrap(); - triggers.push(Query { query: Some(q)}); - graph_builder.clean_graph.merge_file(&File { nodes: vec![self.node.clone()], ..Default::default() })?; + &other_node.get_output_type(), + ) + .unwrap(); + triggers.push(Query { query: Some(q) }); + graph_builder.clean_graph.merge_file(&File { + nodes: vec![self.node.clone()], + ..Default::default() + })?; Ok(true) } - - pub async fn query(&self, file_id: String, url: String, branch: u64, frame: u64) -> anyhow::Result> { + pub async fn query( + &self, + file_id: String, + url: String, + branch: u64, + frame: u64, + ) -> anyhow::Result> { let name = &self.node.core.as_ref().unwrap().name; - let query = construct_query_from_output_type(&name, &name, 
&self.indiv.output_paths).unwrap(); + let query = + construct_query_from_output_type(&name, &name, &self.indiv.output_paths).unwrap(); let mut client = get_client(url).await?; - let result = client.run_query(QueryAtFrame { - id: file_id, - query: Some(Query { - query: Some(query) - }), - frame, - branch, - }).await?; + let result = client + .run_query(QueryAtFrame { + id: file_id, + query: Some(Query { query: Some(query) }), + frame, + branch, + }) + .await?; let res = result.into_inner(); let mut obj = HashMap::new(); for value in res.values.iter() { @@ -716,24 +817,22 @@ impl NodeHandle { } Ok(obj) } - } #[macro_export] macro_rules! register_node_handle { ($c:expr, $name:expr, $handler:expr) => { - $c.register_custom_node_handle($name.to_string(), Handler::new( - move |n| Box::pin(async move { ($handler)(n).await }) - )); + $c.register_custom_node_handle( + $name.to_string(), + Handler::new(move |n| Box::pin(async move { ($handler)(n).await })), + ); }; } - #[cfg(test)] mod tests { use super::*; #[test] - fn test_new_graph() { - } + fn test_new_graph() {} } diff --git a/toolchain/prompt-graph-core/Cargo.toml b/toolchain/prompt-graph-core/Cargo.toml index 6ae10b7..85ec080 100644 --- a/toolchain/prompt-graph-core/Cargo.toml +++ b/toolchain/prompt-graph-core/Cargo.toml @@ -14,6 +14,7 @@ description = "Core of Chidori, compiles graph and node definitions into an inte name = "prompt_graph_core" crate-type = ["cdylib", "lib"] bench = false +proc-macro = true [features] python = [] @@ -36,13 +37,35 @@ tokio.workspace = true env_logger.workspace = true log.workspace = true futures.workspace = true +rkyv.workspace = true +ariadne = "0.3.0" +chumsky = "0.9.3" im = "15.1.0" num_cpus = "1" petgraph = "0.6.3" typescript-type-def = "0.5.7" serde_yaml = "0.9.25" handlebars = "4.3.7" +syn = "1.0" +quote = "1.0" +crossbeam-utils = "0.8.15" +priority-queue = "1.3.2" +rand = "0.8" +async-trait = "0.1.69" + +# TODO: make optional + + +# TODO: make optional +deno_core = "0.236.0" 
+starlark = { version = "0.9.0"} +http-body-util = "0.1.0-rc.2" +zip = "0.6.6" +qdrant-client = "1.3.0" +hnsw_rs_thousand_birds = "0.1.20" + +base64 = "0.21.2" [build-dependencies] tonic-build = "0.9.2" diff --git a/toolchain/prompt-graph-core/README.md b/toolchain/prompt-graph-core/README.md index f1cd1e0..8846b21 100644 --- a/toolchain/prompt-graph-core/README.md +++ b/toolchain/prompt-graph-core/README.md @@ -2,3 +2,13 @@ This implements an interface for constructing prompt graphs. This can be used to annotate existing implementations with graph definitions as well. + +## Features + +- [ ] A graph definition language for reactive programs, wrapping other execution runtimes +- [ ] A pattern for annotating existing code to expose it to the graph definition language +- [ ] A scheduler for executing reactive programs +- [ ] Support for branching and merging reactive programs +- [ ] A wrapper around handlebars for rendering templates that supports tracing +- [ ] A standard library of core agent functionality +- [ ] Support for long running durable execution of agents diff --git a/toolchain/prompt-graph-core/build.rs b/toolchain/prompt-graph-core/build.rs index f801f6c..7377504 100644 --- a/toolchain/prompt-graph-core/build.rs +++ b/toolchain/prompt-graph-core/build.rs @@ -6,9 +6,12 @@ fn main() -> Result<(), Box> { .out_dir("./src/generated_protobufs") .build_server(true) .type_attribute(".", "#[derive(serde::Deserialize, serde::Serialize)]") // adding attributes - .type_attribute("promptgraph.ExecutionStatus", "#[derive(typescript_type_def::TypeDef)]") // adding attributes + .type_attribute( + "promptgraph.ExecutionStatus", + "#[derive(typescript_type_def::TypeDef)]", + ) // adding attributes .compile(&["./protobufs/DSL_v1.proto"], &["./protobufs/"]) .unwrap_or_else(|e| panic!("protobuf compile error: {}", e)); Ok(()) -} \ No newline at end of file +} diff --git a/toolchain/prompt-graph-core/examples/simple/arithmetic.rs 
b/toolchain/prompt-graph-core/examples/simple/arithmetic.rs new file mode 100644 index 0000000..ef9b5e9 --- /dev/null +++ b/toolchain/prompt-graph-core/examples/simple/arithmetic.rs @@ -0,0 +1 @@ +use prompt_graph_core::execution::sdk; diff --git a/toolchain/prompt-graph-core/protobufs/.idea/.gitignore b/toolchain/prompt-graph-core/protobufs/.idea/.gitignore deleted file mode 100644 index 13566b8..0000000 --- a/toolchain/prompt-graph-core/protobufs/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml -# Editor-based HTTP Client requests -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml diff --git a/toolchain/prompt-graph-core/protobufs/.idea/misc.xml b/toolchain/prompt-graph-core/protobufs/.idea/misc.xml deleted file mode 100644 index 639900d..0000000 --- a/toolchain/prompt-graph-core/protobufs/.idea/misc.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/toolchain/prompt-graph-core/protobufs/.idea/modules.xml b/toolchain/prompt-graph-core/protobufs/.idea/modules.xml deleted file mode 100644 index cb581db..0000000 --- a/toolchain/prompt-graph-core/protobufs/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/toolchain/prompt-graph-core/protobufs/.idea/protobufs.iml b/toolchain/prompt-graph-core/protobufs/.idea/protobufs.iml deleted file mode 100644 index d6ebd48..0000000 --- a/toolchain/prompt-graph-core/protobufs/.idea/protobufs.iml +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - - - - \ No newline at end of file diff --git a/toolchain/prompt-graph-core/protobufs/.idea/vcs.xml b/toolchain/prompt-graph-core/protobufs/.idea/vcs.xml deleted file mode 100644 index 6c0b863..0000000 --- a/toolchain/prompt-graph-core/protobufs/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/toolchain/prompt-graph-core/protobufs/DSL_v1.proto 
b/toolchain/prompt-graph-core/protobufs/DSL_v1.proto deleted file mode 100644 index 926544b..0000000 --- a/toolchain/prompt-graph-core/protobufs/DSL_v1.proto +++ /dev/null @@ -1,536 +0,0 @@ -syntax = "proto3"; -package promptgraph; - -// This format is used to serialize and deserialize -// PromptGraph definitions -// This is used to provide a language agnostic interface -// for defining PromptGraphs - - -// TODO: capabilities also gives us a security model - -// A Node is the core primitive within a PromptGraph, it is -// at a base level, just a function. However it includes an -// execution_capabilities which is used to determine where in our -// environment the function should be invoked. This allows us -// to determine in what environment this function must run. -message Node { - string handle = 1; - string execution_capabilities = 2; - repeated string argument_array = 3; - string inner_function_handle = 4; -} - -message Triggerable { - string node_handle = 1; - Query query = 2; -} - -message Subscribeable { - string node_handle = 1; - OutputType output = 2; -} - -enum SupportedChatModel { - GPT_4 = 0; - GPT_4_0314 = 1; - GPT_4_32K = 2; - GPT_4_32K_0314 = 3; - GPT_3_5_TURBO = 4; - GPT_3_5_TURBO_0301 = 5; -} - -enum SupportedCompletionModel { - TEXT_DAVINCI_003 = 0; - TEXT_DAVINCI_002 = 1; - TEXT_CURIE_001 = 2; - TEXT_BABBAGE_001 = 3; - TEXT_ADA_00 = 4; -} - -enum SupportedEmebddingModel { - TEXT_EMBEDDING_ADA_002 = 0; - TEXT_SEARCH_ADA_DOC_001 = 1; -} - -enum SupportedVectorDatabase { - IN_MEMORY = 0; - CHROMA = 1; - PINECONEDB = 2; - QDRANT = 3; -} - -enum SupportedSourceCodeLanguages { - DENO = 0; - STARLARK = 1; -} - -message Query { - optional string query = 1; -} - -// Processed version of the Query -message QueryPaths { - string node = 1; - repeated Path path = 2; -} - -message OutputType { - string output = 2; -} - -// Processed version of the OutputType -message OutputPaths { - string node = 1; - repeated Path path = 2; -} - -// Alias is a reference to 
another node, any value set -// on this node will propagate for the alias as well -message PromptGraphAlias { - string from = 2; - string to = 3; -} - -message PromptGraphConstant { - SerializedValue value = 2; -} - -message PromptGraphVar { -} - -message PromptGraphOutputValue { -} - - -message PromptGraphNodeCodeSourceCode { - SupportedSourceCodeLanguages language = 1; - string sourceCode = 2; - bool template = 3; -} - -message PromptGraphParameterNode { -} - -message PromptGraphMap { - string path = 4; -} - -message PromptGraphNodeCode { - oneof source { - PromptGraphNodeCodeSourceCode sourceCode = 6; - bytes zipfile = 7; - string s3Path = 8; - } -} - -message PromptGraphNodeLoader { - oneof loadFrom { - // Load a zip file, decompress it, and make the paths available as keys - bytes zipfileBytes = 1; - } -} - -message PromptGraphNodeCustom { - string type_name = 1; -} - -// TODO: we should allow the user to freely manipulate wall-clock time -// Output value of this should just be the timestamp -message PromptGraphNodeSchedule { - oneof policy { - string crontab = 1; - string naturalLanguage = 2; - string everyMs = 3; - } -} - -message PromptGraphNodePrompt { - string template = 4; - oneof model { - SupportedChatModel chatModel = 5; - SupportedCompletionModel completionModel = 6; - } - float temperature = 7; - float top_p = 8; - int32 max_tokens = 9; - float presence_penalty = 10; - float frequency_penalty = 11; - repeated string stop = 12; - // TODO: set the user token - // TODO: support logit bias -} - -enum MemoryAction { - READ = 0; - WRITE = 1; - DELETE = 2; -} - -// TODO: this expects a selector for the query? - no its a template and you build that -// TODO: what about the output type? pre-defined -// TODO: what about the metadata? 
-// TODO: metadata could be an independent query, or it could instead be a template too -message PromptGraphNodeMemory { - string collectionName = 3; - string template = 4; - oneof embeddingModel { - SupportedEmebddingModel model = 5; - } - oneof vectorDbProvider { - SupportedVectorDatabase db = 6; - } - MemoryAction action = 7; -} - -message PromptGraphNodeObservation { - string integration = 4; -} - -message PromptGraphNodeComponent { - oneof transclusion { - File inlineFile = 4; - bytes bytesReference = 5; - string s3PathReference = 6; - } -} - -message PromptGraphNodeEcho { -} - -message PromptGraphNodeJoin { - // TODO: configure resolving joins -} - -message ItemCore { - string name = 1; - repeated Query triggers = 2; - repeated string outputTables = 3; - OutputType output = 4; -} - -message Item { - ItemCore core = 1; - oneof item { - PromptGraphAlias alias = 2; - PromptGraphMap map = 3; - PromptGraphConstant constant = 4; - PromptGraphVar variable = 5; - PromptGraphOutputValue output = 6; - // TODO: delete above this line - PromptGraphNodeCode nodeCode = 7; - PromptGraphNodePrompt nodePrompt = 8; - PromptGraphNodeMemory nodeMemory = 9; - PromptGraphNodeComponent nodeComponent = 10; - PromptGraphNodeObservation nodeObservation = 11; - PromptGraphParameterNode nodeParameter = 12; - PromptGraphNodeEcho nodeEcho = 13; - PromptGraphNodeLoader nodeLoader = 14; - PromptGraphNodeCustom nodeCustom = 15; - PromptGraphNodeJoin nodeJoin = 16; - PromptGraphNodeSchedule nodeSchedule = 17; - } -} - -// TODO: add a flag for 'Cleaned', 'Dirty', 'Validated' -message File { - string id = 1; - repeated Item nodes = 2; -} - -message Path { - repeated string address = 1; -} - -message TypeDefinition { - oneof type { - PrimitiveType primitive = 1; - ArrayType array = 2; - ObjectType object = 3; - UnionType union = 4; - IntersectionType intersection = 5; - OptionalType optional = 6; - EnumType enum = 7; - } -} - -message PrimitiveType { - oneof primitive { - bool is_string = 1; - 
bool is_number = 2; - bool is_boolean = 3; - bool is_null = 4; - bool is_undefined = 5; - } -} - -message ArrayType { - TypeDefinition type = 1; -} - -message ObjectType { - map fields = 1; -} - -message UnionType { - repeated TypeDefinition types = 1; -} - -message IntersectionType { - repeated TypeDefinition types = 1; -} - -message OptionalType { - TypeDefinition type = 1; -} - -message EnumType { - map values = 1; -} - - -message SerializedValueArray { - repeated SerializedValue values = 1; -} - -message SerializedValueObject { - map values = 1; -} - -message SerializedValue { - oneof val { - float float = 2; - int32 number = 3; - string string = 4; - bool boolean = 5; - SerializedValueArray array = 6; - SerializedValueObject object = 7; - } -} - -message ChangeValue { - Path path = 1; - SerializedValue value = 2; - uint64 branch = 3; -} - - -message WrappedChangeValue { - uint64 monotonicCounter = 3; - ChangeValue changeValue = 4; -} - -// Computation of a node -message NodeWillExecute { - string sourceNode = 1; - repeated WrappedChangeValue changeValuesUsedInExecution = 2; - uint64 matchedQueryIndex = 3; -} - -// Group of node computations to run -message DispatchResult { - repeated NodeWillExecute operations = 1; -} - -message NodeWillExecuteOnBranch { - uint64 branch = 1; - uint64 counter = 2; - optional string custom_node_type_name = 3; - NodeWillExecute node = 4; -} - -message ChangeValueWithCounter { - repeated ChangeValue filledValues = 1; - repeated uint64 parentMonotonicCounters = 2; - uint64 monotonicCounter = 3; - uint64 branch = 4; - string sourceNode = 5; -} - - -message CounterWithPath { - uint64 monotonicCounter = 1; - Path path = 2; -} - -// Input proposals -message InputProposal { - string name = 1; - OutputType output = 2; - uint64 counter = 3; - uint64 branch = 4; -} - - -message RequestInputProposalResponse { - string id = 1; - uint64 proposal_counter = 2; - repeated ChangeValue changes = 3; - uint64 branch = 4; -} - -message 
DivergentBranch { - uint64 branch = 1; - uint64 diverges_at_counter = 2; -} - -message Branch { - uint64 id = 1; - repeated uint64 source_branch_ids = 2; - repeated DivergentBranch divergent_branches = 3; - uint64 diverges_at_counter = 4; -} - - -message Empty { -} - - -// This is the return value from api calls that reports the current counter and branch the operation -// was performed on. -message ExecutionStatus { - string id = 1; - uint64 monotonicCounter = 2; - uint64 branch = 3; -} - -message FileAddressedChangeValueWithCounter { - string id = 1; - string node_name = 2; - uint64 branch = 3; - uint64 counter = 4; - ChangeValueWithCounter change = 5; -} - -message RequestOnlyId { - string id = 1; - uint64 branch = 2; -} - -message FilteredPollNodeWillExecuteEventsRequest { - string id = 1; -} - - -message RequestAtFrame { - string id = 1; - uint64 frame = 2; - uint64 branch = 3; -} - -message RequestNewBranch { - string id = 1; - uint64 sourceBranchId = 2; - uint64 divergesAtCounter = 3; -} - -message RequestListBranches { - string id = 1; -} - -message ListBranchesRes { - string id = 1; - repeated Branch branches = 2; -} - -message RequestFileMerge { - string id = 1; - File file = 2; - uint64 branch = 3; -} - -message ParquetFile { - bytes data = 1; -} - -message QueryAtFrame { - string id = 1; - Query query = 2; - uint64 frame = 3; - uint64 branch = 4; -} - -message QueryAtFrameResponse { - repeated WrappedChangeValue values = 1; -} - -message RequestAckNodeWillExecuteEvent { - string id = 1; - uint64 branch = 3; - uint64 counter = 4; -} - -message RespondPollNodeWillExecuteEvents { - repeated NodeWillExecuteOnBranch nodeWillExecuteEvents = 1; -} - -message PromptLibraryRecord { - UpsertPromptLibraryRecord record = 1; - uint64 version_counter = 3; -} - -message UpsertPromptLibraryRecord { - string template = 1; - string name = 2; - string id = 3; - optional string description = 4; -} - -message ListRegisteredGraphsResponse { - repeated string ids = 1; -} - 
-// API: -service ExecutionRuntime { - - rpc RunQuery(QueryAtFrame) returns (QueryAtFrameResponse) {} - - // * Merge a new file - if an existing file is available at the id, will merge the new file into the existing one - rpc Merge(RequestFileMerge) returns (ExecutionStatus) {} - - // * Get the current graph state of a file at a branch and counter position - rpc CurrentFileState(RequestOnlyId) returns (File) {} - - // * Get the parquet history for a specific branch and Id - returns bytes - rpc GetParquetHistory(RequestOnlyId) returns (ParquetFile) {} - - // * Resume execution - rpc Play(RequestAtFrame) returns (ExecutionStatus) {} - - // * Pause execution - rpc Pause(RequestAtFrame) returns (ExecutionStatus) {} - - // * Split history into a separate branch - rpc Branch(RequestNewBranch) returns (ExecutionStatus) {} - - // * Get all branches - rpc ListBranches(RequestListBranches) returns (ListBranchesRes) {} - - // * List all registered files - rpc ListRegisteredGraphs(Empty) returns (ListRegisteredGraphsResponse) {} - - // * Receive a stream of input proposals <- this is a server-side stream - rpc ListInputProposals(RequestOnlyId) returns (stream InputProposal) {} - - // * Push responses to input proposals (these wait for some input from a host until they're resolved) <- RPC client to server - rpc RespondToInputProposal (RequestInputProposalResponse) returns (Empty); - - // * Observe the stream of execution events <- this is a server-side stream - rpc ListChangeEvents(RequestOnlyId) returns (stream ChangeValueWithCounter) {} - - rpc ListNodeWillExecuteEvents(RequestOnlyId) returns (stream NodeWillExecuteOnBranch) {} - - // * Observe when the server thinks our local node implementation should execute and with what changes - rpc PollCustomNodeWillExecuteEvents(FilteredPollNodeWillExecuteEventsRequest) returns (RespondPollNodeWillExecuteEvents) {} - // TODO: this should be with the state they need to execute with - // TODO: need to ack that these messages have been 
received, and retry sending them to workers? - // TODO: no workers should poll for them like temporal - // TODO: the pace of pulling these need to be managed by the worker - - rpc AckNodeWillExecuteEvent(RequestAckNodeWillExecuteEvent) returns (ExecutionStatus) {} - - // * Receive events from workers <- this is an RPC client to server, we don't need to wait for a response from the server - rpc PushWorkerEvent(FileAddressedChangeValueWithCounter) returns (ExecutionStatus) {} - - rpc PushTemplatePartial(UpsertPromptLibraryRecord) returns (ExecutionStatus) {} -} \ No newline at end of file diff --git a/toolchain/prompt-graph-core/src/build_runtime_graph/graph_parse.rs b/toolchain/prompt-graph-core/src/build_runtime_graph/graph_parse.rs deleted file mode 100644 index 01ca559..0000000 --- a/toolchain/prompt-graph-core/src/build_runtime_graph/graph_parse.rs +++ /dev/null @@ -1,574 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::fmt::Write; -use std::{fmt, mem}; -use anyhow::anyhow; -use petgraph::dot::{Config, Dot}; -use petgraph::graphmap::DiGraphMap; -use serde::Serialize; -use sqlparser::ast::{Expr, JoinConstraint, Query, Select, SelectItem, SetExpr, Statement, TableWithJoins}; -use sqlparser::dialect::GenericDialect; -use sqlparser::parser::{Parser, ParserError}; -use crate::graph_definition::DefinitionGraph; -use crate::proto::{File, Item, item as dsl_item}; -use sqlparser::ast::{Join, JoinOperator, TableFactor}; -use crate::reactivity::reactive_sql; -use crate::reactivity::reactive_sql::SQLType; - - - -pub type OutputTypeDefinition = HashMap; - - -pub fn parse_output_type_def_to_paths(input: &str) -> OutputPaths { - serde_json::from_str(input).unwrap() -} - -/// Build a map of output paths to the nodes that they refer to -pub fn output_table_from_output_types(output_paths: &HashMap>>) -> HashMap> { - let mut result: HashMap, Vec> = HashMap::new(); - - for (key, paths) in output_paths.iter() { - for path in paths { - result.entry(path.clone()) - 
.or_insert(vec![]) - .push(key.clone()); - } - } - - - // Mutate the result so all of the keys are flat - result - .into_iter() - .map(|(k, v)| (k.join(":"), v)) - .collect() -} - -/// Build a map of query paths to the nodes they refer to -pub fn dispatch_table_from_query_paths(query_paths: &HashMap>>) -> HashMap> { - let mut result: HashMap, Vec> = HashMap::new(); - - // for each query path, get the node as key, and the paths - for (key, all_opt_paths) in query_paths.iter() { - for opt_paths in all_opt_paths { - if let Some(paths) = opt_paths { - for path in paths { - result.entry(path.clone()) - .or_insert(vec![]) - .push(key.clone()); - } - } else { - result.entry(vec![]) - .or_insert(vec![]) - .push(key.clone()); - } - } - } - - - // Mutate the result so all of the keys are flat - result - .into_iter() - .map(|(k, v)| (k.join(":"), v)) - .collect() -} - - -#[derive(Debug, Clone)] -pub struct CleanIndividualNode { - pub name: String, - pub query_path: QueryPath, - pub output_paths: OutputPaths, - pub output_tables: HashSet, -} - -pub fn derive_for_individual_node(node: &Item) -> anyhow::Result { - let name = &node.core.as_ref().unwrap().name; - let core = node.core.as_ref().unwrap(); - - let mut query_path: QueryPath = vec![]; - for query in &core.triggers { - if let Some(q) = &query.query { - let paths = query_path_from_query_string(&q)?; - query_path.push(Some(paths)); - } else { - query_path.push(None); - } - } - - // Add the node to the output table - - - let mut output_tables: HashSet<_> = core.output_tables.iter().cloned().collect(); - output_tables.insert(core.name.clone()); - - let mut output_paths = vec![]; - for output in &core.output { - let output_type: OutputTypeDefinition = serde_yaml::from_str(&output.output)?; - for output_table in &output_tables { - for (output_key, _ty) in &output_type { - output_paths.push(vec![output_table.clone(), output_key.clone()]); - } - } - } - - Ok(CleanIndividualNode { - name: name.clone(), - query_path, - 
output_paths, - output_tables - }) -} - -pub fn query_path_from_query_string(q: &String) -> anyhow::Result>> { - let dependent_on = reactive_sql::parse_tables_and_columns(&q)?; - let mut paths = vec![]; - for (table, columns) in dependent_on { - for column in columns { - let mut path_segment = vec![table.clone()]; - path_segment.push(column); - paths.push(path_segment); - } - } - Ok(paths) -} - - -type QueryVecGroup = Vec>; -type QueryPath = Vec>; -type OutputPaths = Vec>; - - -#[derive(Serialize, Debug, Clone)] -pub struct CleanedDefinitionGraph { - pub query_paths: HashMap, - pub node_by_name: HashMap, - pub dispatch_table: HashMap>, - pub output_table: HashMap>, - pub node_to_output_tables: HashMap>, - pub output_paths: HashMap, - -} - -impl CleanedDefinitionGraph { - pub fn new(definition_graph: &DefinitionGraph) -> Self { - let node_by_name = definition_graph.get_nodes().iter().map(|n| { - let name = &n.core.as_ref().unwrap().name; - (name.clone(), n.clone()) - }).collect(); - - CleanedDefinitionGraph::recompute_parsed_values(node_by_name).unwrap() - } - - pub fn zero() -> Self { - Self { - query_paths: HashMap::new(), - output_table: HashMap::new(), - dispatch_table: HashMap::new(), - output_paths: HashMap::new(), - node_by_name: HashMap::new(), - node_to_output_tables: HashMap::new(), - } - } - - pub fn get_node(&self, name: &str) -> Option<&Item> { - self.node_by_name.get(name) - } - - fn recompute_parsed_values(node_by_name: HashMap) -> anyhow::Result { - // Node name -> list of encoder documents - let mut graph = CleanedDefinitionGraph::zero(); - - for node in node_by_name.values() { - let indiv = &mut derive_for_individual_node(node)?; - let name = &indiv.name; - graph.node_to_output_tables.insert(name.clone(), mem::take(&mut indiv.output_tables)); - // graph.query_types.insert(name.clone(), mem::take(&mut indiv.query_type)); - // graph.output_types.insert(name.clone(), mem::take(&mut indiv.output_type)); - graph.query_paths.insert(name.clone(), 
mem::take(&mut indiv.query_path)); - graph.output_paths.insert(name.clone(), mem::take(&mut indiv.output_paths)); - } - - // This aggregates all the query types into a single query type - // graph.gql_query_type = generate_gql_schema_query_type(&graph.output_types); - graph.output_table = output_table_from_output_types(&graph.output_paths); - graph.dispatch_table = dispatch_table_from_query_paths(&graph.query_paths); - // graph.unified_type_doc = build_type_document(graph.output_types.iter().collect(), &graph.gql_query_type); - graph.node_by_name = node_by_name; - - Ok(graph) - } - - pub fn assert_parsing(&mut self) -> anyhow::Result<()> { - let recomputed = CleanedDefinitionGraph::recompute_parsed_values( - mem::take(&mut self.node_by_name) - ).unwrap(); - self.node_by_name = recomputed.node_by_name; - self.query_paths = recomputed.query_paths; - self.output_paths = recomputed.output_paths; - self.output_table = recomputed.output_table; - self.dispatch_table = recomputed.dispatch_table; - self.node_to_output_tables = recomputed.node_to_output_tables; - Ok(()) - } - - /// Merge a file into the current as a mutation. Returns a list of _updated_ nodes by name. 
- pub fn merge_file(&mut self, file: &File) -> anyhow::Result> { - // We merge each node in the file into our own file, combining their keys - let mut updated_nodes = vec![]; - for node in file.nodes.iter() { - if node.item.is_none() { continue; } - - // Nodes with the same name are merged - let name = &node.core.as_ref().unwrap().name; - - // If the node does not exist we just insert it - if !self.node_by_name.contains_key(name) { - self.node_by_name.insert(name.clone(), node.clone()); - } else { - updated_nodes.push(name.clone()); - // Otherwise we need to merge the node field by field - let existing = self.node_by_name.get_mut(name).unwrap(); - - let core = node.core.clone().unwrap(); - if let Some(existing_core) = &mut existing.core { - existing_core.name = core.name; - existing_core.output = core.output; - existing_core.triggers = core.triggers; - } - - match node.item.clone().unwrap() { - dsl_item::Item::NodeParameter(_n) => { - if let Some(dsl_item::Item::NodeParameter(_ex)) = &mut existing.item { - } - }, - dsl_item::Item::Map(n) => { - if let Some(dsl_item::Item::Map(ex)) = &mut existing.item { - ex.path = n.path; - } - }, - dsl_item::Item::NodeCode(n) => { - if let Some(dsl_item::Item::NodeCode(ex)) = &mut existing.item { - ex.source = n.source; - } - }, - dsl_item::Item::NodePrompt(n) => { - if let Some(dsl_item::Item::NodePrompt(ex)) = &mut existing.item { - ex.template = n.template; - ex.model = n.model; - ex.temperature = n.temperature; - ex.top_p = n.top_p; - ex.max_tokens = n.max_tokens; - ex.presence_penalty = n.presence_penalty; - ex.frequency_penalty = n.frequency_penalty; - ex.stop = n.stop; - } - }, - dsl_item::Item::NodeMemory(n) => { - if let Some(dsl_item::Item::NodeMemory(ex)) = &mut existing.item { - ex.template = n.template; - ex.embedding_model = n.embedding_model; - ex.vector_db_provider = n.vector_db_provider; - ex.action = n.action; - } - }, - dsl_item::Item::NodeComponent(n) => { - if let Some(dsl_item::Item::NodeComponent(ex)) = 
&mut existing.item { - ex.transclusion = n.transclusion - } - }, - dsl_item::Item::NodeObservation(n) => { - if let Some(dsl_item::Item::NodeObservation(ex)) = &mut existing.item { - ex.integration = n.integration; - } - }, - dsl_item::Item::NodeEcho(_n) => { - if let Some(dsl_item::Item::NodeEcho(_ex)) = &mut existing.item { - } - }, - dsl_item::Item::NodeLoader(n) => { - if let Some(dsl_item::Item::NodeLoader(ex)) = &mut existing.item { - ex.load_from = n.load_from; - } - }, - dsl_item::Item::NodeCustom(n) => { - if let Some(dsl_item::Item::NodeCustom(ex)) = &mut existing.item { - ex.type_name = n.type_name; - } - }, - _ => return Err(anyhow!("Node type not supported")) - }; - } - } - - let recomputed = CleanedDefinitionGraph::recompute_parsed_values( - mem::take(&mut self.node_by_name) - ).unwrap(); - - // TODO: Validate the new query that was added, if any - // for node in file.nodes.iter() { - // for opt_query_doc in recomputed.query_types.get(&node.core.as_ref().unwrap().name).unwrap() { - // if let Some(query_doc) = opt_query_doc { - // validate_new_query(&recomputed.unified_type_doc, query_doc); - // } - // } - // } - - self.node_by_name = recomputed.node_by_name; - self.query_paths = recomputed.query_paths; - self.output_paths = recomputed.output_paths; - self.output_table = recomputed.output_table; - self.dispatch_table = recomputed.dispatch_table; - self.node_to_output_tables = recomputed.node_to_output_tables; - - Ok(updated_nodes) - } - - /// Hashjoin output and dispatch tables - fn join_relation_between_output_and_dispatch_tables(&self) -> Vec<(String, String)> { - let mut edges: Vec<(String, String)> = vec![]; - for (key, originating_nodes) in self.output_table.iter() { - for originating_node in originating_nodes { - if let Some(affecting_nodes) = self.dispatch_table.get(key) { - for affecting_node in affecting_nodes { - edges.push((originating_node.clone(), affecting_node.clone())); - } - } - } - } - edges - } - - pub fn serialize_to_yaml(&self) -> 
String { - serde_yaml::to_string(&self).unwrap() - } - - pub fn get_dot_graph(&self) -> String { - let mut graph: DiGraphMap = petgraph::graphmap::DiGraphMap::new(); - - // Convert nodes into a numeric representation - let mut nodes = HashMap::new(); - let mut nodes_inverse = HashMap::new(); - let mut counter: u32 = 0; - let mut keys: Vec<&String> = self.node_by_name.keys().collect(); - keys.sort(); - for node_name in keys { - nodes.insert(node_name, counter); - nodes_inverse.insert(counter, node_name); - graph.add_node(counter); - counter += 1; - } - - // Join output and dispatch tables - let mut edges = self.join_relation_between_output_and_dispatch_tables(); - edges.sort(); - for (originating_node, affecting_node) in edges { - graph.add_edge(*nodes.get(&originating_node).unwrap(), *nodes.get(&affecting_node).unwrap(), 0); - } - - // TODO: this shows an error in intellij but it compiles fine - format!("{:?}", Dot::with_attr_getters( - &graph, - &[Config::EdgeNoLabel], - &|_, _| { "".to_string() }, - &|_, (n, _w)| { - format!("label=\"{}\"", nodes_inverse.get(&n).unwrap()) - } - )) - } - -} - - - - -pub fn construct_query_from_output_type(name: &String, namespace: &String, output_paths: &OutputPaths) -> anyhow::Result { - let projection_items: Vec = output_paths.iter().map(|x| x.join(".")).collect(); - let projection = projection_items.join(", "); - Ok(format!("SELECT {} FROM {}", projection, namespace)) -} - - - -#[cfg(test)] -mod tests { - use indoc::indoc; - - use crate::graph_definition::{create_code_node, SourceNodeType}; - use crate::proto::Query; - - use super::*; - - fn gen_item_hello(output_tables: Vec) -> Item { - create_code_node( - "code_node_test".to_string(), - vec![None], - r#"{ "output": String }"#.to_string(), - SourceNodeType::Code(String::from("DENO"), - indoc! 
{ r#" - return { - "output": "hello" - } - "#}.to_string(), - false - ), - output_tables - ) - } - - fn gen_item_hello_plus_world() -> Item { - create_code_node( - "code_node_test_dep".to_string(), - vec![Some( r#" SELECT output FROM code_node_test"#.to_string(), - )], - r#"{ "result": String }"#.to_string(), - SourceNodeType::Code(String::from("DENO"), - indoc! { r#" - return { - "result": "{{code_node_test.output}}" + " world" - } - "#}.to_string(), - true - ), - vec![] - ) - } - - #[test] - fn test_construct_query_from_output_type() { - let output_paths : OutputPaths = vec![vec!["code_node_test".to_string(), "output".to_string()]]; - let query = construct_query_from_output_type(&"code_node_test".to_string(), &"code_node_test".to_string(), &output_paths).unwrap(); - assert_eq!(query, "SELECT code_node_test.output FROM code_node_test"); - } - - #[test] - fn test_construct_query_from_output_type_multiple_keys() { - let output_paths : OutputPaths = vec![vec!["code_node_test".to_string(), "output".to_string()], vec!["code_node_test".to_string(), "result".to_string()]]; - let query = construct_query_from_output_type(&"code_node_test".to_string(), &"code_node_test".to_string(), &output_paths).unwrap(); - assert_eq!(query, "SELECT code_node_test.output, code_node_test.result FROM code_node_test"); - } - - #[test] - fn test_dispatch_table_from_query_paths() { - let mut file = File { - id: "test".to_string(), - nodes: vec![ - gen_item_hello(vec![]), - // Our goal is to see changes from this node on both branches - gen_item_hello_plus_world() - ], - }; - let mut g = CleanedDefinitionGraph::zero(); - g.merge_file(&mut file).unwrap(); - - assert_eq!(g.dispatch_table, vec![ - ("code_node_test:output".to_string(), vec!["code_node_test_dep".to_string()]), - ("".to_string(), vec!["code_node_test".to_string()]) - ].into_iter().collect()); - } - - #[test] - fn test_dispatch_table_from_query_paths_multiple_keys() { - let mut file = File { - id: "test".to_string(), - nodes: vec![ - 
gen_item_hello(vec![]), - // Our goal is to see changes from this node on both branches - gen_item_hello_plus_world(), - create_code_node( - "code_node_test_multiple".to_string(), - vec![Some( r#" SELECT code_node_test_dep.result, code_node_test_dep.second FROM code_node_test_dep"#.to_string(), - )], - r#"{ "result": String }"#.to_string(), - SourceNodeType::Code( - String::from("DENO"), - indoc! { r#" return { "result": "out" } "#}.to_string(), - true - ), - vec![] - ) - ], - }; - let mut g = CleanedDefinitionGraph::zero(); - g.merge_file(&mut file).unwrap(); - - assert_eq!(g.dispatch_table, vec![ - ("code_node_test_dep:second".to_string(), vec!["code_node_test_multiple".to_string()]), - ("code_node_test_dep:result".to_string(), vec!["code_node_test_multiple".to_string()]), - ("code_node_test:output".to_string(), vec!["code_node_test_dep".to_string()]), - ("".to_string(), vec!["code_node_test".to_string()]) - ].into_iter().collect()); - } - - #[test] - fn test_producing_valid_dot_graph() { - let mut file = File { - id: "test".to_string(), - nodes: vec![ - gen_item_hello(vec![]), - // Our goal is to see changes from this node on both branches - gen_item_hello_plus_world() - ], - }; - let mut g = CleanedDefinitionGraph::zero(); - g.merge_file(&mut file).unwrap(); - - assert_eq!(g.join_relation_between_output_and_dispatch_tables(), vec![("code_node_test".to_string(), "code_node_test_dep".to_string())]); - - assert_eq!(g.get_dot_graph(), indoc!{r#" - digraph { - 0 [ label = "0" label="code_node_test"] - 1 [ label = "1" label="code_node_test_dep"] - 0 -> 1 [ ] - } - "#}); - } - - - #[test] - fn test_producing_valid_dot_graph_with_output_table() { - env_logger::init(); - let mut file = File { - id: "test".to_string(), - nodes: vec![ - gen_item_hello(vec!["OutputTable2".to_string()]), - // Our goal is to see changes from this node on both branches - gen_item_hello_plus_world(), - create_code_node( - "code_node_test_dep_output".to_string(), - vec![Some( r#"SELECT output 
FROM OutputTable2"#.to_string(), - )], - r#"{ result: String }"#.to_string(), - SourceNodeType::Code(String::from("DENO"), - indoc! { r#" - return { - "result": "{{code_node_test.output}}" + " world" - } - "#}.to_string(), - true - ), - vec![] - ) - ], - }; - let mut g = CleanedDefinitionGraph::zero(); - g.merge_file(&mut file).unwrap(); - - let mut list = g.join_relation_between_output_and_dispatch_tables(); - list.sort(); - assert_eq!(list, vec![ - ("code_node_test".to_string(), "code_node_test_dep".to_string()), - ("code_node_test".to_string(), "code_node_test_dep_output".to_string()), - ]); - - assert_eq!(g.get_dot_graph(), indoc!{r#" - digraph { - 0 [ label = "0" label="code_node_test"] - 1 [ label = "1" label="code_node_test_dep"] - 2 [ label = "2" label="code_node_test_dep_output"] - 0 -> 1 [ ] - 0 -> 2 [ ] - } - "#}); - } - -} diff --git a/toolchain/prompt-graph-core/src/build_runtime_graph/mod.rs b/toolchain/prompt-graph-core/src/build_runtime_graph/mod.rs deleted file mode 100644 index da58fca..0000000 --- a/toolchain/prompt-graph-core/src/build_runtime_graph/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod graph_parse; \ No newline at end of file diff --git a/toolchain/prompt-graph-core/src/execution/execution/execution_graph.rs b/toolchain/prompt-graph-core/src/execution/execution/execution_graph.rs new file mode 100644 index 0000000..b12d9ce --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/execution/execution_graph.rs @@ -0,0 +1,842 @@ +use crate::execution::execution::execution_state::{DependencyGraphMutation, ExecutionState}; +use crate::execution::integration::triggerable::{Subscribable, TriggerContext}; +use crate::execution::primitives::identifiers::{ArgumentIndex, OperationId, TimestampOfWrite}; +use crate::execution::primitives::operation::{ + OperationFn, OperationNode, OperationNodeDefinition, Signature, +}; +use crate::execution::primitives::serialized_value::deserialize_from_buf; +use 
crate::execution::primitives::serialized_value::RkyvSerializedValue as RSV; +use crossbeam_utils::sync::Unparker; +use futures::StreamExt; +use im::HashMap as ImHashMap; +use im::HashSet as ImHashSet; +use indoc::indoc; +use petgraph::algo::toposort; +use petgraph::data::Build; +use petgraph::dot::{Config, Dot}; +use petgraph::graph::{DiGraph, NodeIndex}; +use petgraph::graphmap::DiGraphMap; +use petgraph::visit::{Dfs, IntoEdgesDirected, VisitMap, Walker}; +use petgraph::Direction; +use std::cell::RefCell; +use std::collections::HashMap; +use std::collections::HashSet; +use std::fmt::{self, Formatter, Write}; +use std::rc::Rc; +use std::sync::{Arc, Mutex}; +// TODO: update all of these identifiers to include a "space" they're within + +type EdgeIdentity = (OperationId, OperationId, ArgumentIndex); + +/// This models the network of reactive relationships between different components. +/// +/// This is heavily inspired by works such as Salsa, Verde, Incremental, Adapton, and Differential Dataflow. +pub struct ExecutionGraph { + /// Global revision number for modifications to the graph itself + revision: usize, + + /// Operation and its id + pub operation_by_id: HashMap, + + /// This is the graph of dependent execution state + /// + /// (branch, counter) -> stream_outputs_at_head + /// The dependency graph is stored within the execution graph, allowing us to model changes + /// to the dependency graph during the process of execution. + /// This is a graph of the mutations to the dependency graph. + /// As we make changes to the dependency graph itself, we track those transitions here. + /// This is roughly equivalent to a git history of the dependency graph. + /// + /// We store immutable representations of the history of the dependency graph. These + /// can be used to reconstruct a traversable dependency graph at any point in time. + /// + /// Identifiers on this graph refer to points in the execution graph. 
In execution terms, changes + /// along those edges are always considered to have occurred _after_ the target step. + execution_graph: DiGraphMap<(usize, usize), ExecutionState>, + + /// Dependency graph of the computable elements in the graph + /// + /// The dependency graph is a directed graph where the nodes are the ids of the operations and the + /// weights are the index of the input of the next operation. + /// + /// The usize::MAX index is a no-op that indicates that the operation is ready to run, an execution + /// order dependency rather than a value dependency. + dependency_graph: DiGraphMap>, +} + +impl ExecutionGraph { + /// Initialize a new reactivity database. This will create a default input and output node, + /// graphs default to being the unit function x -> x. + pub fn new() -> Self { + let mut dependency_graph = DiGraphMap::new(); + let mut operation_by_id = HashMap::new(); + ExecutionGraph { + operation_by_id, + execution_graph: Default::default(), + dependency_graph, + revision: 0, + } + } + + /// This adds an operation into the database + pub fn upsert_operation( + &mut self, + prev_execution_id: (usize, usize), + previous_state: ExecutionState, + node: usize, + args: usize, + func: Box, + ) -> ((usize, usize), ExecutionState) { + let mut new_state = previous_state + .clone() + .add_operation(node.clone(), args, func); + let output_new_state = new_state.clone(); + self.add_execution_edge(prev_execution_id, new_state, output_new_state) + } + + /// Indicates that this operation depends on the output of the given node + pub fn apply_dependency_graph_mutations( + &mut self, + prev_execution_id: (usize, usize), + previous_state: ExecutionState, + mutations: Vec, + ) -> ((usize, usize), ExecutionState) { + let mut new_state = previous_state + .clone() + .apply_dependency_graph_mutations(mutations); + let output_new_state = new_state.clone(); + self.add_execution_edge(prev_execution_id, new_state, output_new_state) + } + + fn add_execution_edge( + 
&mut self, + prev_execution_id: (usize, usize), + mut new_state: ExecutionState, + output_new_state: ExecutionState, + ) -> ((usize, usize), ExecutionState) { + let edges = self + .execution_graph + .edges_directed(prev_execution_id, Direction::Outgoing); + + let new_id = if let Some((_, max_to, _)) = + edges.max_by(|(_, a_to, _), (_, b_to, _)| (a_to.0).cmp(&(b_to.0))) + { + // Create an edge in the execution graph from the previous state to this new one + let id = (max_to.0 + 1, prev_execution_id.1 + 1); + self.execution_graph + .add_edge(prev_execution_id, id.clone(), new_state); + id + } else { + // Create an edge in the execution graph from the previous state to this new one + let id = (0, prev_execution_id.1 + 1); + self.execution_graph + .add_edge(prev_execution_id, id.clone(), new_state); + id + }; + + (new_id, output_new_state) + } + + pub fn render_execution_graph(&self) { + println!("================ Execution graph ================"); + println!("{:?}", Dot::with_config(&self.execution_graph, &[])); + } + + pub fn step_execution( + &mut self, + prev_execution_id: (usize, usize), + previous_state: ExecutionState, + ) -> ((usize, usize), ExecutionState) { + // Clone the previous immutable state for modification + let mut marked_for_consumption = HashSet::new(); + let mut new_state = previous_state.clone(); + let mut operation_by_id = previous_state.operation_by_id.clone(); + let dependency_graph = previous_state.get_dependency_graph(); + + // Every tick, every operation consumes from each of its incoming edges. 
+ 'traverse_nodes: for operation_id in dependency_graph.nodes() { + let mut op_node = operation_by_id.get_mut(&operation_id).unwrap().borrow_mut(); + let mut dep_count = op_node.dependency_count; + let mut args: Vec<&Option>> = vec![&None; dep_count]; + + // Ops with 0 deps should only execute once + if dep_count == 0 { + if previous_state.check_if_previously_set(&operation_id) { + continue 'traverse_nodes; + } + } + + // TODO: this currently disallows multiple edges from the same node? + // Fetch the values from the previous execution cycle for each edge on this node + for (from, to, argument_indices) in + dependency_graph.edges_directed(operation_id, Direction::Incoming) + { + // TODO: if the dependency is on usize::MAX, then this is an execution order dependency + if let Some(output) = previous_state.state_get(&from) { + marked_for_consumption.insert(from.clone()); + // TODO: we can implement prioritization between different values here + for weight in argument_indices { + args[*weight] = output; + if dep_count > 0 { + dep_count -= 1; + } + } + } + } + + // Some of the required arguments are not yet available, continue to the next node + if dep_count != 0 { + continue 'traverse_nodes; + } + + // Execute the Operation with the given arguments + // TODO: support async/parallel execution + let result = op_node.execute(args.iter().map(|x| &**x).collect()); + + new_state.state_insert(operation_id, result.clone()); + } + new_state.state_consume_marked(marked_for_consumption); + + // The edge from this node is the greatest branching id + 1 + // if we re-evaluate execution at a given node, we get a new execution branch. 
+ self.add_execution_edge(prev_execution_id, new_state.clone(), new_state) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::execution::primitives::serialized_value::RkyvSerializedValue as RSV; + use crate::execution::primitives::serialized_value::{ + deserialize_from_buf, serialize_to_vec, ArchivedRkyvSerializedValue, + }; + use log::warn; + use rkyv::ser::serializers::AllocSerializer; + use rkyv::ser::Serializer; + use rkyv::{archived_root, Deserialize, Serialize}; + use std::collections::HashSet; + use std::sync::atomic::{AtomicUsize, Ordering}; + + /* + Testing the execution of individual nodes. Validating that operations as defined can be executed. + */ + + #[test] + fn test_evaluation_single_node() { + let mut db = ExecutionGraph::new(); + let mut state = ExecutionState::new(); + let state_id = (0, 0); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 1, + 0, + Box::new(|_args| { + let v = RSV::Number(1); + return serialize_to_vec(&v); + }), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 2, + 0, + Box::new(|_args| { + let v = RSV::Number(1); + return serialize_to_vec(&v); + }), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 3, + 2, + Box::new(|args| { + let arg0 = deserialize_from_buf(args[0].as_ref().unwrap().as_slice()); + let arg1 = deserialize_from_buf(args[1].as_ref().unwrap().as_slice()); + + if let (RSV::Number(a), RSV::Number(b)) = (arg0, arg1) { + let v = RSV::Number(a + b); + return serialize_to_vec(&v); + } + + panic!("Invalid arguments") + }), + ); + + let (state_id, mut state) = db.apply_dependency_graph_mutations( + (0, 0), + state, + vec![DependencyGraphMutation::Create { + operation_id: 3, + depends_on: vec![(1, 0), (2, 1)], + }], + ); + + let v0 = RSV::Number(1); + let v1 = RSV::Number(2); + let arg0 = serialize_to_vec(&v0); + let arg1 = serialize_to_vec(&v1); + + // Manually manipulating the state to insert the arguments for this test + 
state.state_insert(1, Some(arg0)); + state.state_insert(2, Some(arg1)); + + let (_, new_state) = db.step_execution(state_id, state.clone()); + + assert!(new_state.state_get(&3).is_some()); + let result = new_state.state_get(&3).unwrap(); + let result_val = deserialize_from_buf(&result.as_ref().clone().unwrap()); + assert_eq!(result_val, RSV::Number(3)); + } + + /* + Testing the traverse of the dependency graph. Validating that execution of the graph moves through + the graph as expected. + */ + + #[test] + fn test_traverse_single_node() { + let mut db = ExecutionGraph::new(); + let mut state = ExecutionState::new(); + let state_id = (0, 0); + let (state_id, mut state) = + db.upsert_operation(state_id, state, 0, 0, Box::new(|_args| vec![0, 0, 0])); + let (state_id, mut state) = + db.upsert_operation(state_id, state, 1, 0, Box::new(|_args| vec![1, 1, 1])); + let (state_id, state) = db.apply_dependency_graph_mutations( + (0, 0), + state, + vec![DependencyGraphMutation::Create { + operation_id: 1, + depends_on: vec![(0, 0)], + }], + ); + let (_, new_state) = db.step_execution(state_id, state); + assert_eq!(new_state.state_get(&1).unwrap(), &Some(vec![1, 1, 1])); + } + + #[test] + fn test_traverse_linear_chain() { + let mut db = ExecutionGraph::new(); + + // Nodes are in this structure + // 0 + // | + // 1 + // | + // 2 + + let mut state = ExecutionState::new(); + let state_id = (0, 0); + + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 0, + 0, + Box::new(|args| RSV::Number(0).into()), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 1, + 1, + Box::new(|args| RSV::Number(1).into()), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 2, + 1, + Box::new(|args| RSV::Number(2).into()), + ); + let (state_id, state) = db.apply_dependency_graph_mutations( + (0, 0), + state, + vec![ + DependencyGraphMutation::Create { + operation_id: 1, + depends_on: vec![(0, 0)], + }, + 
DependencyGraphMutation::Create { + operation_id: 2, + depends_on: vec![(1, 0)], + }, + ], + ); + + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get_value(&1), Some(RSV::Number(1))); + assert_eq!(state.state_get(&2), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get_value(&1), None); + assert_eq!(state.state_get_value(&2), Some(RSV::Number(2))); + } + + #[test] + fn test_traverse_branching() { + let mut db = ExecutionGraph::new(); + + // Nodes are in this structure + // 0 + // | + // 1 + // / \ + // 2 3 + + let mut state = ExecutionState::new(); + let state_id = (0, 0); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 0, + 0, + Box::new(|args| RSV::Number(0).into()), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 1, + 1, + Box::new(|args| RSV::Number(1).into()), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 2, + 1, + Box::new(|args| RSV::Number(2).into()), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 3, + 1, + Box::new(|args| RSV::Number(3).into()), + ); + + let (state_id, state) = db.apply_dependency_graph_mutations( + (0, 0), + state, + vec![ + DependencyGraphMutation::Create { + operation_id: 1, + depends_on: vec![(0, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 2, + depends_on: vec![(1, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 3, + depends_on: vec![(1, 0)], + }, + ], + ); + + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get_value(&1), Some(RSV::Number(1))); + assert_eq!(state.state_get(&2), 
None); + assert_eq!(state.state_get(&3), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get_value(&2), Some(RSV::Number(2))); + assert_eq!(state.state_get_value(&3), Some(RSV::Number(3))); + } + + #[test] + fn test_traverse_branching_and_convergence() { + let mut db = ExecutionGraph::new(); + + // Nodes are in this structure + // 0 + // | + // 1 + // / \ + // 2 3 + // \ / + // 4 + + let mut state = ExecutionState::new(); + let state_id = (0, 0); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 0, + 0, + Box::new(|args| RSV::Number(0).into()), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 1, + 1, + Box::new(|args| RSV::Number(1).into()), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 2, + 1, + Box::new(|args| RSV::Number(2).into()), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 3, + 1, + Box::new(|args| RSV::Number(3).into()), + ); + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 4, + 2, + Box::new(|args| RSV::Number(4).into()), + ); + + let (state_id, state) = db.apply_dependency_graph_mutations( + (0, 0), + state, + vec![ + DependencyGraphMutation::Create { + operation_id: 1, + depends_on: vec![(0, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 2, + depends_on: vec![(1, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 3, + depends_on: vec![(1, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 4, + depends_on: vec![(2, 0), (3, 1)], + }, + ], + ); + + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get_value(&1), Some(RSV::Number(1))); + assert_eq!(state.state_get(&2), None); + let (state_id, state) = 
db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get_value(&2), Some(RSV::Number(2))); + assert_eq!(state.state_get_value(&3), Some(RSV::Number(3))); + assert_eq!(state.state_get(&4), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + assert_eq!(state.state_get(&3), None); + assert_eq!(state.state_get_value(&4), Some(RSV::Number(4))); + } + + #[test] + fn test_traverse_cycle() { + let mut db = ExecutionGraph::new(); + + // Nodes are in this structure _with the following cycle_ + // 0 + // | + // 1 * depends 1 -> 3 + // / \ + // 2 3 + // \ / * depends 3 -> 4 + // 4 + // | + // 5 + + let mut state = ExecutionState::new(); + let state_id = (0, 0); + + // We start with the number 1 at node 0 + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 0, + 0, + Box::new(|_args| { + let v = RSV::Number(1); + return serialize_to_vec(&v); + }), + ); + + // Each node adds 1 to the inbound item (all nodes only have one dependency per index) + let f1 = |args: Vec<&Option>>| { + let arg0 = deserialize_from_buf(args[0].as_ref().unwrap().as_slice()); + + if let RSV::Number(a) = arg0 { + let v = RSV::Number(a + 1); + return serialize_to_vec(&v); + } + + panic!("Invalid arguments") + }; + + let (state_id, mut state) = db.upsert_operation(state_id, state, 1, 1, Box::new(f1)); + let (state_id, mut state) = db.upsert_operation(state_id, state, 2, 1, Box::new(f1)); + let (state_id, mut state) = db.upsert_operation(state_id, state, 3, 1, Box::new(f1)); + let (state_id, mut state) = db.upsert_operation(state_id, state, 4, 1, Box::new(f1)); + let (state_id, mut state) = db.upsert_operation(state_id, state, 5, 1, Box::new(f1)); + + let (state_id, state) = db.apply_dependency_graph_mutations( + (0, 0), + state, + vec![ + DependencyGraphMutation::Create { + operation_id: 1, + depends_on: vec![(0, 0), (3, 0)], + }, + 
DependencyGraphMutation::Create { + operation_id: 2, + depends_on: vec![(1, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 3, + depends_on: vec![(4, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 4, + depends_on: vec![(2, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 5, + depends_on: vec![(4, 0)], + }, + ], + ); + + // We expect to see the value at each node increment repeatedly. + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + let (state_id, state) = db.step_execution(state_id, state); + + assert_eq!(state.state_get_value(&1), Some(RSV::Number(2))); + assert_eq!(state.state_get(&2), None); + assert_eq!(state.state_get(&3), None); + assert_eq!(state.state_get(&4), None); + assert_eq!(state.state_get(&5), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get_value(&2), Some(RSV::Number(3))); + assert_eq!(state.state_get(&3), None); + assert_eq!(state.state_get(&4), None); + assert_eq!(state.state_get(&5), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + assert_eq!(state.state_get(&3), None); + assert_eq!(state.state_get_value(&4), Some(RSV::Number(4))); + assert_eq!(state.state_get(&5), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + assert_eq!(state.state_get_value(&3), Some(RSV::Number(5))); + assert_eq!(state.state_get(&4), None); + assert_eq!(state.state_get_value(&5), Some(RSV::Number(5))); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get_value(&1), Some(RSV::Number(6))); + assert_eq!(state.state_get(&2), None); + assert_eq!(state.state_get(&3), None); + assert_eq!(state.state_get(&4), 
None); + assert_eq!(state.state_get_value(&5), Some(RSV::Number(5))); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get_value(&2), Some(RSV::Number(7))); + assert_eq!(state.state_get(&3), None); + assert_eq!(state.state_get(&4), None); + assert_eq!(state.state_get_value(&5), Some(RSV::Number(5))); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + assert_eq!(state.state_get(&3), None); + assert_eq!(state.state_get_value(&4), Some(RSV::Number(8))); + assert_eq!(state.state_get_value(&5), Some(RSV::Number(5))); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + assert_eq!(state.state_get_value(&3), Some(RSV::Number(9))); + assert_eq!(state.state_get(&4), None); + assert_eq!(state.state_get_value(&5), Some(RSV::Number(9))); + } + + #[test] + fn test_branching_multiple_state_paths() { + let mut db = ExecutionGraph::new(); + + // Nodes are in this structure + // 0 + // | + // 1 + // | + // 2 + + let mut state = ExecutionState::new(); + let state_id = (0, 0); + + // We start with the number 1 at node 0 + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 0, + 0, + Box::new(|_args| { + let v = RSV::Number(1); + return serialize_to_vec(&v); + }), + ); + + // Globally mutates this value, making each call to this function side-effecting + static atomic_usize: AtomicUsize = AtomicUsize::new(0); + let f_side_effect = |args: Vec<&Option>>| { + let arg0 = deserialize_from_buf(args[0].as_ref().unwrap().as_slice()); + + if let RSV::Number(a) = arg0 { + let plus = atomic_usize.fetch_add(1, Ordering::SeqCst); + let v = RSV::Number(a + plus as i32); + return serialize_to_vec(&v); + } + + panic!("Invalid arguments") + }; + + let (state_id, mut state) = + db.upsert_operation(state_id, state, 1, 1, 
Box::new(f_side_effect)); + let (state_id, mut state) = + db.upsert_operation(state_id, state, 2, 1, Box::new(f_side_effect)); + + let (state_id, state) = db.apply_dependency_graph_mutations( + (0, 0), + state, + vec![ + DependencyGraphMutation::Create { + operation_id: 1, + depends_on: vec![(0, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 2, + depends_on: vec![(1, 0)], + }, + ], + ); + + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get(&2), None); + let (x_state_id, x_state) = db.step_execution(state_id, state); + assert_eq!(x_state.state_get_value(&1), Some(RSV::Number(1))); + assert_eq!(x_state.state_get(&2), None); + + let (state_id, state) = db.step_execution(x_state_id.clone(), x_state.clone()); + assert_eq!(state_id.0, 0); + assert_eq!(state.state_get(&1), None); + assert_eq!(state.state_get_value(&2), Some(RSV::Number(2))); + + // When we re-evaluate from a previous point, we should get a new branch + let (state_id, state) = db.step_execution(x_state_id.clone(), x_state); + // The state_id.0 being incremented indicates that we're on a new branch + assert_eq!(state_id.0, 1); + assert_eq!(state.state_get(&1), None); + // Op 2 should re-evaluate to 3, since it's on a new branch but continuing to mutate the stateful counter + assert_eq!(state.state_get_value(&2), Some(RSV::Number(3))); + } + + #[test] + fn test_mutation_of_the_dependency_graph_on_branches() { + let mut db = ExecutionGraph::new(); + + // Nodes are in this structure + // 0 + // | + // 1 * we're going to be changing the definiton of the function of this node on one branch + // | + // 2 + + let mut state = ExecutionState::new(); + let state_id = (0, 0); + + // We start with the number 0 at node 0 + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 0, + 0, + Box::new(|_args| { + let v = RSV::Number(0); + return serialize_to_vec(&v); + }), + ); + + let f_v1 = |args: Vec<&Option>>| { 
+ let arg0 = deserialize_from_buf(args[0].as_ref().unwrap().as_slice()); + + if let RSV::Number(a) = arg0 { + let v = RSV::Number(a + 1); + return serialize_to_vec(&v); + } + + panic!("Invalid arguments") + }; + + let f_v2 = |args: Vec<&Option>>| { + let arg0 = deserialize_from_buf(args[0].as_ref().unwrap().as_slice()); + + if let RSV::Number(a) = arg0 { + let v = RSV::Number(a + 200); + return serialize_to_vec(&v); + } + + panic!("Invalid arguments") + }; + + let (state_id, mut state) = db.upsert_operation(state_id, state, 1, 1, Box::new(f_v1)); + let (state_id, mut state) = db.upsert_operation(state_id, state, 2, 1, Box::new(f_v1)); + + let (state_id, state) = db.apply_dependency_graph_mutations( + (0, 0), + state, + vec![ + DependencyGraphMutation::Create { + operation_id: 1, + depends_on: vec![(0, 0)], + }, + DependencyGraphMutation::Create { + operation_id: 2, + depends_on: vec![(1, 0)], + }, + ], + ); + + let (x_state_id, x_state) = db.step_execution(state_id, state); + assert_eq!(x_state.state_get(&1), None); + assert_eq!(x_state.state_get(&2), None); + let (state_id, state) = db.step_execution(x_state_id, x_state.clone()); + assert_eq!(state.state_get_value(&1), Some(RSV::Number(1))); + assert_eq!(state.state_get(&2), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get_value(&1), None); + assert_eq!(state.state_get_value(&2), Some(RSV::Number(2))); + + // Change the definition of the operation "1" to add 200 instead of 1, then re-evaluate + let (state_id, mut state) = db.upsert_operation(x_state_id, x_state, 1, 1, Box::new(f_v2)); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get_value(&1), Some(RSV::Number(200))); + assert_eq!(state.state_get(&2), None); + let (state_id, state) = db.step_execution(state_id, state); + assert_eq!(state.state_get_value(&1), None); + assert_eq!(state.state_get_value(&2), Some(RSV::Number(201))); + } +} diff --git 
a/toolchain/prompt-graph-core/src/execution/execution/execution_state.rs b/toolchain/prompt-graph-core/src/execution/execution/execution_state.rs new file mode 100644 index 0000000..1bc48a2 --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/execution/execution_state.rs @@ -0,0 +1,176 @@ +use crate::execution::primitives::identifiers::{ArgumentIndex, OperationId}; +use crate::execution::primitives::operation::{OperationFn, OperationNode}; +use crate::execution::primitives::serialized_value::{ + deserialize_from_buf, RkyvSerializedValue as RSV, +}; +use im::{HashMap as ImHashMap, HashSet as ImHashSet}; +use indoc::indoc; +use petgraph::dot::{Config, Dot}; +use petgraph::graphmap::DiGraphMap; +use std::cell::RefCell; +use std::collections::HashSet; +use std::fmt; +use std::fmt::Formatter; +use std::rc::Rc; + +pub enum DependencyGraphMutation { + Create { + operation_id: OperationId, + depends_on: Vec<(OperationId, ArgumentIndex)>, + }, + Delete { + operation_id: OperationId, + }, +} + +#[derive(Clone)] +pub struct ExecutionState { + // TODO: update all operations to use this id instead of a separate representation + id: (usize, usize), + + state: ImHashMap>>>, + + pub operation_by_id: ImHashMap>>, + + /// Note what keys have _ever_ been set, which is an optimization to avoid needing to do + /// a complete historical traversal to verify that a value has been set. + has_been_set: ImHashSet, + + /// Dependency graph of the computable elements in the graph + /// + /// The dependency graph is a directed graph where the nodes are the ids of the operations and the + /// weights are the index of the input of the next operation. + /// + /// The usize::MAX index is a no-op that indicates that the operation is ready to run, an execution + /// order dependency rather than a value dependency. 
+ dependency_graph: ImHashMap>, +} + +impl std::fmt::Debug for ExecutionState { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(&render_map_as_table(self)) + } +} + +fn render_map_as_table(exec_state: &ExecutionState) -> String { + let mut table = String::from(indoc!( + r" + | Key | Value | + |---|---|" + )); + + for key in exec_state.state.keys() { + if let Some(val) = exec_state.state_get_value(key) { + table.push_str(&format!( + indoc!( + r" + | {} | {:?} |" + ), + key, val, + )); + } + } + + table +} + +impl ExecutionState { + pub fn new() -> Self { + ExecutionState { + id: (0, 0), + state: Default::default(), + operation_by_id: Default::default(), + has_been_set: Default::default(), + dependency_graph: Default::default(), + } + } + + pub fn state_get_value(&self, operation_id: &OperationId) -> Option { + self.state_get(operation_id) + .map(|x| deserialize_from_buf(x.as_ref().unwrap())) + } + + pub fn state_get(&self, operation_id: &OperationId) -> Option<&Option>> { + self.state.get(operation_id).map(|x| x.as_ref()) + } + + pub fn check_if_previously_set(&self, operation_id: &OperationId) -> bool { + self.has_been_set.contains(operation_id) + } + + pub fn state_consume_marked(&mut self, marked_for_consumption: HashSet) { + for key in marked_for_consumption.clone().into_iter() { + self.state.remove(&key); + } + } + + pub fn state_insert(&mut self, operation_id: OperationId, value: Option>) { + self.state.insert(operation_id, Rc::new(value)); + self.has_been_set.insert(operation_id); + } + + pub fn render_dependency_graph(&self) { + println!("================ Dependency graph ================"); + println!( + "{:?}", + Dot::with_config(&self.get_dependency_graph(), &[Config::EdgeNoLabel]) + ); + } + + pub fn get_dependency_graph(&self) -> DiGraphMap> { + let mut graph = DiGraphMap::new(); + for (node, value) in self.dependency_graph.clone().into_iter() { + graph.add_node(node); + for (depends_on, index) in value.into_iter() { + let r = 
graph.add_edge(depends_on, node, vec![index]); + if r.is_some() { + graph + .edge_weight_mut(depends_on, node) + .unwrap() + .append(&mut r.unwrap()); + } + } + } + graph + } + + pub fn add_operation(&mut self, node: usize, args: usize, func: Box) -> Self { + let mut s = self.clone(); + s.operation_by_id.insert( + node.clone(), + Rc::new(RefCell::new(OperationNode::new(args, Some(func)))), + ); + s + } + + pub fn apply_dependency_graph_mutations( + &self, + mutations: Vec, + ) -> Self { + let mut s = self.clone(); + for mutation in mutations { + match mutation { + DependencyGraphMutation::Create { + operation_id, + depends_on, + } => { + if let Some(e) = s.dependency_graph.get_mut(&operation_id) { + e.clear(); + e.extend(depends_on.into_iter()); + } else { + s.dependency_graph + .entry(operation_id) + .or_insert(HashSet::from_iter(depends_on.into_iter())); + } + } + DependencyGraphMutation::Delete { operation_id } => { + s.dependency_graph.remove(&operation_id); + } + } + } + s + } +} + +#[cfg(test)] +mod tests {} diff --git a/toolchain/prompt-graph-core/src/execution/execution/mod.rs b/toolchain/prompt-graph-core/src/execution/execution/mod.rs new file mode 100644 index 0000000..112008c --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/execution/mod.rs @@ -0,0 +1,59 @@ +pub mod execution_graph; +pub mod execution_state; +pub mod mutate_active_execution; +use crate::execution::integration::triggerable::{Subscribable, TriggerContext}; +use crate::execution::primitives::identifiers::{ArgumentIndex, OperationId, TimestampOfWrite}; +use crate::execution::primitives::operation::{ + OperationFn, OperationNode, OperationNodeDefinition, Signature, +}; +use crate::execution::primitives::serialized_value::deserialize_from_buf; +use crate::execution::primitives::serialized_value::RkyvSerializedValue as RSV; +use crossbeam_utils::sync::Unparker; +pub use execution_state::{DependencyGraphMutation, ExecutionState}; +use futures::StreamExt; +use im::HashMap as 
ImHashMap; +use im::HashSet as ImHashSet; +use indoc::indoc; +use petgraph::algo::toposort; +use petgraph::data::Build; +use petgraph::dot::{Config, Dot}; +use petgraph::graph::{DiGraph, NodeIndex}; +use petgraph::graphmap::DiGraphMap; +use petgraph::visit::{Dfs, IntoEdgesDirected, VisitMap, Walker}; +use petgraph::Direction; +use std::cell::RefCell; +use std::collections::HashMap; +use std::collections::HashSet; +use std::fmt::{self, Formatter, Write}; +use std::rc::Rc; +use std::sync::{Arc, Mutex}; + +type OperationValue = Vec; +type OperationEventHandler = Box; +type OperationEventHandlers = Rc>>; + +/// The set of async nodes for which the scheduler has received ready +/// notifications. +#[derive(Clone)] +struct Notifications { + /// Nodes that received notifications. + nodes: Arc>>, + + /// Handle to wake up the scheduler thread when a notification arrives. + unparker: Unparker, +} + +impl Notifications { + fn new(size: usize, unparker: Unparker) -> Self { + Self { + nodes: Arc::new(Mutex::new(HashSet::with_capacity(size))), + unparker, + } + } + + /// Add a new notification. + fn notify(&self, node_id: OperationId) { + self.nodes.lock().unwrap().insert(node_id); + self.unparker.unpark(); + } +} diff --git a/toolchain/prompt-graph-core/src/execution/execution/mutate_active_execution.rs b/toolchain/prompt-graph-core/src/execution/execution/mutate_active_execution.rs new file mode 100644 index 0000000..66b628d --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/execution/mutate_active_execution.rs @@ -0,0 +1,153 @@ +use crate::execution::execution::execution_graph::ExecutionGraph; +use crate::execution::primitives::operation::{OperationNode, OperationNodeDefinition}; + +pub enum GraphMutation { + Update { + node_id: usize, + operation_node: OperationNodeDefinition, + }, + Delete { + node_id: usize, + }, +} + +impl ExecutionGraph { + /// This function is called when a change is made to the definition of the graph. 
+ /// When a change is made to the graph, we need to identify which elements are now dirtied and must + /// be re-executed + pub fn handle_operation_change(&mut self, node_id: usize, incoming_change: GraphMutation) { + // Changes the operation of a cell + let node_id = match &incoming_change { + GraphMutation::Update { node_id, .. } | GraphMutation::Delete { node_id, .. } => { + node_id.clone() + } + _ => return, + }; + + if let GraphMutation::Delete { node_id, .. } = &incoming_change { + // TODO: tombstone the target node + unimplemented!(); + } + + if let Some(existing_op_node) = self.operation_by_id.remove(&node_id) { + if let GraphMutation::Delete { node_id, .. } = &incoming_change {} + if let GraphMutation::Update { + node_id, + operation_node, + } = incoming_change + { + // Update dependency graph + let existing_dependencies = &existing_op_node.dependency_count; + let new_dependencies = &operation_node.dependency_count; + if existing_dependencies != new_dependencies { + // for neighbor in existing_dependencies { + // self.dependency_graph + // .remove_edge(neighbor.clone(), node_id.clone()); + // } + // for (idx, dependency) in new_dependencies.iter().enumerate() { + // self.dependency_graph + // .add_edge(dependency.clone(), node_id.clone(), idx); + // } + } + + // Overwrite the existing node definition + self.operation_by_id + .insert(node_id.clone(), OperationNode::from(operation_node)); + } + } else { + // TODO: existing node does not exist - create the node + if let GraphMutation::Update { + node_id, + operation_node, + } = incoming_change + { + // for (idx, dependency) in operation_node.dependencies.iter().enumerate() { + // self.add_value_dependency_to_operation( + // node_id.clone(), + // dependency.clone(), + // idx, + // ); + // } + self.operation_by_id + .insert(node_id.clone(), OperationNode::from(operation_node)); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::execution::primitives::operation::OperationNodeDefinition; + 
/* + Testing the operation change event handler - giving structure to modifications of the evaluation graph + */ + + #[test] + fn test_handle_operation_change_add() { + let mut db = ExecutionGraph::new(); + let node_id = 2; + let operation_node = OperationNodeDefinition { + operation: None, + dependency_count: 0, + }; + + db.handle_operation_change( + node_id, + GraphMutation::Update { + node_id, + operation_node, + }, + ); + } + + #[test] + fn test_handle_operation_change_update() { + let mut db = ExecutionGraph::new(); + let node_id = 1; + let operation_node = OperationNodeDefinition { + operation: None, + dependency_count: 0, + }; + + db.handle_operation_change( + node_id, + GraphMutation::Update { + node_id, + operation_node, + }, + ); + } + + #[test] + fn test_handle_operation_change_delete() { + let mut db = ExecutionGraph::new(); + let node_id = 1; + + db.handle_operation_change(node_id, GraphMutation::Delete { node_id }); + + assert!(db.operation_by_id.get(&node_id).is_none()); + } + + #[test] + fn test_handle_operation_change_execution_order() { + let mut db = ExecutionGraph::new(); + let node_id = 2; + let operation_node = OperationNodeDefinition { + operation: None, + dependency_count: 0, + }; + + db.handle_operation_change( + node_id, + GraphMutation::Update { + node_id, + operation_node, + }, + ); + } + + /* + Testing stepwise evaluation of the graph based on the execution order + */ +} diff --git a/toolchain/prompt-graph-core/src/execution/integration/mod.rs b/toolchain/prompt-graph-core/src/execution/integration/mod.rs new file mode 100644 index 0000000..742ec58 --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/integration/mod.rs @@ -0,0 +1 @@ +pub mod triggerable; diff --git a/toolchain/prompt-graph-core/src/reactivity/triggerable.rs b/toolchain/prompt-graph-core/src/execution/integration/triggerable.rs similarity index 58% rename from toolchain/prompt-graph-core/src/reactivity/triggerable.rs rename to 
toolchain/prompt-graph-core/src/execution/integration/triggerable.rs index b004184..9bd531b 100644 --- a/toolchain/prompt-graph-core/src/reactivity/triggerable.rs +++ b/toolchain/prompt-graph-core/src/execution/integration/triggerable.rs @@ -1,11 +1,10 @@ +use crate::execution::execution::execution_graph::ExecutionGraph; +use crate::time_travel::global::HistoricalData; +use im::HashMap as ImmutableHashMap; /// This describes an API for a reactive system with a function "make_triggerable" that wraps /// a method passing to it a set of triggerable relationships. When those relationships fire, /// the associated method is invoked. - use std::collections::HashMap; -use im::HashMap as ImmutableHashMap; -use crate::reactivity::database::ReactivityDatabase; -use crate::time_travel::global::HistoricalData; impl Subscribable for HistoricalData { fn has_changed(&self) -> bool { @@ -14,10 +13,8 @@ impl Subscribable for HistoricalData { self.has_processed_change_at(self.history.len().saturating_sub(1)) } } - pub trait TriggerContext {} - struct Context<'a> { triggered: &'a mut bool, } @@ -28,28 +25,25 @@ pub trait Subscribable { fn has_changed(&self) -> bool; } -pub fn make_triggerable( - reactivity_db: &mut ReactivityDatabase, - func: F, -) - where - T: TriggerContext, - S: Subscribable, - F: 'static + FnMut(&mut T), +/// Registers a given function with a reactivity db, based on a pointer to the boxed function. +/// We use this as the identity of that method and return that identity so that we can continue to +/// mutate the composition of the registered function. 
+pub fn make_triggerable(reactivity_db: &mut ExecutionGraph, args: usize, func: F) -> usize +where + F: 'static + FnMut(Vec<&Option>>) -> Vec, { let boxed_fn = Box::new(func); let box_address = &*boxed_fn as *const _ as usize; - reactivity_db.add_operation(box_address, boxed_fn); + // reactivity_db.upsert_operation(box_address, args, boxed_fn) + assert!(false); + 0 } - - #[cfg(test)] mod tests { use super::*; use crate::time_travel::global::HistoricalData; - #[derive(Debug, Clone)] struct Element { name: String, @@ -68,41 +62,16 @@ mod tests { impl<'a> TriggerContext for HistoryContext<'a> {} - - #[test] - fn test_has_changed() { - let element = Element { name: "unchanged".into() }; - assert_eq!(element.has_changed(), false); - - let changed_element = Element { name: "name changed".into() }; - assert_eq!(changed_element.has_changed(), true); - } - - #[test] - fn test_historical_data() { - let initial_data = ImmutableHashMap::unit(String::from("status"), String::from("initial")); - let updated_data = ImmutableHashMap::unit(String::from("status"), String::from("updated")); - - let mut historical_data = HistoricalData::new(initial_data); - assert_eq!(historical_data.has_changed(), false); - - historical_data.update(updated_data); - assert_eq!(historical_data.has_changed(), true); - } - - #[test] fn test_trigger_on_dependency_change() { let mut binding = false; { - let mut registry = ReactivityDatabase::new(); + let mut registry = ExecutionGraph::new(); let mut context = Context { triggered: &mut binding, }; - make_triggerable( &mut registry, |ctx: &mut Context| { - *ctx.triggered = true; - }, &mut context); + make_triggerable(&mut registry, 2, |ctx: Vec<&Option>>| vec![1, 2]); assert_eq!(*context.triggered, true); } @@ -114,15 +83,12 @@ mod tests { let mut historical_data_a = HistoricalData::new(initial_data_a); { - let mut registry = ReactivityDatabase::new(); + let mut registry = ExecutionGraph::new(); let mut context = HistoryContext { data: &mut historical_data_a, 
}; - make_triggerable(&mut registry, |ctx: &mut HistoryContext| { - ctx.data.update(ImmutableHashMap::unit(String::from("status"), String::from("changed again"))); - }, &mut context); + make_triggerable(&mut registry, 2, |ctx: Vec<&Option>>| vec![1, 2]); } } } - diff --git a/toolchain/prompt-graph-core/src/execution/language/compile.rs b/toolchain/prompt-graph-core/src/execution/language/compile.rs new file mode 100644 index 0000000..5c1986a --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/language/compile.rs @@ -0,0 +1,103 @@ +use crate::execution::execution::execution_graph::ExecutionGraph; +use crate::execution::execution::execution_state::ExecutionState; +/// Logic to convert the rust ast to our scheduler graph +use crate::execution::language::parser::{BinaryOp, Error, Expr, Func, Program, Value}; +use std::collections::HashMap; + +use crate::execution::primitives::serialized_value::{ + serialize_to_vec, RkyvSerializedValue as RSV, +}; + +// TODO: how does this actually fetch the functions? +fn compile_to_graph(program: Program) -> ExecutionGraph { + let mut db = ExecutionGraph::new(); + let mut state = ExecutionState::new(); + let state_id = (0, 0); + + // We start with the number 0 at node 0 + let (state_id, mut state) = db.upsert_operation( + state_id, + state, + 0, + 0, + Box::new(|_args| { + let v = RSV::Number(0); + return serialize_to_vec(&v); + }), + ); + + if let Some(main_func) = program.funcs.get("main") { + eval_to_graph(&main_func.body.0, &program.funcs, &mut db); + } + // ast.body + + // fn eval_expr( + // expr: &Spanned, + // funcs: &HashMap, + // stack: &mut Vec<(String, Value)>, + // ) -> Result { + + // TODO: look for a "main" function + // 1. For each function call, create a node + // 2. For each import, create a node + // 3. Create connections between nodes when we refer to variables + // 4. 
For each function assigned to a variable, create a connection to all function invocations that refer to that variable + db +} + +fn eval_to_graph( + expr: &Expr, + funcs: &HashMap, + db: &mut ExecutionGraph, +) -> Result<(), Error> { + // match &expr { + // Expr::Error => unreachable!(), // Error expressions only get created by parser errors, so cannot exist in a valid AST + // Expr::Value(val) => val.clone(), + // Expr::List(items) => Value::List(), + // Expr::Local(name) => stack, + // Expr::Let(local, val, body) => {} + // Expr::Then(a, b) => { + // eval_to_graph(a, funcs, stack)?; + // eval_to_graph(b, funcs, stack)? + // } + // Expr::Binary(a, BinaryOp::Add, b) => Value::Num(), + // Expr::Binary(a, BinaryOp::Sub, b) => Value::Num(), + // Expr::Binary(a, BinaryOp::Mul, b) => Value::Num(), + // Expr::Binary(a, BinaryOp::Div, b) => Value::Num(), + // Expr::Binary(a, BinaryOp::Eq, b) => {} + // Expr::Binary(a, BinaryOp::NotEq, b) => {} + // Expr::Binary(a, BinaryOp::PipeOp, b) => {} + // Expr::Call(func, args) => {} + // Expr::If(cond, a, b) => {} + // Expr::Print(a) => {} + // } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::execution::language::parser::parse; + use std::collections::HashMap; + + #[test] + fn test_compiling_simple_program() { + let program = parse( + r#" + fn main() { + let x = 1; + let y = 2; + print(x + y); + x |> + y |> + print(x + y); + } + "# + .to_string(), + ) + .unwrap() + .unwrap(); + + let db = compile_to_graph(program); + } +} diff --git a/toolchain/prompt-graph-core/src/execution/language/mod.rs b/toolchain/prompt-graph-core/src/execution/language/mod.rs new file mode 100644 index 0000000..8638589 --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/language/mod.rs @@ -0,0 +1,5 @@ +// This language exists to be able to author lazily evaluated functions. +// It's possible to do this in Rust, but it's not ergonomic. 
+ +pub mod compile; +pub mod parser; diff --git a/toolchain/prompt-graph-core/src/execution/language/parser.rs b/toolchain/prompt-graph-core/src/execution/language/parser.rs new file mode 100644 index 0000000..2931f26 --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/language/parser.rs @@ -0,0 +1,755 @@ +//! This is an entire parser and interpreter for a dynamically-typed Rust-like expression-oriented +//! programming language. This is taken from the chumsky examples. +//! +//! The goal of this mini language is the definition of LLM supported software. This should produce +//! an AST that can be translated to our graph representation. + +use ariadne::{Color, Fmt, Label, Report, ReportKind, Source}; +use chumsky::{prelude::*, stream::Stream}; +use std::io::Cursor; +use std::{collections::HashMap, env, fmt, fs}; + +pub type Span = std::ops::Range; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Token { + Null, + Bool(bool), + Num(String), + Str(String), + Op(String), + Ctrl(char), + Ident(String), + Fn, + Use, + Let, + Print, + If, + Else, +} + +impl fmt::Display for Token { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Token::Null => write!(f, "null"), + Token::Bool(x) => write!(f, "{}", x), + Token::Num(n) => write!(f, "{}", n), + Token::Str(s) => write!(f, "{}", s), + Token::Op(s) => write!(f, "{}", s), + Token::Ctrl(c) => write!(f, "{}", c), + Token::Ident(s) => write!(f, "{}", s), + Token::Fn => write!(f, "fn"), + Token::Use => write!(f, "use"), + Token::Let => write!(f, "let"), + Token::Print => write!(f, "print"), + Token::If => write!(f, "if"), + Token::Else => write!(f, "else"), + } + } +} + +fn lexer() -> impl Parser, Error = Simple> { + // A parser for numbers + let num = text::int(10) + .chain::(just('.').chain(text::digits(10)).or_not().flatten()) + .collect::() + .map(Token::Num); + + // A parser for strings + let str_ = just('"') + .ignore_then(filter(|c| *c != '"').repeated()) + .then_ignore(just('"')) 
+ .collect::() + .map(Token::Str); + + // A parser for operators + let op = one_of("+-*/!=") + .repeated() + .at_least(1) + .collect::() + .map(|op| Token::Op(op)) + .or(just("|>").map(|_| Token::Op("|>".to_string()))); + + // A parser for control characters (delimiters, semicolons, etc.) + let ctrl = one_of("()[]{};,").map(|c| Token::Ctrl(c)); + + // A parser for identifiers and keywords + let ident = text::ident().map(|ident: String| match ident.as_str() { + "fn" => Token::Fn, + "use" => Token::Use, + "let" => Token::Let, + "print" => Token::Print, + "if" => Token::If, + "else" => Token::Else, + "true" => Token::Bool(true), + "false" => Token::Bool(false), + "null" => Token::Null, + _ => Token::Ident(ident), + }); + + // A single token can be one of the above + let token = num + .or(str_) + .or(op) + .or(ctrl) + .or(ident) + .recover_with(skip_then_retry_until([])); + + let comment = just("//").then(take_until(just('\n'))).padded(); + + token + .map_with_span(|tok, span| (tok, span)) + .padded_by(comment.repeated()) + .padded() + .repeated() +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Value { + Null, + Bool(bool), + Num(f64), + Str(String), + List(Vec), + Func(String), +} + +impl Value { + fn num(self, span: Span) -> Result { + if let Value::Num(x) = self { + Ok(x) + } else { + Err(Error { + span, + msg: format!("'{}' is not a number", self), + }) + } + } +} + +impl std::fmt::Display for Value { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Null => write!(f, "null"), + Self::Bool(x) => write!(f, "{}", x), + Self::Num(x) => write!(f, "{}", x), + Self::Str(x) => write!(f, "{}", x), + Self::List(xs) => write!( + f, + "[{}]", + xs.iter() + .map(|x| x.to_string()) + .collect::>() + .join(", ") + ), + Self::Func(name) => write!(f, "", name), + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum BinaryOp { + Add, + Sub, + Mul, + Div, + Eq, + NotEq, + PipeOp, +} + +pub type Spanned = (T, Span); + +// An expression 
node in the AST. Children are spanned so we can generate useful runtime errors. +#[derive(Debug, PartialEq)] +pub enum Expr { + Error, + Value(Value), + List(Vec>), + Local(String), + Let(String, Box>, Box>), + // Used to chain blocks together. + Then(Box>, Box>), + Binary(Box>, BinaryOp, Box>), + Call(Box>, Vec>), + If(Box>, Box>, Box>), + Import(String, String, String), + Print(Box>), +} + +// A function node in the AST. +#[derive(Debug, PartialEq)] +pub struct Func { + pub args: Vec, + pub body: Spanned, +} + +fn import_parser() -> impl Parser, Error = Simple> + Clone { + let string_literal = select! { Token::Str(s) => s.clone() }.labelled("string literal"); + + let import = just(Token::Use) + .ignore_then(just(Token::Ident("import".to_string()))) + .then_ignore(just(Token::Ctrl('('))) + .ignore_then(string_literal.clone()) // Source path + .then_ignore(just(Token::Ctrl(','))) + .then(string_literal) // Language + .then_ignore(just(Token::Ctrl(')'))) + .then_ignore(just(Token::Ident("as".to_string()))) + .then(select! { Token::Ident(ident) => ident.clone() }) // Alias + .then_ignore(just(Token::Ctrl(';'))) + .map_with_span(|((path, lang), alias), span| { + // Here, create an Expr variant to represent an import statement + (Expr::Import(path, lang, alias), span) + }) + .labelled("import statement"); + + import +} + +fn expr_parser() -> impl Parser, Error = Simple> + Clone { + recursive(|expr| { + let raw_expr = recursive(|raw_expr| { + let val = select! { + Token::Null => Expr::Value(Value::Null), + Token::Bool(x) => Expr::Value(Value::Bool(x)), + Token::Num(n) => Expr::Value(Value::Num(n.parse().unwrap())), + Token::Str(s) => Expr::Value(Value::Str(s)), + } + .labelled("value"); + + let ident = select! 
{ Token::Ident(ident) => ident.clone() }.labelled("identifier"); + + // A list of expressions + let items = expr + .clone() + .separated_by(just(Token::Ctrl(','))) + .allow_trailing(); + + // A let expression + let let_ = just(Token::Let) + .ignore_then(ident) + .then_ignore(just(Token::Op("=".to_string()))) + .then(raw_expr) + .then_ignore(just(Token::Ctrl(';'))) + .then(expr.clone()) + .map(|((name, val), body)| Expr::Let(name, Box::new(val), Box::new(body))); + + let list = items + .clone() + .delimited_by(just(Token::Ctrl('[')), just(Token::Ctrl(']'))) + .map(Expr::List); + + // 'Atoms' are expressions that contain no ambiguity + let atom = val + .or(ident.map(Expr::Local)) + .or(let_) + .or(list) + // In Nano Rust, `print` is just a keyword, just like Python 2, for simplicity + .or(just(Token::Print) + .ignore_then( + expr.clone() + .delimited_by(just(Token::Ctrl('(')), just(Token::Ctrl(')'))), + ) + .map(|expr| Expr::Print(Box::new(expr)))) + .map_with_span(|expr, span| (expr, span)) + // Atoms can also just be normal expressions, but surrounded with parentheses + .or(expr + .clone() + .delimited_by(just(Token::Ctrl('(')), just(Token::Ctrl(')')))) + // Attempt to recover anything that looks like a parenthesised expression but contains errors + .recover_with(nested_delimiters( + Token::Ctrl('('), + Token::Ctrl(')'), + [ + (Token::Ctrl('['), Token::Ctrl(']')), + (Token::Ctrl('{'), Token::Ctrl('}')), + ], + |span| (Expr::Error, span), + )) + // Attempt to recover anything that looks like a list but contains errors + .recover_with(nested_delimiters( + Token::Ctrl('['), + Token::Ctrl(']'), + [ + (Token::Ctrl('('), Token::Ctrl(')')), + (Token::Ctrl('{'), Token::Ctrl('}')), + ], + |span| (Expr::Error, span), + )); + + // Function calls have very high precedence so we prioritise them + let call = atom + .then( + items + .delimited_by(just(Token::Ctrl('(')), just(Token::Ctrl(')'))) + .map_with_span(|args, span: Span| (args, span)) + .repeated(), + ) + .foldl(|f, args| 
{ + let span = f.1.start..args.1.end; + (Expr::Call(Box::new(f), args.0), span) + }); + + // Product ops (multiply and divide) have equal precedence + let op = just(Token::Op("*".to_string())) + .to(BinaryOp::Mul) + .or(just(Token::Op("/".to_string())).to(BinaryOp::Div)); + let product = call + .clone() + .then(op.then(call).repeated()) + .foldl(|a, (op, b)| { + let span = a.1.start..b.1.end; + (Expr::Binary(Box::new(a), op, Box::new(b)), span) + }); + + // Sum ops (add and subtract) have equal precedence + let op = just(Token::Op("+".to_string())) + .to(BinaryOp::Add) + .or(just(Token::Op("-".to_string())).to(BinaryOp::Sub)); + let sum = product + .clone() + .then(op.then(product).repeated()) + .foldl(|a, (op, b)| { + let span = a.1.start..b.1.end; + (Expr::Binary(Box::new(a), op, Box::new(b)), span) + }); + + // Comparison ops (equal, not-equal) have equal precedence + let op = just(Token::Op("==".to_string())) + .to(BinaryOp::Eq) + .or(just(Token::Op("!=".to_string())).to(BinaryOp::NotEq)); + let compare = sum + .clone() + .then(op.then(sum).repeated()) + .foldl(|a, (op, b)| { + let span = a.1.start..b.1.end; + (Expr::Binary(Box::new(a), op, Box::new(b)), span) + }); + + // Pipe operator ops have equal precedence + let op = just(Token::Op("|>".to_string())).to(BinaryOp::PipeOp); + let pipe = compare + .clone() + .then(op.then(compare).repeated()) + .foldl(|a, (op, b)| { + let span = a.1.start..b.1.end; + (Expr::Binary(Box::new(a), op, Box::new(b)), span) + }); + + pipe + }); + + // Blocks are expressions but delimited with braces + let block = expr + .clone() + .delimited_by(just(Token::Ctrl('{')), just(Token::Ctrl('}'))) + // Attempt to recover anything that looks like a block but contains errors + .recover_with(nested_delimiters( + Token::Ctrl('{'), + Token::Ctrl('}'), + [ + (Token::Ctrl('('), Token::Ctrl(')')), + (Token::Ctrl('['), Token::Ctrl(']')), + ], + |span| (Expr::Error, span), + )); + + let if_ = recursive(|if_| { + just(Token::If) + 
.ignore_then(expr.clone()) + .then(block.clone()) + .then( + just(Token::Else) + .ignore_then(block.clone().or(if_)) + .or_not(), + ) + .map_with_span(|((cond, a), b), span: Span| { + ( + Expr::If( + Box::new(cond), + Box::new(a), + Box::new(match b { + Some(b) => b, + // If an `if` expression has no trailing `else` block, we magic up one that just produces null + None => (Expr::Value(Value::Null), span.clone()), + }), + ), + span, + ) + }) + }); + + // Both blocks and `if` are 'block expressions' and can appear in the place of statements + let block_expr = block.or(if_).labelled("block"); + + let block_chain = block_expr + .clone() + .then(block_expr.clone().repeated()) + .foldl(|a, b| { + let span = a.1.start..b.1.end; + (Expr::Then(Box::new(a), Box::new(b)), span) + }); + + block_chain + // Expressions, chained by semicolons, are statements + .or(raw_expr.clone()) + .then(just(Token::Ctrl(';')).ignore_then(expr.or_not()).repeated()) + .foldl(|a, b| { + // This allows creating a span that covers the entire Then expression. + // b_end is the end of b if it exists, otherwise it is the end of a. + let a_start = a.1.start; + let b_end = b.as_ref().map(|b| b.1.end).unwrap_or(a.1.end); + ( + Expr::Then( + Box::new(a), + Box::new(match b { + Some(b) => b, + // Since there is no b expression then its span is empty. 
+ None => (Expr::Value(Value::Null), b_end..b_end), + }), + ), + a_start..b_end, + ) + }) + }) +} + +fn funcs_parser() -> impl Parser, Error = Simple> + Clone { + let ident = filter_map(|span, tok| match tok { + Token::Ident(ident) => Ok(ident.clone()), + _ => Err(Simple::expected_input_found(span, Vec::new(), Some(tok))), + }); + + // Argument lists are just identifiers separated by commas, surrounded by parentheses + let args = ident + .clone() + .separated_by(just(Token::Ctrl(','))) + .allow_trailing() + .delimited_by(just(Token::Ctrl('(')), just(Token::Ctrl(')'))) + .labelled("function args"); + + let func = just(Token::Fn) + .ignore_then( + ident + .map_with_span(|name, span| (name, span)) + .labelled("function name"), + ) + .then(args) + .then( + expr_parser() + .delimited_by(just(Token::Ctrl('{')), just(Token::Ctrl('}'))) + // Attempt to recover anything that looks like a function body but contains errors + .recover_with(nested_delimiters( + Token::Ctrl('{'), + Token::Ctrl('}'), + [ + (Token::Ctrl('('), Token::Ctrl(')')), + (Token::Ctrl('['), Token::Ctrl(']')), + ], + |span| (Expr::Error, span), + )), + ) + .map(|((name, args), body)| (name, Func { args, body })) + .labelled("function"); + + func.repeated() + .try_map(|fs, _| { + let mut funcs = HashMap::new(); + for ((name, name_span), f) in fs { + if funcs.insert(name.clone(), f).is_some() { + return Err(Simple::custom( + name_span.clone(), + format!("Function '{}' already exists", name), + )); + } + } + Ok(funcs) + }) + .then_ignore(end()) +} + +pub struct Error { + span: Span, + msg: String, +} + +fn file_parser() -> impl Parser> + Clone { + // Define the parsers for imports and functions + let import_p = import_parser(); + let funcs_p = funcs_parser(); + + // The main parser should parse a sequence of imports and functions + import_p + .clone() + .repeated() // Allows for multiple import statements + .then(funcs_p.clone()) + .map(|(imports, funcs)| Program { imports, funcs }) + // Handle the end of the 
input + .then_ignore(end()) +} + +// Define a structure to hold the parsed program +#[derive(Debug)] +pub struct Program { + pub imports: Vec>, + pub funcs: HashMap, +} + +pub fn parse(src: String) -> Result, Vec> { + let (tokens, mut errs) = lexer().parse_recovery(src.as_str()); + + let (ast, parse_errs) = if let Some(tokens) = tokens { + //dbg!(tokens); + let len = src.chars().count(); + let (ast, parse_errs) = + file_parser().parse_recovery(Stream::from_iter(len..len + 1, tokens.into_iter())); + + (ast, parse_errs) + } else { + (None, Vec::new()) + }; + + let formatted_errs = errs + .into_iter() + .map(|e| e.map(|c| c.to_string())) + .chain(parse_errs.into_iter().map(|e| e.map(|tok| tok.to_string()))) + .map(|e| { + let report = Report::build(ReportKind::Error, (), e.span().start); + + let report = match e.reason() { + chumsky::error::SimpleReason::Unclosed { span, delimiter } => report + .with_message(format!( + "Unclosed delimiter {}", + delimiter.fg(Color::Yellow) + )) + .with_label( + Label::new(span.clone()) + .with_message(format!( + "Unclosed delimiter {}", + delimiter.fg(Color::Yellow) + )) + .with_color(Color::Yellow), + ) + .with_label( + Label::new(e.span()) + .with_message(format!( + "Must be closed before this {}", + e.found() + .unwrap_or(&"end of file".to_string()) + .fg(Color::Red) + )) + .with_color(Color::Red), + ), + chumsky::error::SimpleReason::Unexpected => report + .with_message(format!( + "{}, expected {}", + if e.found().is_some() { + "Unexpected token in input" + } else { + "Unexpected end of input" + }, + if e.expected().len() == 0 { + "something else".to_string() + } else { + e.expected() + .map(|expected| match expected { + Some(expected) => expected.to_string(), + None => "end of input".to_string(), + }) + .collect::>() + .join(", ") + } + )) + .with_label( + Label::new(e.span()) + .with_message(format!( + "Unexpected token {}", + e.found() + .unwrap_or(&"end of file".to_string()) + .fg(Color::Red) + )) + .with_color(Color::Red), + 
), + chumsky::error::SimpleReason::Custom(msg) => report.with_message(msg).with_label( + Label::new(e.span()) + .with_message(format!("{}", msg.fg(Color::Red))) + .with_color(Color::Red), + ), + }; + + let mut buffer = Cursor::new(Vec::new()); + report + .finish() + .write(Source::from(&src), &mut buffer) + .unwrap(); + let s = String::from_utf8(buffer.into_inner()).expect("Found invalid UTF-8"); + s + }); + + Ok(ast) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn test_parsing_simple_program() { + let result = parse( + r#" + fn main() { + let x = 1; + let y = 2; + print(x + y); + x |> + y |> + print(x + y); + } + "# + .to_string(), + ); + + assert_eq!(result.is_ok(), true); + let result = result.unwrap(); + assert_eq!(result.is_some(), true); + let result = result.unwrap(); + + if let Some(Func { args, body }) = result.funcs.get("main") { + let (Expr::Let(n, bx, rest), _) = body else { panic!("Unexpected expression structure") }; + let (Expr::Value(Value::Num(v)), _) = **bx else { panic!("Unexpected expression structure") }; + assert_eq!(n, &"x".to_string()); + assert_eq!(v, 1.0); + + let (Expr::Let(ref n, ref bx, ref rest), _) = **rest else { panic!("Unexpected expression structure") }; + let (Expr::Value(Value::Num(v)), _) = **bx else { panic!("Unexpected expression structure") }; + assert_eq!(n, &"y".to_string()); + assert_eq!(v, 2.0); + + let (Expr::Then(ref left, ref right), _) = **rest else { panic!("Unexpected expression structure") }; + let (Expr::Print(ref bx), _) = **left else { panic!("Unexpected expression structure") }; + let (Expr::Binary(ref left, ref op, ref right), _) = **bx else { panic!("Unexpected expression structure") }; + let (Expr::Local(ref n), _) = **left else { panic!("Unexpected expression structure") }; + let (Expr::Local(ref n2), _) = **right else { panic!("Unexpected expression structure") }; + assert_eq!(n, &"x".to_string()); + assert_eq!(n2, &"y".to_string()); + assert_eq!(op, 
&BinaryOp::Add); + } else { + panic!("Function 'main' not found"); + } + } + + #[test] + fn test_valid_import() { + let parser = import_parser().then_ignore(end()); + let test_input = "use import(\"src/js/example\", \"javascript\") as z;"; + let len = test_input.chars().count(); + let (tokens, mut errs) = lexer().parse_recovery(test_input); + let parsed = parser + .parse(Stream::from_iter(len..len + 1, tokens.unwrap().into_iter())) + .unwrap() + .0; + + let expected = Expr::Import( + "src/js/example".to_string(), + "javascript".to_string(), + "z".to_string(), + ); + + assert_eq!(parsed, expected); + } + + #[test] + fn test_invalid_import_missing_as() { + let parser = import_parser().then_ignore(end()); + let test_input = "use import(\"src/js/example\", \"javascript\");"; + let len = test_input.chars().count(); + let (tokens, mut errs) = lexer().parse_recovery(test_input); + let parsed = parser.parse(Stream::from_iter(len..len + 1, tokens.unwrap().into_iter())); + + assert!(parsed.is_err()); + } + + #[test] + fn test_invalid_import_missing_quotes() { + let parser = import_parser().then_ignore(end()); + let test_input = "use import(src/js/example, javascript) as z;"; + let len = test_input.chars().count(); + let (tokens, mut errs) = lexer().parse_recovery(test_input); + let parsed = parser.parse(Stream::from_iter(len..len + 1, tokens.unwrap().into_iter())); + assert!(parsed.is_err()); + } + + #[test] + fn test_import_with_extra_tokens() { + let parser = import_parser().then_ignore(end()); + let test_input = "use import(\"src/js/example\", \"javascript\") as z extra;"; + let len = test_input.chars().count(); + let (tokens, mut errs) = lexer().parse_recovery(test_input); + let parsed = parser.parse(Stream::from_iter(len..len + 1, tokens.unwrap().into_iter())); + + assert!(parsed.is_err()); + } + + #[test] + fn test_parsing_program_with_imports() { + let result = parse( + r#" + use import("src/js/example", "javascript") as z; + use import("src/prompt/example", "prompt") 
as p; + use import("src/prompt/example", "text") as t; + use import("src/py/example", "python") as py; + + fn main() { + let x = 1; + let y = 2; + print(x + y); + } + "# + .to_string(), + ); + + let result = result.unwrap(); + let result = result.unwrap(); + + assert_eq!( + Expr::Import("src/js/example".into(), "javascript".into(), "z".into()), + result.imports[0].0 + ); + assert_eq!( + Expr::Import("src/prompt/example".into(), "prompt".into(), "p".into()), + result.imports[1].0 + ); + assert_eq!( + Expr::Import("src/prompt/example".into(), "text".into(), "t".into()), + result.imports[2].0 + ); + assert_eq!( + Expr::Import("src/py/example".into(), "python".into(), "py".into()), + result.imports[3].0 + ); + + if let Some(Func { args, body }) = result.funcs.get("main") { + let (Expr::Let(n, bx, rest), _) = body else { panic!("Unexpected expression structure") }; + let (Expr::Value(Value::Num(v)), _) = **bx else { panic!("Unexpected expression structure") }; + assert_eq!(n, &"x".to_string()); + assert_eq!(v, 1.0); + + let (Expr::Let(ref n, ref bx, ref rest), _) = **rest else { panic!("Unexpected expression structure") }; + let (Expr::Value(Value::Num(v)), _) = **bx else { panic!("Unexpected expression structure") }; + assert_eq!(n, &"y".to_string()); + assert_eq!(v, 2.0); + + let (Expr::Then(ref left, ref right), _) = **rest else { panic!("Unexpected expression structure") }; + let (Expr::Print(ref bx), _) = **left else { panic!("Unexpected expression structure") }; + let (Expr::Binary(ref left, ref op, ref right), _) = **bx else { panic!("Unexpected expression structure") }; + let (Expr::Local(ref n), _) = **left else { panic!("Unexpected expression structure") }; + let (Expr::Local(ref n2), _) = **right else { panic!("Unexpected expression structure") }; + assert_eq!(n, &"x".to_string()); + assert_eq!(n2, &"y".to_string()); + assert_eq!(op, &BinaryOp::Add); + } else { + panic!("Function 'main' not found"); + } + } +} diff --git 
a/toolchain/prompt-graph-core/src/reactivity/typechecker.rs b/toolchain/prompt-graph-core/src/execution/language/typechecker.rs similarity index 100% rename from toolchain/prompt-graph-core/src/reactivity/typechecker.rs rename to toolchain/prompt-graph-core/src/execution/language/typechecker.rs diff --git a/toolchain/prompt-graph-core/src/execution/mod.rs b/toolchain/prompt-graph-core/src/execution/mod.rs new file mode 100644 index 0000000..f9ba2a4 --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/mod.rs @@ -0,0 +1,4 @@ +pub mod execution; +pub mod integration; +pub mod language; +pub mod primitives; diff --git a/toolchain/prompt-graph-core/src/execution/primitives/identifiers.rs b/toolchain/prompt-graph-core/src/execution/primitives/identifiers.rs new file mode 100644 index 0000000..3d1188d --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/primitives/identifiers.rs @@ -0,0 +1,3 @@ +pub type OperationId = usize; +pub type ArgumentIndex = usize; +pub type TimestampOfWrite = usize; diff --git a/toolchain/prompt-graph-core/src/execution/primitives/mod.rs b/toolchain/prompt-graph-core/src/execution/primitives/mod.rs new file mode 100644 index 0000000..7ff65bf --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/primitives/mod.rs @@ -0,0 +1,3 @@ +pub mod identifiers; +pub mod operation; +pub mod serialized_value; diff --git a/toolchain/prompt-graph-core/src/reactivity/operation.rs b/toolchain/prompt-graph-core/src/execution/primitives/operation.rs similarity index 58% rename from toolchain/prompt-graph-core/src/reactivity/operation.rs rename to toolchain/prompt-graph-core/src/execution/primitives/operation.rs index 99d689d..0785ae9 100644 --- a/toolchain/prompt-graph-core/src/reactivity/operation.rs +++ b/toolchain/prompt-graph-core/src/execution/primitives/operation.rs @@ -1,6 +1,11 @@ +/// An Operation is a function, which can be executed on the graph. It +/// can be pure or impure, and it can be mutable or immutable. 
Each Operation +/// has a unique identifier within a given graph. +use crate::execution::integration::triggerable::TriggerContext; use std::collections::HashMap; -use crate::reactivity::triggerable::{TriggerContext}; +use std::fmt; +#[derive(PartialEq, Debug)] pub struct Signature { /// Signature of the total inputs for this graph input_signature: HashMap, @@ -18,12 +23,13 @@ impl Signature { } } - +#[derive(PartialEq, Debug)] enum Purity { Pure, - Impure + Impure, } +#[derive(PartialEq, Debug)] enum Mutability { Mutable, Immutable, @@ -46,17 +52,18 @@ enum Mutability { /// It is up to the user to structure those maps in such a way that they don't collide with other /// values being represented in the state of our system. These inputs and outputs are managed /// by our Execution Database. -pub type OperationFn = dyn FnMut(&[u8]) -> Vec; +pub type OperationFn = dyn FnMut(Vec<&Option>>) -> Vec; pub struct OperationNodeDefinition { /// The operation function itself - operation: Option>, + pub(crate) operation: Option>, /// Dependencies of this node - pub(crate) dependencies: Vec, + pub(crate) dependency_count: usize, } pub struct OperationNode { + pub(crate) id: usize, /// Is the node pure or impure, does it have side effects? Does it depend on external state? purity: Purity, @@ -64,9 +71,6 @@ pub struct OperationNode { /// Is the node mutable or immutable, can its value change after an execution? mutability: Mutability, - /// Is this node observed by a consumer? 
- is_observed: bool, - /// When did the output of this node last actually change changed_at: usize, @@ -86,41 +90,66 @@ pub struct OperationNode { operation: Option>, /// Dependencies of this node - pub(crate) dependencies: Vec, + pub(crate) arity: usize, + pub(crate) dependency_count: usize, + pub(crate) unresolved_dependencies: Vec, /// Partial application arena - this stores partially applied arguments for this OperationNode - partial_application: Vec + partial_application: Vec, +} + +// TODO: OperationNode need + +impl fmt::Debug for OperationNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OperationNode") + .field("id", &self.id) + .field("purity", &self.purity) + .field("mutability", &self.mutability) + .field("changed_at", &self.changed_at) + .field("verified_at", &self.verified_at) + .field("dirty", &self.dirty) + .field("height", &self.height) + .field("signature", &self.signature) + .field("operation", &self.operation.is_some()) + .field("dependency_count", &self.dependency_count) + .field("unresolved_dependencies", &self.unresolved_dependencies) + .field("partial_application", &self.partial_application) + .finish() + } } impl Default for OperationNode { fn default() -> Self { OperationNode { + id: 0, purity: Purity::Pure, mutability: Mutability::Mutable, - is_observed: true, changed_at: 0, verified_at: 0, height: 0, - dirty: false, + dirty: true, signature: Signature::new(), operation: None, - dependencies: vec![], + arity: 0, + dependency_count: 0, + unresolved_dependencies: vec![], partial_application: Vec::new(), } } } impl OperationNode { - pub(crate) fn new(f: Option>) -> Self { + pub(crate) fn new(args: usize, f: Option>) -> Self { let mut node = OperationNode::default(); node.operation = f; + node.dependency_count = args; node } - pub(crate) fn from(d: &OperationNodeDefinition) -> Self { + pub(crate) fn from(mut d: OperationNodeDefinition) -> Self { let mut node = OperationNode::default(); - node.operation = 
d.operation.clone(); - node.dependencies = d.dependencies.clone(); + node.operation = d.operation.take(); node } @@ -128,9 +157,11 @@ impl OperationNode { unimplemented!(); } - pub(crate) fn execute(&self, context: &[u8]) { - if let Some(exec) = self.operation.as_ref() { - exec(context); + pub(crate) fn execute(&mut self, context: Vec<&Option>>) -> Option> { + if let Some(exec) = self.operation.as_deref_mut() { + Some(exec(context)) + } else { + None } } } @@ -142,24 +173,26 @@ mod tests { #[test] fn test_execute_with_operation() { let mut executed = false; - let operation: Box = Box::new(|context: &[u8]| -> Vec { - context.to_vec() - }); + let operation: Box = + Box::new(|context: Vec<&Option>>| -> Vec { vec![0, 1] }); let mut node = OperationNode::default(); node.operation = Some(operation); let bytes = vec![1, 2, 3]; - node.execute(&bytes); + node.execute(vec![&Some(bytes)]); assert_eq!(executed, true); } #[test] fn test_execute_without_operation() { - let node = OperationNode::default(); + let mut node = OperationNode::default(); let bytes = vec![1, 2, 3]; - node.execute(&bytes); // should not panic + node.execute(vec![&Some(bytes)]); // should not panic } + + // TODO: test application of Operations/composition + // TODO: test manual evaluation of a composition of operations } diff --git a/toolchain/prompt-graph-core/src/execution/primitives/serialized_value.rs b/toolchain/prompt-graph-core/src/execution/primitives/serialized_value.rs new file mode 100644 index 0000000..0a40288 --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/primitives/serialized_value.rs @@ -0,0 +1,153 @@ +use rkyv::{ + archived_root, check_archived_root, + ser::{serializers::AllocSerializer, Serializer}, + Archive, Deserialize, Serialize, +}; +use std::collections::HashMap; + +#[derive(Archive, Serialize, Deserialize, Debug, PartialEq)] +#[archive(bound(serialize = "__S: rkyv::ser::ScratchSpace + rkyv::ser::Serializer"))] +#[archive(check_bytes)] +#[archive_attr(check_bytes( + 
bound = "__C: rkyv::validation::ArchiveContext, <__C as rkyv::Fallible>::Error: std::error::Error" +))] +#[archive_attr(derive(Debug))] +pub enum RkyvSerializedValue { + StreamPointer(u32), + FunctionPointer(u32), + Float(f32), + Number(i32), + String(String), + Boolean(bool), + Null, + + Array( + #[omit_bounds] + #[archive_attr(omit_bounds)] + Vec, + ), + + Object( + #[omit_bounds] + #[archive_attr(omit_bounds)] + HashMap, + ), +} + +impl From for Vec { + fn from(item: RkyvSerializedValue) -> Self { + serialize_to_vec(&item) + } +} + +pub fn serialize_to_vec(v: &RkyvSerializedValue) -> Vec { + let mut serializer = AllocSerializer::<4096>::default(); + serializer.serialize_value(v).unwrap(); + let buf = serializer.into_serializer().into_inner(); + buf.to_vec() +} + +pub fn deserialize_from_buf(v: &[u8]) -> RkyvSerializedValue { + let rkyv1 = unsafe { archived_root::(v) }; + let arg1: RkyvSerializedValue = rkyv1.deserialize(&mut rkyv::Infallible).unwrap(); + arg1 +} + +#[cfg(test)] +mod tests { + use super::*; + use rkyv::{ + archived_root, + ser::{serializers::AllocSerializer, Serializer}, + with::{ArchiveWith, DeserializeWith, SerializeWith}, + Deserialize, Infallible, + }; + + fn round_trip(value: RkyvSerializedValue) -> () { + let mut serializer = AllocSerializer::<4096>::default(); + serializer.serialize_value(&value).unwrap(); + let buf = serializer.into_serializer().into_inner(); + let archived_value = unsafe { archived_root::(&buf) }; + check_archived_root::(&buf).unwrap(); + let deserialized: RkyvSerializedValue = + archived_value.deserialize(&mut rkyv::Infallible).unwrap(); + assert_eq!(deserialized, value); + } + + #[test] + fn test_float() { + let value = RkyvSerializedValue::Float(42.0); + round_trip(value); + } + + #[test] + fn test_number() { + let value = RkyvSerializedValue::Number(42); + round_trip(value); + } + + #[test] + fn test_string() { + let value = RkyvSerializedValue::String("Hello".to_string()); + round_trip(value); + } + + #[test] + fn 
test_boolean() { + let value = RkyvSerializedValue::Boolean(true); + round_trip(value); + } + + #[test] + fn test_array() { + let value = RkyvSerializedValue::Array(vec![ + RkyvSerializedValue::Number(42), + RkyvSerializedValue::Boolean(true), + ]); + round_trip(value); + } + + #[test] + fn test_object() { + let mut map = HashMap::new(); + map.insert( + "key".to_string(), + RkyvSerializedValue::String("value".to_string()), + ); + let value = RkyvSerializedValue::Object(map); + round_trip(value); + } + + #[test] + fn test_serialize_to_vec() { + let value = RkyvSerializedValue::String("Hello".to_string()); + let serialized_vec = serialize_to_vec(&value); + + // Verify if serialized_vec is non-empty, or any other conditions. + assert!(!serialized_vec.is_empty()); + } + + #[test] + fn test_deserialize_from_vec() { + let value = RkyvSerializedValue::String("Hello".to_string()); + let serialized_vec = serialize_to_vec(&value); + let deserialized_value = deserialize_from_buf(&serialized_vec); + + // Verify if deserialized_value matches the original value. + assert_eq!(value, deserialized_value); + } + + #[test] + fn test_serialize_deserialize_cycle() { + let value = RkyvSerializedValue::String("Hello".to_string()); + let serialized_vec = serialize_to_vec(&value); + let deserialized_value = deserialize_from_buf(&serialized_vec); + + // Verify if deserialization after serialization yields the original value. + assert_eq!(value, deserialized_value); + + // Further tests to ensure that serialization -> deserialization is an identity operation. + let reserialized_vec = serialize_to_vec(&deserialized_value); + assert_eq!(serialized_vec, reserialized_vec); + } +} diff --git a/toolchain/prompt-graph-core/src/execution/sdk/entry.rs b/toolchain/prompt-graph-core/src/execution/sdk/entry.rs new file mode 100644 index 0000000..8834cf7 --- /dev/null +++ b/toolchain/prompt-graph-core/src/execution/sdk/entry.rs @@ -0,0 +1,10 @@ +/// This is an SDK for building execution graphs. 
It is designed to be used iteratively. + +/// Start a new execution graph. +fn create() {} + +/// Add a node to the execution graph. +fn add_node() {} + +/// Add a relationship between two nodes in the execution graph. +fn add_relationship() {} diff --git a/toolchain/prompt-graph-core/src/execution/sdk/mod.rs b/toolchain/prompt-graph-core/src/execution/sdk/mod.rs new file mode 100644 index 0000000..e69de29 diff --git a/toolchain/prompt-graph-core/src/execution_router.rs b/toolchain/prompt-graph-core/src/execution_router.rs deleted file mode 100644 index 88b1ca3..0000000 --- a/toolchain/prompt-graph-core/src/execution_router.rs +++ /dev/null @@ -1,355 +0,0 @@ -use crate::build_runtime_graph::graph_parse::CleanedDefinitionGraph; -use crate::proto::{ChangeValue, ChangeValueWithCounter, DispatchResult, NodeWillExecute, WrappedChangeValue}; - -pub trait ExecutionState { - fn get_count_node_execution(&self, node: &[u8]) -> Option; - fn inc_counter_node_execution(&mut self, node: &[u8]) -> u64; - fn get_value(&self, address: &[u8]) -> Option<(u64, ChangeValue)>; - fn set_value(&mut self, address: &[u8], counter: u64, value: ChangeValue); -} - - -/// This is used to evaluate if a newly introduced node should be immediately evaluated -/// against the state of the system. 
-pub fn evaluate_changes_against_node( - state: &impl ExecutionState, - paths_to_satisfy: &Vec> -) -> Option> { - // for each of the matched nodes, we need to evaluate the query against the current state - // check if the updated state object is satisfying all necessary paths for this query - let mut satisfied_paths = vec![]; - - for path in paths_to_satisfy { - if let Some(change_value) = state.get_value(path.join(":").as_bytes()) { - satisfied_paths.push(change_value.clone()); - } - } - - if satisfied_paths.len() != paths_to_satisfy.len() { return None } - - Some(satisfied_paths.into_iter().map(|(counter, v)| WrappedChangeValue { - monotonic_counter: counter, - change_value: Some(v), - }).collect()) -} - - -/// This dispatch method is responsible for identifying which nodes should execute based on -/// a current key value state and a clean definition graph. It returns a list of nodes that -/// should be executed, and the path references that they were satisfied with. This exists in -/// the core implementation because it may be used in client code or our server. It will mutate -/// the provided ExecutionState to reflect the application of the provided change. Execution state -/// may internally persist records of what environment this change occurred in. -pub fn dispatch_and_mutate_state( - clean_definition_graph: &CleanedDefinitionGraph, - state: &mut impl ExecutionState, - change_value_with_counter: &ChangeValueWithCounter -) -> DispatchResult { - let g = clean_definition_graph; - - // TODO: dispatch with an vec![] address path should do what? - - // First pass we update the values present in the change - for filled_value in &change_value_with_counter.filled_values { - let filled_value_address = &filled_value.clone().path.unwrap().address; - - // In order to avoid double-execution of nodes, we need to check if the value has changed. - // matching here means that the state we are assessing execution of has already bee applied to our state. 
- // The state may have already been applied in a parent branch if the execution is taking place there as well. - if let Some((prev_counter, _prev_change_value)) = state.get_value(filled_value_address.join(":").as_bytes()) { - if prev_counter >= change_value_with_counter.monotonic_counter { - // Value has not updated - skip this change reflecting it and continue to the next change - continue - } - } - - state.set_value( - filled_value_address.join(":").as_bytes().clone(), - change_value_with_counter.monotonic_counter, - filled_value.clone()); - - } - - - // node_executions looks like a list of node names and their inputs - // Nodes may execute _multiple times_ in response to some changes that might occur. - let mut node_executions: Vec = vec![]; - // Apply a second pass to resolve into nodes that should execute - for filled_value in &change_value_with_counter.filled_values { - let filled_value_address = &filled_value.clone().path.unwrap().address; - - // TODO: if we're subscribed to all of the outputs of a node this will-eval a lot - // filter to nodes matched by the affected path -> name - // nodes with no queries are referred to by the empty string (derived from empty vec![]) and are always matched - if let Some(matched_node_names) = g.dispatch_table.get(filled_value_address.join(":").as_str()) { - for node_that_should_exec in matched_node_names { - if let Some(choice_paths_to_satisfy) = g.query_paths.get(node_that_should_exec) { - for (idx, opt_paths_to_satisfy) in choice_paths_to_satisfy.iter().enumerate() { - // TODO: NodeWillExecute should include _which_ query was satisfied - if let Some(paths_to_satisfy) = opt_paths_to_satisfy { - if let Some(change_values_used_in_execution) = evaluate_changes_against_node(state, paths_to_satisfy) { - let node_will_execute = NodeWillExecute { - source_node: node_that_should_exec.clone(), - change_values_used_in_execution, - matched_query_index: idx as u64 - }; - node_executions.push(node_will_execute); - } - } else { - // 
No paths to satisfy - // we've already executed this node, so we don't need to do it again - if state.get_count_node_execution(node_that_should_exec.as_bytes()).unwrap_or(0) > 0 { - continue; - } - node_executions.push(NodeWillExecute { - source_node: node_that_should_exec.clone(), - change_values_used_in_execution: vec![], - matched_query_index: idx as u64 - }); - } - state.inc_counter_node_execution(node_that_should_exec.as_bytes()); - - } - - } - } - } - } - - // we only _tell_ what we think should happen. We don't actually do it. - // it is up to the wrapping SDK what to do or not do with our information - DispatchResult { - operations: node_executions, - } -} - - -#[cfg(test)] -mod tests { - use crate::proto::{File, item, Item, ItemCore, OutputType, Path, PromptGraphNodeEcho, Query, SerializedValue}; - use crate::graph_definition::DefinitionGraph; - use std::collections::HashMap; - use crate::proto::serialized_value::Val; - - use super::*; - - #[derive(Debug)] - pub struct TestState { - value: HashMap, (u64, ChangeValue)>, - node_executions: HashMap, u64> - } - impl TestState { - fn new() -> Self { - Self { - value: HashMap::new(), - node_executions: HashMap::new() - } - } - } - - impl ExecutionState for TestState { - fn inc_counter_node_execution(&mut self, node: &[u8]) -> u64 { - let v = self.node_executions.entry(node.to_vec()).or_insert(0); - *v += 1; - *v - } - - fn get_count_node_execution(&self, node: &[u8]) -> Option { - self.node_executions.get(node).map(|x| *x) - } - - fn get_value(&self, address: &[u8]) -> Option<(u64, ChangeValue)> { - self.value.get(address).cloned() - } - - fn set_value(&mut self, address: &[u8], counter: u64, value: ChangeValue) { - self.value.insert(address.to_vec(), (counter, value)); - } - } - - fn get_file_empty_query() -> File { - File { - id: "test".to_string(), - nodes: vec![Item{ - core: Some(ItemCore { - name: "EmptyNode".to_string(), - triggers: vec![Query{ query: None}], - output: Some(OutputType { - output: 
"{}".to_string(), - }), - output_tables: vec![], - }), - item: Some(item::Item::NodeEcho(PromptGraphNodeEcho { - }))}], - } - } - - fn get_file() -> File { - File { - id: "test".to_string(), - nodes: vec![Item{ - core: Some(ItemCore { - name: "".to_string(), - triggers: vec![Query { - query: None, - }], - output: Some(OutputType { - output: "{} ".to_string(), - }), - output_tables: vec![] - }), - item: Some(item::Item::NodeEcho(PromptGraphNodeEcho { - }))}], - } - } - - fn get_file_with_paths() -> File { - File { - id: "test".to_string(), - nodes: vec![Item{ - core: Some(ItemCore { - name: "test_node".to_string(), - triggers: vec![Query { - query: Some("SELECT path1, path2 FROM source".to_string()), - }], - output: Some(OutputType { - output: "{} ".to_string(), - }), - output_tables: vec![] - }), - item: Some(item::Item::NodeEcho(PromptGraphNodeEcho { }))}], - } - } - - - #[test] - fn test_dispatch_with_file_and_change() { - let mut state = TestState::new(); - let file = get_file(); - let d = DefinitionGraph::from_file(file); - let g = CleanedDefinitionGraph::new(&d); - let c = ChangeValueWithCounter { - filled_values: vec![], - parent_monotonic_counters: vec![], - monotonic_counter: 0, - branch: 0, - source_node: "".to_string(), - }; - let result = dispatch_and_mutate_state(&g, &mut state, &c); - assert_eq!(result.operations.len(), 0); - } - - #[test] - fn test_we_dispatch_nodes_that_have_no_query_once() { - let mut state = TestState::new(); - let file = get_file_empty_query(); - let d = DefinitionGraph::from_file(file); - let g = CleanedDefinitionGraph::new(&d); - let c = ChangeValueWithCounter { - filled_values: vec![ChangeValue { - path: Some(Path { - address: vec![], - }), - value: None, - branch: 0, - }], - parent_monotonic_counters: vec![], - monotonic_counter: 0, - branch: 0, - source_node: "EmptyNode".to_string(), - }; - let result = dispatch_and_mutate_state(&g, &mut state, &c); - assert_eq!(result.operations.len(), 1); - assert_eq!(result.operations[0], 
NodeWillExecute { - source_node: "EmptyNode".to_string(), - change_values_used_in_execution: vec![], - matched_query_index: 0 - }); - - // Does not re-execute - let result = dispatch_and_mutate_state(&g, &mut state, &c); - assert_eq!(result.operations.len(), 0); - } - - #[test] - fn test_all_paths_must_be_satisfied_before_dispatch() { - // State should start empty - let mut state = TestState::new(); - let file = get_file_with_paths(); - let d = DefinitionGraph::from_file(file); - let g = CleanedDefinitionGraph::new(&d); - - // Confirm the dispatch table has the paths that we expect - assert_eq!(g.dispatch_table.get("source:path1"), Some(&vec!["test_node".to_string()])); - assert_eq!(g.dispatch_table.get("source:path2"), Some(&vec!["test_node".to_string()])); - - // Dispatch a change that satisfies only one of the two paths - let c = ChangeValueWithCounter { - filled_values: vec![ - ChangeValue { - path: Some(Path { - address: vec!["source".to_string(), "path1".to_string()], - }), - value: Some(SerializedValue{ val: Some(Val::String("value".to_string()))}), - branch: 0, - }, - ], - parent_monotonic_counters: vec![], - monotonic_counter: 1, - branch: 0, - source_node: "__initialize__".to_string(), - }; - - let result = dispatch_and_mutate_state(&g, &mut state, &c); - // The dispatch should return no operations - assert_eq!(result.operations.len(), 0); - - - // Fill the second path - let c = ChangeValueWithCounter { - filled_values: vec![ - ChangeValue { - path: Some(Path { - address: vec!["source".to_string(), "path2".to_string()], - }), - value: Some(SerializedValue{ val: Some(Val::String("value".to_string()))}), - branch: 0, - }, - ], - parent_monotonic_counters: vec![], - monotonic_counter: 1, - branch: 0, - source_node: "__initialize__".to_string(), - }; - let result = dispatch_and_mutate_state(&g, &mut state, &c); - // This should now return an operation because both paths have been satisfied - assert_eq!(result.operations.len(), 1); - 
assert_eq!(result.operations[0], NodeWillExecute { - source_node: "test_node".to_string(), - change_values_used_in_execution: vec![ - WrappedChangeValue { - monotonic_counter: 1, - change_value: Some(ChangeValue { - path: Some(Path { - address: vec!["source".to_string(), "path1".to_string()], - }), - value: Some(SerializedValue{ val: Some(Val::String("value".to_string()))}), - branch: 0, - }), - }, - WrappedChangeValue { - monotonic_counter: 1, - change_value: Some(ChangeValue { - path: Some(Path { - address: vec!["source".to_string(), "path2".to_string()], - }), - value: Some(SerializedValue{ val: Some(Val::String("value".to_string()))}), - branch: 0, - }), - }, - ], - matched_query_index: 0, - }); - } - -} diff --git a/toolchain/prompt-graph-core/src/generated_protobufs/promptgraph.rs b/toolchain/prompt-graph-core/src/generated_protobufs/promptgraph.rs deleted file mode 100644 index 246d21a..0000000 --- a/toolchain/prompt-graph-core/src/generated_protobufs/promptgraph.rs +++ /dev/null @@ -1,2569 +0,0 @@ -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Query { - #[prost(string, optional, tag = "1")] - pub query: ::core::option::Option<::prost::alloc::string::String>, -} -/// Processed version of the Query -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct QueryPaths { - #[prost(string, tag = "1")] - pub node: ::prost::alloc::string::String, - #[prost(message, repeated, tag = "2")] - pub path: ::prost::alloc::vec::Vec, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct OutputType { - #[prost(string, tag = "2")] - pub output: ::prost::alloc::string::String, -} -/// Processed version of the OutputType -#[derive(serde::Deserialize, serde::Serialize)] 
-#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct OutputPaths { - #[prost(string, tag = "1")] - pub node: ::prost::alloc::string::String, - #[prost(message, repeated, tag = "2")] - pub path: ::prost::alloc::vec::Vec, -} -/// Alias is a reference to another node, any value set -/// on this node will propagate for the alias as well -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphAlias { - #[prost(string, tag = "2")] - pub from: ::prost::alloc::string::String, - #[prost(string, tag = "3")] - pub to: ::prost::alloc::string::String, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphConstant { - #[prost(message, optional, tag = "2")] - pub value: ::core::option::Option, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphVar {} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphOutputValue {} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeCodeSourceCode { - #[prost(enumeration = "SupportedSourceCodeLanguages", tag = "1")] - pub language: i32, - #[prost(string, tag = "2")] - pub source_code: ::prost::alloc::string::String, - #[prost(bool, tag = "3")] - pub template: bool, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphParameterNode {} -#[derive(serde::Deserialize, serde::Serialize)] 
-#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphMap { - #[prost(string, tag = "4")] - pub path: ::prost::alloc::string::String, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeCode { - #[prost(oneof = "prompt_graph_node_code::Source", tags = "6, 7, 8")] - pub source: ::core::option::Option, -} -/// Nested message and enum types in `PromptGraphNodeCode`. -pub mod prompt_graph_node_code { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Source { - #[prost(message, tag = "6")] - SourceCode(super::PromptGraphNodeCodeSourceCode), - #[prost(bytes, tag = "7")] - Zipfile(::prost::alloc::vec::Vec), - #[prost(string, tag = "8")] - S3Path(::prost::alloc::string::String), - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeLoader { - #[prost(oneof = "prompt_graph_node_loader::LoadFrom", tags = "1")] - pub load_from: ::core::option::Option, -} -/// Nested message and enum types in `PromptGraphNodeLoader`. 
-pub mod prompt_graph_node_loader { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum LoadFrom { - /// Load a zip file, decompress it, and make the paths available as keys - #[prost(bytes, tag = "1")] - ZipfileBytes(::prost::alloc::vec::Vec), - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeCustom { - #[prost(string, tag = "1")] - pub type_name: ::prost::alloc::string::String, -} -/// TODO: we should allow the user to freely manipulate wall-clock time -/// Output value of this should just be the timestamp -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeSchedule { - #[prost(oneof = "prompt_graph_node_schedule::Policy", tags = "1, 2, 3")] - pub policy: ::core::option::Option, -} -/// Nested message and enum types in `PromptGraphNodeSchedule`. 
-pub mod prompt_graph_node_schedule { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Policy { - #[prost(string, tag = "1")] - Crontab(::prost::alloc::string::String), - #[prost(string, tag = "2")] - NaturalLanguage(::prost::alloc::string::String), - #[prost(string, tag = "3")] - EveryMs(::prost::alloc::string::String), - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodePrompt { - #[prost(string, tag = "4")] - pub template: ::prost::alloc::string::String, - #[prost(float, tag = "7")] - pub temperature: f32, - #[prost(float, tag = "8")] - pub top_p: f32, - #[prost(int32, tag = "9")] - pub max_tokens: i32, - #[prost(float, tag = "10")] - pub presence_penalty: f32, - #[prost(float, tag = "11")] - pub frequency_penalty: f32, - /// TODO: set the user token - /// TODO: support logit bias - #[prost(string, repeated, tag = "12")] - pub stop: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - #[prost(oneof = "prompt_graph_node_prompt::Model", tags = "5, 6")] - pub model: ::core::option::Option, -} -/// Nested message and enum types in `PromptGraphNodePrompt`. -pub mod prompt_graph_node_prompt { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Model { - #[prost(enumeration = "super::SupportedChatModel", tag = "5")] - ChatModel(i32), - #[prost(enumeration = "super::SupportedCompletionModel", tag = "6")] - CompletionModel(i32), - } -} -/// TODO: this expects a selector for the query? - no its a template and you build that -/// TODO: what about the output type? pre-defined -/// TODO: what about the metadata? 
-/// TODO: metadata could be an independent query, or it could instead be a template too -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeMemory { - #[prost(string, tag = "3")] - pub collection_name: ::prost::alloc::string::String, - #[prost(string, tag = "4")] - pub template: ::prost::alloc::string::String, - #[prost(enumeration = "MemoryAction", tag = "7")] - pub action: i32, - #[prost(oneof = "prompt_graph_node_memory::EmbeddingModel", tags = "5")] - pub embedding_model: ::core::option::Option< - prompt_graph_node_memory::EmbeddingModel, - >, - #[prost(oneof = "prompt_graph_node_memory::VectorDbProvider", tags = "6")] - pub vector_db_provider: ::core::option::Option< - prompt_graph_node_memory::VectorDbProvider, - >, -} -/// Nested message and enum types in `PromptGraphNodeMemory`. -pub mod prompt_graph_node_memory { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum EmbeddingModel { - #[prost(enumeration = "super::SupportedEmebddingModel", tag = "5")] - Model(i32), - } - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum VectorDbProvider { - #[prost(enumeration = "super::SupportedVectorDatabase", tag = "6")] - Db(i32), - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeObservation { - #[prost(string, tag = "4")] - pub integration: ::prost::alloc::string::String, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeComponent { - #[prost(oneof = "prompt_graph_node_component::Transclusion", tags = 
"4, 5, 6")] - pub transclusion: ::core::option::Option, -} -/// Nested message and enum types in `PromptGraphNodeComponent`. -pub mod prompt_graph_node_component { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Transclusion { - #[prost(message, tag = "4")] - InlineFile(super::File), - #[prost(bytes, tag = "5")] - BytesReference(::prost::alloc::vec::Vec), - #[prost(string, tag = "6")] - S3PathReference(::prost::alloc::string::String), - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeEcho {} -/// TODO: configure resolving joins -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptGraphNodeJoin {} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ItemCore { - #[prost(string, tag = "1")] - pub name: ::prost::alloc::string::String, - #[prost(message, repeated, tag = "2")] - pub triggers: ::prost::alloc::vec::Vec, - #[prost(string, repeated, tag = "3")] - pub output_tables: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - #[prost(message, optional, tag = "4")] - pub output: ::core::option::Option, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Item { - #[prost(message, optional, tag = "1")] - pub core: ::core::option::Option, - #[prost( - oneof = "item::Item", - tags = "2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17" - )] - pub item: ::core::option::Option, -} -/// Nested message and enum types in `Item`. 
-pub mod item { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Item { - #[prost(message, tag = "2")] - Alias(super::PromptGraphAlias), - #[prost(message, tag = "3")] - Map(super::PromptGraphMap), - #[prost(message, tag = "4")] - Constant(super::PromptGraphConstant), - #[prost(message, tag = "5")] - Variable(super::PromptGraphVar), - #[prost(message, tag = "6")] - Output(super::PromptGraphOutputValue), - /// TODO: delete above this line - #[prost(message, tag = "7")] - NodeCode(super::PromptGraphNodeCode), - #[prost(message, tag = "8")] - NodePrompt(super::PromptGraphNodePrompt), - #[prost(message, tag = "9")] - NodeMemory(super::PromptGraphNodeMemory), - #[prost(message, tag = "10")] - NodeComponent(super::PromptGraphNodeComponent), - #[prost(message, tag = "11")] - NodeObservation(super::PromptGraphNodeObservation), - #[prost(message, tag = "12")] - NodeParameter(super::PromptGraphParameterNode), - #[prost(message, tag = "13")] - NodeEcho(super::PromptGraphNodeEcho), - #[prost(message, tag = "14")] - NodeLoader(super::PromptGraphNodeLoader), - #[prost(message, tag = "15")] - NodeCustom(super::PromptGraphNodeCustom), - #[prost(message, tag = "16")] - NodeJoin(super::PromptGraphNodeJoin), - #[prost(message, tag = "17")] - NodeSchedule(super::PromptGraphNodeSchedule), - } -} -/// TODO: add a flag for 'Cleaned', 'Dirty', 'Validated' -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct File { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(message, repeated, tag = "2")] - pub nodes: ::prost::alloc::vec::Vec, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Path { - #[prost(string, repeated, tag = "1")] - pub address: 
::prost::alloc::vec::Vec<::prost::alloc::string::String>, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct TypeDefinition { - #[prost(oneof = "type_definition::Type", tags = "1, 2, 3, 4, 5, 6, 7")] - pub r#type: ::core::option::Option, -} -/// Nested message and enum types in `TypeDefinition`. -pub mod type_definition { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Type { - #[prost(message, tag = "1")] - Primitive(super::PrimitiveType), - #[prost(message, tag = "2")] - Array(::prost::alloc::boxed::Box), - #[prost(message, tag = "3")] - Object(super::ObjectType), - #[prost(message, tag = "4")] - Union(super::UnionType), - #[prost(message, tag = "5")] - Intersection(super::IntersectionType), - #[prost(message, tag = "6")] - Optional(::prost::alloc::boxed::Box), - #[prost(message, tag = "7")] - Enum(super::EnumType), - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PrimitiveType { - #[prost(oneof = "primitive_type::Primitive", tags = "1, 2, 3, 4, 5")] - pub primitive: ::core::option::Option, -} -/// Nested message and enum types in `PrimitiveType`. 
-pub mod primitive_type { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Primitive { - #[prost(bool, tag = "1")] - IsString(bool), - #[prost(bool, tag = "2")] - IsNumber(bool), - #[prost(bool, tag = "3")] - IsBoolean(bool), - #[prost(bool, tag = "4")] - IsNull(bool), - #[prost(bool, tag = "5")] - IsUndefined(bool), - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ArrayType { - #[prost(message, optional, boxed, tag = "1")] - pub r#type: ::core::option::Option<::prost::alloc::boxed::Box>, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ObjectType { - #[prost(map = "string, message", tag = "1")] - pub fields: ::std::collections::HashMap< - ::prost::alloc::string::String, - TypeDefinition, - >, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct UnionType { - #[prost(message, repeated, tag = "1")] - pub types: ::prost::alloc::vec::Vec, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct IntersectionType { - #[prost(message, repeated, tag = "1")] - pub types: ::prost::alloc::vec::Vec, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct OptionalType { - #[prost(message, optional, boxed, tag = "1")] - pub r#type: ::core::option::Option<::prost::alloc::boxed::Box>, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct EnumType 
{ - #[prost(map = "string, string", tag = "1")] - pub values: ::std::collections::HashMap< - ::prost::alloc::string::String, - ::prost::alloc::string::String, - >, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SerializedValueArray { - #[prost(message, repeated, tag = "1")] - pub values: ::prost::alloc::vec::Vec, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SerializedValueObject { - #[prost(map = "string, message", tag = "1")] - pub values: ::std::collections::HashMap< - ::prost::alloc::string::String, - SerializedValue, - >, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SerializedValue { - #[prost(oneof = "serialized_value::Val", tags = "2, 3, 4, 5, 6, 7")] - pub val: ::core::option::Option, -} -/// Nested message and enum types in `SerializedValue`. 
-pub mod serialized_value { - #[derive(serde::Deserialize, serde::Serialize)] - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Val { - #[prost(float, tag = "2")] - Float(f32), - #[prost(int32, tag = "3")] - Number(i32), - #[prost(string, tag = "4")] - String(::prost::alloc::string::String), - #[prost(bool, tag = "5")] - Boolean(bool), - #[prost(message, tag = "6")] - Array(super::SerializedValueArray), - #[prost(message, tag = "7")] - Object(super::SerializedValueObject), - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ChangeValue { - #[prost(message, optional, tag = "1")] - pub path: ::core::option::Option, - #[prost(message, optional, tag = "2")] - pub value: ::core::option::Option, - #[prost(uint64, tag = "3")] - pub branch: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct WrappedChangeValue { - #[prost(uint64, tag = "3")] - pub monotonic_counter: u64, - #[prost(message, optional, tag = "4")] - pub change_value: ::core::option::Option, -} -/// Computation of a node -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct NodeWillExecute { - #[prost(string, tag = "1")] - pub source_node: ::prost::alloc::string::String, - #[prost(message, repeated, tag = "2")] - pub change_values_used_in_execution: ::prost::alloc::vec::Vec, - #[prost(uint64, tag = "3")] - pub matched_query_index: u64, -} -/// Group of node computations to run -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct DispatchResult { - #[prost(message, repeated, tag = "1")] - pub operations: ::prost::alloc::vec::Vec, -} 
-#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct NodeWillExecuteOnBranch { - #[prost(uint64, tag = "1")] - pub branch: u64, - #[prost(uint64, tag = "2")] - pub counter: u64, - #[prost(string, optional, tag = "3")] - pub custom_node_type_name: ::core::option::Option<::prost::alloc::string::String>, - #[prost(message, optional, tag = "4")] - pub node: ::core::option::Option, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ChangeValueWithCounter { - #[prost(message, repeated, tag = "1")] - pub filled_values: ::prost::alloc::vec::Vec, - #[prost(uint64, repeated, tag = "2")] - pub parent_monotonic_counters: ::prost::alloc::vec::Vec, - #[prost(uint64, tag = "3")] - pub monotonic_counter: u64, - #[prost(uint64, tag = "4")] - pub branch: u64, - #[prost(string, tag = "5")] - pub source_node: ::prost::alloc::string::String, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CounterWithPath { - #[prost(uint64, tag = "1")] - pub monotonic_counter: u64, - #[prost(message, optional, tag = "2")] - pub path: ::core::option::Option, -} -/// Input proposals -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct InputProposal { - #[prost(string, tag = "1")] - pub name: ::prost::alloc::string::String, - #[prost(message, optional, tag = "2")] - pub output: ::core::option::Option, - #[prost(uint64, tag = "3")] - pub counter: u64, - #[prost(uint64, tag = "4")] - pub branch: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RequestInputProposalResponse { - 
#[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(uint64, tag = "2")] - pub proposal_counter: u64, - #[prost(message, repeated, tag = "3")] - pub changes: ::prost::alloc::vec::Vec, - #[prost(uint64, tag = "4")] - pub branch: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct DivergentBranch { - #[prost(uint64, tag = "1")] - pub branch: u64, - #[prost(uint64, tag = "2")] - pub diverges_at_counter: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Branch { - #[prost(uint64, tag = "1")] - pub id: u64, - #[prost(uint64, repeated, tag = "2")] - pub source_branch_ids: ::prost::alloc::vec::Vec, - #[prost(message, repeated, tag = "3")] - pub divergent_branches: ::prost::alloc::vec::Vec, - #[prost(uint64, tag = "4")] - pub diverges_at_counter: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Empty {} -/// This is the return value from api calls that reports the current counter and branch the operation -/// was performed on. 
-#[derive(serde::Deserialize, serde::Serialize)] -#[derive(typescript_type_def::TypeDef)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ExecutionStatus { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(uint64, tag = "2")] - pub monotonic_counter: u64, - #[prost(uint64, tag = "3")] - pub branch: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct FileAddressedChangeValueWithCounter { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(string, tag = "2")] - pub node_name: ::prost::alloc::string::String, - #[prost(uint64, tag = "3")] - pub branch: u64, - #[prost(uint64, tag = "4")] - pub counter: u64, - #[prost(message, optional, tag = "5")] - pub change: ::core::option::Option, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RequestOnlyId { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(uint64, tag = "2")] - pub branch: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct FilteredPollNodeWillExecuteEventsRequest { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RequestAtFrame { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(uint64, tag = "2")] - pub frame: u64, - #[prost(uint64, tag = "3")] - pub branch: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct 
RequestNewBranch { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(uint64, tag = "2")] - pub source_branch_id: u64, - #[prost(uint64, tag = "3")] - pub diverges_at_counter: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RequestListBranches { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ListBranchesRes { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(message, repeated, tag = "2")] - pub branches: ::prost::alloc::vec::Vec, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RequestFileMerge { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(message, optional, tag = "2")] - pub file: ::core::option::Option, - #[prost(uint64, tag = "3")] - pub branch: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ParquetFile { - #[prost(bytes = "vec", tag = "1")] - pub data: ::prost::alloc::vec::Vec, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct QueryAtFrame { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(message, optional, tag = "2")] - pub query: ::core::option::Option, - #[prost(uint64, tag = "3")] - pub frame: u64, - #[prost(uint64, tag = "4")] - pub branch: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub 
struct QueryAtFrameResponse { - #[prost(message, repeated, tag = "1")] - pub values: ::prost::alloc::vec::Vec, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RequestAckNodeWillExecuteEvent { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(uint64, tag = "3")] - pub branch: u64, - #[prost(uint64, tag = "4")] - pub counter: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RespondPollNodeWillExecuteEvents { - #[prost(message, repeated, tag = "1")] - pub node_will_execute_events: ::prost::alloc::vec::Vec, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PromptLibraryRecord { - #[prost(message, optional, tag = "1")] - pub record: ::core::option::Option, - #[prost(uint64, tag = "3")] - pub version_counter: u64, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct UpsertPromptLibraryRecord { - #[prost(string, tag = "1")] - pub template: ::prost::alloc::string::String, - #[prost(string, tag = "2")] - pub name: ::prost::alloc::string::String, - #[prost(string, tag = "3")] - pub id: ::prost::alloc::string::String, - #[prost(string, optional, tag = "4")] - pub description: ::core::option::Option<::prost::alloc::string::String>, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ListRegisteredGraphsResponse { - #[prost(string, repeated, tag = "1")] - pub ids: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, -} -#[derive(serde::Deserialize, serde::Serialize)] -#[derive(Clone, Copy, Debug, 
PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum SupportedChatModel { - Gpt4 = 0, - Gpt40314 = 1, - Gpt432k = 2, - Gpt432k0314 = 3, - Gpt35Turbo = 4, - Gpt35Turbo0301 = 5, -} -impl SupportedChatModel { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - SupportedChatModel::Gpt4 => "GPT_4", - SupportedChatModel::Gpt40314 => "GPT_4_0314", - SupportedChatModel::Gpt432k => "GPT_4_32K", - SupportedChatModel::Gpt432k0314 => "GPT_4_32K_0314", - SupportedChatModel::Gpt35Turbo => "GPT_3_5_TURBO", - SupportedChatModel::Gpt35Turbo0301 => "GPT_3_5_TURBO_0301", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "GPT_4" => Some(Self::Gpt4), - "GPT_4_0314" => Some(Self::Gpt40314), - "GPT_4_32K" => Some(Self::Gpt432k), - "GPT_4_32K_0314" => Some(Self::Gpt432k0314), - "GPT_3_5_TURBO" => Some(Self::Gpt35Turbo), - "GPT_3_5_TURBO_0301" => Some(Self::Gpt35Turbo0301), - _ => None, - } - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum SupportedCompletionModel { - TextDavinci003 = 0, - TextDavinci002 = 1, - TextCurie001 = 2, - TextBabbage001 = 3, - TextAda00 = 4, -} -impl SupportedCompletionModel { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
- pub fn as_str_name(&self) -> &'static str { - match self { - SupportedCompletionModel::TextDavinci003 => "TEXT_DAVINCI_003", - SupportedCompletionModel::TextDavinci002 => "TEXT_DAVINCI_002", - SupportedCompletionModel::TextCurie001 => "TEXT_CURIE_001", - SupportedCompletionModel::TextBabbage001 => "TEXT_BABBAGE_001", - SupportedCompletionModel::TextAda00 => "TEXT_ADA_00", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "TEXT_DAVINCI_003" => Some(Self::TextDavinci003), - "TEXT_DAVINCI_002" => Some(Self::TextDavinci002), - "TEXT_CURIE_001" => Some(Self::TextCurie001), - "TEXT_BABBAGE_001" => Some(Self::TextBabbage001), - "TEXT_ADA_00" => Some(Self::TextAda00), - _ => None, - } - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum SupportedEmebddingModel { - TextEmbeddingAda002 = 0, - TextSearchAdaDoc001 = 1, -} -impl SupportedEmebddingModel { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - SupportedEmebddingModel::TextEmbeddingAda002 => "TEXT_EMBEDDING_ADA_002", - SupportedEmebddingModel::TextSearchAdaDoc001 => "TEXT_SEARCH_ADA_DOC_001", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. 
- pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "TEXT_EMBEDDING_ADA_002" => Some(Self::TextEmbeddingAda002), - "TEXT_SEARCH_ADA_DOC_001" => Some(Self::TextSearchAdaDoc001), - _ => None, - } - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum SupportedVectorDatabase { - InMemory = 0, - Chroma = 1, - Pineconedb = 2, - Qdrant = 3, -} -impl SupportedVectorDatabase { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - SupportedVectorDatabase::InMemory => "IN_MEMORY", - SupportedVectorDatabase::Chroma => "CHROMA", - SupportedVectorDatabase::Pineconedb => "PINECONEDB", - SupportedVectorDatabase::Qdrant => "QDRANT", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "IN_MEMORY" => Some(Self::InMemory), - "CHROMA" => Some(Self::Chroma), - "PINECONEDB" => Some(Self::Pineconedb), - "QDRANT" => Some(Self::Qdrant), - _ => None, - } - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum SupportedSourceCodeLanguages { - Deno = 0, - Starlark = 1, -} -impl SupportedSourceCodeLanguages { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
- pub fn as_str_name(&self) -> &'static str { - match self { - SupportedSourceCodeLanguages::Deno => "DENO", - SupportedSourceCodeLanguages::Starlark => "STARLARK", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "DENO" => Some(Self::Deno), - "STARLARK" => Some(Self::Starlark), - _ => None, - } - } -} -#[derive(serde::Deserialize, serde::Serialize)] -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum MemoryAction { - Read = 0, - Write = 1, - Delete = 2, -} -impl MemoryAction { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - MemoryAction::Read => "READ", - MemoryAction::Write => "WRITE", - MemoryAction::Delete => "DELETE", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "READ" => Some(Self::Read), - "WRITE" => Some(Self::Write), - "DELETE" => Some(Self::Delete), - _ => None, - } - } -} -/// Generated client implementations. -pub mod execution_runtime_client { - #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] - use tonic::codegen::*; - use tonic::codegen::http::Uri; - /// API: - #[derive(Debug, Clone)] - pub struct ExecutionRuntimeClient { - inner: tonic::client::Grpc, - } - impl ExecutionRuntimeClient { - /// Attempt to create a new client by connecting to a given endpoint. 
- pub async fn connect(dst: D) -> Result - where - D: TryInto, - D::Error: Into, - { - let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; - Ok(Self::new(conn)) - } - } - impl ExecutionRuntimeClient - where - T: tonic::client::GrpcService, - T::Error: Into, - T::ResponseBody: Body + Send + 'static, - ::Error: Into + Send, - { - pub fn new(inner: T) -> Self { - let inner = tonic::client::Grpc::new(inner); - Self { inner } - } - pub fn with_origin(inner: T, origin: Uri) -> Self { - let inner = tonic::client::Grpc::with_origin(inner, origin); - Self { inner } - } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> ExecutionRuntimeClient> - where - F: tonic::service::Interceptor, - T::ResponseBody: Default, - T: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, - >, - >, - , - >>::Error: Into + Send + Sync, - { - ExecutionRuntimeClient::new(InterceptedService::new(inner, interceptor)) - } - /// Compress requests with the given encoding. - /// - /// This requires the server to support it otherwise it might respond with an - /// error. - #[must_use] - pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.send_compressed(encoding); - self - } - /// Enable decompressing responses. - #[must_use] - pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.accept_compressed(encoding); - self - } - /// Limits the maximum size of a decoded message. - /// - /// Default: `4MB` - #[must_use] - pub fn max_decoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_decoding_message_size(limit); - self - } - /// Limits the maximum size of an encoded message. 
- /// - /// Default: `usize::MAX` - #[must_use] - pub fn max_encoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_encoding_message_size(limit); - self - } - pub async fn run_query( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/RunQuery", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("promptgraph.ExecutionRuntime", "RunQuery")); - self.inner.unary(req, path, codec).await - } - /// * Merge a new file - if an existing file is available at the id, will merge the new file into the existing one - pub async fn merge( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/Merge", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("promptgraph.ExecutionRuntime", "Merge")); - self.inner.unary(req, path, codec).await - } - /// * Get the current graph state of a file at a branch and counter position - pub async fn current_file_state( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = 
http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/CurrentFileState", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new("promptgraph.ExecutionRuntime", "CurrentFileState"), - ); - self.inner.unary(req, path, codec).await - } - /// * Get the parquet history for a specific branch and Id - returns bytes - pub async fn get_parquet_history( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/GetParquetHistory", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new("promptgraph.ExecutionRuntime", "GetParquetHistory"), - ); - self.inner.unary(req, path, codec).await - } - /// * Resume execution - pub async fn play( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/Play", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("promptgraph.ExecutionRuntime", "Play")); - self.inner.unary(req, path, codec).await - } - /// * Pause execution - pub async fn pause( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = 
tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/Pause", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("promptgraph.ExecutionRuntime", "Pause")); - self.inner.unary(req, path, codec).await - } - /// * Split history into a separate branch - pub async fn branch( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/Branch", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("promptgraph.ExecutionRuntime", "Branch")); - self.inner.unary(req, path, codec).await - } - /// * Get all branches - pub async fn list_branches( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/ListBranches", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("promptgraph.ExecutionRuntime", "ListBranches")); - self.inner.unary(req, path, codec).await - } - /// * List all registered files - pub async fn list_registered_graphs( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) 
- })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/ListRegisteredGraphs", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new( - "promptgraph.ExecutionRuntime", - "ListRegisteredGraphs", - ), - ); - self.inner.unary(req, path, codec).await - } - /// * Receive a stream of input proposals <- this is a server-side stream - pub async fn list_input_proposals( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response>, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/ListInputProposals", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new("promptgraph.ExecutionRuntime", "ListInputProposals"), - ); - self.inner.server_streaming(req, path, codec).await - } - /// * Push responses to input proposals (these wait for some input from a host until they're resolved) <- RPC client to server - pub async fn respond_to_input_proposal( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/RespondToInputProposal", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new( - "promptgraph.ExecutionRuntime", - "RespondToInputProposal", - ), - ); - self.inner.unary(req, path, codec).await - } - /// * Observe the stream of execution events <- this is a 
server-side stream - pub async fn list_change_events( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response>, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/ListChangeEvents", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new("promptgraph.ExecutionRuntime", "ListChangeEvents"), - ); - self.inner.server_streaming(req, path, codec).await - } - pub async fn list_node_will_execute_events( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response>, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/ListNodeWillExecuteEvents", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new( - "promptgraph.ExecutionRuntime", - "ListNodeWillExecuteEvents", - ), - ); - self.inner.server_streaming(req, path, codec).await - } - /// * Observe when the server thinks our local node implementation should execute and with what changes - pub async fn poll_custom_node_will_execute_events( - &mut self, - request: impl tonic::IntoRequest< - super::FilteredPollNodeWillExecuteEventsRequest, - >, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = 
http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/PollCustomNodeWillExecuteEvents", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new( - "promptgraph.ExecutionRuntime", - "PollCustomNodeWillExecuteEvents", - ), - ); - self.inner.unary(req, path, codec).await - } - pub async fn ack_node_will_execute_event( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/AckNodeWillExecuteEvent", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new( - "promptgraph.ExecutionRuntime", - "AckNodeWillExecuteEvent", - ), - ); - self.inner.unary(req, path, codec).await - } - /// * Receive events from workers <- this is an RPC client to server, we don't need to wait for a response from the server - pub async fn push_worker_event( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/PushWorkerEvent", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new("promptgraph.ExecutionRuntime", "PushWorkerEvent"), - ); - self.inner.unary(req, path, codec).await - } - pub async fn push_template_partial( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() 
- .await - .map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/promptgraph.ExecutionRuntime/PushTemplatePartial", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new( - "promptgraph.ExecutionRuntime", - "PushTemplatePartial", - ), - ); - self.inner.unary(req, path, codec).await - } - } -} -/// Generated server implementations. -pub mod execution_runtime_server { - #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] - use tonic::codegen::*; - /// Generated trait containing gRPC methods that should be implemented for use with ExecutionRuntimeServer. - #[async_trait] - pub trait ExecutionRuntime: Send + Sync + 'static { - async fn run_query( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// * Merge a new file - if an existing file is available at the id, will merge the new file into the existing one - async fn merge( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * Get the current graph state of a file at a branch and counter position - async fn current_file_state( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * Get the parquet history for a specific branch and Id - returns bytes - async fn get_parquet_history( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * Resume execution - async fn play( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * Pause execution - async fn pause( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * Split history into a separate branch - async fn branch( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * Get all 
branches - async fn list_branches( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * List all registered files - async fn list_registered_graphs( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// Server streaming response type for the ListInputProposals method. - type ListInputProposalsStream: futures_core::Stream< - Item = std::result::Result, - > - + Send - + 'static; - /// * Receive a stream of input proposals <- this is a server-side stream - async fn list_input_proposals( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// * Push responses to input proposals (these wait for some input from a host until they're resolved) <- RPC client to server - async fn respond_to_input_proposal( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// Server streaming response type for the ListChangeEvents method. - type ListChangeEventsStream: futures_core::Stream< - Item = std::result::Result, - > - + Send - + 'static; - /// * Observe the stream of execution events <- this is a server-side stream - async fn list_change_events( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// Server streaming response type for the ListNodeWillExecuteEvents method. 
- type ListNodeWillExecuteEventsStream: futures_core::Stream< - Item = std::result::Result, - > - + Send - + 'static; - async fn list_node_will_execute_events( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// * Observe when the server thinks our local node implementation should execute and with what changes - async fn poll_custom_node_will_execute_events( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - async fn ack_node_will_execute_event( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * Receive events from workers <- this is an RPC client to server, we don't need to wait for a response from the server - async fn push_worker_event( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - async fn push_template_partial( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - } - /// API: - #[derive(Debug)] - pub struct ExecutionRuntimeServer { - inner: _Inner, - accept_compression_encodings: EnabledCompressionEncodings, - send_compression_encodings: EnabledCompressionEncodings, - max_decoding_message_size: Option, - max_encoding_message_size: Option, - } - struct _Inner(Arc); - impl ExecutionRuntimeServer { - pub fn new(inner: T) -> Self { - Self::from_arc(Arc::new(inner)) - } - pub fn from_arc(inner: Arc) -> Self { - let inner = _Inner(inner); - Self { - inner, - accept_compression_encodings: Default::default(), - send_compression_encodings: Default::default(), - max_decoding_message_size: None, - max_encoding_message_size: None, - } - } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> InterceptedService - where - F: tonic::service::Interceptor, - { - InterceptedService::new(Self::new(inner), interceptor) - } - /// Enable decompressing requests with the given encoding. 
- #[must_use] - pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.accept_compression_encodings.enable(encoding); - self - } - /// Compress responses with the given encoding, if the client supports it. - #[must_use] - pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.send_compression_encodings.enable(encoding); - self - } - /// Limits the maximum size of a decoded message. - /// - /// Default: `4MB` - #[must_use] - pub fn max_decoding_message_size(mut self, limit: usize) -> Self { - self.max_decoding_message_size = Some(limit); - self - } - /// Limits the maximum size of an encoded message. - /// - /// Default: `usize::MAX` - #[must_use] - pub fn max_encoding_message_size(mut self, limit: usize) -> Self { - self.max_encoding_message_size = Some(limit); - self - } - } - impl tonic::codegen::Service> for ExecutionRuntimeServer - where - T: ExecutionRuntime, - B: Body + Send + 'static, - B::Error: Into + Send + 'static, - { - type Response = http::Response; - type Error = std::convert::Infallible; - type Future = BoxFuture; - fn poll_ready( - &mut self, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - fn call(&mut self, req: http::Request) -> Self::Future { - let inner = self.inner.clone(); - match req.uri().path() { - "/promptgraph.ExecutionRuntime/RunQuery" => { - #[allow(non_camel_case_types)] - struct RunQuerySvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService - for RunQuerySvc { - type Response = super::QueryAtFrameResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { (*inner).run_query(request).await }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = 
self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = RunQuerySvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/Merge" => { - #[allow(non_camel_case_types)] - struct MergeSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService - for MergeSvc { - type Response = super::ExecutionStatus; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { (*inner).merge(request).await }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = MergeSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/CurrentFileState" => { - #[allow(non_camel_case_types)] - struct CurrentFileStateSvc(pub Arc); - impl< - T: ExecutionRuntime, - > 
tonic::server::UnaryService - for CurrentFileStateSvc { - type Response = super::File; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).current_file_state(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = CurrentFileStateSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/GetParquetHistory" => { - #[allow(non_camel_case_types)] - struct GetParquetHistorySvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService - for GetParquetHistorySvc { - type Response = super::ParquetFile; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).get_parquet_history(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - 
let inner = inner.0; - let method = GetParquetHistorySvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/Play" => { - #[allow(non_camel_case_types)] - struct PlaySvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService for PlaySvc { - type Response = super::ExecutionStatus; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { (*inner).play(request).await }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = PlaySvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/Pause" => { - #[allow(non_camel_case_types)] - struct PauseSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService - for PauseSvc { - type Response = super::ExecutionStatus; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - 
request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { (*inner).pause(request).await }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = PauseSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/Branch" => { - #[allow(non_camel_case_types)] - struct BranchSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService - for BranchSvc { - type Response = super::ExecutionStatus; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { (*inner).branch(request).await }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = BranchSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - 
.apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/ListBranches" => { - #[allow(non_camel_case_types)] - struct ListBranchesSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService - for ListBranchesSvc { - type Response = super::ListBranchesRes; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).list_branches(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = ListBranchesSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/ListRegisteredGraphs" => { - #[allow(non_camel_case_types)] - struct ListRegisteredGraphsSvc(pub Arc); - impl tonic::server::UnaryService - for ListRegisteredGraphsSvc { - type Response = super::ListRegisteredGraphsResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).list_registered_graphs(request).await - }; - Box::pin(fut) - } - } - let 
accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = ListRegisteredGraphsSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/ListInputProposals" => { - #[allow(non_camel_case_types)] - struct ListInputProposalsSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::ServerStreamingService - for ListInputProposalsSvc { - type Response = super::InputProposal; - type ResponseStream = T::ListInputProposalsStream; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).list_input_proposals(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = ListInputProposalsSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - 
max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.server_streaming(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/RespondToInputProposal" => { - #[allow(non_camel_case_types)] - struct RespondToInputProposalSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService - for RespondToInputProposalSvc { - type Response = super::Empty; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).respond_to_input_proposal(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = RespondToInputProposalSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/ListChangeEvents" => { - #[allow(non_camel_case_types)] - struct ListChangeEventsSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::ServerStreamingService - for ListChangeEventsSvc { - type Response = super::ChangeValueWithCounter; - type ResponseStream = T::ListChangeEventsStream; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async 
move { - (*inner).list_change_events(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = ListChangeEventsSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.server_streaming(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/ListNodeWillExecuteEvents" => { - #[allow(non_camel_case_types)] - struct ListNodeWillExecuteEventsSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::ServerStreamingService - for ListNodeWillExecuteEventsSvc { - type Response = super::NodeWillExecuteOnBranch; - type ResponseStream = T::ListNodeWillExecuteEventsStream; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).list_node_will_execute_events(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = ListNodeWillExecuteEventsSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = 
tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.server_streaming(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/PollCustomNodeWillExecuteEvents" => { - #[allow(non_camel_case_types)] - struct PollCustomNodeWillExecuteEventsSvc( - pub Arc, - ); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService< - super::FilteredPollNodeWillExecuteEventsRequest, - > for PollCustomNodeWillExecuteEventsSvc { - type Response = super::RespondPollNodeWillExecuteEvents; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request< - super::FilteredPollNodeWillExecuteEventsRequest, - >, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).poll_custom_node_will_execute_events(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = PollCustomNodeWillExecuteEventsSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/AckNodeWillExecuteEvent" => { - #[allow(non_camel_case_types)] - struct AckNodeWillExecuteEventSvc(pub Arc); - impl< - T: 
ExecutionRuntime, - > tonic::server::UnaryService - for AckNodeWillExecuteEventSvc { - type Response = super::ExecutionStatus; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request< - super::RequestAckNodeWillExecuteEvent, - >, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).ack_node_will_execute_event(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = AckNodeWillExecuteEventSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/PushWorkerEvent" => { - #[allow(non_camel_case_types)] - struct PushWorkerEventSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService< - super::FileAddressedChangeValueWithCounter, - > for PushWorkerEventSvc { - type Response = super::ExecutionStatus; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request< - super::FileAddressedChangeValueWithCounter, - >, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).push_worker_event(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = 
self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = PushWorkerEventSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/promptgraph.ExecutionRuntime/PushTemplatePartial" => { - #[allow(non_camel_case_types)] - struct PushTemplatePartialSvc(pub Arc); - impl< - T: ExecutionRuntime, - > tonic::server::UnaryService - for PushTemplatePartialSvc { - type Response = super::ExecutionStatus; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - (*inner).push_template_partial(request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let inner = inner.0; - let method = PushTemplatePartialSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - _ => { - 
Box::pin(async move { - Ok( - http::Response::builder() - .status(200) - .header("grpc-status", "12") - .header("content-type", "application/grpc") - .body(empty_body()) - .unwrap(), - ) - }) - } - } - } - } - impl Clone for ExecutionRuntimeServer { - fn clone(&self) -> Self { - let inner = self.inner.clone(); - Self { - inner, - accept_compression_encodings: self.accept_compression_encodings, - send_compression_encodings: self.send_compression_encodings, - max_decoding_message_size: self.max_decoding_message_size, - max_encoding_message_size: self.max_encoding_message_size, - } - } - } - impl Clone for _Inner { - fn clone(&self) -> Self { - Self(Arc::clone(&self.0)) - } - } - impl std::fmt::Debug for _Inner { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.0) - } - } - impl tonic::server::NamedService for ExecutionRuntimeServer { - const NAME: &'static str = "promptgraph.ExecutionRuntime"; - } -} diff --git a/toolchain/prompt-graph-core/src/graph_definition.rs b/toolchain/prompt-graph-core/src/graph_definition.rs deleted file mode 100644 index 1cd81ce..0000000 --- a/toolchain/prompt-graph-core/src/graph_definition.rs +++ /dev/null @@ -1,437 +0,0 @@ -use anyhow::anyhow; -use prost::Message; -use serde::{Deserialize, Serialize}; - -use crate::proto as dsl; -use crate::proto::{ItemCore, Query}; -use crate::proto::prompt_graph_node_loader::LoadFrom; - -/// Maps a string to a supported vector database type -fn map_string_to_vector_database(encoding: &str) -> anyhow::Result { - match encoding { - "IN_MEMORY" => Ok(dsl::SupportedVectorDatabase::InMemory), - "CHROMA" => Ok(dsl::SupportedVectorDatabase::Chroma), - "PINECONEDB" => Ok(dsl::SupportedVectorDatabase::Pineconedb), - "QDRANT" => Ok(dsl::SupportedVectorDatabase::Qdrant), - _ => { - Err(anyhow!("Unknown vector database: {}", encoding)) - }, - } -} - -/// Maps a string to a supported embedding model type -fn map_string_to_embedding_model(encoding: &str) -> 
anyhow::Result { - match encoding { - "TEXT_EMBEDDING_ADA_002" => Ok(dsl::SupportedEmebddingModel::TextEmbeddingAda002), - "TEXT_SEARCH_ADA_DOC_001" => Ok(dsl::SupportedEmebddingModel::TextSearchAdaDoc001), - _ => { - Err(anyhow!("Unknown embedding model: {}", encoding)) - }, - } -} - -/// Maps a string to a supported chat model type -fn map_string_to_chat_model(encoding: &str) -> anyhow::Result { - match encoding { - "GPT_4" => Ok(dsl::SupportedChatModel::Gpt4), - "GPT_4_0314" => Ok(dsl::SupportedChatModel::Gpt40314), - "GPT_4_32K" => Ok(dsl::SupportedChatModel::Gpt432k), - "GPT_4_32K_0314" => Ok(dsl::SupportedChatModel::Gpt432k0314), - "GPT_3_5_TURBO" => Ok(dsl::SupportedChatModel::Gpt35Turbo), - "GPT_3_5_TURBO_0301" => Ok(dsl::SupportedChatModel::Gpt35Turbo0301), - _ => { - Err(anyhow!("Unknown chat model: {}", encoding)) - }, - } -} - -/// Maps a string to a supported source language type -fn map_string_to_supported_source_langauge(encoding: &str) -> anyhow::Result { - match encoding { - "DENO" => Ok(dsl::SupportedSourceCodeLanguages::Deno), - "STARLARK" => Ok(dsl::SupportedSourceCodeLanguages::Starlark), - _ => { - Err(anyhow!("Unknown source language: {}", encoding)) - }, - } -} - -/// Converts a string representing a query definition to a Query type -fn create_query(query_def: Option) -> dsl::Query { - dsl::Query { - query: query_def.map(|d|d), - } -} - -/// Converts a string representing an output definition to an OutputType type -fn create_output(output_def: &str) -> Option { - Some(dsl::OutputType { - output: output_def.to_string(), - }) -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum SourceNodeType { - Code(String, String, bool), - S3(String), - Zipfile(Vec), -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct DefinitionGraph { - internal: dsl::File, -} - - -/// A graph definition or DefinitionGraph defines a graph of executable nodes connected by edges or 'triggers'. 
-/// The graph is defined in a DSL (domain specific language) that is compiled into a binary formatted File that can be -/// executed by the prompt-graph-core runtime. -impl DefinitionGraph { - - /// Returns the File object representing this graph definition - pub fn get_file(&self) -> &dsl::File { - &self.internal - } - - /// Returns an empty graph definition - pub fn zero() -> Self { - Self { - internal: dsl::File::default() - } - } - - /// Sets this graph definition to read from & write to the given File object - pub fn from_file(file: dsl::File) -> Self { - Self { - internal: file - } - } - - /// Store the given bytes (representing protobuf graph definition) as a - /// new File object and associate this graph definition with it - pub fn new(bytes: &[u8]) -> Self { - Self { - internal: dsl::File::decode(bytes).unwrap() - } - } - - /// Read and return the nodes from internal File object - pub(crate) fn get_nodes(&self) -> &Vec { - &self.internal.nodes - } - - /// Read and return a mutable collection of nodes from internal File object - pub(crate) fn get_nodes_mut(&mut self) -> &Vec { - &self.internal.nodes - } - - /// Serialize the internal File object to bytes and return them - pub(crate) fn serialize(&self) -> Vec { - let mut buffer = Vec::new(); - self.internal.encode(&mut buffer).unwrap(); - buffer - } - - /// Push a given node (defined as Item type) to the internal graph definition - pub fn register_node(&mut self, item: dsl::Item) { - self.internal.nodes.push(item); - } - - /// Push a given node (defined as bytes) to the internal graph definition - pub fn register_node_bytes(&mut self, item: &[u8]) { - let item = dsl::Item::decode(item).unwrap(); - self.internal.nodes.push(item); - } -} - - -#[deprecated(since="0.1.0", note="do not use")] -pub fn create_entrypoint_query( - query_def: Option -) -> dsl::Item { - let query_element = dsl::Query { - query: query_def.map(|x| x.to_string()), - }; - let _node = dsl::PromptGraphNodeCode { - source: None, - }; - 
dsl::Item { - core: Some(ItemCore { - name: "RegistrationCodeNode".to_string(), - triggers: vec![query_element], - output: Default::default(), - output_tables: vec![], - }), - item: None, - } -} - -/// Takes in common node parameters and returns a fulfilled node type (a dsl::Item type) -pub fn create_node_parameter( - name: String, - output_def: String -) -> dsl::Item { - dsl::Item { - core: Some(ItemCore { - name: name.to_string(), - output: create_output(&output_def), - triggers: vec![Query { query: None }], - output_tables: vec![], - }), - item: Some(dsl::item::Item::NodeParameter(dsl::PromptGraphParameterNode { - })), - } -} - -/// Returns a Map type node, which maps a Path (key) to a given String (value) -pub fn create_op_map( - name: String, - query_defs: Vec>, - path: String, - output_tables: Vec -) -> dsl::Item { - dsl::Item { - core: Some(ItemCore { - name: name.to_string(), - triggers: query_defs.into_iter().map(create_query).collect(), - // TODO: needs to have the type of the input - output: create_output(r#" - { - result: String - } - "#), - output_tables, - }), - item: Some(dsl::item::Item::Map(dsl::PromptGraphMap { - path: path.to_string(), - })), - } -} - -// TODO: automatically wire these into prompt nodes that support function calling -// TODO: https://platform.openai.com/docs/guides/gpt/function-calling -/// Takes in executable code and returns a node that executes said code when triggered -/// This executable code can take the format of: -/// - a raw string of code in a supported language -/// - a path to an S3 bucket containing code in a supported language -/// - a zip file containing code in a supported language -pub fn create_code_node( - name: String, - query_defs: Vec>, - output_def: String, - source_type: SourceNodeType, - output_tables: Vec, -) -> dsl::Item { - let source = match source_type { - SourceNodeType::Code(language, code, template) => { - // https://github.com/denoland/deno/discussions/17345 - // 
https://github.com/a-poor/js-in-rs/blob/main/src/main.rs - dsl::prompt_graph_node_code::Source::SourceCode( dsl::PromptGraphNodeCodeSourceCode{ - template, - language: map_string_to_supported_source_langauge(&language).unwrap() as i32, - source_code: code.to_string(), - }) - } - SourceNodeType::S3(path) => { - dsl::prompt_graph_node_code::Source::S3Path(path) - } - SourceNodeType::Zipfile(file) => { - dsl::prompt_graph_node_code::Source::Zipfile(file) - } - }; - - dsl::Item { - core: Some(ItemCore { - name: name.to_string(), - triggers: query_defs.into_iter().map(create_query).collect(), - output: create_output(&output_def), - output_tables - }), - item: Some(dsl::item::Item::NodeCode(dsl::PromptGraphNodeCode{ - source: Some(source), - })), - } -} - - - -// TODO: automatically wire these into prompt nodes that support function calling -// TODO: https://platform.openai.com/docs/guides/gpt/function-calling -/// Returns a custom node that executes a given function -/// When registering a custom node in the SDK, you provide an in-language function and -/// tell chidori to register that function under the given "type_name". 
-/// This function executed is then executed in the graph -/// when referenced by this "type_name" parameter -pub fn create_custom_node( - name: String, - query_defs: Vec>, - output_def: String, - type_name: String, - output_tables: Vec -) -> dsl::Item { - dsl::Item { - core: Some(ItemCore { - name: name.to_string(), - triggers: query_defs.into_iter().map(create_query).collect(), - output: create_output(&output_def), - output_tables - }), - item: Some(dsl::item::Item::NodeCustom(dsl::PromptGraphNodeCustom{ - type_name, - })), - } -} - -/// Returns a node that, when triggered, echoes back its input for easier querying -pub fn create_observation_node( - name: String, - query_defs: Vec>, - output_def: String, - output_tables: Vec -) -> dsl::Item { - dsl::Item { - core: Some(ItemCore { - name: name.to_string(), - triggers: query_defs.into_iter().map(create_query).collect(), - output: create_output(&output_def), - output_tables - }), - item: Some(dsl::item::Item::NodeObservation(dsl::PromptGraphNodeObservation{ - integration: "".to_string(), - })), - } -} - -/// Returns a node that can perform some READ/WRITE/DELETE operation on -/// a specified Vector database, using the specified configuration options -/// (options like the embedding_model to use and collection_name namespace to query within) -pub fn create_vector_memory_node( - name: String, - query_defs: Vec>, - output_def: String, - action: String, - embedding_model: String, - template: String, - db_vendor: String, - collection_name: String, - output_tables: Vec -) -> anyhow::Result { - let model = dsl::prompt_graph_node_memory::EmbeddingModel::Model(map_string_to_embedding_model(&embedding_model)? as i32); - let vector_db = dsl::prompt_graph_node_memory::VectorDbProvider::Db(map_string_to_vector_database(&db_vendor)? 
as i32); - - let action = match action.as_str() { - "READ" => { - dsl::MemoryAction::Read as i32 - }, - "WRITE" => { - dsl::MemoryAction::Write as i32 - }, - "DELETE" => { - dsl::MemoryAction::Delete as i32 - } - _ => { unreachable!("Invalid action") } - }; - - Ok(dsl::Item { - core: Some(ItemCore { - name: name.to_string(), - triggers: query_defs.into_iter().map(create_query).collect(), - output: create_output(&output_def), - output_tables - }), - item: Some(dsl::item::Item::NodeMemory(dsl::PromptGraphNodeMemory{ - collection_name: collection_name, - action, - embedding_model: Some(model), - template: template, - vector_db_provider: Some(vector_db), - })), - }) -} - -/// Returns a node that can implement logic from another graph definition -/// This is useful for reusing logic across multiple graphs -/// The graph definition to transclude is specified by either -/// - a path to an S3 bucket containing a graph definition -/// - raw bytes of a graph definition -/// - a File object containing a graph definition -pub fn create_component_node( - name: String, - query_defs: Vec>, - output_def: String, - output_tables: Vec, -) -> dsl::Item { - dsl::Item { - core: Some(ItemCore { - name: name.to_string(), - triggers: query_defs.into_iter().map(create_query).collect(), - output: create_output(&output_def), - output_tables - }), - item: Some(dsl::item::Item::NodeComponent(dsl::PromptGraphNodeComponent { - transclusion: None, - })), - } -} - -/// Returns a node that can read bytes from a given source -pub fn create_loader_node( - name: String, - query_defs: Vec>, - output_def: String, - load_from: LoadFrom, - output_tables: Vec, -) -> dsl::Item { - dsl::Item { - core: Some(ItemCore { - name: name.to_string(), - triggers: query_defs.into_iter().map(create_query).collect(), - output: create_output(&output_def), - output_tables - }), - item: Some(dsl::item::Item::NodeLoader(dsl::PromptGraphNodeLoader { - load_from: Some(load_from), - })), - } -} - -/// Returns a node that, when 
triggered, performs an API call to a given language model endpoint, -/// using the template parameter as the prompt input to the language model, and returns the result -/// to the graph as a String type labeled "promptResult" -pub fn create_prompt_node( - name: String, - query_defs: Vec>, - template: String, - model: String, - output_tables: Vec, -) -> anyhow::Result { - let chat_model = map_string_to_chat_model(&model)?; - let model = dsl::prompt_graph_node_prompt::Model::ChatModel(chat_model as i32); - // TODO: use handlebars Template object in order to inspect the contents of and validate the template against the query - // https://github.com/sunng87/handlebars-rust/blob/23ca8d76bee783bf72f627b4c4995d1d11008d17/src/template.rs#L963 - // self.handlebars.register_template_string(name, template).unwrap(); - // println!("{:?}", Template::compile(&template).unwrap()); - Ok(dsl::Item { - core: Some(ItemCore { - name: name.to_string(), - triggers: query_defs.into_iter().map(create_query).collect(), - output: create_output(r#" - { - promptResult: String - } - "#), - output_tables - }), - item: Some(dsl::item::Item::NodePrompt(dsl::PromptGraphNodePrompt{ - template: template.to_string(), - model: Some(model), - // TODO: add output but set it to some sane defaults - temperature: 1.0, - top_p: 1.0, - max_tokens: 100, - presence_penalty: 0.0, - frequency_penalty: 0.0, - stop: vec![], - })), - }) -} diff --git a/toolchain/prompt-graph-core/src/lib.rs b/toolchain/prompt-graph-core/src/lib.rs index 709c41a..08602ac 100644 --- a/toolchain/prompt-graph-core/src/lib.rs +++ b/toolchain/prompt-graph-core/src/lib.rs @@ -1,27 +1,8 @@ +#![feature(is_sorted)] extern crate protobuf; -use crate::proto::{ChangeValue, Path, SerializedValue}; -use crate::proto::serialized_value::Val; -pub mod graph_definition; -pub mod execution_router; -pub mod utils; -pub mod proto; -pub mod build_runtime_graph; -pub mod reactivity; -pub mod time_travel; +pub mod execution; +pub mod library; pub mod 
prompt_composition; - - -/// Our local server implementation is an extension of this. Implementing support for multiple -/// agent implementations to run on the same machine. -pub fn create_change_value(address: Vec, val: Option, branch: u64) -> ChangeValue { - ChangeValue{ - path: Some(Path { - address, - }), - value: Some(SerializedValue { - val, - }), - branch, - } -} +pub mod time_travel; +pub mod utils; diff --git a/toolchain/prompt-graph-core/src/library/mod.rs b/toolchain/prompt-graph-core/src/library/mod.rs new file mode 100644 index 0000000..df1033a --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/mod.rs @@ -0,0 +1 @@ +pub mod std; diff --git a/toolchain/prompt-graph-core/src/library/std/code/mod.rs b/toolchain/prompt-graph-core/src/library/std/code/mod.rs new file mode 100644 index 0000000..0ac4fba --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/code/mod.rs @@ -0,0 +1,2 @@ +pub mod runtime_deno; +pub mod runtime_starlark; diff --git a/toolchain/prompt-graph-core/src/library/std/code/runtime_deno.rs b/toolchain/prompt-graph-core/src/library/std/code/runtime_deno.rs new file mode 100644 index 0000000..6e82ad9 --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/code/runtime_deno.rs @@ -0,0 +1,61 @@ +use anyhow::Result; +use deno_core::serde_json::Value; +use deno_core::{serde_json, serde_v8, v8, FastString, JsRuntime, RuntimeOptions}; + +pub fn source_code_run_deno(source_code: String, _state: Option) -> Result> { + // Wrap the source code in an entrypoint function so that it immediately evaluates + let wrapped_source_code = format!( + r#"(function main() {{ + {} + }})();"#, + source_code + ); + + let mut runtime = JsRuntime::new(RuntimeOptions::default()); + + // TODO: the script receives the arguments as a json payload "#state" + let result = runtime.execute_script( + "main.js", + FastString::Owned(wrapped_source_code.into_boxed_str()), + ); + + match result { + Ok(global) => { + let scope = &mut 
runtime.handle_scope(); + let local = v8::Local::new(scope, global); + let deserialized_value = serde_v8::from_v8::(scope, local); + return Ok(if let Ok(value) = deserialized_value { + Some(value) + } else { + None + }); + } + Err(e) => Err(e), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_source_code_run_deno_success() { + let source_code = String::from("return 42;"); + let result = source_code_run_deno(source_code, None); + assert_eq!(result.unwrap(), Some(serde_json::json!(42))); + } + + #[test] + fn test_source_code_run_deno_failure() { + let source_code = String::from("throw new Error('Test Error');"); + let result = source_code_run_deno(source_code, None); + assert!(result.is_err()); + } + + #[test] + fn test_source_code_run_deno_json_serialization() { + let source_code = String::from("return {foo: 'bar'};"); + let result = source_code_run_deno(source_code, None); + assert_eq!(result.unwrap(), Some(serde_json::json!({"foo": "bar"}))); + } +} diff --git a/toolchain/prompt-graph-core/src/library/std/code/runtime_pyo3.rs b/toolchain/prompt-graph-core/src/library/std/code/runtime_pyo3.rs new file mode 100644 index 0000000..e69de29 diff --git a/toolchain/prompt-graph-core/src/library/std/code/runtime_starlark.rs b/toolchain/prompt-graph-core/src/library/std/code/runtime_starlark.rs new file mode 100644 index 0000000..0c50a73 --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/code/runtime_starlark.rs @@ -0,0 +1,42 @@ +use indoc::indoc; +use starlark::environment::{Globals, Module as StarlarkModule}; +use starlark::eval::Evaluator; +use starlark::syntax::{AstModule, Dialect}; +use starlark::values::Value as StarlarkValue; + +pub fn source_code_run_starlark(source_code: String) -> Option { + let ast: AstModule = AstModule::parse( + "hello_world.star", + source_code.to_owned(), + &Dialect::Standard, + ) + .unwrap(); + let globals: Globals = Globals::standard(); + let module: StarlarkModule = StarlarkModule::new(); + let mut 
eval: Evaluator = Evaluator::new(&module); + let res: StarlarkValue = eval.eval_module(ast, &globals).unwrap(); + let v: serde_json::Value = serde_json::from_str(&res.to_json().unwrap()).unwrap(); + Some(v) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_source_code_run_starlark() { + // Define a sample Starlark code + let starlark_code = indoc! { r#" + # Starlark code that outputs JSON data + def main(): + return {"key": "value"} + + main() + "#} + .to_string(); + + let expected_output = serde_json::json!({"key": "value"}); + let actual_output = source_code_run_starlark(starlark_code).unwrap(); + assert_eq!(actual_output, expected_output); + } +} diff --git a/toolchain/prompt-graph-core/src/library/std/io/mod.rs b/toolchain/prompt-graph-core/src/library/std/io/mod.rs new file mode 100644 index 0000000..16a80cb --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/io/mod.rs @@ -0,0 +1 @@ +pub mod zip; diff --git a/toolchain/prompt-graph-core/src/library/std/io/zip/mod.rs b/toolchain/prompt-graph-core/src/library/std/io/zip/mod.rs new file mode 100644 index 0000000..953b7bc --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/io/zip/mod.rs @@ -0,0 +1,46 @@ +use anyhow::Result; +use std::io::{Cursor, Read}; +use zip; + +// TODO: +// * zipfile in message +// * zipfile over http +// * http - load webpage +// * http - load json +// * sqlite - database proxy +// * arbitrary changes pushed by the host environment + +pub fn extract_zip(bytes: &[u8]) -> Result { + let cursor = Cursor::new(bytes); + let mut zip = zip::ZipArchive::new(cursor)?; + for i in 0..zip.len() { + let mut file = zip.by_index(i)?; + if file.is_dir() { + continue; + } + if file.name().contains("__MACOSX") || file.name().contains(".DS_Store") { + continue; + } + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer).expect("Failed to read file"); + let string = String::from_utf8_lossy(&buffer); + } + Ok(true) +} + +#[cfg(test)] +mod tests { + use super::*; 
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use std::fs::File;

    /// Round-trips a fixture zip from disk through `extract_zip`.
    #[test]
    fn test_exec_load_node_zip_bytes() -> Result<()> {
        // Open the file in read-only mode
        let mut file = File::open("./tests/data/files_and_dirs.zip")?;
        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)?;
        // Fix: the Result of `extract_zip` was previously ignored, so an
        // archive that failed to extract still passed the test.
        assert!(extract_zip(&buffer)?);
        Ok(())
    }
}
// Values from 200 to 800 are standard initialising values, the higher the more time consuming. 200, - // Distance function - DistDot{} + DistDot {}, ); hnsw.set_extend_candidates(true); Self { db: HashMap::new(), id_counter: 0, - hnsw + hnsw, } } - pub fn insert(&mut self, data: &Vec<(&Vec, T)>) { + pub fn insert(&mut self, data: &Vec<(&Vec, Value)>) { // usize is the id let mut insert_set = vec![]; for item in data { @@ -56,20 +55,68 @@ impl InMemoryVectorDb where T: Clone{ // TODO: supports searching multiple keys at once, we should support that let neighbors = self.hnsw.parallel_search(&vec![data], num_neighbors, 16); for neighbor in neighbors.first().unwrap() { - results.push((neighbor.clone(), self.db.get(&neighbor.d_id).unwrap().clone())); + results.push(( + neighbor.clone(), + self.db.get(&neighbor.d_id).unwrap().clone(), + )); } results } +} +struct MemoryInMemory { + client: InMemoryVectorDb, + collection_name: String, } +#[async_trait] +impl VectorDatabase> for MemoryInMemory { + fn attach_client(client: InMemoryVectorDb) -> Result { + // Assuming QdrantClient can be constructed from a connection string + // let client = MyQdrantClient( + // QdrantClient::new(Some(QdrantClientConfig { + // uri: connection_string.to_string(), + // ..Default::default() + // })) + // .unwrap(), + // ); + Ok(MemoryInMemory { + client, + collection_name: "default_collection".to_string(), // Or use a parameter + }) + } + + async fn create_collection( + &self, + collection_name: String, + embedding_length: u64, + ) -> Result<(), VectorDbError> { + unimplemented!(); + } + async fn insert_vector( + &self, + id: u64, + vector: Vec, + payload: Option, + ) -> Result<(), VectorDbError> { + db.insert(&row); + unimplemented!(); + } + + async fn query_by_vector( + &self, + vector: Vec, + top_k: usize, + ) -> Result, VectorDbError> { + unimplemented!(); + } +} #[cfg(test)] mod tests { use super::*; - #[ignore] #[test] fn test_memory_db() { let mut db = InMemoryVectorDb::new(); @@ -77,9 
+124,9 @@ mod tests { let contents = HashMap::from([("name", "test")]); let row = vec![(&embedding, contents)]; db.insert(&row); - let search = vec![0.1, 0.1, 0.1]; + let search = vec![0.1, 0.2, 0.3]; let result = db.search(search, 1); assert_eq!(result.len(), 1); assert_eq!(result[0].1, HashMap::from([("name", "test")])); } -} \ No newline at end of file +} diff --git a/toolchain/prompt-graph-core/src/library/std/memory/mod.rs b/toolchain/prompt-graph-core/src/library/std/memory/mod.rs new file mode 100644 index 0000000..6ecc34e --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/memory/mod.rs @@ -0,0 +1,61 @@ +use anyhow; +use async_trait::async_trait; + +use std::collections::HashMap; +pub mod in_memory; +pub mod qdrant; + +// Define a custom error type for our vector database interactions + +#[derive(Debug)] +pub enum VectorDbError { + ConnectionError(String), + QueryError(String), + InsertionError(String), + CollectionCreationError(String), + // other error types... +} + +struct CoreValueEmbedding {} + +trait TraitValueEmbedding { + fn embed(&self) -> anyhow::Result>; +} + +struct CoreVectorDatabase { + name: String, + table: String, + schema: String, +} + +// The trait for vector database interaction +#[async_trait] +pub trait VectorDatabase { + // Connects to the vector database + fn attach_client(client: C) -> Result + where + Self: Sized; + + async fn create_collection( + &self, + collection_name: String, + embedding_length: u64, + ) -> Result<(), VectorDbError>; + + // Inserts a vector into the database + async fn insert_vector( + &self, + id: u64, + vector: Vec, + payload: Option, + ) -> Result<(), VectorDbError>; + + // Queries the database by vector + async fn query_by_vector( + &self, + vector: Vec, + top_k: usize, + ) -> Result, VectorDbError>; + + // Additional methods like update, delete, etc. 
use crate::library::std::memory::{VectorDatabase, VectorDbError};
use async_trait::async_trait;
use base64;
use handlebars::JsonValue;
use qdrant_client::prelude::*;
use qdrant_client::qdrant::point_id::PointIdOptions;
use qdrant_client::qdrant::vectors_config::Config;
use qdrant_client::qdrant::{
    CreateCollection, PointId, SearchPoints, SearchResponse, VectorParams, Vectors,
};
use serde_json::Value;
use std::collections::HashMap;

/// Abstraction over the subset of the Qdrant client API used by
/// `MemoryQdrant`, so that tests can substitute `MockQdrantClient`.
/// NOTE(review): errors are bare `String`s; consider a typed error enum.
#[async_trait]
pub trait WrappedQdrantClient {
    // Upserts `points` into `collection_name`, blocking until indexed.
    // NOTE(review): the generic parameters of `points`/`option` were lost in
    // formatting; `Vec<PointStruct>` / `Option<WriteOrdering>` are restored
    // best-effort from the call sites — confirm.
    async fn upsert_points_blocking(
        &self,
        collection_name: String,
        points: Vec<PointStruct>,
        option: Option<WriteOrdering>,
    ) -> Result<(), String>;

    // Runs a vector similarity search with the given parameters.
    async fn search_points(&self, params: &SearchPoints) -> Result<SearchResponse, String>;

    // Creates a new collection from a full `CreateCollection` request.
    async fn create_collection(&self, collection: &CreateCollection) -> Result<(), String>;
}

/// Newtype wrapper so `WrappedQdrantClient` can be implemented for the real
/// `QdrantClient` (orphan-rule workaround).
pub struct MyQdrantClient(QdrantClient);

#[async_trait]
impl WrappedQdrantClient for MyQdrantClient {
    async fn upsert_points_blocking(
        &self,
        collection_name: String,
        points: Vec<PointStruct>,
        option: Option<WriteOrdering>,
    ) -> Result<(), String> {
        // Actual implementation for QdrantClient
        // NOTE(review): stub — silently succeeds without writing anything.
        Ok(())
    }

    async fn search_points(&self, params: &SearchPoints) -> Result<SearchResponse, String> {
        // Actual implementation for QdrantClient
        Err("Not implemented".to_string())
    }

    async fn create_collection(&self, collection: &CreateCollection) -> Result<(), String> {
        // Actual implementation for QdrantClient
        Err("Not implemented".to_string())
    }
}

/// Qdrant-backed implementation of `VectorDatabase`, generic over the client
/// so tests can inject a mock.
struct MemoryQdrant<C> {
    client: C,
    collection_name: String,
}

// NOTE(review): the impl's generic bounds were lost in formatting; restored
// best-effort as `C: WrappedQdrantClient + Send + Sync` (required by the
// async trait methods) — confirm.
#[async_trait]
impl<C: WrappedQdrantClient + Send + Sync> VectorDatabase<C> for MemoryQdrant<C> {
    fn attach_client(client: C) -> Result<Self, VectorDbError> {
        // Assuming QdrantClient can be constructed from a connection string
        // let client = MyQdrantClient(
        //     QdrantClient::new(Some(QdrantClientConfig {
        //         uri: connection_string.to_string(),
        //         ..Default::default()
        //     }))
        //     .unwrap(),
        // );
        Ok(MemoryQdrant {
            client,
            collection_name: "default_collection".to_string(), // Or use a parameter
        })
    }

    // Creates a cosine-distance collection sized for `embedding_length`.
    async fn create_collection(
        &self,
        collection_name: String,
        embedding_length: u64,
    ) -> Result<(), VectorDbError> {
        self.client
            .create_collection(&CreateCollection {
                collection_name: collection_name.into(),
                vectors_config: Some(qdrant_client::qdrant::VectorsConfig {
                    config: Some(Config::Params(VectorParams {
                        size: embedding_length,
                        distance: Distance::Cosine.into(),
                        ..Default::default()
                    })),
                }),
                ..Default::default()
            })
            .await
            .map_err(|e| VectorDbError::CollectionCreationError(e.to_string()))
    }

    // Upserts a single point; a missing payload becomes an empty `Payload`.
    async fn insert_vector(
        &self,
        id: u64,
        vector: Vec<f32>,
        payload: Option<Value>,
    ) -> Result<(), VectorDbError> {
        // Additional payload handling here if needed
        let points = vec![PointStruct::new(
            PointId::from(id),
            Vectors::from(vector.to_vec()),
            if let Some(payload) = payload {
                payload.try_into().unwrap()
            } else {
                Payload::default()
            },
        )];
        self.client
            .upsert_points_blocking(self.collection_name.clone(), points, None)
            .await
            .map_err(|e| VectorDbError::InsertionError(e.to_string())) // Map the error to VectorDbError
    }

    // Nearest-neighbour search; returns only the numeric point ids.
    // NOTE(review): non-numeric (uuid) point ids are mapped to 0 — confirm
    // that is the intended behavior rather than an error.
    async fn query_by_vector(
        &self,
        vector: Vec<f32>,
        top_k: usize,
    ) -> Result<Vec<u64>, VectorDbError> {
        let search_result = self
            .client
            .search_points(&SearchPoints {
                collection_name: self.collection_name.clone(),
                vector: vector.to_vec(),
                filter: None,
                limit: top_k as u64,
                with_payload: Some(true.into()),
                ..Default::default()
            })
            .await
            .map_err(|e| VectorDbError::QueryError(e.to_string()))?; // Map the error to VectorDbError

        let ids = search_result
            .result
            .into_iter()
            .filter_map(|point| {
                point.id.map(|id| match id.point_id_options.unwrap() {
                    PointIdOptions::Num(id) => id,
                    _ => 0,
                })
            })
            .collect();

        Ok(ids)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use qdrant_client::qdrant::PointId;
    use qdrant_client::qdrant::ScoredPoint;
    use qdrant_client::qdrant::Vectors;

    // Mock a QdrantClient for testing purposes
    struct MockQdrantClient {}

    impl MockQdrantClient {
        fn new(_connection_string: &str) -> Self {
            MockQdrantClient {}
        }
    }

    #[async_trait]
    impl WrappedQdrantClient for MockQdrantClient {
        async fn upsert_points_blocking(
            &self,
            _collection_name: String,
            _points: Vec<PointStruct>,
            _option: Option<WriteOrdering>,
        ) -> Result<(), String> {
            Ok(())
        }

        // Returns two canned scored points so id extraction can be asserted.
        async fn search_points(&self, _params: &SearchPoints) -> Result<SearchResponse, String> {
            Ok(SearchResponse {
                result: vec![
                    ScoredPoint {
                        id: Some(PointId::from(1)),
                        payload: Default::default(),
                        score: 0.9, // Example score
                        version: 1, // Example version
                        vectors: Some(Vectors::from(vec![0.1, 0.2])),
                    },
                    ScoredPoint {
                        id: Some(PointId::from(2)),
                        payload: Default::default(),
                        score: 0.8, // Example score
                        version: 1, // Example version
                        vectors: Some(Vectors::from(vec![0.3, 0.4])),
                    },
                ],
                time: 0.0,
            })
        }

        async fn create_collection(&self, _collection: &CreateCollection) -> Result<(), String> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_insert_vector() {
        let db = MemoryQdrant {
            client: MockQdrantClient::new("mock_connection_string"),
            collection_name: "test_collection".to_string(),
        };

        let result = db.insert_vector(123, vec![0.5, 0.6], None).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_query_by_vector() {
        let db = MemoryQdrant {
            client: MockQdrantClient::new("mock_connection_string"),
            collection_name: "test_collection".to_string(),
        };

        let result = db.query_by_vector(vec![0.5, 0.6], 2).await;
        assert!(result.is_ok());
        let ids = result.unwrap();
        assert_eq!(ids, vec![1, 2]);
    }
}
b/toolchain/prompt-graph-core/src/library/std/mod.rs new file mode 100644 index 0000000..0ae71d1 --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/mod.rs @@ -0,0 +1,5 @@ +pub mod code; +pub mod io; +pub mod memory; +// pub mod prompt; +// pub mod schedule; diff --git a/toolchain/prompt-graph-core/src/library/std/prompt/mod.rs b/toolchain/prompt-graph-core/src/library/std/prompt/mod.rs new file mode 100644 index 0000000..729577f --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/prompt/mod.rs @@ -0,0 +1,7 @@ +struct ChatCompletionReq { + model: SupportedChatModel, + frequency_penalty: Option, + max_tokens: Option, + presence_penalty: Option, + stop: Option>, +} diff --git a/toolchain/prompt-graph-exec/src/integrations/openai/batch.rs b/toolchain/prompt-graph-core/src/library/std/prompt/openai/batch.rs similarity index 75% rename from toolchain/prompt-graph-exec/src/integrations/openai/batch.rs rename to toolchain/prompt-graph-core/src/library/std/prompt/openai/batch.rs index ac900db..97f689c 100644 --- a/toolchain/prompt-graph-exec/src/integrations/openai/batch.rs +++ b/toolchain/prompt-graph-core/src/library/std/prompt/openai/batch.rs @@ -1,22 +1,27 @@ use std::env; +use crate::library::std::prompt::ChatCompletionReq; use openai_api_rs::v1::api::Client; -use openai_api_rs::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse}; use openai_api_rs::v1::chat_completion; +use openai_api_rs::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse}; use openai_api_rs::v1::chat_completion::{ - GPT3_5_TURBO, - GPT3_5_TURBO_0301, - GPT4, - GPT4_0314, - GPT4_32K, - GPT4_32K_0314 + GPT3_5_TURBO, GPT3_5_TURBO_0301, GPT4, GPT4_0314, GPT4_32K, GPT4_32K_0314, }; use openai_api_rs::v1::error::APIError; -use prompt_graph_core::proto::{PromptGraphNodePrompt, SupportedChatModel}; - +pub enum SupportedChatModel { + Gpt4, + Gpt40314, + Gpt432k, + Gpt432k0314, + Gpt35Turbo, + Gpt35Turbo0301, +} -pub async fn chat_completion(_n: 
&PromptGraphNodePrompt, openai_model: SupportedChatModel, templated_string: String) -> Result { +pub async fn chat_completion( + chat_completion_req: ChatCompletionReq, + templated_string: String, +) -> Result { let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let model = match openai_model { @@ -26,7 +31,8 @@ pub async fn chat_completion(_n: &PromptGraphNodePrompt, openai_model: Supported SupportedChatModel::Gpt432k0314 => GPT4_32K_0314, SupportedChatModel::Gpt35Turbo => GPT3_5_TURBO, SupportedChatModel::Gpt35Turbo0301 => GPT3_5_TURBO_0301, - }.to_string(); + } + .to_string(); let req = ChatCompletionRequest { model, diff --git a/toolchain/prompt-graph-core/src/library/std/prompt/openai/mod.rs b/toolchain/prompt-graph-core/src/library/std/prompt/openai/mod.rs new file mode 100644 index 0000000..e69de29 diff --git a/toolchain/prompt-graph-exec/src/integrations/openai/streaming.rs b/toolchain/prompt-graph-core/src/library/std/prompt/openai/streaming.rs similarity index 73% rename from toolchain/prompt-graph-exec/src/integrations/openai/streaming.rs rename to toolchain/prompt-graph-core/src/library/std/prompt/openai/streaming.rs index 48ee91d..0271edf 100644 --- a/toolchain/prompt-graph-exec/src/integrations/openai/streaming.rs +++ b/toolchain/prompt-graph-core/src/library/std/prompt/openai/streaming.rs @@ -1,11 +1,10 @@ -use std::pin::Pin; -use std::task::{Context, Poll}; use deno_core::serde_json; use futures_util::stream::Stream; use openai_api_rs::v1::chat_completion::ChatCompletionRequest; use reqwest::{Client, Response}; use serde_json::Value; - +use std::pin::Pin; +use std::task::{Context, Poll}; pub struct GptStream { response: Pin> + Send>>, @@ -13,7 +12,10 @@ pub struct GptStream { first_chunk: bool, } -pub async fn gpt_stream(api_key: String, completion: ChatCompletionRequest) -> Result { +pub async fn gpt_stream( + api_key: String, + completion: ChatCompletionRequest, +) -> Result { let api_url = 
"https://api.openai.com/v1/chat/completions"; let client = Client::new(); let response: Response = match client @@ -35,7 +37,10 @@ pub async fn gpt_stream(api_key: String, completion: ChatCompletionRequest) -> R first_chunk: true, }) } else { - let error_text = response.text().await.unwrap_or_else(|_| String::from("Unknown error")); + let error_text = response + .text() + .await + .unwrap_or_else(|_| String::from("Unknown error")); Err(format!("API request error: {}", error_text)) } } @@ -67,7 +72,9 @@ impl Stream for GptStream { Ok(json) => { if let Some(choices) = json.get("choices") { if let Some(choice) = choices.get(0) { - if let Some(content) = choice.get("delta").and_then(|delta| delta.get("content")) { + if let Some(content) = + choice.get("delta").and_then(|delta| delta.get("content")) + { if let Some(content_str) = content.as_str() { self.buffer.push_str(content_str); let output = self.buffer.replace("\\n", "\n"); @@ -95,42 +102,48 @@ impl Stream for GptStream { } } - #[cfg(test)] mod tests { - use std::env; use super::*; use futures_util::stream::StreamExt; use openai_api_rs::v1::chat_completion::{ChatCompletionMessage, MessageRole}; + use std::env; #[cfg(feature = "integration-tests")] #[tokio::test] async fn test_gpt_stream_raw_line() { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let stream = gpt_stream(api_key, ChatCompletionRequest { - model: "gpt-3.5-turbo".to_string(), - messages: vec![ChatCompletionMessage { - role: MessageRole::user, - content: Some("One sentence to describe a simple advanced usage of Rust".to_string()), - name: None, + let stream = gpt_stream( + api_key, + ChatCompletionRequest { + model: "gpt-3.5-turbo".to_string(), + messages: vec![ChatCompletionMessage { + role: MessageRole::user, + content: Some( + "One sentence to describe a simple advanced usage of Rust".to_string(), + ), + name: None, + function_call: None, + }], + functions: None, function_call: None, - }], - functions: None, - function_call: None, - 
temperature: None, - top_p: None, - n: None, - stream: Some(true), - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - }).await.unwrap(); + temperature: None, + top_p: None, + n: None, + stream: Some(true), + stop: None, + max_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + }, + ) + .await + .unwrap(); let mut stream = Box::pin(stream); while let Some(value) = stream.next().await { println!("{}", value); } } -} \ No newline at end of file +} diff --git a/toolchain/prompt-graph-core/src/library/std/prompt/prompt.rs b/toolchain/prompt-graph-core/src/library/std/prompt/prompt.rs new file mode 100644 index 0000000..0f412a6 --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/prompt/prompt.rs @@ -0,0 +1,14 @@ +use tokio::time::sleep; + +use prompt_graph_core::prompt_composition::templates::render_template_prompt; +use prompt_graph_core::proto::serialized_value::Val; +use prompt_graph_core::proto::{item, ChangeValue, SupportedChatModel}; +use std::time::Duration; + +use crate::executor::NodeExecutionContext; +use crate::integrations::openai::batch::chat_completion; +use prompt_graph_core::proto::prompt_graph_node_prompt::Model; + +trait ChatModel {} + +trait CompletionModel {} diff --git a/toolchain/prompt-graph-core/src/library/std/schedule/README.md b/toolchain/prompt-graph-core/src/library/std/schedule/README.md new file mode 100644 index 0000000..a402ca5 --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/schedule/README.md @@ -0,0 +1,24 @@ + +# Services to schedule tasks + +For handling scheduling tasks similar to cron, but with a focus on durable scheduling, there are several services and technologies you might consider. These include both cloud-based services and open-source tools: + +AWS CloudWatch Events & AWS Lambda: AWS CloudWatch Events can be used to trigger scheduled events. 
These events can then initiate AWS Lambda functions, which can perform a variety of tasks. This setup is highly scalable and reliable. + +Google Cloud Scheduler: This is a fully managed cron job service from Google Cloud that allows you to schedule virtually any task. It can trigger HTTP/S endpoints or publish messages to a Pub/Sub topic, integrating seamlessly with other Google Cloud services. + +Azure Logic Apps & Azure Functions: Azure Logic Apps provides a way to schedule and automate workflows. When combined with Azure Functions, it allows for powerful, serverless execution of tasks. + +Kubernetes CronJobs: If you're using Kubernetes, CronJobs can schedule tasks (jobs) to run at specific times or intervals. This is particularly useful in containerized environments. + +Apache Airflow: An open-source tool designed to orchestrate complex computational workflows and data processing pipelines. It allows you to programmatically author, schedule, and monitor workflows. + +Celery Beat: For Python applications, Celery with Celery Beat can be used to schedule regular tasks. It is often used in conjunction with Django but can be used in any Python application. + +Quartz Scheduler: A richly featured, open-source job scheduling library that can be integrated within virtually any Java application. + +Rundeck: An open-source job scheduler and runbook automation tool that enables you to run defined tasks on a schedule. It's useful for operational tasks and can integrate with various external tools. + +Hangfire (for .NET): An open-source framework for background job processing in .NET applications. It supports scheduled tasks and can be used in any .NET application. + +Nomad (by HashiCorp): While primarily a workload orchestrator, Nomad can handle periodic, cron-like tasks across a distributed infrastructure. 
\ No newline at end of file diff --git a/toolchain/prompt-graph-core/src/library/std/schedule/mod.rs b/toolchain/prompt-graph-core/src/library/std/schedule/mod.rs new file mode 100644 index 0000000..8337712 --- /dev/null +++ b/toolchain/prompt-graph-core/src/library/std/schedule/mod.rs @@ -0,0 +1 @@ +// diff --git a/toolchain/prompt-graph-core/src/prompt_composition/mod.rs b/toolchain/prompt-graph-core/src/prompt_composition/mod.rs index 4058a3c..b9a1684 100644 --- a/toolchain/prompt-graph-core/src/prompt_composition/mod.rs +++ b/toolchain/prompt-graph-core/src/prompt_composition/mod.rs @@ -1 +1 @@ -pub mod templates; \ No newline at end of file +pub mod templates; diff --git a/toolchain/prompt-graph-core/src/prompt_composition/templates.rs b/toolchain/prompt-graph-core/src/prompt_composition/templates.rs index 9f9b232..b326ac3 100644 --- a/toolchain/prompt-graph-core/src/prompt_composition/templates.rs +++ b/toolchain/prompt-graph-core/src/prompt_composition/templates.rs @@ -1,15 +1,14 @@ -/// This is a wasm-compatible implementation of how we handle templates for prompts -/// I made the decision to implement this in order to avoid needing to build equivalents for multiple platforms. - -use std::collections::HashMap; -use handlebars::{Handlebars, Path, Template}; +use crate::execution::primitives::serialized_value::RkyvSerializedValue as RKV; +///! The goal of prompt_composition is to enable the composition of prompts in a way where we can +///! trace how the final prompt was assembled and why. +///! +///! 
This is a wasm-compatible implementation of how we handle templates for prompts +use anyhow::Result; use handlebars::template::{Parameter, TemplateElement}; +use handlebars::{Handlebars, Path, Template}; +use serde_json::value::Map as JsonMap; use serde_json::{Map, Value}; -use serde_json::value::{Map as JsonMap}; -use anyhow::{Result}; -use crate::proto::serialized_value::Val; -use crate::proto::{ChangeValue, PromptLibraryRecord, SerializedValue, SerializedValueArray, SerializedValueObject}; - +use std::collections::HashMap; // https://github.com/microsoft/guidance @@ -18,16 +17,15 @@ use crate::proto::{ChangeValue, PromptLibraryRecord, SerializedValue, Serialized // https://github.com/sunng87/handlebars-rust/blob/23ca8d76bee783bf72f627b4c4995d1d11008d17/src/template.rs#L963 // self.handlebars.register_template_string(name, template).unwrap(); -/// Verify that the template and included query paths are valid -pub fn validate_template(template_str: &str, _query_paths: Vec>) { - // let mut handlebars = Handlebars::new(); - let template = Template::compile(template_str).unwrap(); - let mut reference_paths = Vec::new(); - traverse_ast(&template, &mut reference_paths, vec![]); - println!("{:?}", reference_paths); - // TODO: check that all query paths are satisfied by this template - // handlebars.register_template("test", template).unwrap(); -} +// /// Verify that the template and included query paths are valid +// pub fn validate_template(template_str: &str, _query_paths: Vec>) { +// // let mut handlebars = Handlebars::new(); +// let template = Template::compile(template_str).unwrap(); +// let mut reference_paths = Vec::new(); +// println!("{:?}", reference_paths); +// // TODO: check that all query paths are satisfied by this template +// // handlebars.register_template("test", template).unwrap(); +// } #[derive(Debug, Clone)] struct ContextBlock { @@ -38,139 +36,175 @@ struct ContextBlock { /// Traverse over every partial template in a Template (which can be a set of 
template partials) and validate that each /// partial template can be matched to a either 1) some template type that Handlebars recognizes /// or 2) a query path that can pull data out of the event log -fn traverse_ast(template: &Template, reference_paths: &mut Vec<(Path, Vec)>, context: Vec) { +fn analyze_referenced_partials( + template: &Template, + reference_paths: &mut Vec<(Path, Vec)>, + context: Vec, + partial_library: &HashMap, +) { for el in &template.elements { match el { TemplateElement::RawString(_) => {} - TemplateElement::HtmlExpression(helper_block) | - TemplateElement::Expression(helper_block) | - TemplateElement::HelperBlock(helper_block) => { + TemplateElement::HtmlExpression(helper_block) + | TemplateElement::Expression(helper_block) + | TemplateElement::HelperBlock(helper_block) => { let deref = *(helper_block.clone()); - let _params = &deref.params; - match &deref.name { - Parameter::Name(_name) => { - // println!("name, {:?} - params {:?}", name, params); - // reference_paths.push((None, context.clone())); - } - Parameter::Path(path) => { - reference_paths.push((path.clone(), context.clone())); - } - Parameter::Literal(_) => { - } - Parameter::Subexpression(_) => {} - } if let Some(next_template) = deref.template { let mut ctx = context.clone(); ctx.extend(vec![ContextBlock { name: deref.name.clone(), params: deref.params.clone(), }]); - traverse_ast(&next_template, reference_paths, ctx); + analyze_referenced_partials( + &next_template, + reference_paths, + ctx, + partial_library, + ); } } - TemplateElement::DecoratorExpression(_) => {} + TemplateElement::DecoratorExpression(decorator_block) => {} TemplateElement::DecoratorBlock(_) => {} - TemplateElement::PartialExpression(_) => {} - TemplateElement::PartialBlock(_) => {} + TemplateElement::PartialExpression(x) => { + println!("PartialExpression {:?}", x); + let deref = *(x.clone()); + if let Some(ident) = deref.indent { + if let Some(record) = partial_library.get(&ident) { + let mut ctx = 
context.clone(); + ctx.extend(vec![ContextBlock { + name: deref.name.clone(), + params: deref.params.clone(), + }]); + // TODO: recursively analyze the partial template + } + } + } + TemplateElement::PartialBlock(x) => { + println!("PartialBlock {:?}", x) + } TemplateElement::Comment(_) => {} } } } -fn convert_template_to_prompt() { - -} +fn convert_template_to_prompt() {} -fn infer_query_from_template() { +fn infer_query_from_template() {} +#[derive(PartialEq, Debug)] +enum ChatModelRoles { + User, + System, + Assistant, } -fn extract_roles_from_template() { - -} - -/// Recursively flatten a SerializedValue into a set of key paths and values -pub fn flatten_value_keys(sval: SerializedValue, current_path: Vec) -> Vec<(Vec, Val)> { - let mut flattened = vec![]; - match sval.val { - Some(Val::Object(a)) => { - for (key, value) in &a.values { - let mut path = current_path.clone(); - path.push(key.clone()); - flattened.extend(flatten_value_keys(value.clone(), path)); +fn extract_roles_from_template( + template: &Template, + context: Vec, +) -> Vec<(ChatModelRoles, Option