Skip to content

Commit

Permalink
Merge pull request #1356 from stan-dev/fix/tuple-arg-passing
Browse files Browse the repository at this point in the history
Fix two issues with tuple functions
  • Loading branch information
WardBrian authored Sep 8, 2023
2 parents 504f5c1 + ca286a2 commit 934496d
Show file tree
Hide file tree
Showing 35 changed files with 19,713 additions and 5,603 deletions.
1 change: 1 addition & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def runPerformanceTests(String testsPath, String stancFlags = ""){
sh """
cd performance-tests-cmdstan/cmdstan
echo 'O=0' >> make/local
echo 'CXXFLAGS+=-Wall' >> make/local
make -j${env.PARALLEL} build; cd ..
./runPerformanceTests.py -j${env.PARALLEL} --runs=0 ${testsPath}
"""
Expand Down
20 changes: 10 additions & 10 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,17 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
match is_fun_soa_supported name exprs with
| true -> Set.Poly.union acc demoted_eigen_names
| false -> Set.Poly.union acc demoted_and_top_level_names ) )
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) ->
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec | FnMakeTuple) ->
Set.Poly.union acc demoted_and_top_level_names
| CompilerInternal (_ : 'a Internal_fun.t) -> acc
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) ->
Set.Poly.union acc demoted_and_top_level_names

