Skip to content

Commit

Permalink
Merge pull request #1421 from stan-dev/fix/complex-matrix-division-mi…
Browse files Browse the repository at this point in the history
…scompilation

Fix complex matrix division miscompilation
  • Loading branch information
WardBrian authored May 3, 2024
2 parents 9e9953f + 5b4e80b commit 2c42646
Show file tree
Hide file tree
Showing 5 changed files with 1,382 additions and 1,039 deletions.
16 changes: 14 additions & 2 deletions src/stan_math_backend/Lower_expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,18 @@ let functor_suffix_select = function
let is_scalar e =
match Expr.Typed.type_of e with UInt | UReal | UComplex -> true | _ -> false

let is_matrix e = Expr.Typed.type_of e = UMatrix
let is_row_vector e = Expr.Typed.type_of e = URowVector
(** Used to determine if [operator/] should be
mdivide_right() or divide() *)
let is_matrix e =
match Expr.Typed.type_of e with
| UMatrix | UComplexMatrix -> true
| _ -> false

let is_row_vector e =
match Expr.Typed.type_of e with
| URowVector | UComplexRowVector -> true
| _ -> false

let first es = List.nth_exn es 0
let second es = List.nth_exn es 1
let default_multiplier = 1
Expand Down Expand Up @@ -241,6 +251,8 @@ and lower_operator_app op es_in =
| Minus -> lower_binary_op Subtract "stan::math::subtract" es
| Times -> lower_binary_op Multiply "stan::math::multiply" es
| Divide | IntDivide ->
(* XXX: This conditional is probably a sign that we need to rethink how we store Operators
in the MIR *)
if
is_matrix (second es)
&& (is_matrix (first es) || is_row_vector (first es))
Expand Down
10 changes: 10 additions & 0 deletions test/integration/good/code-gen/complex_numbers/basic_op_param.stan
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ transformed parameters {
tp_c_matrix = vec * crowvec;
tp_c_matrix = cvec * rowvec;

// matrix-matrix division
tp_c_matrix = cmat / cmat;
tp_c_matrix = cmat / mat;
tp_c_matrix = mat / cmat;

complex_vector[N] tp_c_vector = crowvec';
// matrix-vector products
tp_c_vector = cmat * cvec;
Expand Down Expand Up @@ -115,6 +120,11 @@ transformed parameters {
tp_c_rowvector = -crowvec;
tp_c_rowvector = -rowvec;

// rowvector-matrix division
tp_c_rowvector = crowvec / cmat;
tp_c_rowvector = crowvec / mat;
tp_c_rowvector = rowvec / cmat;

complex tp_c;
// rowvector-vector multiply
tp_c = crowvec * cvec;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ generated quantities {
gq_c_matrix = vec * crowvec;
gq_c_matrix = cvec * rowvec;

// matrix-matrix division
gq_c_matrix = cmat / cmat;
gq_c_matrix = cmat / mat;
gq_c_matrix = mat / cmat;

complex_vector[N] gq_c_vector = crowvec';
// matrix-vector products
gq_c_vector = cmat * cvec;
Expand Down Expand Up @@ -113,6 +118,11 @@ generated quantities {
gq_c_rowvector = -crowvec;
gq_c_rowvector = -rowvec;

// rowvector-matrix division
gq_c_rowvector = crowvec / cmat;
gq_c_rowvector = crowvec / mat;
gq_c_rowvector = rowvec / cmat;

complex gq_c;
// rowvector-vector multiply
gq_c = crowvec * cvec;
Expand Down
91 changes: 51 additions & 40 deletions test/integration/good/code-gen/complex_numbers/basic_ops_mix.stan
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
data {
int N;
complex_matrix[N,N] cmat;
complex_matrix[N, N] cmat;
complex_vector[N] cvec;
complex_row_vector[N] crowvec;
complex z;

matrix[N,N] mat;
matrix[N, N] mat;
vector[N] vec;
row_vector[N] rowvec;
real r;
}

parameters {
complex_matrix[N,N] cvmat;
complex_matrix[N, N] cvmat;
complex_vector[N] cvvec;
complex_row_vector[N] cvrowvec;
complex zv;

matrix[N,N] vmat;
matrix[N, N] vmat;
vector[N] vvec;
row_vector[N] vrowvec;
real v;
}


transformed parameters {
complex_matrix[N,N] tp_c_matrix;

complex_matrix[N, N] tp_c_matrix;
// matrix-matrix multiply and elt
tp_c_matrix = cmat * cvmat;
tp_c_matrix = cmat * vmat;
Expand All @@ -37,160 +34,174 @@ transformed parameters {
tp_c_matrix = cmat ./ cvmat;
tp_c_matrix = cmat ./ vmat;
tp_c_matrix = mat ./ cvmat;

// matrix-scalar multiply
tp_c_matrix = cmat * zv;
tp_c_matrix = z * cvmat;
tp_c_matrix = r * cvmat;
tp_c_matrix = cmat * v;
tp_c_matrix = mat * zv;
tp_c_matrix = z * vmat;

// matrix-matrix addition and subtraction
tp_c_matrix = cmat + cvmat;
tp_c_matrix = cmat + vmat;
tp_c_matrix = mat + cvmat;
tp_c_matrix = cmat - cvmat;
tp_c_matrix = cmat - vmat;
tp_c_matrix = mat - cvmat;

// vector-rowvector multiply
tp_c_matrix = cvec * cvrowvec;
tp_c_matrix = vec * cvrowvec;
tp_c_matrix = cvec * vrowvec;


// matrix-matrix division
tp_c_matrix = cvmat / cmat;
tp_c_matrix = cvmat / mat;
tp_c_matrix = vmat / cmat;
tp_c_matrix = cmat / cvmat;
tp_c_matrix = cmat / vmat;
tp_c_matrix = mat / cvmat;

complex_vector[N] tp_c_vector = cvrowvec';
// matrix-vector products
tp_c_vector = cmat * cvvec;
tp_c_vector = cvmat * cvec;
tp_c_vector = mat * cvvec;
tp_c_vector = cmat * vvec;

// vector-scalar multiplication
tp_c_vector = z * cvvec;
tp_c_vector = cvec * zv;
tp_c_vector = r * cvvec;
tp_c_vector = cvec * v;
tp_c_vector = z * vvec;
tp_c_vector = vec * zv;

// vector-vector elt mult and div
tp_c_vector = cvec .* cvvec;
tp_c_vector = vec .* cvvec;
tp_c_vector = cvec .* vvec;
tp_c_vector = cvec ./ cvvec;
tp_c_vector = vec ./ cvvec;
tp_c_vector = cvec ./ vvec;

// vector-vector addition and subtraction
tp_c_vector = cvec + cvvec;
tp_c_vector = vec + cvvec;
tp_c_vector = cvec + vvec;
tp_c_vector = cvec - cvvec;
tp_c_vector = vec - cvvec;
tp_c_vector = cvec - vvec;

complex_row_vector[N] tp_c_rowvector = cvvec';
// rowvector-matrix multiplication
tp_c_rowvector = crowvec * cvmat;
tp_c_rowvector = rowvec * cvmat;
tp_c_rowvector = crowvec * vmat;

// rowvector-scalar multiplication
tp_c_rowvector = z * cvrowvec;
tp_c_rowvector = crowvec * zv;
tp_c_rowvector = r * cvrowvec;
tp_c_rowvector = crowvec * v;
tp_c_rowvector = z * vrowvec;
tp_c_rowvector = rowvec * zv;

// rowvector-rowvector elt mult and div
tp_c_rowvector = crowvec .* cvrowvec;
tp_c_rowvector = crowvec .* vrowvec;
tp_c_rowvector = rowvec .* cvrowvec;
tp_c_rowvector = crowvec ./ cvrowvec;
tp_c_rowvector = crowvec ./ vrowvec;
tp_c_rowvector = rowvec ./ cvrowvec;

// rowvector-rowvector addition and subtraction
tp_c_rowvector = crowvec + cvrowvec;
tp_c_rowvector = crowvec + vrowvec;
tp_c_rowvector = rowvec + cvrowvec;
tp_c_rowvector = crowvec - cvrowvec;
tp_c_rowvector = crowvec - vrowvec;
tp_c_rowvector = rowvec - cvrowvec;


// rowvector-matrix division
tp_c_rowvector = cvrowvec / cmat;
tp_c_rowvector = cvrowvec / mat;
tp_c_rowvector = vrowvec / cmat;
tp_c_rowvector = crowvec / cvmat;
tp_c_rowvector = crowvec / vmat;
tp_c_rowvector = rowvec / cvmat;

complex tp_c;
// rowvector-vector multiply
tp_c = crowvec * cvvec;
tp_c = cvrowvec * cvec;
tp_c = crowvec * vvec;
tp_c = rowvec * cvvec;

// broadcasting
tp_c_matrix = z - cvmat;
tp_c_matrix = r - cvmat;
tp_c_matrix = cvmat - r;
tp_c_matrix = cvmat - z;

tp_c_matrix = z + cvmat;
tp_c_matrix = r + cvmat;
tp_c_matrix = cvmat + r;
tp_c_matrix = cvmat + z;

tp_c_matrix = zv - cmat;
tp_c_matrix = v - cmat;
tp_c_matrix = cmat - v;
tp_c_matrix = cmat - zv;

tp_c_matrix = zv + cmat;
tp_c_matrix = v + cmat;
tp_c_matrix = cmat + v;
tp_c_matrix = cmat + zv;

tp_c_matrix = z ./ cvmat;
tp_c_matrix = r ./ cvmat;
tp_c_matrix = cvmat ./ r;
tp_c_matrix = cvmat ./ z;

tp_c_matrix = zv ./ cmat;
tp_c_matrix = v ./ cmat;
tp_c_matrix = cmat ./ v;
tp_c_matrix = cmat ./ zv;



tp_c_matrix = z .* cvmat;
tp_c_matrix = r .* cvmat;
tp_c_matrix = cvmat .* r;
tp_c_matrix = cvmat .* z;

tp_c_matrix = zv .* cmat;
tp_c_matrix = v .* cmat;
tp_c_matrix = cmat .* v;
tp_c_matrix = cmat .* zv;

tp_c_matrix = z / cvmat;
tp_c_matrix = z / cvmat;
tp_c_matrix = r / cvmat;
tp_c_matrix = cvmat / r;
tp_c_matrix = cvmat / z;

tp_c_matrix = zv / cmat;
tp_c_matrix = v / cmat;
tp_c_matrix = cmat / v;
tp_c_matrix = cmat / zv;

tp_c_matrix = z * cvmat;
tp_c_matrix = r * cvmat;
tp_c_matrix = cvmat * r;
tp_c_matrix = cvmat * z;

tp_c_matrix = zv * cmat;
tp_c_matrix = v * cmat;
tp_c_matrix = cmat * v;
tp_c_matrix = cmat * zv;

// transformations
array[N,N] complex carray;
array[N, N] complex carray;
carray = to_array_2d(cvmat);

}
Loading

0 comments on commit 2c42646

Please sign in to comment.