Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 75 additions & 3 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import abc
import sys
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import numpy as np
Expand All @@ -74,7 +75,6 @@

if TYPE_CHECKING:
import loopy as lp
import pyopencl as cl
import pytato

if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
Expand Down Expand Up @@ -235,6 +235,16 @@ def get_target(self):

# {{{ PytatoPyOpenCLArrayContext


@dataclass
class ProfileEvent:
"""Holds a profile event that has not been collected by the profiler yet."""

start_cl_event: cl._cl.Event
stop_cl_event: cl._cl.Event
t_unit_name: str


class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
"""
An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
Expand All @@ -259,7 +269,7 @@ def __init__(
self, queue: cl.CommandQueue, allocator=None, *,
use_memory_pool: bool | None = None,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,

profile_kernels: bool = False,
# do not use: only for testing
_force_svm_arg_limit: int | None = None,
) -> None:
Expand Down Expand Up @@ -322,6 +332,59 @@ def __init__(

self._force_svm_arg_limit = _force_svm_arg_limit

self._enable_profiling(profile_kernels)

# {{{ Profiling functionality

def _enable_profiling(self, enable: bool) -> None:
# List of ProfileEvents that haven't been transferred to profiled
# results yet
self._profile_events: list[ProfileEvent] = []

# Dict of kernel name -> list of kernel execution times
self._profile_results: dict[str, list[int]] = {}

if enable:
import pyopencl as cl
if not self.queue.properties & cl.command_queue_properties.PROFILING_ENABLE:
raise RuntimeError("Profiling was not enabled in the command queue. "
"Please create the queue with "
"cl.command_queue_properties.PROFILING_ENABLE.")
self.profile_kernels = True

else:
self.profile_kernels = False

def _wait_and_transfer_profile_events(self) -> None:
"""Wait for all profiling events to finish and transfer the results
to *self._profile_results*."""
import pyopencl as cl
# First, wait for completion of all events
if self._profile_events:
cl.wait_for_events([p_event.stop_cl_event
for p_event in self._profile_events])

# Then, collect all events and store them
for t in self._profile_events:
name = t.t_unit_name

time = t.stop_cl_event.profile.end - t.start_cl_event.profile.end

self._profile_results.setdefault(name, []).append(time)

self._profile_events = []

def _add_profiling_events(self, start: cl._cl.Event, stop: cl._cl.Event,
t_unit_name: str) -> None:
"""Add profiling events to the list of profiling events."""
self._profile_events.append(ProfileEvent(start, stop, t_unit_name))

def _reset_profiling_data(self) -> None:
"""Reset profiling data."""
self._profile_results = {}

# }}}

@property
def _frozen_array_types(self) -> tuple[type, ...]:
import pyopencl.array as cla
Expand Down Expand Up @@ -546,9 +609,18 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
self._dag_transform_cache[normalized_expr])

assert len(pt_prg.bound_arguments) == 0
_evt, out_dict = pt_prg(self.queue,

if self.profile_kernels:
import pyopencl as cl
start_evt = cl.enqueue_marker(self.queue)

evt, out_dict = pt_prg(self.queue,
allocator=self.allocator,
**bound_arguments)

if self.profile_kernels:
self._add_profiling_events(start_evt, evt, pt_prg.program.entrypoint)

assert len(set(out_dict) & set(key_to_frozen_subary)) == 0

key_to_frozen_subary = {
Expand Down
18 changes: 16 additions & 2 deletions arraycontext/impl/pytato/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,10 +636,17 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)

_evt, out_dict = self.pytato_program(queue=self.actx.queue,
if self.actx.profile_kernels:
import pyopencl as cl
start_evt = cl.enqueue_marker(self.actx.queue)

evt, out_dict = self.pytato_program(queue=self.actx.queue,
allocator=self.actx.allocator,
**input_kwargs_for_loopy)

if self.actx.profile_kernels:
self.actx._add_profiling_events(start_evt, evt, fn_name)

def to_output_template(keys, _):
name_in_program = self.output_id_to_name_in_program[keys]
return self.actx.thaw(to_tagged_cl_array(
Expand Down Expand Up @@ -675,10 +682,17 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)

_evt, out_dict = self.pytato_program(queue=self.actx.queue,
if self.actx.profile_kernels:
import pyopencl as cl
start_evt = cl.enqueue_marker(self.actx.queue)

evt, out_dict = self.pytato_program(queue=self.actx.queue,
allocator=self.actx.allocator,
**input_kwargs_for_loopy)

if self.actx.profile_kernels:
self.actx._add_profiling_events(start_evt, evt, fn_name)

return self.actx.thaw(to_tagged_cl_array(out_dict[self.output_name],
axes=get_cl_axes_from_pt_axes(
self.output_axes),
Expand Down
46 changes: 46 additions & 0 deletions arraycontext/impl/pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
__doc__ = """
.. autofunction:: transfer_from_numpy
.. autofunction:: transfer_to_numpy


