Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,19 @@ find_package( Boost 1.36.0 REQUIRED )

#### optional external libraries.
# Listed here such that we know if we should compile extra utilities
option(DISABLE_LLN_MATRIX "disable use of LLN library" OFF)
option(DISABLE_ITK "disable use of ITK library" OFF)
option(DISABLE_HDF5 "disable use of HDF5 libraries" OFF)
option(DISABLE_STIR_LOCAL "disable use of LOCAL extensions to STIR" OFF)
option(DISABLE_CERN_ROOT "disable use of Cern ROOT libraries" OFF)
option(DISABLE_NLOHMANN_JSON "disable use of nlohmann JSON libraries" OFF)
option(DISABLE_LLN_MATRIX "disable use of LLN library" ON)
option(DISABLE_ITK "disable use of ITK library" ON)
option(DISABLE_HDF5 "disable use of HDF5 libraries" ON)
option(DISABLE_STIR_LOCAL "disable use of LOCAL extensions to STIR" ON)
option(DISABLE_CERN_ROOT "disable use of Cern ROOT libraries" ON)
option(DISABLE_NLOHMANN_JSON "disable use of nlohmann JSON libraries" ON)
option(STIR_ENABLE_EXPERIMENTAL "disable use of STIR experimental code" OFF) # disable by default
option(DISABLE_NiftyPET_PROJECTOR "disable use of NiftyPET projector" OFF)
option(DISABLE_Parallelproj_PROJECTOR "disable use of Parallelproj projector" OFF)
option(DISABLE_NiftyPET_PROJECTOR "disable use of NiftyPET projector" ON)
option(DISABLE_Parallelproj_PROJECTOR "disable use of Parallelproj projector" ON)
OPTION(DOWNLOAD_ZENODO_TEST_DATA "download zenodo data for tests" OFF)
option(DISABLE_UPENN "disable use of UPENN filetypes" OFF)
option(DISABLE_UPENN "disable use of UPENN filetypes" ON)
option(DISABLE_TORCH "disable use of TORCH" ON)
option(DISABLE_STIR_CUDA "disable use of TORCH" ON)


if(NOT DISABLE_ITK)
Expand Down Expand Up @@ -218,6 +220,31 @@ if(NOT DISABLE_STIR_CUDA)
endif()
endif()



# Enable CUDA if available
if (TORCH_CUDA_ARCH_LIST)
message(STATUS "CUDA support enabled with architectures: ${TORCH_CUDA_ARCH_LIST}")
else()
message(STATUS "CUDA support not enabled.")
endif()

if(NOT DISABLE_TORCH)
set(Torch_DIR "" CACHE PATH "Path to the Torch CMake configuration directory")
# Try to find Torch
find_package(Torch QUIET)

# Check if Torch was found
if (NOT Torch_FOUND)
message(WARNING "Torch not found. Setting STIR_WITH_TORCH to OFF.")
set(STIR_WITH_TORCH OFF)
else()
message(STATUS "Torch found. Setting STIR_WITH_TORCH to ON.")
set(STIR_WITH_TORCH ON)
message(STATUS "Torch include paths: ${TORCH_INCLUDE_DIRS}")
endif()
endif()

# Parallelproj
if(NOT DISABLE_Parallelproj_PROJECTOR)
find_package(parallelproj 1.3.4 CONFIG)
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ endif()
add_library(stir_registries OBJECT ${STIR_REGISTRIES})
# TODO, really should use stir_libs.cmake
target_include_directories(stir_registries PRIVATE ${STIR_INCLUDE_DIR})
target_include_directories(stir_registries PRIVATE ${TORCH_INCLUDE_DIRS})
target_include_directories(stir_registries PRIVATE ${Boost_INCLUDE_DIR})

# go and look for CMakeLists.txt files in all those directories
Expand Down
41 changes: 38 additions & 3 deletions src/IO/interfile.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,15 @@ read_interfile_image(istream& input, const string& directory_for_data)
return 0;
}

#ifdef STIR_WITH_TORCH
auto max_it = std::max_element(hdr.image_scaling_factors[0].begin(), hdr.image_scaling_factors[0].end());
if(*max_it > 1.f)
image_ptr->slice_wise_mult(hdr.image_scaling_factors[0]);
#else
for (int i = 0; i < hdr.matrix_size[2][0]; i++)
if (hdr.image_scaling_factors[0][i] != 1)
(*image_ptr)[i] *= static_cast<float>(hdr.image_scaling_factors[0][i]);
#endif

// Check number of time frames
if (image_ptr->get_exam_info().get_time_frame_definitions().get_num_frames() > 1)
Expand Down Expand Up @@ -198,11 +204,13 @@ read_interfile_dynamic_image(istream& input, const string& directory_for_data)
warning("read_interfile_dynamic_image: error reading data or scale factor returned by read_data not equal to 1");
return 0;
}

#ifdef STIR_WITH_TORCH
//TODO NE: Tomorrow
#else
for (int i = 0; i < hdr.matrix_size[2][0]; i++)
if (fabs(hdr.image_scaling_factors[frame_num - 1][i] - double(1)) > double(1e-10))
(*image_sptr)[i] *= static_cast<float>(hdr.image_scaling_factors[frame_num - 1][i]);

#endif
// Set the time frame of the individual frame
_exam_info.time_frame_definitions = TimeFrameDefinitions(hdr.get_exam_info().time_frame_definitions, frame_num);
image_sptr->set_exam_info(_exam_info);
Expand Down Expand Up @@ -734,7 +742,11 @@ write_basic_interfile_image_header(const string& header_file_name,
template <class elemT>
Succeeded
write_basic_interfile(const string& filename,
#ifdef STIR_WITH_TORCH
const TensorWrapper<3, elemT>& image,
#else
const Array<3, elemT>& image,
#endif
const NumericType output_type,
const float scale,
const ByteOrder byte_order)
Expand All @@ -748,7 +760,11 @@ template <class NUMBER>
Succeeded
write_basic_interfile(const string& filename,
const ExamInfo& exam_info,
#ifdef STIR_WITH_TORCH
const TensorWrapper<3, NUMBER>& image,
#else
const Array<3, NUMBER>& image,
#endif
const CartesianCoordinate3D<float>& voxel_size,
const CartesianCoordinate3D<float>& origin,
const NumericType output_type,
Expand Down Expand Up @@ -788,7 +804,11 @@ write_basic_interfile(const string& filename,
template <class NUMBER>
Succeeded
write_basic_interfile(const string& filename,
#ifdef STIR_WITH_TORCH
const TensorWrapper<3, NUMBER>& image,
#else
const Array<3, NUMBER>& image,
#endif
const CartesianCoordinate3D<float>& voxel_size,
const CartesianCoordinate3D<float>& origin,
const NumericType output_type,
Expand Down Expand Up @@ -1450,7 +1470,22 @@ write_basic_interfile_PDFS_header(const string& data_filename, const ProjDataFro
/**********************************************************************
template instantiations
**********************************************************************/
#ifdef STIR_WITH_TORCH
template Succeeded write_basic_interfile<>(const string& filename,
const TensorWrapper<3, float>&,
const CartesianCoordinate3D<float>& voxel_size,
const CartesianCoordinate3D<float>& origin,
const NumericType output_type,
const float scale,
const ByteOrder byte_order);


template Succeeded write_basic_interfile<>(const string& filename,
const TensorWrapper<3, float>& image,
const NumericType output_type,
const float scale,
const ByteOrder byte_order);
#else
template Succeeded write_basic_interfile<>(const string& filename,
const Array<3, signed short>&,
const CartesianCoordinate3D<float>& voxel_size,
Expand Down Expand Up @@ -1491,5 +1526,5 @@ template Succeeded write_basic_interfile<>(const string& filename,
const NumericType output_type,
const float scale,
const ByteOrder byte_order);

#endif
END_NAMESPACE_STIR
12 changes: 10 additions & 2 deletions src/buildblock/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ set(${dir_LIB_SOURCES}
ExamData.cxx
RadionuclideDB.cxx
Radionuclide.cxx
find_STIR_config.cxx
TensorWrapper.cxx
ProjData.cxx
ProjDataInfo.cxx
ProjDataInfoCylindrical.cxx
Expand All @@ -46,7 +48,7 @@ set(${dir_LIB_SOURCES}
Viewgram.cxx
Sinogram.cxx
RelatedViewgrams.cxx
zoom.cxx
# zoom.cxx
DataSymmetriesForViewSegmentNumbers.cxx
recon_array_functions.cxx
utilities.cxx
Expand Down Expand Up @@ -84,16 +86,18 @@ if (NOT MINI_STIR)
ArrayFilter1DUsingConvolution.cxx
ArrayFilter2DUsingConvolution.cxx
ArrayFilter3DUsingConvolution.cxx
if not (STIR_WITH_TORCH) #Use Torch
MaximalArrayFilter3D.cxx
MaximalImageFilter3D.cxx
endif()
FilePath.cxx
SeparableConvolutionImageFilter.cxx
NonseparableConvolutionUsingRealDFTImageFilter.cxx
ArcCorrection.cxx
SSRB.cxx
inverse_SSRB.cxx
centre_of_gravity.cxx
DynamicProjData.cxx
# DynamicProjData.cxx
MultipleProjData.cxx
MultipleDataSetHeader.cxx
GatedProjData.cxx
Expand All @@ -118,6 +122,10 @@ endif()

include(stir_lib_target)

if (STIR_WITH_TORCH)
target_include_directories(buildblock PUBLIC ${TORCH_INCLUDE_DIRS})
target_link_libraries(buildblock PUBLIC ${TORCH_LIBRARY})
endif()

if (HAVE_JSON)
# Add the header-only nlohman_json header-only library
Expand Down
Loading
Loading