Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FFT using pyvkfft and use loopy callables #114

Merged
merged 67 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from 58 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
ad39de6
Use a separate class for M2L translation
isuruf Apr 29, 2022
9cb4caa
Fix docs and caching
isuruf Apr 29, 2022
c5bc9b2
Fix p2p warning
isuruf Apr 29, 2022
88c9fd8
Use VkFFT for M2L generate data
isuruf May 1, 2022
bab8f38
Fix profiling events
isuruf May 1, 2022
68864a2
simplify m2l data zeros
isuruf May 1, 2022
03cc94e
Add pyvkfft to requirements
isuruf May 2, 2022
c440950
Fix flake8 warning
isuruf May 2, 2022
78faead
Fix typo
isuruf May 2, 2022
4f33038
VkFFT for M2L preprocess local
isuruf May 2, 2022
efb6d99
vkfft for postprocess local
isuruf May 2, 2022
2e1b10e
Fix AggregateProfilingEvent
isuruf May 2, 2022
59d7be6
Fix another typo
isuruf May 2, 2022
72b4875
M2L Translation Factory
isuruf May 3, 2022
83f7fd8
vim markers
isuruf May 3, 2022
bed782f
Merge branch 'isuruf/m2l' into fft
isuruf May 3, 2022
7cf5404
Fix tests
isuruf May 5, 2022
584d2c9
Fix toys
isuruf May 5, 2022
6f5ad1f
Fix test_m2l_toeplitz
isuruf May 5, 2022
47a4a27
Fix more tests
isuruf May 5, 2022
60ef708
Use a better rscale to get the test passing
isuruf May 5, 2022
85e0ed1
Use pytential dev branch
isuruf May 5, 2022
9f74eec
Merge branch 'isuruf/m2l' into fft
isuruf May 5, 2022
9162d17
Merge branch 'main' into m2l
isuruf May 6, 2022
3880661
Merge branch 'isuruf/m2l' of https://github.com/inducer/sumpy into fft
isuruf May 6, 2022
a57f727
remove whitespace on blank line
isuruf May 6, 2022
ea6a99c
Try 2r/order instead of r/order
isuruf May 6, 2022
df991c6
Merge branch 'isuruf/m2l' of https://github.com/inducer/sumpy into fft
isuruf May 8, 2022
4eaad6a
fix using updated pytential
isuruf May 8, 2022
9119880
Merge branch 'main' of https://github.com/inducer/sumpy into fft
isuruf May 13, 2022
21236d7
Fix tests
isuruf May 13, 2022
9ddd3a9
use pytential branch with pyvkfft req
isuruf May 14, 2022
e5dea13
Add explanation about caller being responsible for the FFT
isuruf May 14, 2022
52de95d
Fix for bessel
isuruf May 14, 2022
3ddc0db
Merge branch 'main' into fft
inducer May 19, 2022
bec642e
Add pyvkfft to setup.py reqs
isuruf May 25, 2022
e6d62a4
use list comprehension
isuruf May 25, 2022
774f869
Type annotations
isuruf May 25, 2022
2c8a5bf
fix vim marker
isuruf May 25, 2022
61ddc0b
remove unused function
isuruf May 25, 2022
50f8bb3
m2l_data_inner -> m2l_data
isuruf May 25, 2022
c3eaa32
more descriptive name for child_knl
isuruf May 25, 2022
9cc214e
knl -> expr_knl for clarity
isuruf May 25, 2022
7d9f535
move loop unroll to optimized
isuruf May 25, 2022
07c1c93
Add explanation about translation_classes_dependent_data_loopy_knl
isuruf May 25, 2022
25dd7fc
make coeffs output only and rewrite
isuruf May 25, 2022
1252c5a
Re-arrange m2l so that event processing is easier
isuruf May 25, 2022
3b3f6f3
flake8: single quotes -> double quotes
isuruf May 25, 2022
8e9649f
Fix data not being input
isuruf May 25, 2022
a66a7cc
make args to cached_vkfft_app explicit
isuruf May 25, 2022
fff0d13
cache vkfftapp in wrangler
isuruf May 25, 2022
e80c71d
keep coeffs is_input and is_output for e2e
isuruf May 25, 2022
8d72d66
out-of-place fft
isuruf Jun 6, 2022
98b4bcc
Use a separate queue for configuration
isuruf Jun 6, 2022
d7da927
Merge branch 'main' into fft
isuruf Jun 6, 2022
3e9632b
allocate array for out-of-place
isuruf Jun 6, 2022
aafe8b2
fix typo
isuruf Jun 6, 2022
d4bfc05
Remove caching of opencl fft app
isuruf Jun 6, 2022
899006e
Comment out pytentual fork
isuruf Jun 9, 2022
a0942ed
fix vkfft queues
isuruf Jun 9, 2022
a5996a5
use private API for now
isuruf Jun 22, 2022
da7f310
Merge branch 'main' into fft
isuruf Jun 23, 2022
bf083c9
Merge branch 'main' into fft
isuruf Jun 25, 2022
741a1cc
Add comment on pyvkfft PR
isuruf Jun 25, 2022
15df148
Merge branch 'main' into fft
inducer Jun 25, 2022
b8af2bd
remove inplace
isuruf Aug 1, 2022
e073d73
Merge branch 'main' into fft
isuruf Aug 1, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ jobs:
run: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
if [[ "$DOWNSTREAM_PROJECT" == "pytential" && "$GITHUB_HEAD_REF" == "m2l" ]]; then
DOWNSTREAM_PROJECT=https://github.com/isuruf/pytential.git@m2l_translation
if [[ "$DOWNSTREAM_PROJECT" == "pytential" && "$GITHUB_HEAD_REF" == "fft" ]]; then
DOWNSTREAM_PROJECT=https://github.com/isuruf/pytential.git@pyvkfft
fi
test_downstream "$DOWNSTREAM_PROJECT"

