Skip to content

Commit 9ea8a44

Browse files
committed
changes
1 parent b902efa commit 9ea8a44

File tree

8 files changed

+52
-66
lines changed

8 files changed

+52
-66
lines changed

firedrake/assemble.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,8 @@ def allocate(self):
365365
else:
366366
test, trial = self._form.arguments()
367367
sparsity = ExplicitMatrixAssembler._make_sparsity(test, trial, self._mat_type, self._sub_mat_type, self.maps_and_regions)
368-
return matrix.Matrix(self._form, self._bcs, self._mat_type, sparsity, ScalarType,
369-
sub_mat_type=self._sub_mat_type,
370-
options_prefix=self._options_prefix)
368+
op2mat = op2.Mat(sparsity, mat_type=self._mat_type, sub_mat_type=self._sub_mat_type, dtype=ScalarType)
369+
return matrix.Matrix(self._form, op2mat, bcs=self._bcs, options_prefix=self._options_prefix, fc_params=self._form_compiler_params)
371370
else:
372371
raise NotImplementedError("Only implemented for rank = 2 and diagonal = False")
373372

@@ -626,8 +625,7 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
626625
raise TypeError(f"Unrecognised BaseForm instance: {expr}")
627626

628627
def assembled_matrix(self, expr, bcs, petscmat):
629-
return matrix.AssembledMatrix(expr.arguments(), petscmat, self._mat_type,
630-
bcs=bcs, options_prefix=self._options_prefix)
628+
return matrix.AssembledMatrix(expr.arguments(), petscmat, bcs=bcs, options_prefix=self._options_prefix)
631629

