From b65cd82611b82e25902c10389beb5853fe8b8550 Mon Sep 17 00:00:00 2001 From: Stephen Chung Date: Fri, 27 Sep 2024 13:07:37 +0800 Subject: [PATCH 1/2] Add SQL custom syntax test --- tests/custom_syntax.rs | 69 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/tests/custom_syntax.rs b/tests/custom_syntax.rs index ac504dfe7..9a2a3fbd3 100644 --- a/tests/custom_syntax.rs +++ b/tests/custom_syntax.rs @@ -1,6 +1,5 @@ #![cfg(not(feature = "no_custom_syntax"))] - -use rhai::{Dynamic, Engine, EvalAltResult, FnPtr, ImmutableString, LexError, ParseErrorType, Position, Scope, INT}; +use rhai::{Dynamic, Engine, EvalAltResult, ImmutableString, LexError, ParseErrorType, Position, Scope, INT}; #[test] fn test_custom_syntax() { @@ -432,3 +431,69 @@ fn test_custom_syntax_raw2() { assert_eq!(engine.eval::("#42/2").unwrap(), 21); assert_eq!(engine.eval::("sign(#1)").unwrap(), 1); } + +#[test] +fn test_custom_syntax_raw_sql() { + let mut engine = Engine::new(); + + engine.register_custom_syntax_with_state_raw( + "SELECT", + |symbols, lookahead, state| { + // Build a SQL statement as the state + let mut sql: String = if state.is_unit() { Default::default() } else { state.take().cast::().into() }; + + // At every iteration, the last symbol is the new one + let r = match symbols.last().unwrap().as_str() { + // Terminate parsing when we see `;` + ";" => None, + // Variable substitution -- parse the following as a block + "$" => Some("$block$".into()), + // Block parsed, replace it with `?` as SQL parameter + "$block$" => { + if !sql.is_empty() { + sql.push(' '); + } + sql.push('?'); + Some(lookahead.into()) // Always accept the next token + } + // Otherwise simply concat the tokens + _ => { + if !sql.is_empty() { + sql.push(' '); + } + sql.push_str(symbols.last().unwrap().as_str()); + Some(lookahead.into()) // Always accept the next token + } + }; + + // SQL statement done! + *state = sql.into(); + + match lookahead { + // End of script? + "{EOF}" => Ok(None), + _ => Ok(r), + } + }, + false, + |context, inputs, state| { + // Our SQL statement + let sql = state.as_immutable_string_ref().unwrap(); + let mut output = sql.to_string(); + + // Inputs will be parameters + for input in inputs { + let value = context.eval_expression_tree(input).unwrap(); + output.push('\n'); + output.push_str(&value.to_string()); + } + + Ok(output.into()) + }, + ); + + let mut scope = Scope::new(); + scope.push("id", 123 as INT); + + assert_eq!(engine.eval_with_scope::(&mut scope, "SELECT * FROM table WHERE id = ${id}").unwrap(), "SELECT * FROM table WHERE id = ?\n123"); +} From 5107a6a3be3298f45c84dc5be51cb873d5673964 Mon Sep 17 00:00:00 2001 From: Stephen Chung Date: Fri, 27 Sep 2024 22:28:58 +0800 Subject: [PATCH 2/2] Fix Scope deserialization lifetime --- src/serde/deserialize.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serde/deserialize.rs b/src/serde/deserialize.rs index c8f59bd85..6325ac1f9 100644 --- a/src/serde/deserialize.rs +++ b/src/serde/deserialize.rs @@ -270,7 +270,7 @@ impl<'de> Deserialize<'de> for ImmutableString { } } -impl<'de> Deserialize<'de> for Scope<'de> { +impl<'de> Deserialize<'de> for Scope<'_> { #[inline(always)] fn deserialize>(deserializer: D) -> Result { #[derive(Debug, Clone, Hash, Deserialize)]