(**
* Recurse through subexpressions and return a list of Unsized types.
* Recursion continues until
* 1. A non-autodiffable type is found
* 2. An autodiffable scalar is found
(**
* Recurse through subexpressions and return a list of Unsized types.
* Recursion continues until
* 1. A non-autodiffable type is found
* 2. An autodiffable scalar is found
* 3. A `Var` type is found that is an autodiffable matrix
*)
let rec extract_nonderived_admatrix_types
Expand All @@ -225,11 +225,11 @@ let rec extract_nonderived_admatrix_types
else [(adlevel, type_)]

(**
* Recurse through functions to find nonderived ad matrix types.
* Special cases for StanLib functions are for
* Recurse through functions to find nonderived ad matrix types.
* Special cases for StanLib functions are for
* - `check_matching_dims`: compiler function that has no effect on optimization
* - `rep_*vector` These are templated in the C++ to cast up to `Var<Matrix>` types
* - `rep_matrix`. When it's only a scalar being propogated an math library overload can upcast to `Var<Matrix>`
* - `rep_*vector` These are templated in the C++ to cast up to `Var<Matrix>` types
* - `rep_matrix`. When it's only a scalar being propogated an math library overload can upcast to `Var<Matrix>`
*)
and extract_nonderived_admatrix_types_fun (kind : 'a Fun_kind.t)
(exprs : Expr.Typed.t list) =
Expand Down
17 changes: 13 additions & 4 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,17 @@ let truncate_dist ud_dists (id : Ast.identifier)
, Some y ) } in
let funapp meta kind name args =
Expr.{Fixed.pattern= FunApp (trans_fn_kind kind name, args); meta} in
let maybe_promote_to_real tp lb : Expr.Typed.t =
match (tp, Expr.Typed.type_of lb) with
| UnsizedType.UInt, _ -> lb
| _, UInt ->
{ pattern= Promotion (lb, UReal, lb.meta.adlevel)
; meta= {lb.meta with type_= UReal} }
| _ -> lb in
let inclusive_bound tp (lb : Expr.Typed.t) =
if UnsizedType.is_int_type tp then
Expr.Helpers.binop lb Minus Expr.Helpers.one
else lb in
else maybe_promote_to_real tp lb in
let size_adjust e =
if
(not (UnsizedType.is_container ast_obs.Ast.emeta.type_))
Expand All @@ -172,18 +179,20 @@ let truncate_dist ud_dists (id : Ast.identifier)
(funapp lb.meta fk fn
(inclusive_bound tp lb :: trans_exprs ast_args) ) ) ) ]
| TruncateDownFrom ub ->
let fk, fn, _ = find_function_info cdf_suffices in
let fk, fn, tp = find_function_info cdf_suffices in
let ub = trans_expr ub in
[ trunc Greater "max" ub
(targetme ub.meta.loc
(size_adjust (funapp ub.meta fk fn (ub :: trans_exprs ast_args))) )
(size_adjust
(funapp ub.meta fk fn
(maybe_promote_to_real tp ub :: trans_exprs ast_args) ) ) )
]
| TruncateBetween (lb, ub) ->
let fk, fn, tp = find_function_info cdf_suffices in
let lb, ub = (trans_expr lb, trans_expr ub) in
let expr args =
funapp ub.meta (Ast.StanLib FnPlain) "log_diff_exp"
[ funapp ub.meta fk fn (ub :: args)
[ funapp ub.meta fk fn (maybe_promote_to_real tp ub :: args)
; funapp ub.meta fk fn (inclusive_bound tp lb :: args) ] in
let statement =
match
Expand Down
31 changes: 23 additions & 8 deletions src/stan_math_backend/Cpp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ module Types = struct
let local_scalar = TypeLiteral "local_scalar_t__"

(** A [std::vector<t>] *)
let std_vector t = StdVector t
let rec std_vector ?(dims = 1) t =
if dims = 0 then t else std_vector ~dims:(dims - 1) (StdVector t)

let bool = TypeLiteral "bool"
let complex s = Complex s
Expand Down Expand Up @@ -266,6 +267,7 @@ module Decls = struct
VariableDefn
(make_variable_defn ~type_:Int ~name:"current_statement__"
~init:(Assignment (Literal "0")) () )
:: Stmts.unused "current_statement__"

let dummy_var =
VariableDefn
Expand Down Expand Up @@ -299,7 +301,7 @@ end

type template_parameter =
| Typename of string (** The name of a template typename *)
| RequireIs of string * string
| RequireAllCondition of [`Exact of string | `OneOf of string list] * type_
(** A C++ type trait (e.g. is_arithmetic) and the template
name which needs to satisfy that.
These are collated into one require_all_t<> *)
Expand Down Expand Up @@ -412,15 +414,22 @@ module Printing = struct

let pp_requires ~default ppf requires =
if not (List.is_empty requires) then
let pp_require ppf (trait, name) = pf ppf "%s<%s>" trait name in
let pp_single_require t ppf trait = pf ppf "%s<%a>" trait pp_type_ t in
let pp_require ppf (req, t) =
match req with
| `Exact trait -> pp_single_require t ppf trait
| `OneOf traits ->
pf ppf "stan::math::disjunction<@[%a@]>"
(list ~sep:comma (pp_single_require t))
traits in
pf ppf ",@ stan::require_all_t<@[%a@]>*%s"
(list ~sep:comma pp_require)
requires
(if default then " = nullptr" else "")

(**
Pretty print a list of templates as [template <parameter-list>].name
This function pools together [RequireIs] nodes into a [require_all_t]
This function pools together [RequireAllCondition] nodes into a [require_all_t]
*)
let pp_template ~default ppf template_parameters =
let pp_basic_template ppf = function
Expand All @@ -432,7 +441,7 @@ module Printing = struct
if not (List.is_empty template_parameters) then
let templates, requires =
List.partition_map template_parameters ~f:(function
| RequireIs (trait, name) -> Second (trait, name)
| RequireAllCondition (trait, name) -> Second (trait, name)
| Typename name -> First (`Typename name)
| Bool name -> First (`Bool name)
| Require (requirement, args) -> First (`Require (requirement, args)) )
Expand Down Expand Up @@ -727,7 +736,7 @@ module Tests = struct
let ts =
let open Types in
[ matrix (complex local_scalar); const_char_array 43
; std_vector (std_vector Double); const_ref (TemplateType "T0__") ] in
; std_vector ~dims:2 Double; const_ref (TemplateType "T0__") ] in
let open Fmt in
pf stdout "@[<v>%a@]" (list ~sep:comma Printing.pp_type_) ts ;
[%expect
Expand Down Expand Up @@ -762,15 +771,21 @@ module Tests = struct
let funs =
[ make_fun_defn
~templates_init:
([[Typename "T0__"; RequireIs ("stan::is_foobar", "T0__")]], true)
( [ [ Typename "T0__"
; RequireAllCondition
(`Exact "stan::is_foobar", TemplateType "T0__") ] ]
, true )
~name:"foobar" ~return_type:Void ~inline:true ()
; (let s =
[ Comment "A potentially \n long comment"
; Expression (Assign (Var "foo", Literal "3")) ] in
let rethrow = Stmts.rethrow_located s in
make_fun_defn
~templates_init:
([[Typename "T0__"; RequireIs ("stan::is_foobar", "T0__")]], false)
( [ [ Typename "T0__"
; RequireAllCondition
(`Exact "stan::is_foobar", TemplateType "T0__") ] ]
, false )
~name:"foobar" ~return_type:Void ~inline:true ~body:rethrow () ) ]
in
let open Fmt in
Expand Down
35 changes: 17 additions & 18 deletions src/stan_math_backend/Lower_expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ and lower_functionals fname suffix es mem_pattern =
| _, args -> (fname, args @ [msgs]) in
let fname = stan_namespace_qualify fname in
let templates = templates false suffix in
Exprs.templated_fun_call fname templates (lower_exprs args) in
Exprs.templated_fun_call fname templates
(lower_exprs ~promote_reals:true args) in
Some lower_hov

and lower_fun_app suffix fname es mem_pattern
Expand All @@ -400,23 +401,21 @@ and lower_user_defined_fun f suffix es =

and lower_compiler_internal ad ut f es =
let open Expression_syntax in
let gen_tuple_literal es : expr =
(* NB: This causes some inefficencies such as eagerly
evaluating eigen expressions and copying data vectors *)
let is_simple (e : Expr.Typed.t) =
match e.pattern with
| Var _ -> e.meta.adlevel <> DataOnly
| Lit _ -> true
| Promotion ({pattern= Var _ | Lit _; _}, _, _) -> is_scalar e
| _ -> false in
if List.for_all ~f:is_simple es then
fun_call "std::forward_as_tuple" (lower_exprs es)
else
Constructor
( Tuple
(List.map es ~f:(fun {meta= {adlevel; type_; _}; _} ->
lower_unsizedtype_local adlevel type_ ) )
, lower_exprs es ) in
let gen_tuple_literal (es : Expr.Typed.t list) : expr =
(* we make full copies of tuples
due to a lack of templating sophistication
in function generation *)
let types =
List.map es ~f:(fun {meta= {adlevel; type_; _}; _} ->
let base_type = lower_unsizedtype_local adlevel type_ in
if
UnsizedType.is_dataonlytype adlevel
&& not
( UnsizedType.is_scalar_type type_
|| UnsizedType.contains_tuple type_ )
then Types.const_ref base_type
else base_type ) in
Constructor (Tuple types, lower_exprs es) in
match f with
| Internal_fun.FnMakeArray ->
let ut =
Expand Down
Loading

0 comments on commit 934496d

Please sign in to comment.