632630
@staticmethod
633631
def base_form_postorder_traversal(expr, visitor, visited={}):
@@ -1379,8 +1377,7 @@ def allocate(self):
13791377
sparsity, mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
13801378
dtype=ScalarType
13811379
)
1382-
return matrix.Matrix(self._form, op2mat, self._mat_type,
1383-
bcs=self._bcs,
1380+
return matrix.Matrix(self._form, op2mat, bcs=self._bcs,
13841381
fc_params=self._form_compiler_params,
13851382
options_prefix=self._options_prefix)
13861383

@@ -1585,7 +1582,7 @@ def allocate(self):
15851582
ctx = ImplicitMatrixContext(
15861583
self._form, row_bcs=self._bcs, col_bcs=self._bcs,
15871584
fc_params=self._form_compiler_params,
1588-
appctx=self._appctx or {}
1585+
appctx=self._appctx
15891586
)
15901587
return matrix.ImplicitMatrix(
15911588
self._form, ctx, self._bcs,

firedrake/external_operators/ml_operator.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from firedrake.external_operators import AbstractExternalOperator, assemble_method
2-
from firedrake.matrix import AssembledMatrix
2+
from firedrake.matrix import Matrix
33

44

55
class MLOperator(AbstractExternalOperator):
@@ -58,20 +58,16 @@ def assemble_jacobian(self, *args, **kwargs):
5858
"""Assemble the Jacobian using the AD engine of the ML framework."""
5959
# Delegate computation to the ML framework.
6060
J = self._jac()
61-
# Set bcs
62-
bcs = ()
63-
return AssembledMatrix(self, bcs, J)
61+
return Matrix(self, J)
6462

6563
@assemble_method(1, (1, 0))
6664
def assemble_jacobian_adjoint(self, *args, **kwargs):
6765
"""Assemble the Jacobian Hermitian transpose using the AD engine of the ML framework."""
6866
# Delegate computation to the ML framework.
6967
J = self._jac()
70-
# Set bcs
71-
bcs = ()
7268
# Take the adjoint (Hermitian transpose)
7369
J.hermitianTranspose()
74-
return AssembledMatrix(self, bcs, J)
70+
return Matrix(self, J)
7571

7672
@assemble_method(1, (0, None))
7773
def assemble_jacobian_action(self, *args, **kwargs):

firedrake/formmanipulation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,7 @@ def matrix(self, o):
180180
args.append(asplit)
181181

182182
submat = o.petscmat.createSubMatrix(*ises)
183-
bcs = ()
184-
return AssembledMatrix(tuple(args), bcs, submat)
183+
return AssembledMatrix(tuple(args), submat)
185184

186185
def zero_base_form(self, o):
187186
return ZeroBaseForm(tuple(map(self, o.arguments())))

firedrake/function.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import ufl
55
import warnings
6+
from ufl.algorithms.analysis import extract_arguments
67
from ufl.duals import is_dual
78
from ufl.formatting.ufl2unicode import ufl2unicode
89
from ufl.domain import extract_unique_domain
@@ -383,6 +384,8 @@ def interpolate(self,
383384
Returns `self`
384385
"""
385386
from firedrake import interpolate, assemble
387+
if len(extract_arguments(expression)) > 0:
388+
raise ValueError("Can't interpolate an expression with arguments into a Function.")
386389
V = self.function_space()
387390
interp = interpolate(expression, V, **kwargs)
388391
return assemble(interp, tensor=self, ad_block_tag=ad_block_tag)

firedrake/interpolation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def assemble(
314314
return ImplicitMatrix(self.ufl_interpolate, ctx, bcs=bcs)
315315

316316
result = self._get_callable(tensor=tensor, bcs=bcs, mat_type=mat_type, sub_mat_type=sub_mat_type)()
317+
317318
if self.rank == 2:
318319
# Assembling the operator
319320
assert isinstance(tensor, MatrixBase | None)
@@ -322,7 +323,7 @@ def assemble(
322323
result.copy(tensor.petscmat)
323324
return tensor
324325
else:
325-
return Matrix(self.ufl_interpolate, result, mat_type, bcs=bcs)
326+
return Matrix(self.ufl_interpolate, result, bcs=bcs)
326327
else:
327328
assert isinstance(tensor, Function | Cofunction | None)
328329
return tensor.assign(result) if tensor else result

firedrake/matrix.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class MatrixBase(ufl.Matrix):
2727
def __init__(
2828
self,
2929
a: ufl.BaseForm | TensorBase | tuple[BaseArgument, BaseArgument],
30-
mat_type: Literal["aij", "baij", "dense", "nest", "matfree"],
3130
bcs: Iterable[BCBase] | None = None,
3231
fc_params: dict[str, Any] | None = None,
3332
):
@@ -38,10 +37,7 @@ def __init__(
3837
----------
3938
a
4039
A UFL BaseForm (with two arguments) that this MatrixBase represents,
41-
or a tuple of the arguments it represents.
42-
mat_type
43-
Matrix type used in the assembly of the PETSc matrix: 'aij', 'baij', 'dense' or 'nest',
44-
or 'matfree' for matrix-free.
40+
or a tuple of the arguments it represents, or a slate TensorBase.
4541
fc_params
4642
A dictionary of form compiler parameters for this matrix.
4743
bcs
@@ -70,15 +66,12 @@ def __init__(
7066
self._analyze_form_arguments()
7167
self._arguments = arguments
7268

73-
if bcs is None:
74-
bcs = ()
75-
self.bcs = bcs
69+
self.bcs = bcs or ()
7670
self.comm = test.function_space().comm
7771
self._comm = internal_comm(self.comm, self)
7872
self.block_shape = (len(test.function_space()),
7973
len(trial.function_space()))
80-
self.mat_type = mat_type
81-
self.form_compiler_parameters = fc_params
74+
self.form_compiler_parameters = fc_params or {}
8275

8376
def arguments(self):
8477
if self.a:
@@ -155,7 +148,6 @@ def __init__(
155148
self,
156149
a: ufl.BaseForm,
157150
mat: op2.Mat | PETSc.Mat,
158-
mat_type: Literal["aij", "baij", "dense", "nest"],
159151
bcs: Iterable[BCBase] | None = None,
160152
fc_params: dict[str, Any] | None = None,
161153
options_prefix: str | None = None,
@@ -168,8 +160,6 @@ def __init__(
168160
The bilinear form this :class:`Matrix` represents.
169161
mat : op2.Mat | PETSc.Mat
170162
The underlying matrix object. Either a PyOP2 Mat or a PETSc Mat.
171-
mat_type
172-
The type of the PETSc matrix.
173163
bcs : Iterable[DirichletBC] | None, optional
174164
An iterable of boundary conditions to apply to this :class:`Matrix`.
175165
May be `None` if there are no boundary conditions to apply.
@@ -179,7 +169,7 @@ def __init__(
179169
options_prefix : str | None, optional
180170
PETSc options prefix to apply, by default None.
181171
"""
182-
super().__init__(a, mat_type, bcs=bcs, fc_params=fc_params)
172+
super().__init__(a, bcs=bcs, fc_params=fc_params)
183173
if isinstance(mat, op2.Mat):
184174
self.M = mat
185175
else:
@@ -188,7 +178,7 @@ def __init__(
188178
self.petscmat = self.M.handle
189179
if options_prefix:
190180
self.petscmat.setOptionsPrefix(options_prefix)
191-
self.mat_type = mat_type
181+
self.mat_type = self.petscmat.getType()
192182

193183
def assemble(self):
194184
raise NotImplementedError("API compatibility to apply bcs after 'assemble(a)'\
@@ -226,7 +216,7 @@ def __init__(
226216
options_prefix
227217
PETSc options prefix to apply, by default None.
228218
"""
229-
super().__init__(a, "matfree", bcs=bcs, fc_params=fc_params)
219+
super().__init__(a, bcs=bcs, fc_params=fc_params)
230220

231221
self.petscmat = PETSc.Mat().create(comm=self.comm)
232222
self.petscmat.setType("python")
@@ -236,6 +226,7 @@ def __init__(
236226
self.petscmat.setOptionsPrefix(options_prefix)
237227
self.petscmat.setUp()
238228
self.petscmat.assemble()
229+
self.mat_type = "matfree"
239230

240231
def assemble(self):
241232
# Bump petsc matrix state by assembling it.
@@ -250,7 +241,6 @@ def __init__(
250241
self,
251242
args: tuple[BaseArgument, BaseArgument],
252243
petscmat: PETSc.Mat,
253-
mat_type: Literal["aij", "baij", "dense", "nest", "matfree"],
254244
bcs: Iterable[BCBase] | None = None,
255245
options_prefix: str | None = None,
256246
):
@@ -262,19 +252,18 @@ def __init__(
262252
A tuple of the arguments the matrix represents.
263253
petscmat
264254
The PETSc matrix this object wraps.
265-
mat_type
266-
The type of the PETSc matrix.
267255
bcs
268256
an iterable of boundary conditions to apply to this :class:`Matrix`.
269257
May be `None` if there are no boundary conditions to apply. By default None.
270258
options_prefix
271259
PETSc options prefix to apply, by default None.
272260
"""
273-
super().__init__(args, mat_type, bcs=bcs)
261+
super().__init__(args, bcs=bcs)
274262

275263
self.petscmat = petscmat
276264
if options_prefix:
277265
self.petscmat.setOptionsPrefix(options_prefix)
266+
self.mat_type = self.petscmat.getType()
278267

279268
# this mimics op2.Mat.handle
280269
self.M = DummyOP2Mat(self.petscmat)

firedrake/matrix_free/operators.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import OrderedDict
2-
from typing import Iterable
2+
from typing import Any, Iterable
33
import itertools
44

55
from mpi4py import MPI
@@ -64,26 +64,6 @@ def find_sub_block(iset, ises, comm):
6464
return found
6565

6666

67-
"""This class gives the Python context for a PETSc Python matrix.
68-
69-
:arg a: The bilinear form defining the matrix
70-
71-
:arg row_bcs: An iterable of the :class.`.DirichletBC`s that are
72-
imposed on the test space. We distinguish between row and
73-
column boundary conditions in the case of submatrices off of the
74-
diagonal.
75-
76-
:arg col_bcs: An iterable of the :class.`.DirichletBC`s that are
77-
imposed on the trial space.
78-
79-
:arg fcparams: A dictionary of parameters to pass on to the form
80-
compiler.
81-
82-
:arg appctx: Any extra user-supplied context, available to
83-
preconditioners and the like.
84-
85-
"""
86-
8767
class ImplicitMatrixContext(object):
8868
# By default, these matrices will represent diagonal blocks (the
8969
# (0,0) block of a 1x1 block matrix is on the diagonal).
@@ -96,17 +76,38 @@ def __init__(
9676
a: ufl.BaseForm,
9777
row_bcs: Iterable[DirichletBC] | None = None,
9878
col_bcs: Iterable[DirichletBC] | None = None,
99-
fc_params=None,
100-
appctx=None
79+
fc_params : dict[str, Any] | None = None,
80+
appctx: dict[str, Any] | None = None
10181
):
82+
"""This class gives the Python context for a PETSc Python matrix.
83+
84+
Parameters
85+
----------
86+
a
87+
The bilinear form defining the matrix.
88+
row_bcs
89+
An iterable of the :class.`.DirichletBC`s that are
90+
imposed on the test space. We distinguish between row and
91+
column boundary conditions in the case of submatrices off
92+
of the diagonal. By default None.
93+
col_bcs
94+
An iterable of the :class.`.DirichletBC`s that are imposed
95+
on the trial space. By default None.
96+
fc_params
97+
A dictionary of parameters to pass on to the form compiler.
98+
By default None.
99+
appctx
100+
Any extra user-supplied context, available to preconditioners
101+
and the like. By default None.
102+
"""
102103
from firedrake.assemble import get_assembler
103104

104105
self.a = a
105106
self.aT = adjoint(a)
106107
self.comm = a.arguments()[0].function_space().comm
107108
self._comm = internal_comm(self.comm, self)
108-
self.fc_params = fc_params
109-
self.appctx = appctx
109+
self.fc_params = fc_params or {}
110+
self.appctx = appctx or {}
110111

111112
# Collect all DirichletBC instances including
112113
# DirichletBCs applied to an EquationBC.

tests/firedrake/regression/test_matrix.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from firedrake import *
2-
from firedrake import matrix
2+
from firedrake.matrix import Matrix, AssembledMatrix
33
import pytest
44

55

@@ -33,7 +33,7 @@ def mat_type(request):
3333
def test_assemble_returns_matrix(a):
3434
A = assemble(a)
3535

36-
assert isinstance(A, matrix.Matrix)
36+
assert isinstance(A, Matrix)
3737

3838

3939
def test_solve_with_assembled_matrix(a):
@@ -42,7 +42,7 @@ def test_solve_with_assembled_matrix(a):
4242
x, = SpatialCoordinate(V.mesh())
4343
f = Function(V).interpolate(x)
4444

45-
A = AssembledMatrix((v, u), bcs=(), petscmat=assemble(a).petscmat)
45+
A = AssembledMatrix((v, u), assemble(a).petscmat)
4646
L = inner(f, v) * dx
4747

4848
solution = Function(V)

0 commit comments

Comments
 (0)