Skip to content

Commit

Permalink
implement action of 4D tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
daanhb committed Oct 3, 2022
1 parent ea056a6 commit 6d15e97
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/bases/poly/ops/laguerre.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ iscompatiblegrid(dict::Laguerre, grid::LaguerreNodes) = length(dict) == length(g
isorthogonal(dict::Laguerre, measure::GaussLaguerre) = laguerre_α(dict) laguerre_α(measure) && opsorthogonal(dict, measure)

isorthonormal(dict::Laguerre, measure::LaguerreWeight) = isorthogonal(dict, measure) && laguerre_α(dict) == 0
isorthonormal(dict::Laguerre, measure::GaussLaguerre) where T = isorthogonal(dict, measure) && laguerre_α(dict) == 0
isorthonormal(dict::Laguerre, measure::GaussLaguerre) = isorthogonal(dict, measure) && laguerre_α(dict) == 0
issymmetric(::Laguerre) = false

isorthogonal(dict::Laguerre, measure::LaguerreWeight) = dict.α == measure.α
Expand Down
157 changes: 137 additions & 20 deletions src/operator/tensorproductoperator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,6 @@ function apply_tensor!(op, coef_dest, coef_src, operators::Tuple{A,B}, scratch,
op2 = operators[2]
intermediate = scratch[1]

# println("coef_src: ", size(coef_src))
# println("coef_dest: ", size(coef_dest))
# println("src1: ", size(src1))
# println("src2: ", size(src2))
# println("dest1: ", size(dest1))
# println("dest2: ", size(dest2))
# println("op1: ", size(op1))
# println("op2: ", size(op2))
# println("intermediate: ", size(intermediate))
#
for j in eachindex(src2)
for i in eachindex(src1)
src1[i] = coef_src[i,j]
Expand Down Expand Up @@ -202,16 +192,6 @@ function apply_tensor!(op, coef_dest, coef_src, operators::Tuple{A,B,C}, scratch
intermediate1 = scratch[1]
intermediate2 = scratch[2]

# println("coef_src: ", size(coef_src))
# println("coef_dest: ", size(coef_dest))
# println("src1: ", size(src1))
# println("src2: ", size(src2))
# println("dest1: ", size(dest1))
# println("dest2: ", size(dest2))
# println("op1: ", size(op1))
# println("op2: ", size(op2))
# println("intermediate: ", size(intermediate))

for k in eachindex(src3)
for j in eachindex(src2)
for i in eachindex(src1)
Expand Down Expand Up @@ -292,6 +272,143 @@ function apply_inplace_tensor!(op, coef_srcdest, operators::Tuple{A,B,C}, src_sc
coef_srcdest
end

function apply_tensor!(op, coef_dest, coef_src, operators::Tuple{A,B,C,D}, scratch, src_scratch, dest_scratch) where {A,B,C,D}
src1 = src_scratch[1]
src2 = src_scratch[2]
src3 = src_scratch[3]
src4 = src_scratch[4]
dest1 = dest_scratch[1]
dest2 = dest_scratch[2]
dest3 = dest_scratch[3]
dest4 = dest_scratch[4]
op1 = operators[1]
op2 = operators[2]
op3 = operators[3]
op4 = operators[4]
intermediate1 = scratch[1]
intermediate2 = scratch[2]
intermediate3 = scratch[3]

for l in eachindex(src4)
for k in eachindex(src3)
for j in eachindex(src2)
for i in eachindex(src1)
src1[i] = coef_src[i,j,k,l]
end
apply!(op1, dest1, src1)
for i in eachindex(dest1)
intermediate1[i,j,k,l] = dest1[i]
end
end
end
end
for l in eachindex(src4)
for k in eachindex(src3)
for i in eachindex(dest1)
for j in eachindex(src2)
src2[j] = intermediate1[i,j,k,l]
end
apply!(op2, dest2, src2)
for j in eachindex(dest2)
intermediate2[i,j,k,l] = dest2[j]
end
end
end
end
for l in eachindex(src4)
for i in eachindex(dest1)
for j in eachindex(dest2)
for k in eachindex(src3)
src3[k] = intermediate2[i,j,k,l]
end
apply!(op3, dest3, src3)
for k in eachindex(dest3)
intermediate3[i,j,k,l] = dest3[k]
end
end
end
end
for i in eachindex(dest1)
for j in eachindex(dest2)
for k in eachindex(dest3)
for l in eachindex(src4)
src4[l] = intermediate3[i,j,k,l]
end
apply!(op4, dest4, src4)
for l in eachindex(dest4)
coef_dest[i,j,k,l] = dest4[l]
end
end
end
end
coef_dest
end

function apply_inplace_tensor!(op, coef_srcdest, operators::Tuple{A,B,C,D}, src_scratch) where {A,B,C,D}
src1 = src_scratch[1]
src2 = src_scratch[2]
src3 = src_scratch[3]
src4 = src_scratch[4]
op1 = operators[1]
op2 = operators[2]
op3 = operators[3]
op4 = operators[4]

for l in eachindex(src4)
for k in eachindex(src3)
for j in eachindex(src2)
for i in eachindex(src1)
src1[i] = coef_srcdest[i,j,k,l]
end
apply!(op1, src1)
for i in eachindex(src1)
coef_srcdest[i,j,k,l] = src1[i]
end
end
end
end
for l in eachindex(src4)
for k in eachindex(src3)
for i in eachindex(dest1)
for j in eachindex(src2)
src2[j] = coef_srcdest[i,j,k,l]
end
apply!(op2, src2)
for j in eachindex(src2)
coef_srcdest[i,j,k,l] = src2[j]
end
end
end
end
for l in eachindex(src4)
for i in eachindex(dest1)
for j in eachindex(dest2)
for k in eachindex(src3)
src3[k] = coef_srcdest[i,j,k,l]
end
apply!(op3, src3)
for k in eachindex(src3)
coef_srcdest[i,j,k,l] = src3[k]
end
end
end
end
for i in eachindex(dest1)
for j in eachindex(dest2)
for k in eachindex(dest3)
for l in eachindex(src4)
src4[l] = coef_srcdest[i,j,k,l]
end
apply!(op4, src4)
for l in eachindex(src4)
coef_srcdest[i,j,k,l] = src4[l]
end
end
end
end
coef_srcdest
end

function stencilarray(op::TensorProductOperator)
A = Any[]
push!(A, component(op,1))
Expand Down
2 changes: 1 addition & 1 deletion src/util/arrays/specialarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ unsafe_array_getindex(A::RestrictionIndexMatrix{T,1}, i::Int, j::Int) where {T}
unsafe_array_getindex(A::RestrictionIndexMatrix{T,N}, i::Int, j::Int) where {T,N} =
unsafe_array_getindex(A, i, CartesianIndices(CartesianIndex(_original_size(A)))[j])

unsafe_array_getindex(A::RestrictionIndexMatrix{T,N}, i::Int, j::CartesianIndex{N}) where {T,I,N} =
unsafe_array_getindex(A::RestrictionIndexMatrix{T,N}, i::Int, j::CartesianIndex{N}) where {T,N} =
@inbounds getindex(subindices(A),i)==j ? one(T) : zero(T)

Base.eltype(::IndexMatrix{T}) where T = T
Expand Down

0 comments on commit 6d15e97

Please sign in to comment.