Skip to content

Problem with gradient in nonlinear function f #8

@vboussange

Description

@vboussange

Currently, the gradients ∇v_y and ∇v_z cannot be used in the nonlinear function

f(y,z,v_y,v_z,∇v_y,∇v_z, p, t)

Uncommenting these two lines https://github.com/vboussange/HighDimPDE.jl/blob/21a25b7f13f15332a1b5872961315b7bcedc5a5b/src/DeepSplitting.jl#L92-L93

should make this work. However, possibly because of issue FluxML/Flux.jl#1464, this fails: the gradient returns a one-element array, `∇vi(y1)[1] = [0f0]`. Here is an example to play with (assuming the lines above have been uncommented).

# Reproducer: the nonlinear term `f` depends on the gradient `∇v_y`.
# For each dimension `d`, the problem is solved twice, each time with a freshly
# constructed network and optimiser. The test then checks that the two final
# solutions agree to within a mean relative L2 error of 0.1.
# NOTE(review): this relies on the gradient lines in src/DeepSplitting.jl
# (L92-L93 at commit 21a25b7) being uncommented — TODO confirm.
@testset "DeepSplitting algorithm - gradient squared" begin
    batch_size = 2000
    train_steps = 1000
    K = 1
    tspan = (0f0, 5f-1)
    dt = 5f-2  # time step

    μ(x, p, t) = 0f0   # advection coefficient
    σ(x, p, t) = 1f-1  # diffusion coefficient

    for d in [1, 2, 5]
        u1s = []  # final solution of each of the two runs for this `d`
        for _ in 1:2
            u_domain = (fill(-5f-1, d), fill(5f-1, d))

            hls = d + 50  # hidden layer size

            # Network and optimiser are rebuilt on every run so that the two runs
            # start from independent initialisations.
            nn = Flux.Chain(Dense(d, hls, tanh),
                            Dense(hls, hls, tanh),
                            Dense(hls, 1))

            opt = ADAM(1e-2)  # optimiser
            alg = DeepSplitting(nn, K = K, opt = opt,
                                mc_sample = UniformSampling(u_domain[1], u_domain[2]))

            x = fill(0f0, d)  # initial point
            g(X) = exp.(-0.25f0 * sum(X .^ 2, dims = 1))  # initial condition

            # Nonlinearity: sum of squared entries of ∇v_y along dims=1
            # (presumably one value per sample column, as in `g` — TODO confirm).
            # `f` is evaluated on every training batch, so an unconditional `@show`
            # floods stdout; `@debug` keeps the diagnostic available through the
            # JULIA_DEBUG environment variable without slowing the test down.
            function f(y, z, v_y, v_z, ∇v_y, ∇v_z, p, t)
                @debug "∇v_y" ∇v_y
                return sum(∇v_y .^ 2, dims = 1)
            end

            # defining the problem
            prob = PIDEProblem(g, f, μ, σ, tspan, x = x, neumann = u_domain)
            # solving
            @time xs, ts, sol = solve(prob,
                                      alg,
                                      dt,
                                      # verbose = true,
                                      # abstol=1e-5,
                                      use_cuda = false,
                                      maxiters = train_steps,
                                      batch_size = batch_size)
            push!(u1s, sol[end])
            println("d = $d, u1 = $(sol[end])")
        end
        e_l2 = mean(rel_error_l2.(u1s[1], u1s[2]))
        println("rel_error_l2 = ", e_l2, "\n")
        @test e_l2 < 0.1
    end
end

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions