Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.16.7"
version = "0.16.8"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -52,7 +52,7 @@ Enzyme = "0.13"
EnzymeCore = "0.7.7, 0.8.4"
Functors = "0.5"
MLCore = "1.0.0"
MLDataDevices = "1.4.2"
MLDataDevices = "1.16.0"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
Expand Down
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ using EnzymeCore: EnzymeCore
default_device_rng,
gpu_device, cpu_device, xla_device,
CPUDevice,
CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice,
XLADevice,
# get_device, # we define get_device here for retrocompatibility
gpu_backend!,
Expand Down
25 changes: 25 additions & 0 deletions test/ext_opencl/basic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@testset "Basic GPU movement" begin
    # Host arrays moved with `Flux.gpu` should arrive as CLArrays of the
    # matching element type and rank.
    v64 = rand(Float64, 16)
    m64 = rand(Float64, 16, 16)
    m32 = rand(Float32, 16, 16)
    @test Flux.gpu(v64) isa CLArray{Float64, 1}
    @test Flux.gpu(m64) isa CLArray{Float64, 2}
    @test Flux.gpu(m32) isa CLArray{Float32, 2}

    # Differentiating through the device transfer must work in both
    # directions (CPU -> GPU and GPU -> CPU).
    @test gradient(x -> sum(Flux.gpu(x)), rand(Float32, 4, 4)) isa Tuple
    @test gradient(x -> sum(cpu(x)), OpenCL.rand(Float32, 4, 4)) isa Tuple
end

@testset "Dense no bias" begin
    # A bias-free Dense layer moved to the GPU: check forward pass type,
    # the all-zero output for zero weights input, and that the gradient
    # reports no bias component.
    layer = Flux.gpu(Dense(3 => 4; bias=false))
    input = Flux.gpu(zeros(Float32, 3, 4))
    output = layer(input)
    @test output isa CLArray{Float32, 2}
    @test sum(output) ≈ 0f0
    grads = gradient(l -> sum(l(input)), layer)
    # With bias=false the layer stores `false` as its bias, so the
    # gradient entry for it must be `nothing`.
    @test isnothing(grads[1].bias)
end

@testset "Chain of Dense layers" begin
    # End-to-end forward/backward check of a small MLP on the OpenCL device.
    model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
    data = rand(Float32, 10, 10)
    @test (gpu(model))(gpu(data)) isa CLArray{Float32, 2}
    # Gradients are compared GPU-vs-CPU; the finite-difference comparison
    # is disabled for this backend.
    test_gradients(model, data, test_gpu=true, compare_finite_diff=false)
end

44 changes: 44 additions & 0 deletions test/ext_opencl/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
@testset "data movement" begin
    # The default GPU device must resolve to the OpenCL backend here.
    dev = Flux.gpu_device()
    host = cpu_device()

    @test dev isa Flux.OpenCLDevice

    # Round-trip a matrix through the device objects.
    a = randn(Float32, 5, 5)
    a_dev = a |> dev
    @test a_dev isa OpenCL.CLMatrix{Float32}
    a_back = a_dev |> host
    @test a_back isa Matrix{Float32}
    @test a ≈ a_back

    # Selecting the device by ordinal index yields the same device type.
    dev = gpu_device(1)
    @test dev isa Flux.OpenCLDevice

    # The `cpu`/`gpu` convenience functions round-trip data losslessly too.
    @test cpu(a_dev) isa Matrix{Float32}
    @test cpu(a_dev) ≈ a

    @test gpu(a) isa OpenCL.CLMatrix{Float32}
    @test cpu(gpu(a)) ≈ a
end

# Run the basic movement/layer tests (basic.jl) against the OpenCL backend.
@testset "Basic" begin
    include("basic.jl")
end

# Shared recurrent-layer AD suite. BROKEN_TESTS is a global presumably read
# by the included file to mark layers known to fail on this backend — verify
# against ext_common/recurrent_gpu_ad.jl.
@testset "Recurrent" begin
    global BROKEN_TESTS = [:lstm, :gru, :gruv3]
    include("../ext_common/recurrent_gpu_ad.jl")
end

@testset "Huber Loss test" begin
    # For |x - y| <= δ (δ = 1 by default) the Huber loss is the quadratic
    # mean(0.5 * (x - y)^2), so the gradient per element is (x - y) / n.
    # Here x - y = [-1, 1] and n = 2, giving ±0.5.
    a = Flux.gpu(Float32[0, 1])
    b = Flux.gpu(Float32[1, 0])

    grads = Flux.gradient((p, q) -> Flux.Losses.huber_loss(p, q), a, b)

    @test Flux.cpu(grads[1]) == [-0.5, 0.5]
    @test Flux.cpu(grads[2]) == [0.5, -0.5]
end

22 changes: 21 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
using BSON: BSON
using Pkg

# OpenCL, pocl_jll are required to load before Flux to successfully use the
# jll, so install and load them up front when OpenCL testing is requested.
if get(ENV, "FLUX_TEST_OPENCL", "false") == "true"
    Pkg.add(["OpenCL", "pocl_jll"])
    using OpenCL, pocl_jll
end

using FiniteDifferences: FiniteDifferences
using Flux
using Flux: OneHotArray, OneHotMatrix, OneHotVector,
Expand All @@ -12,7 +19,6 @@ using LinearAlgebra
using MLUtils: MLUtils, batch, unstack, unsqueeze,
unbatch, getobs, numobs, flatten, DataLoader
using Optimisers: Optimisers
using Pkg
using Random
using SparseArrays
using Statistics
Expand Down Expand Up @@ -164,6 +170,20 @@ end
@info "Skipping Metal tests, set FLUX_TEST_METAL=true to run them."
end

if get(ENV, "FLUX_TEST_OPENCL", "false") == "true"
    # Only run the suite when at least one OpenCL platform exposes a device;
    # short-circuit keeps `cl.devices` from being called on an empty platform list.
    opencl_usable = !isempty(cl.platforms()) &&
                    !isempty(vcat(cl.devices.(cl.platforms())...))
    if opencl_usable
        @testset "OpenCL" begin
            include("ext_opencl/runtests.jl")

            flux_testsuite(gpu)
        end
    else
        @info "OpenCL.jl package is not functional. Skipping OpenCL tests."
    end
else
    @info "Skipping OpenCL tests, set FLUX_TEST_OPENCL=true to run them."
end

if get(ENV, "FLUX_TEST_DISTRIBUTED_MPI", "false") == "true" || get(ENV, "FLUX_TEST_DISTRIBUTED_NCCL", "false") == true
Pkg.add(["MPI"])
using MPI
Expand Down
Loading