Expand Down
1 change: 1 addition & 0 deletions .test-conda-env-py3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ dependencies:
- python-symengine=0.6.0
- pyfmmlib
- pyrsistent
- pyvkfft
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
numpy
sympy
pyrsistent
pyvkfft
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you thinking of this as an optional dependency? If yes, it should be declared as an extra in setup.py. If not, it should be in setup.py outright. (I don't see a reason not to make it a hard dependency.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made it a hard dep in bec642e

git+https://github.com/inducer/pytools.git#egg=pytools
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/inducer/islpy.git#egg=islpy
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,6 @@ def write_git_revision(package_name):
"dataclasses>=0.7;python_version<='3.6'",
"sympy>=0.7.2",
"pymbolic>=2021.1",
"pyvkfft>=2022.1",
],
)
237 changes: 88 additions & 149 deletions sumpy/e2e.py

Large diffs are not rendered by default.

311 changes: 270 additions & 41 deletions sumpy/expansion/m2l.py

Large diffs are not rendered by default.

178 changes: 118 additions & 60 deletions sumpy/fmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@
E2EFromChildren, E2EFromParent,
M2LGenerateTranslationClassesDependentData,
M2LPreprocessMultipole, M2LPostprocessLocal)
from sumpy.tools import to_complex_dtype
from sumpy.tools import (to_complex_dtype, AggregateProfilingEvent,
run_opencl_fft, get_opencl_fft_app)

from typing import TypeVar, List, Union


# {{{ tree-independent data for wrangler
Expand Down Expand Up @@ -176,6 +179,11 @@ def p2p(self):
exclude_self=self.exclude_self,
strength_usage=self.strength_usage)

@memoize_method
def opencl_fft_app(self, shape, dtype, inplace):
with cl.CommandQueue(self.cl_context) as queue:
return get_opencl_fft_app(queue, shape, dtype, inplace)

# }}}


Expand All @@ -184,16 +192,28 @@ def p2p(self):
_SECONDS_PER_NANOSECOND = 1e-9


"""
EventLike objects have an attribute native_event that returns
a cl.Event that indicates the end of the event.
"""
EventLike = TypeVar("CLEventLike")


class UnableToCollectTimingData(UserWarning):
pass


class SumpyTimingFuture:

def __init__(self, queue, events):
def __init__(self, queue, events: List[Union[cl.Event, EventLike]]):
self.queue = queue
self.events = events

@property
def native_events(self) -> List[cl.Event]:
return [evt if isinstance(evt, cl.Event) else evt.native_event
for evt in self.events]

