Skip to content

Commit e06308d

Browse files
authored
Merge pull request #307 from firedrakeproject/mscroggs/gdim
Update UFL element interface
2 parents 8c1c4c0 + 947d74f commit e06308d

File tree

4 files changed

+9
-8
lines changed

4 files changed

+9
-8
lines changed

tests/test_interpolation_factorisation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_sum_factorisation_scalar_tensor(mesh, element):
5454
source = element(degree - 1)
5555
target = element(degree)
5656
tensor_flops = flop_count(mesh, source, target)
57-
expect = numpy.prod(target.value_shape)
57+
expect = FunctionSpace(mesh, target).value_size
5858
if isinstance(target, FiniteElement):
5959
scalar_flops = tensor_flops
6060
else:

tests/test_tsfc_204.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ def test_physically_mapped_facet():
1313
V = FiniteElement("P", mesh.ufl_cell(), 1)
1414
R = FiniteElement("P", mesh.ufl_cell(), 1)
1515
Vv = VectorElement(BrokenElement(V))
16-
Qhat = VectorElement(BrokenElement(V[facet]))
17-
Vhat = VectorElement(V[facet])
16+
Qhat = VectorElement(BrokenElement(V[facet]), dim=2)
17+
Vhat = VectorElement(V[facet], dim=2)
1818
Z = FunctionSpace(mesh, MixedElement(U, Vv, Qhat, Vhat, R))
1919

2020
z = Coefficient(Z)

tsfc/kernel_interface/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def prepare_coefficient(coefficient, name, interior_facet=False):
472472

473473
if coefficient.ufl_element().family() == 'Real':
474474
# Constant
475-
value_size = coefficient.ufl_element().value_size
475+
value_size = coefficient.ufl_function_space().value_size
476476
expression = gem.reshape(gem.Variable(name, (value_size,)),
477477
coefficient.ufl_shape)
478478
return expression

tsfc/ufl_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,9 @@ def apply_mapping(expression, element, domain):
412412
mesh = domain
413413
if domain is not None and mesh != domain:
414414
raise NotImplementedError("Multiple domains not supported")
415-
if expression.ufl_shape != element.value_shape:
416-
raise ValueError(f"Mismatching shapes, got {expression.ufl_shape}, expected {element.value_shape}")
415+
pvs = element.pullback.physical_value_shape(element, mesh)
416+
if expression.ufl_shape != pvs:
417+
raise ValueError(f"Mismatching shapes, got {expression.ufl_shape}, expected {pvs}")
417418
mapping = element.mapping().lower()
418419
if mapping == "identity":
419420
rexpression = expression
@@ -451,7 +452,7 @@ def apply_mapping(expression, element, domain):
451452
sub_elem = element.sub_elements[0]
452453
shape = expression.ufl_shape
453454
flat = ufl.as_vector([expression[i] for i in numpy.ndindex(shape)])
454-
vs = sub_elem.value_shape
455+
vs = sub_elem.pullback.physical_value_shape(sub_elem, mesh)
455456
rvs = sub_elem.reference_value_shape
456457
seen = set()
457458
rpieces = []
@@ -472,7 +473,7 @@ def apply_mapping(expression, element, domain):
472473
# And reshape
473474
rexpression = as_tensor(numpy.asarray(rpieces).reshape(element.reference_value_shape))
474475
else:
475-
raise NotImplementedError(f"Don't know how to handle mapping type {mapping} for expression of rank {element.value_shape}")
476+
raise NotImplementedError(f"Don't know how to handle mapping type {mapping} for expression of rank {ufl.FunctionSpace(mesh, element).value_shape}")
476477
if rexpression.ufl_shape != element.reference_value_shape:
477478
raise ValueError(f"Mismatching reference shapes, got {rexpression.ufl_shape} expected {element.reference_value_shape}")
478479
return rexpression

0 commit comments

Comments
 (0)