Skip to content

Commit c0c8fd2

Browse files
committed
mg map_cache handling for compute backends
1 parent 2da0f5b commit c0c8fd2

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

firedrake/mg/utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,28 @@
88
from firedrake.cython import mgimpl as impl
99

1010

11-
def get_or_set_mg_hierarchy_map_cache(cache_dict, entity_dofs_key,
12-
create_map_on_cpu):
11+
def get_or_set_mg_hierarchy_map_cache(cache_dict, entity_dofs_key, create_map_on_cpu):
1312
"""
1413
:arg cache_dict: An instance of :class:`dict` that maps from tuple
1514
``(entity_dofs_key, compute_backend)`` to the corresponding map.
16-
:arg create_host_map: A callable that takes no argument and returns the map
15+
:arg create_map_on_cpu: A callable that takes no argument and returns the map
1716
on the CPU backend.
1817
:returns map: An instance of :class:`pyop2.base.Map`.
1918
"""
2019
try:
2120
return cache_dict[(entity_dofs_key, op2.compute_backend)]
2221
except KeyError:
22+
2323
from pyop2.sequential import cpu_backend
24-
host_map = cache_dict.setdefault((entity_dofs_key,
25-
cpu_backend), create_map_on_cpu())
26-
return cache_dict.setdefault((entity_dofs_key, op2.compute_backend),
27-
op2.compute_backend.Map(host_map))
24+
if (entity_dofs_key, cpu_backend) not in cache_dict:
25+
cache_dict[(entity_dofs_key, cpu_backend)] = create_map_on_cpu()
26+
27+
map_on_cpu = cache_dict[(entity_dofs_key, cpu_backend)]
28+
29+
if (entity_dofs_key, op2.compute_backend) not in cache_dict:
30+
cache_dict[(entity_dofs_key, op2.compute_backend)] = op2.compute_backend.Map(map_on_cpu)
31+
32+
return cache_dict[(entity_dofs_key, op2.compute_backend)]
2833

2934

3035
def fine_node_to_coarse_node_map(Vf, Vc):

0 commit comments

Comments
 (0)