@memoize_method
def result(self):
from boxtree.timing import TimingResult
Expand All @@ -208,7 +228,7 @@ def result(self):
return TimingResult(wall_elapsed=None)

if self.events:
pyopencl.wait_for_events(self.events)
pyopencl.wait_for_events(self.native_events)

result = 0
for event in self.events:
Expand All @@ -222,7 +242,7 @@ def done(self):
return all(
event.get_info(cl.event_info.COMMAND_EXECUTION_STATUS)
== cl.command_execution_status.COMPLETE
for event in self.events)
for event in self.native_events)

# }}}

Expand Down Expand Up @@ -389,10 +409,18 @@ def local_expansion_zeros(self, template_ary):
dtype=self.dtype)

def m2l_translation_classes_dependent_data_zeros(self, queue):
return cl.array.zeros(
queue,
self.m2l_translation_classes_dependent_data_level_starts()[-1],
dtype=self.preprocessed_mpole_dtype)
result = []
for level in range(self.tree.nlevels):
expn_start, expn_stop = \
self.m2l_translation_classes_dependent_data_level_starts()[
level:level+2]
translation_class_start, translation_class_stop = \
self.m2l_translation_class_level_start_box_nrs()[level:level+2]
exprs_level = cl.array.zeros(queue, expn_stop - expn_start,
dtype=self.preprocessed_mpole_dtype)
result.append(exprs_level.reshape(
translation_class_stop - translation_class_start, -1))
return result

def multipole_expansions_view(self, mpole_exps, level):
expn_start, expn_stop = \
Expand All @@ -412,14 +440,10 @@ def local_expansions_view(self, local_exps, level):

def m2l_translation_classes_dependent_data_view(self,
m2l_translation_classes_dependent_data, level):
expn_start, expn_stop = \
self.m2l_translation_classes_dependent_data_level_starts()[level:level+2]
translation_class_start, translation_class_stop = \
translation_class_start, _ = \
self.m2l_translation_class_level_start_box_nrs()[level:level+2]

exprs_level = m2l_translation_classes_dependent_data[expn_start:expn_stop]
return (translation_class_start, exprs_level.reshape(
translation_class_stop - translation_class_start, -1))
exprs_level = m2l_translation_classes_dependent_data[level]
return (translation_class_start, exprs_level)

@memoize_method
def m2l_preproc_mpole_expansions_level_starts(self):
Expand All @@ -434,18 +458,19 @@ def order_to_size(order):
level_starts=self.tree.level_start_box_nrs)

def m2l_preproc_mpole_expansion_zeros(self, template_ary):
return cl.array.zeros(
template_ary.queue,
self.m2l_preproc_mpole_expansions_level_starts()[-1],
dtype=self.preprocessed_mpole_dtype)

def m2l_preproc_mpole_expansions_view(self, mpole_exps, level):
expn_start, expn_stop = \
result = []
for level in range(self.tree.nlevels):
expn_start, expn_stop = \
self.m2l_preproc_mpole_expansions_level_starts()[level:level+2]
box_start, box_stop = self.tree.level_start_box_nrs[level:level+2]
box_start, box_stop = self.tree.level_start_box_nrs[level:level+2]
exprs_level = cl.array.zeros(template_ary.queue, expn_stop - expn_start,
dtype=self.preprocessed_mpole_dtype)
result.append(exprs_level.reshape(box_stop - box_start, -1))
return result

return (box_start,
mpole_exps[expn_start:expn_stop].reshape(box_stop-box_start, -1))
def m2l_preproc_mpole_expansions_view(self, mpole_exps, level):
box_start, _ = self.tree.level_start_box_nrs[level:level+2]
return (box_start, mpole_exps[level])

m2l_work_array_view = m2l_preproc_mpole_expansions_view
m2l_work_array_zeros = m2l_preproc_mpole_expansion_zeros
Expand Down Expand Up @@ -522,6 +547,11 @@ def box_target_list_kwargs(self):

# }}}

def run_opencl_fft(self, queue, input_vec, inverse, wait_for, inplace):
app = get_opencl_fft_app(queue, input_vec.shape, input_vec.dtype,
isuruf marked this conversation as resolved.
Show resolved Hide resolved
inplace)
return run_opencl_fft(app, queue, input_vec, inverse, wait_for)

def form_multipoles(self,
level_start_source_box_nrs, source_boxes,
src_weight_vecs):
Expand Down Expand Up @@ -647,6 +677,7 @@ def eval_direct(self, target_boxes, source_box_starts,

@memoize_method
def multipole_to_local_precompute(self):
result = []
with cl.CommandQueue(self.tree_indep.cl_context) as queue:
m2l_translation_classes_dependent_data = \
self.m2l_translation_classes_dependent_data_zeros(queue)
Expand All @@ -666,6 +697,8 @@ def multipole_to_local_precompute(self):
m2l_translation_classes_dependent_data_view.shape[0]

if ntranslation_classes == 0:
result.append(pyopencl.array.empty_like(
m2l_translation_classes_dependent_data_view))
continue

data = self.translation_classes_data
Expand All @@ -683,13 +716,19 @@ def multipole_to_local_precompute(self):
ntranslation_vectors=m2l_translation_vectors.shape[1],
**self.kernel_extra_kwargs
)
m2l_translation_classes_dependent_data.add_event(evt)

m2l_translation_classes_dependent_data.finish()
if self.tree_indep.m2l_translation.use_fft:
_, m2l_translation_classes_dependent_data_view = \
self.run_opencl_fft(queue,
m2l_translation_classes_dependent_data_view,
inverse=False, wait_for=[evt], inplace=False)
result.append(m2l_translation_classes_dependent_data_view)

m2l_translation_classes_dependent_data = \
m2l_translation_classes_dependent_data.with_queue(None)
return m2l_translation_classes_dependent_data
for lev in range(self.tree.nlevels):
result[lev].finish()

result = [arr.with_queue(None) for arr in result]
return result

def _add_m2l_precompute_kwargs(self, kwargs_for_m2l,
lev):
Expand Down Expand Up @@ -717,25 +756,40 @@ def multipole_to_local(self,
target_boxes, src_box_starts, src_box_lists,
mpole_exps):

preprocess_evts = []
queue = mpole_exps.queue
local_exps = self.local_expansion_zeros(mpole_exps)

if self.tree_indep.m2l_translation.use_preprocessing:
preprocessed_mpole_exps = \
self.m2l_preproc_mpole_expansion_zeros(mpole_exps)
for lev in range(self.tree.nlevels):
m2l_work_array = self.m2l_work_array_zeros(local_exps)
mpole_exps_view_func = self.m2l_preproc_mpole_expansions_view
local_exps_view_func = self.m2l_work_array_view
else:
preprocessed_mpole_exps = mpole_exps
m2l_work_array = local_exps
mpole_exps_view_func = self.multipole_expansions_view
local_exps_view_func = self.local_expansions_view

preprocess_evts = []
translate_evts = []
postprocess_evts = []

for lev in range(self.tree.nlevels):
wait_for = []

start, stop = level_start_target_box_nrs[lev:lev+2]
if start == stop:
continue

if self.tree_indep.m2l_translation.use_preprocessing:
order = self.level_orders[lev]
preprocess_mpole_kernel = \
self.tree_indep.m2l_preprocess_mpole_kernel(order, order)

_, source_mpoles_view = \
self.multipole_expansions_view(mpole_exps, lev)

_, preprocessed_source_mpoles_view = \
self.m2l_preproc_mpole_expansions_view(
preprocessed_mpole_exps, lev)