Profiling-related functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: tabulate_profiling_data
"""


Expand Down Expand Up @@ -35,6 +41,7 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast

import pytools
from pytato.array import (
AbstractResultWithNamedArrays,
Array,
Expand All @@ -51,6 +58,7 @@

from arraycontext import ArrayContext
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext


if TYPE_CHECKING:
Expand Down Expand Up @@ -221,4 +229,42 @@ def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:

# }}}


# {{{ Profiling

def tabulate_profiling_data(actx: PytatoPyOpenCLArrayContext) -> pytools.Table:
"""Return a :class:`pytools.Table` with the profiling results."""
actx._wait_and_transfer_profile_events()

tbl = pytools.Table()

# Table header
tbl.add_row(("Kernel", "# Calls", "Time_sum [ns]", "Time_avg [ns]"))

# Precision of results
g = ".5g"

total_calls = 0
total_time = 0.0

for kernel_name, times in actx._profile_results.items():
num_calls = len(times)
total_calls += num_calls

t_sum = sum(times)
t_avg = t_sum / num_calls
if t_sum is not None:
total_time += t_sum

tbl.add_row((kernel_name, num_calls, f"{t_sum:{g}}", f"{t_avg:{g}}"))

tbl.add_row(("", "", "", ""))
tbl.add_row(("Total", total_calls, f"{total_time:{g}}", "--"))

actx._reset_profiling_data()

return tbl

# }}}

# vim: foldmethod=marker
94 changes: 94 additions & 0 deletions test/test_pytato_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import logging

import numpy as np
import pytest

from pytools.tag import Tag
Expand Down Expand Up @@ -274,6 +275,99 @@ def twice(x, y, a):
assert isinstance(ep.arg_dict["_actx_in_2"], lp.ArrayArg)


def test_profiling_actx():
import pyopencl as cl
cl_ctx = cl.create_some_context()
queue = cl.CommandQueue(cl_ctx,
properties=cl.command_queue_properties.PROFILING_ENABLE)

actx = PytatoPyOpenCLArrayContext(queue, profile_kernels=True)

def twice(x):
return 2 * x

# {{{ Compiled test

f = actx.compile(twice)

assert len(actx._profile_events) == 0

for _ in range(10):
assert actx.to_numpy(f(99)) == 198

assert len(actx._profile_events) == 10
actx._wait_and_transfer_profile_events()
assert len(actx._profile_events) == 0
assert len(actx._profile_results) == 1
assert len(actx._profile_results["twice"]) == 10

from arraycontext.impl.pytato.utils import tabulate_profiling_data

print(tabulate_profiling_data(actx))
assert len(actx._profile_results) == 0

# }}}

# {{{ Uncompiled/frozen test

assert len(actx._profile_events) == 0

for _ in range(10):
assert np.all(actx.to_numpy(twice(actx.from_numpy(np.array([99, 99])))) == 198)

assert len(actx._profile_events) == 10
actx._wait_and_transfer_profile_events()
assert len(actx._profile_events) == 0
assert len(actx._profile_results) == 1
assert len(actx._profile_results["frozen_result"]) == 10

print(tabulate_profiling_data(actx))

assert len(actx._profile_results) == 0

# }}}

# {{{ test disabling profiling

actx._enable_profiling(False)

assert len(actx._profile_events) == 0

for _ in range(10):
assert actx.to_numpy(f(99)) == 198

assert len(actx._profile_events) == 0
assert len(actx._profile_results) == 0

# }}}

# {{{ test enabling profiling

actx._enable_profiling(True)

assert len(actx._profile_events) == 0

for _ in range(10):
assert actx.to_numpy(f(99)) == 198

assert len(actx._profile_events) == 10
actx._wait_and_transfer_profile_events()
assert len(actx._profile_events) == 0
assert len(actx._profile_results) == 1

# }}}

queue2 = cl.CommandQueue(cl_ctx)

with pytest.raises(RuntimeError):
PytatoPyOpenCLArrayContext(queue2, profile_kernels=True)

actx2 = PytatoPyOpenCLArrayContext(queue2)

with pytest.raises(RuntimeError):
actx2._enable_profiling(True)


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down
Loading