Skip to content

Commit

Permalink
Merge pull request #524 from stan-dev/feature/standalone_functions
Browse files Browse the repository at this point in the history
Feature/standalone functions
  • Loading branch information
rok-cesnovar authored Oct 5, 2020
2 parents 6dd4f2b + 011adbe commit 17c9e13
Show file tree
Hide file tree
Showing 9 changed files with 848 additions and 18 deletions.
85 changes: 68 additions & 17 deletions src/stan_math_backend/Stan_math_code_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ open Fmt
open Expression_gen
open Statement_gen

let standalone_functions = ref false

let stanc_args_to_print =
let sans_model_and_hpp_paths x =
not String.(is_suffix ~suffix:".stan" x || is_prefix ~prefix:"--o" x)
Expand Down Expand Up @@ -155,6 +157,9 @@ let pp_template_decorator ppf = function
| templates ->
pf ppf "@[<hov>template <%a>@]@ " (list ~sep:comma string) templates

let mk_extra_args templates args =
List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args)

(** Print the C++ function definition.
@param ppf A pretty printer
Refactor this please - one idea might be to have different functions for
Expand All @@ -170,9 +175,6 @@ let pp_fun_def ppf Program.({fdrt; fdname; fdargs; fdbody; _})
else if is_rng then (["base_rng__"], ["RNG"])
else ([], [])
in
let mk_extra_args templates args =
List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args)
in
let argtypetemplates, args = get_templates_and_args fdargs in
let pp_body ppf (Stmt.Fixed.({pattern; _}) as fdbody) =
let text = pf ppf "%s@;" in
Expand Down Expand Up @@ -266,6 +268,47 @@ let pp_fun_def ppf Program.({fdrt; fdname; fdargs; fdbody; _})
, List.map ~f:(fun (_, name, _) -> name) fdargs
@ extra @ ["pstream__"] )

(* Creates functions outside the model namespaces which only call the ones
inside the namespaces *)
let pp_standalone_fun_def namespace_fun ppf
Program.({fdname; fdargs; fdbody; fdrt; _}) =
let extra, extra_templates =
if is_user_lp fdname then
(["lp__"; "lp_accum__"], ["double"; "stan::math::accumulator<double>"])
else if String.is_suffix fdname ~suffix:"_rng" then
(["base_rng__"], ["boost::ecuyer1988"])
else ([], [])
in
let args =
List.map
~f:(fun (_, name, ut) ->
strf "const %a& %s" pp_unsizedtype_custom_scalar
(stantype_prim_str ut, ut)
name )
fdargs
in
let pp_sig_standalone ppf _ =
let arg_strs =
args
@ mk_extra_args extra_templates extra
@ ["std::ostream* pstream__ = nullptr"]
in
pf ppf "(@[<hov>%a@]) " (list ~sep:comma string) arg_strs
in
let mark_function_comment = "// [[stan::function]]" in
let return_type = match fdrt with None -> "void" | _ -> "auto" in
let return_stmt = match fdrt with None -> "" | _ -> "return " in
match fdbody with
| None -> pf ppf ";@ "
| Some _ ->
pf ppf "@,%s@,%s %s%a @,{@, %s%s::%a;@,}@," mark_function_comment
return_type fdname pp_sig_standalone "" return_stmt namespace_fun
pp_call_str
( ( if is_user_dist fdname || is_user_lp fdname then fdname ^ "<false>"
else fdname )
, List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"]
)

let version = "// Code generated by %%NAME%% %%VERSION%%"
let includes = "#include <stan/model/model_header.hpp>"

Expand All @@ -283,7 +326,7 @@ let pp_validate_data ppf (name, st) =
pp_call
("context__.to_vec", pp_expr, SizedType.get_dims st)