tr_classes = self.m2l_translation_class_level_start_box_nrs()
if tr_classes[lev] == tr_classes[lev + 1]:
# There is no M2L happening in this level
Expand All @@ -744,33 +798,29 @@ def multipole_to_local(self,
evt, _ = preprocess_mpole_kernel(
queue,
src_expansions=source_mpoles_view,
preprocessed_src_expansions=preprocessed_source_mpoles_view,
preprocessed_src_expansions=preprocessed_mpole_exps[lev],
src_rscale=self.level_to_rscale(lev),
wait_for=wait_for,
**self.kernel_extra_kwargs
)
preprocess_evts.append(evt)
mpole_exps = preprocessed_mpole_exps
m2l_work_array = self.m2l_work_array_zeros(local_exps)
mpole_exps_view_func = self.m2l_preproc_mpole_expansions_view
local_exps_view_func = self.m2l_work_array_view
else:
m2l_work_array = local_exps
mpole_exps_view_func = self.multipole_expansions_view
local_exps_view_func = self.local_expansions_view
wait_for.append(evt)

translate_evts = []
if self.tree_indep.m2l_translation.use_fft:
evt_fft, preprocessed_mpole_exps[lev] = \
self.run_opencl_fft(queue,
preprocessed_mpole_exps[lev],
inverse=False, wait_for=wait_for, inplace=False)
wait_for.append(evt_fft.native_event)
evt = AggregateProfilingEvent([evt, evt_fft])

for lev in range(self.tree.nlevels):
start, stop = level_start_target_box_nrs[lev:lev+2]
if start == stop:
continue
preprocess_evts.append(evt)

order = self.level_orders[lev]
m2l = self.tree_indep.m2l(order, order,
self.supports_translation_classes)

source_level_start_ibox, source_mpoles_view = \
mpole_exps_view_func(mpole_exps, lev)
mpole_exps_view_func(preprocessed_mpole_exps, lev)
target_level_start_ibox, target_locals_view = \
local_exps_view_func(m2l_work_array, lev)

Expand All @@ -795,14 +845,11 @@ def multipole_to_local(self,
kwargs["m2l_translation_classes_dependent_data"].size == 0:
# There is nothing to do for this level
continue
evt, _ = m2l(queue, **kwargs, wait_for=preprocess_evts)

evt, _ = m2l(queue, **kwargs, wait_for=wait_for)
wait_for.append(evt)
translate_evts.append(evt)

postprocess_evts = []

if self.tree_indep.m2l_translation.use_preprocessing:
for lev in range(self.tree.nlevels):
if self.tree_indep.m2l_translation.use_preprocessing:
order = self.level_orders[lev]
postprocess_local_kernel = \
self.tree_indep.m2l_postprocess_local_kernel(order, order)
Expand All @@ -819,17 +866,28 @@ def multipole_to_local(self,
# There is no M2L happening in this level
continue

if self.tree_indep.m2l_translation.use_fft:
evt_fft, target_locals_before_postprocessing_view = \
self.run_opencl_fft(queue,
target_locals_before_postprocessing_view,
inverse=True, wait_for=wait_for, inplace=False)
wait_for.append(evt_fft.native_event)

evt, _ = postprocess_local_kernel(
queue,
tgt_expansions=target_locals_view,
tgt_expansions_before_postprocessing=(
target_locals_before_postprocessing_view),
src_rscale=self.level_to_rscale(lev),
tgt_rscale=self.level_to_rscale(lev),
wait_for=translate_evts,
wait_for=wait_for,
**self.kernel_extra_kwargs,
)
postprocess_evts.append(evt)

if self.tree_indep.m2l_translation.use_fft:
postprocess_evts.append(AggregateProfilingEvent([evt, evt_fft]))
else:
postprocess_evts.append(evt)

timing_events = preprocess_evts + translate_evts + postprocess_evts

Expand Down
2 changes: 1 addition & 1 deletion sumpy/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ def get_kernel(self, max_nsources_in_one_box, max_ntargets_in_one_box,
"{[iknl]: 0 <= iknl < noutputs}",
"{[isrc_box]: isrc_box_start <= isrc_box < isrc_box_end}",
"{[idim]: 0 <= idim < dim}",
"{[istrength]: 0 <= istrength < nstrengths}",
"{[isrc]: isrc_start <= isrc < isrc_end}"
]

Expand All @@ -483,6 +482,7 @@ def get_kernel(self, max_nsources_in_one_box, max_ntargets_in_one_box,
shape=(self.strength_count, max_nsources_in_one_box)),
]
domains += [
"{[istrength]: 0 <= istrength < nstrengths}",
"{[inner]: 0 <= inner < nsplit}",
"{[itgt_offset_outer]: 0 <= itgt_offset_outer <= tgt_outer_limit}",
"{[isrc_offset_outer]: 0 <= isrc_offset_outer <= src_outer_limit}",
Expand Down
Loading