Skip to content

Commit 0541ade

Browse files
committed
NLVS tape block stores adj_sol per block not per solver
1 parent f36c82b commit 0541ade

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(self, lhs, rhs, func, bcs, *args, **kwargs):
5656
# Solution function
5757
self.func = func
5858
self.function_space = self.func.function_space()
59+
self.adj_state_buf = self.func.copy(deepcopy=True)
5960
# Boundary conditions
6061
self.bcs = []
6162
if bcs is not None:
@@ -193,6 +194,8 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
193194
if self.adj_bdy_cb is not None and compute_bdy:
194195
self.adj_bdy_cb(adj_sol_bdy)
195196

197+
self.adj_state_buf.assign(adj_sol)
198+
196199
r = {}
197200
r["form"] = F_form
198201
r["adj_sol"] = adj_sol
@@ -399,6 +402,8 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
399402
if tlm_output is None:
400403
return
401404

405+
self.adj_state.assign(self.adj_state_buf)
406+
402407
F_form = self._create_F_form()
403408

404409
# Using the equation Form derive dF/du, d^2F/du^2 * du/dm * direction.
@@ -727,6 +732,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
727732
)
728733
adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy)
729734
self.adj_state = adj_sol
735+
self.adj_state_buf.assign(adj_sol)
730736
if self.adj_cb is not None:
731737
self.adj_cb(adj_sol)
732738
if self.adj_bdy_cb is not None and compute_bdy:

tests/firedrake/adjoint/test_hessian.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_simple_solve(rg):
3737
mesh = IntervalMesh(10, 0, 1)
3838
V = FunctionSpace(mesh, "Lagrange", 1)
3939

40-
f = Function(V).assign(2)
40+
f = Function(V).assign(2.)
4141

4242
u = TrialFunction(V)
4343
v = TestFunction(V)
@@ -76,10 +76,10 @@ def test_mixed_derivatives(rg):
7676
mesh = IntervalMesh(10, 0, 1)
7777
V = FunctionSpace(mesh, "Lagrange", 1)
7878

79-
f = Function(V).assign(2)
79+
f = Function(V).assign(2.)
8080
control_f = Control(f)
8181

82-
g = Function(V).assign(3)
82+
g = Function(V).assign(3.)
8383
control_g = Control(g)
8484

8585
u = TrialFunction(V)
@@ -126,7 +126,7 @@ def test_function(rg):
126126
R = FunctionSpace(mesh, "R", 0)
127127
c = Function(R, val=4)
128128
control_c = Control(c)
129-
f = Function(V).assign(3)
129+
f = Function(V).assign(3.)
130130
control_f = Control(f)
131131

132132
u = Function(V)
@@ -139,14 +139,14 @@ def test_function(rg):
139139
J = assemble(c ** 2 * u ** 2 * dx)
140140

141141
Jhat = ReducedFunctional(J, [control_c, control_f])
142-
dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True)
142+
dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True)
143143

144144
# Step direction for derivatives and convergence test
145145
h_c = Function(R, val=1.0)
146146
h_f = rg.uniform(V, 0, 10)
147147

148148
# Total derivative
149-
dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True)
149+
dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True)
150150
dJdm = assemble(dJdc * h_c * dx + dJdf * h_f * dx)
151151

152152
# Hessian
@@ -163,7 +163,7 @@ def test_nonlinear(rg):
163163
mesh = UnitSquareMesh(10, 10)
164164
V = FunctionSpace(mesh, "Lagrange", 1)
165165
R = FunctionSpace(mesh, "R", 0)
166-
f = Function(V).assign(5)
166+
f = Function(V).assign(5.)
167167

168168
u = Function(V)
169169
v = TestFunction(V)
@@ -201,11 +201,11 @@ def test_dirichlet(rg):
201201
mesh = UnitSquareMesh(10, 10)
202202
V = FunctionSpace(mesh, "Lagrange", 1)
203203

204-
f = Function(V).assign(30)
204+
f = Function(V).assign(30.)
205205

206206
u = Function(V)
207207
v = TestFunction(V)
208-
c = Function(V).assign(1)
208+
c = Function(V).assign(1.)
209209
bc = DirichletBC(V, c, "on_boundary")
210210

211211
F = inner(grad(u), grad(v)) * dx + u**4*v*dx - f**2 * v * dx
@@ -249,24 +249,25 @@ def Dt(u, u_, timestep):
249249
pr = project(sin(2*pi*x), V, annotate=False)
250250
ic = Function(V).assign(pr)
251251

252-
u_ = Function(V)
253-
u = Function(V)
252+
u_ = Function(V).assign(ic)
253+
u = Function(V).assign(ic)
254254
v = TestFunction(V)
255255

256256
nu = Constant(0.0001)
257257

258-
timestep = Constant(1.0/n)
258+
dt = 0.01
259+
nt = 20
259260

260261
params = {
261262
'snes_rtol': 1e-10,
262263
'ksp_type': 'preonly',
263264
'pc_type': 'lu',
264265
}
265266

266-
F = (Dt(u, ic, timestep)*v
267+
F = (Dt(u, u_, dt)*v
267268
+ u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx
269+
268270
bc = DirichletBC(V, 0.0, "on_boundary")
269-
t = 0.0
270271

271272
if solve_type == "nlvs":
272273
use_nlvs = True
@@ -285,21 +286,14 @@ def Dt(u, u_, timestep):
285286
else:
286287
solve(F == 0, u, bc, solver_parameters=params)
287288
u_.assign(u)
288-
t += float(timestep)
289289

290-
F = (Dt(u, u_, timestep)*v
291-
+ u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx
292-
293-
end = 0.2
294-
while (t <= end):
290+
for _ in range(nt):
295291
if use_nlvs:
296292
solver.solve()
297293
else:
298294
solve(F == 0, u, bc, solver_parameters=params)
299295
u_.assign(u)
300296

301-
t += float(timestep)
302-
303297
J = assemble(u_*u_*dx + ic*ic*dx)
304298

305299
Jhat = ReducedFunctional(J, Control(ic))

0 commit comments

Comments
 (0)