(** Print the constructor of the model class.
(** Print the constructor of the model class.
Read in data steps:
1. context__.validate_dims() to verify the dimensions are correct at runtime.
1. find vals_%s__ from context__.vals_%s(vident)
Expand Down Expand Up @@ -360,7 +403,7 @@ let pp_model_private ppf {Program.prepare_data; _} =
@param intro Anything that needs printed before the method body.
@param outro Anything that needs printed after the method body.
@param cv_attr Optional parameter to add method attributes.
@param ppbody (?A pretty printer of the method's body)
@param ppbody (?A pretty printer of the method's body)
*)
let pp_method ppf rt name params intro ?(outro = []) ?(cv_attr = ["const"])
ppbody =
Expand Down Expand Up @@ -438,8 +481,8 @@ let pp_write_array ppf {Program.prog_name; generate_quantities; _} =
in
pp_method_b ppf "void" "write_array" params intro generate_quantities

(** Prints the for loop for `constrained_param_names`
and `unconstrained_param_names`
(** Prints the for loop for `constrained_param_names`
and `unconstrained_param_names`
@param index_ids Optional named parameter of a SizedType's dimensions
@param ppf A pretty printer
@param dims A list of the dimensions of a SizedType
Expand Down Expand Up @@ -501,7 +544,7 @@ let pp_constrained_param_names ppf {Program.output_vars; _} =
(list ~sep:cut pp_param_names, gqvars) )
~cv_attr

(* Print the `unconstrained_param_names` method of the model class.
(* Print the `unconstrained_param_names` method of the model class.
This is just a copy of constrained, I need to figure out which one is wrong
and fix it eventually. From Bob,
Expand Down Expand Up @@ -589,7 +632,7 @@ let pp_log_prob ppf Program.({prog_name; log_prob; _}) =
let cv_attr = ["const"] in
pp_method_b ppf "T__" "log_prob" params intro log_prob ~outro ~cv_attr

(** Print the body of the constrained and unconstrained sizedtype methods
(** Print the body of the constrained and unconstrained sizedtype methods
in the model class
@param ppf A pretty printer
@param method_name The name of the method to wrap the body in.
Expand Down Expand Up @@ -728,7 +771,7 @@ using namespace stan::math;
using stan::math::pow; |}

(** Functions needed in the model class not defined yet in stan math.
FIXME: Move these to the Stan repo when these repos are joined.
FIXME: Move these to the Stan repo when these repos are joined.
*)
let custom_functions =
{|
Expand Down Expand Up @@ -801,19 +844,27 @@ let pp_prog ppf (p : Program.Typed.t) =
(is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p)
(is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p)
in
let reduce_sum_struct_decl =
let reduce_sum_struct_decls =
String.Set.map
~f:(fun x -> "struct " ^ x ^ reduce_sum_functor_suffix ^ ";")
(is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p)
|> Set.elements |> String.concat ~sep:"\n"
in
pf ppf "@[<v>@ %s@ %s@ namespace %s {@ %s@ %s@ %a@ %s@ %a@ %a@ }@ @]" version
includes (namespace p) custom_functions usings Locations.pp_globals s
(String.concat ~sep:"\n" (String.Set.elements reduce_sum_struct_decl))
reduce_sum_struct_decls
(list ~sep:cut pp_fun_def_with_variadic_fn_list)
p.functions_block pp_model p ;
pf ppf "@,using stan_model = %s_namespace::%s;@," p.prog_name p.prog_name ;
pf ppf
{|
p.functions_block
(if !standalone_functions then fun _ _ -> () else pp_model)
p ;
if !standalone_functions then
pf ppf "@[<v>%a@ @]"
(list ~sep:cut (pp_standalone_fun_def (namespace p)))
p.functions_block
else (
pf ppf "@,using stan_model = %s_namespace::%s;@," p.prog_name p.prog_name ;
pf ppf
{|
#ifndef USING_R

// Boilerplate
Expand All @@ -827,4 +878,4 @@ stan::model::model_base& new_model(

#endif
|} ;
pf ppf "@[<v>%a@]" pp_register_map_rect_functors p
pf ppf "@[<v>%a@]" pp_register_map_rect_functors p )
6 changes: 5 additions & 1 deletion src/stanc/stanc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ let options =
, " Deprecated. Same as --include-paths." )
; ( "--use-opencl"
, Arg.Set Transform_Mir.use_opencl
, " If set, try to use matrix_cl signatures." ) ]
, " If set, try to use matrix_cl signatures." )
; ( "--standalone-functions"
, Arg.Set Stan_math_code_gen.standalone_functions
, " If set, the generated C++ will be the standalone functions C++ code."
) ]

let print_deprecated_arg_warning =
(* is_prefix is used to also cover the --include-paths=... *)
Expand Down
1 change: 1 addition & 0 deletions src/stancjs/stancjs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ let stan2cpp model_name model_string flags =
Semantic_check.check_that_all_functions_have_definition :=
not (is_flag_set "allow_undefined" || is_flag_set "allow-undefined") ;
Transform_Mir.use_opencl := is_flag_set "use-opencl" ;
Stan_math_code_gen.standalone_functions := is_flag_set "standalone-functions" ;
let ast =
Parse.parse_string Parser.Incremental.program model_string
|> Result.map_error ~f:(Fmt.to_to_string Errors.pp_syntax_error)
Expand Down
43 changes: 43 additions & 0 deletions test/integration/good/code-gen/standalone_functions/basic.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
functions {
real my_log1p_exp(real x) {
return log1p_exp(x);
}

real array_fun(real[] a)
{
return sum(a);
}

real int_array_fun(int[] a)
{
return sum(a);
}

vector my_vector_mul_by_5(vector x)
{
vector[num_elements(x)] result = x * 5.0;
return result;
}

int int_only_multiplication(int a, int b) {
return a*b;
}

real test_lgamma(real x) {
return lgamma(x);
}

// test special functions
void test_lp(real a) {
a ~ normal(0, 1);
}

real test_rng(real a) {
return normal_rng(a, 1);
}

real test_lpdf(real a, real b) {
return normal_lpdf(a | b, 1);
}
}

Loading

0 comments on commit 17c9e13

Please sign in to comment.