diff --git a/CMakeLists.txt b/CMakeLists.txt
index a382a8fc79..6ed9852f75 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -92,17 +92,19 @@ find_package( Boost 1.36.0 REQUIRED )
 #### optional external libraries.
 # Listed here such that we know if we should compile extra utilities
-option(DISABLE_LLN_MATRIX "disable use of LLN library" OFF)
-option(DISABLE_ITK "disable use of ITK library" OFF)
-option(DISABLE_HDF5 "disable use of HDF5 libraries" OFF)
-option(DISABLE_STIR_LOCAL "disable use of LOCAL extensions to STIR" OFF)
-option(DISABLE_CERN_ROOT "disable use of Cern ROOT libraries" OFF)
-option(DISABLE_NLOHMANN_JSON "disable use of nlohmann JSON libraries" OFF)
+option(DISABLE_LLN_MATRIX "disable use of LLN library" ON)
+option(DISABLE_ITK "disable use of ITK library" ON)
+option(DISABLE_HDF5 "disable use of HDF5 libraries" ON)
+option(DISABLE_STIR_LOCAL "disable use of LOCAL extensions to STIR" ON)
+option(DISABLE_CERN_ROOT "disable use of Cern ROOT libraries" ON)
+option(DISABLE_NLOHMANN_JSON "disable use of nlohmann JSON libraries" ON)
 option(STIR_ENABLE_EXPERIMENTAL "disable use of STIR experimental code" OFF)
 # disable by default
-option(DISABLE_NiftyPET_PROJECTOR "disable use of NiftyPET projector" OFF)
-option(DISABLE_Parallelproj_PROJECTOR "disable use of Parallelproj projector" OFF)
+option(DISABLE_NiftyPET_PROJECTOR "disable use of NiftyPET projector" ON)
+option(DISABLE_Parallelproj_PROJECTOR "disable use of Parallelproj projector" ON)
 OPTION(DOWNLOAD_ZENODO_TEST_DATA "download zenodo data for tests" OFF)
-option(DISABLE_UPENN "disable use of UPENN filetypes" OFF)
+option(DISABLE_UPENN "disable use of UPENN filetypes" ON)
+option(DISABLE_TORCH "disable use of TORCH" ON)
+option(DISABLE_STIR_CUDA "disable use of CUDA in STIR" ON)
 
 if(NOT DISABLE_ITK)
@@ -218,6 +220,31 @@ if(NOT DISABLE_STIR_CUDA)
   endif()
 endif()
+
+
+# Report whether CUDA architectures were requested for Torch
+if (TORCH_CUDA_ARCH_LIST)
+  message(STATUS "CUDA support enabled with architectures: ${TORCH_CUDA_ARCH_LIST}")
+else()
+  message(STATUS "CUDA support not enabled.")
+endif()
+
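+# NOTE (illustrative, not taken from this patch): with the defaults above
+# flipped to ON, Torch support is opt-in. A plausible configure line, assuming
+# a libtorch tree unpacked under ~/libtorch, would be
+#
+#   cmake -S . -B build -DDISABLE_TORCH=OFF -DDISABLE_STIR_CUDA=OFF \
+#         -DTorch_DIR=~/libtorch/share/cmake/Torch
+#
+# where Torch_DIR must point at the directory holding TorchConfig.cmake so
+# that the find_package(Torch) call below can succeed.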
Setting STIR_WITH_TORCH to ON.") + set(STIR_WITH_TORCH ON) + message(STATUS "Torch include paths: ${TORCH_INCLUDE_DIRS}") + endif() +endif() + # Parallelproj if(NOT DISABLE_Parallelproj_PROJECTOR) find_package(parallelproj 1.3.4 CONFIG) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a20e1c5d39..394977be99 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -255,6 +255,7 @@ endif() add_library(stir_registries OBJECT ${STIR_REGISTRIES}) # TODO, really should use stir_libs.cmake target_include_directories(stir_registries PRIVATE ${STIR_INCLUDE_DIR}) +target_include_directories(stir_registries PRIVATE ${TORCH_INCLUDE_DIRS}) target_include_directories(stir_registries PRIVATE ${Boost_INCLUDE_DIR}) # go and look for CMakeLists.txt files in all those directories diff --git a/src/IO/interfile.cxx b/src/IO/interfile.cxx index 58e32bd0fd..be151cf720 100644 --- a/src/IO/interfile.cxx +++ b/src/IO/interfile.cxx @@ -145,9 +145,15 @@ read_interfile_image(istream& input, const string& directory_for_data) return 0; } +#ifdef STIR_WITH_TORCH + auto max_it = std::max_element(hdr.image_scaling_factors[0].begin(), hdr.image_scaling_factors[0].end()); + if(*max_it > 1.f) + image_ptr->slice_wise_mult(hdr.image_scaling_factors[0]); +#else for (int i = 0; i < hdr.matrix_size[2][0]; i++) if (hdr.image_scaling_factors[0][i] != 1) (*image_ptr)[i] *= static_cast(hdr.image_scaling_factors[0][i]); +#endif // Check number of time frames if (image_ptr->get_exam_info().get_time_frame_definitions().get_num_frames() > 1) @@ -198,11 +204,13 @@ read_interfile_dynamic_image(istream& input, const string& directory_for_data) warning("read_interfile_dynamic_image: error reading data or scale factor returned by read_data not equal to 1"); return 0; } - +#ifdef STIR_WITH_TORCH + //TODO NE: Tomorrow +#else for (int i = 0; i < hdr.matrix_size[2][0]; i++) if (fabs(hdr.image_scaling_factors[frame_num - 1][i] - double(1)) > double(1e-10)) (*image_sptr)[i] *= static_cast(hdr.image_scaling_factors[frame_num - 1][i]); - +#endif // Set the time frame of the individual frame _exam_info.time_frame_definitions = TimeFrameDefinitions(hdr.get_exam_info().time_frame_definitions, frame_num); image_sptr->set_exam_info(_exam_info); @@ -734,7 +742,11 @@ write_basic_interfile_image_header(const string& header_file_name, template Succeeded write_basic_interfile(const string& filename, +#ifdef STIR_WITH_TORCH + const TensorWrapper<3, elemT>& image, +#else const Array<3, elemT>& image, +#endif const NumericType output_type, const float scale, const ByteOrder byte_order) @@ -748,7 +760,11 @@ template Succeeded write_basic_interfile(const string& filename, const ExamInfo& exam_info, +#ifdef STIR_WITH_TORCH + const TensorWrapper<3, NUMBER>& image, +#else const Array<3, NUMBER>& image, +#endif const CartesianCoordinate3D& voxel_size, const CartesianCoordinate3D& origin, const NumericType output_type, @@ -788,7 +804,11 @@ write_basic_interfile(const string& filename, template Succeeded write_basic_interfile(const string& filename, +#ifdef STIR_WITH_TORCH + const TensorWrapper<3, NUMBER>& image, +#else const Array<3, NUMBER>& image, +#endif const CartesianCoordinate3D& voxel_size, const CartesianCoordinate3D& origin, const NumericType output_type, @@ -1450,7 +1470,22 @@ write_basic_interfile_PDFS_header(const string& data_filename, const ProjDataFro /********************************************************************** template instantiations **********************************************************************/ +#ifdef STIR_WITH_TORCH 
@@ -1450,7 +1470,22 @@ write_basic_interfile_PDFS_header(const string& data_filename, const ProjDataFro
 /**********************************************************************
    template instantiations
   **********************************************************************/
+#ifdef STIR_WITH_TORCH
+template Succeeded write_basic_interfile<>(const string& filename,
+                                           const TensorWrapper<3, float>&,
+                                           const CartesianCoordinate3D<float>& voxel_size,
+                                           const CartesianCoordinate3D<float>& origin,
+                                           const NumericType output_type,
+                                           const float scale,
+                                           const ByteOrder byte_order);
+
+template Succeeded write_basic_interfile<>(const string& filename,
+                                           const TensorWrapper<3, float>& image,
+                                           const NumericType output_type,
+                                           const float scale,
+                                           const ByteOrder byte_order);
+#else
 template Succeeded write_basic_interfile<>(const string& filename,
                                            const Array<3, signed short>&,
                                            const CartesianCoordinate3D<float>& voxel_size,
@@ -1491,5 +1526,5 @@
 template Succeeded write_basic_interfile<>(const string& filename,
                                            const NumericType output_type,
                                            const float scale,
                                            const ByteOrder byte_order);
-
+#endif
 END_NAMESPACE_STIR
diff --git a/src/buildblock/CMakeLists.txt b/src/buildblock/CMakeLists.txt
index e080af4ed7..3d5f37fba4 100644
--- a/src/buildblock/CMakeLists.txt
+++ b/src/buildblock/CMakeLists.txt
@@ -22,6 +22,8 @@ set(${dir_LIB_SOURCES}
   ExamData.cxx
   RadionuclideDB.cxx
   Radionuclide.cxx
+  find_STIR_config.cxx
+  TensorWrapper.cxx
   ProjData.cxx
   ProjDataInfo.cxx
   ProjDataInfoCylindrical.cxx
@@ -46,7 +48,7 @@ set(${dir_LIB_SOURCES}
   Viewgram.cxx
   Sinogram.cxx
   RelatedViewgrams.cxx
-  zoom.cxx
+  # zoom.cxx
   DataSymmetriesForViewSegmentNumbers.cxx
   recon_array_functions.cxx
   utilities.cxx
@@ -84,8 +86,10 @@ if (NOT MINI_STIR)
   ArrayFilter1DUsingConvolution.cxx
   ArrayFilter2DUsingConvolution.cxx
   ArrayFilter3DUsingConvolution.cxx
+if(NOT STIR_WITH_TORCH) # otherwise use Torch
   MaximalArrayFilter3D.cxx
   MaximalImageFilter3D.cxx
+endif()
   FilePath.cxx
   SeparableConvolutionImageFilter.cxx
   NonseparableConvolutionUsingRealDFTImageFilter.cxx
@@ -93,7 +97,7 @@ if (NOT MINI_STIR)
   SSRB.cxx
   inverse_SSRB.cxx
   centre_of_gravity.cxx
-  DynamicProjData.cxx
+  # DynamicProjData.cxx
   MultipleProjData.cxx
   MultipleDataSetHeader.cxx
   GatedProjData.cxx
@@ -118,6 +122,10 @@ endif()
 
 include(stir_lib_target)
 
+if (STIR_WITH_TORCH)
+  target_include_directories(buildblock PUBLIC ${TORCH_INCLUDE_DIRS})
+  target_link_libraries(buildblock PUBLIC ${TORCH_LIBRARY})
+endif()
 
 if (HAVE_JSON)
   # Add the header-only nlohman_json header-only library
diff --git a/src/buildblock/TensorWrapper.cxx b/src/buildblock/TensorWrapper.cxx
new file mode 100644
index 0000000000..134f3b4ceb
--- /dev/null
+++ b/src/buildblock/TensorWrapper.cxx
@@ -0,0 +1,247 @@
+#include "stir/TensorWrapper.h"
+
+START_NAMESPACE_STIR
+
+
+template <int num_dimensions, typename elemT>
+TensorWrapper<num_dimensions, elemT>::TensorWrapper()
+    : tensor(nullptr), device("cpu")
+{}
+
+// Implementation of TensorWrapper methods
+template <int num_dimensions, typename elemT>
+TensorWrapper<num_dimensions, elemT>::TensorWrapper(const at::Tensor& tensor,
+                                                    const std::vector<int>& offsets,
+                                                    const std::string& device)
+    : tensor(tensor.to(device)), device(device), offsets(offsets)
+{
+  ends.resize(offsets.size());
+  // inclusive upper index per dimension
+  for (int i = 0; i < static_cast<int>(offsets.size()); ++i)
+    ends[i] = offsets[i] + this->tensor.size(i) - 1;
+}
+
+template <int num_dimensions, typename elemT>
+TensorWrapper<num_dimensions, elemT>::TensorWrapper(const IndexRange<num_dimensions>& range, const std::string& device)
+    : device(device)
+{
+  std::vector<int64_t> shape = convertIndexRangeToShape(range);
+  torch::Dtype dtype = getTorchDtype();
+  tensor = torch::zeros(shape, torch::TensorOptions().dtype(dtype));
+  // Extract offsets
+  offsets = extract_offsets_recursive(range);
+  ends.resize(offsets.size());
+  for (int i = 0; i < static_cast<int>(offsets.size()); ++i)
+    ends[i] = offsets[i] + tensor.size(i) - 1;
+}
+
+template <int num_dimensions, typename elemT>
+TensorWrapper<num_dimensions, elemT>::TensorWrapper(const IndexRange<num_dimensions>& range,
+                                                    shared_ptr<elemT[]> data_sptr,
+                                                    const std::string& device)
+    : device(device)
+{
+  // tensor = torch::from_blob(data_sptr.get(), torch::TensorOptions().dtype(dtype));
+  elemT* data_ptr = data_sptr.get();
+  init(range, data_ptr);
+  // note: Tensor::to() is not in-place, so the result has to be assigned back
+  tensor = tensor.to(device == "cuda" ? torch::kCUDA : torch::kCPU);
+}
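+
+// Illustration with hypothetical numbers (not STIR data): for a range with
+// z in [0,46], y in [-63,64], x in [-63,64] the constructors above produce
+//   offsets = {0, -63, -63}   (per-dimension minimum indices)
+//   sizes   = {47, 128, 128}
+//   ends    = {46, 64, 64}    (offsets[i] + size(i) - 1, inclusive)
+// so get_min_index()/get_max_index() reproduce STIR's negative-offset
+// indexing on top of the zero-based torch tensor.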
+
+template <int num_dimensions, typename elemT>
+TensorWrapper<num_dimensions, elemT>::TensorWrapper(const shared_ptr<std::vector<int64_t>> shape, const std::string& device)
+    : device(device)
+{
+  if (shape->size() != num_dimensions) {
+    throw std::invalid_argument("Shape size does not match the number of dimensions");
+  }
+  torch::Dtype dtype = getTorchDtype();
+  tensor = torch::zeros((*shape), torch::TensorOptions().dtype(dtype));
+  compute_strides();
+  tensor = tensor.to(device == "cuda" ? torch::kCUDA : torch::kCPU);
+}
+
+// template <int num_dimensions, typename elemT>
+// TensorWrapper<num_dimensions, elemT>::TensorWrapper(
+//     const shared_ptr<DiscretisedDensity<num_dimensions, elemT>> discretised_density, const std::string& device)
+//     : device(device)
+// {
+//   std::vector<int64_t> shape = getShapeFromDiscretisedDensity(discretised_density);
+//   extract_offsets_recursive(discretised_density->get_index_range());
+//   tensor = torch::zeros(shape, torch::TensorOptions().dtype(getTorchDtype()));
+//   fillTensorFromDiscretisedDensity(discretised_density);
+//   tensor = tensor.to(device == "cuda" ? torch::kCUDA : torch::kCPU);
+// }
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::to_gpu()
+{
+  if (torch::cuda::is_available()) {
+    device = "cuda";
+    tensor = tensor.to(torch::kCUDA);
+  } else {
+    throw std::runtime_error("CUDA is not available. Cannot move tensor to GPU.");
+  }
+}
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::to_cpu()
+{
+  device = "cpu";
+  tensor = tensor.to(torch::kCPU);
+}
+
+template <int num_dimensions, typename elemT>
+bool TensorWrapper<num_dimensions, elemT>::is_on_gpu() const
+{
+  return tensor.device().is_cuda();
+}
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::print_device() const
+{
+  std::cout << "Tensor is on device: " << tensor.device() << std::endl;
+}
+
+
+template <int num_dimensions, typename elemT>
+elemT TensorWrapper<num_dimensions, elemT>::find_max() const
+{
+  return tensor.max().template item<elemT>();
+}
+
+template <int num_dimensions, typename elemT>
+elemT TensorWrapper<num_dimensions, elemT>::sum() const
+{
+  return tensor.sum().template item<elemT>();
+}
+
+template <int num_dimensions, typename elemT>
+elemT TensorWrapper<num_dimensions, elemT>::sum_positive() const
+{
+  return tensor.clamp_min(0).sum().template item<elemT>();
+}
+
+template <int num_dimensions, typename elemT>
+elemT TensorWrapper<num_dimensions, elemT>::find_min() const
+{
+  return tensor.min().template item<elemT>();
+}
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::fill(const elemT& n)
+{
+  tensor.fill_(n);
+}
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::apply_lower_threshold(const elemT& l)
+{
+  // in-place variant; plain clamp_min() would discard the result
+  tensor.clamp_min_(l);
+}
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::apply_upper_threshold(const elemT& u)
+{
+  tensor.clamp_max_(u);
+}
+
+// template <int num_dimensions, typename elemT>
+// TensorWrapper<num_dimensions, elemT> TensorWrapper<num_dimensions, elemT>::get_empty_copy() const
+// {
+//   torch::Tensor empty_tensor = torch::empty_like(tensor);
+//   empty_tensor.to(device);
+//   return TensorWrapper(empty_tensor);
+// }
+
+// template <int num_dimensions, typename elemT>
+// std::unique_ptr<TensorWrapper<num_dimensions, elemT>> TensorWrapper<num_dimensions, elemT>::clone() const
+// {
+//   torch::Tensor cloned_tensor = tensor.clone();
+//   tensor.to(device);
+//   return std::make_unique<TensorWrapper<num_dimensions, elemT>>(cloned_tensor);
+// }
+
+template <int num_dimensions, typename elemT>
+torch::Tensor& TensorWrapper<num_dimensions, elemT>::getTensor()
+{
+  return tensor;
+}
+
+template <int num_dimensions, typename elemT>
+const torch::Tensor& TensorWrapper<num_dimensions, elemT>::getTensor() const
+{
+  return tensor;
+}
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::printSizes() const
+{
+  std::cout << "Tensor sizes: [";
+  for (size_t i = 0; i < num_dimensions; ++i) {
+    std::cout << tensor.size(i) << " ";
+  }
+  std::cout << "]" << std::endl;
+}
+
+template <int num_dimensions, typename elemT>
+stir::IndexRange<num_dimensions> TensorWrapper<num_dimensions, elemT>::get_index_range() const
+{
+  // note: currently returns a default-constructed (empty) range
+  return stir::IndexRange<num_dimensions>();
+}
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::resize(const stir::IndexRange<num_dimensions>& range)
+{
+  std::vector<int64_t> new_shape = convertIndexRangeToShape(range);
+  torch::Dtype dtype = getTorchDtype();
+  tensor = torch::empty(new_shape, torch::TensorOptions().dtype(dtype));
+
+  // Extract offsets
+  offsets = extract_offsets_recursive(range);
+  ends.resize(offsets.size());
+  for (int i = 0; i < static_cast<int>(offsets.size()); ++i)
+    ends[i] = offsets[i] + tensor.size(i) - 1;
+}
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::grow(const stir::IndexRange<num_dimensions>& range)
+{
+  resize(range);
+}
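+
+// libtorch note: clamp_min()/clamp_max() return a new tensor, while the
+// trailing-underscore forms modify in place; the thresholds above therefore
+// use clamp_min_/clamp_max_. Quick check, assuming only libtorch:
+//
+//   torch::Tensor t = torch::tensor({-1.f, 2.f});
+//   t.clamp_min(0.f);   // no effect on t: the result is discarded
+//   t.clamp_min_(0.f);  // t is now {0, 2}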
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::print() const
+{
+  std::cout << "TensorWrapper Tensor:" << std::endl;
+  std::cout << tensor << std::endl;
+}
+
+template <int num_dimensions, typename elemT>
+void TensorWrapper<num_dimensions, elemT>::mask_cyl(const float radius, const bool in)
+{
+  float _radius = 0.0;
+  if (radius == -1.f) {
+    // default: half the extent along x (dimension 2)
+    _radius = static_cast<float>(get_max_index(2) - get_min_index(2)) / 2.f;
+  }
+  else
+    _radius = radius;
+  auto y_coords = torch::arange(0, tensor.size(1), tensor.options().dtype(torch::kFloat).device(tensor.device()));
+  auto x_coords = torch::arange(0, tensor.size(2), tensor.options().dtype(torch::kFloat).device(tensor.device()));
+
+  const float center_y = (tensor.size(1) - 1) / 2.0f;
+  const float center_x = (tensor.size(2) - 1) / 2.0f;
+
+  auto yy = y_coords.view({-1, 1}).expand({tensor.size(1), tensor.size(2)});
+  auto xx = x_coords.view({1, -1}).expand({tensor.size(1), tensor.size(2)});
+
+  // Compute the squared Euclidean distance in the x-y plane
+  auto dist_squared = (yy - center_y).pow(2) + (xx - center_x).pow(2);
+  auto mask_2d = dist_squared <= (_radius * _radius);
+  auto mask_3d = mask_2d.unsqueeze(0).expand({tensor.size(0), tensor.size(1), tensor.size(2)});
+  tensor.mul_(mask_3d);
+}
+
+// Explicit template instantiations
+
+template class TensorWrapper<1, int32_t>;
+template class TensorWrapper<2, int32_t>;
+template class TensorWrapper<3, int32_t>;
+template class TensorWrapper<1, float>;
+template class TensorWrapper<2, float>;
+template class TensorWrapper<3, float>;
+template class TensorWrapper<1, double>;
+template class TensorWrapper<2, double>;
+template class TensorWrapper<3, double>;
+
+// template class TensorWrapper<4, float>;
+END_NAMESPACE_STIR
diff --git a/src/buildblock/VoxelsOnCartesianGrid.cxx b/src/buildblock/VoxelsOnCartesianGrid.cxx
index e7a8bbfbd7..79f4a86b76 100644
--- a/src/buildblock/VoxelsOnCartesianGrid.cxx
+++ b/src/buildblock/VoxelsOnCartesianGrid.cxx
@@ -134,12 +134,20 @@ VoxelsOnCartesianGrid<elemT>::VoxelsOnCartesianGrid()
 {}
 
 template <class elemT>
-VoxelsOnCartesianGrid<elemT>::VoxelsOnCartesianGrid(const Array<3, elemT>& v,
+#ifdef STIR_WITH_TORCH
+VoxelsOnCartesianGrid<elemT>::VoxelsOnCartesianGrid(const TensorWrapper<3, elemT>& v,
+#else
+VoxelsOnCartesianGrid<elemT>::VoxelsOnCartesianGrid(const Array<3, elemT>& v,
+#endif
                                                     const CartesianCoordinate3D<float>& origin,
                                                     const BasicCoordinate<3, float>& grid_spacing)
     : DiscretisedDensityOnCartesianGrid<3, elemT>(v.get_index_range(), origin, grid_spacing)
 {
+#ifdef STIR_WITH_TORCH
+  TensorWrapper<3, elemT>::operator=(v);
+#else
   Array<3, elemT>::operator=(v);
+#endif
 }
 
 template <class elemT>
@@ -151,12 +159,20 @@ VoxelsOnCartesianGrid<elemT>::VoxelsOnCartesianGrid(const IndexRange<3>& range,
 
 template <class elemT>
 VoxelsOnCartesianGrid<elemT>::VoxelsOnCartesianGrid(const shared_ptr<ExamInfo>& exam_info_sptr,
+#ifdef STIR_WITH_TORCH
+                                                    const TensorWrapper<3, elemT>& v,
+#else
                                                     const Array<3, elemT>& v,
+#endif
                                                     const CartesianCoordinate3D<float>& origin,
                                                     const BasicCoordinate<3, float>& grid_spacing)
     : DiscretisedDensityOnCartesianGrid<3, elemT>(exam_info_sptr, v.get_index_range(), origin, grid_spacing)
 {
+#ifdef STIR_WITH_TORCH
+  TensorWrapper<3, elemT>::operator=(v);
+#else
   Array<3, elemT>::operator=(v);
+#endif
 }
 
 template <class elemT>
@@ -306,6 +322,7 @@ VoxelsOnCartesianGrid<elemT>::set_voxel_size(const BasicCoordinate<3, float>& c)
   this->set_grid_spacing(c);
 }
 
+#ifndef STIR_WITH_TORCH
 template <class elemT>
 PixelsOnCartesianGrid<elemT>
 VoxelsOnCartesianGrid<elemT>::get_plane(const int z) const
@@ -330,7 +347,7 @@ VoxelsOnCartesianGrid<elemT>::set_plane(const PixelsOnCartesianGrid<elemT>& plane)
   this->operator[](z) = plane;
 }
 
-
+#endif
 template <class elemT>
 void
 VoxelsOnCartesianGrid<elemT>::grow_z_range(const int min_z, const int max_z)
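
Note (illustration, not part of the patch): mask_cyl in TensorWrapper.cxx above
builds its cylinder mask from two broadcasted 1-D coordinate vectors instead of
a voxel loop. The 2-D core of that construction, assuming only libtorch:

    #include <torch/torch.h>
    #include <iostream>

    int main()
    {
      const int ny = 5, nx = 5;
      auto y = torch::arange(0, ny, torch::kFloat).view({-1, 1});
      auto x = torch::arange(0, nx, torch::kFloat).view({1, -1});
      const float cy = (ny - 1) / 2.0f, cx = (nx - 1) / 2.0f;
      auto dist2 = (y - cy).pow(2) + (x - cx).pow(2);
      auto disc = dist2 <= 2.0f * 2.0f; // boolean disc of radius 2
      std::cout << disc << '\n';
      return 0;
    }
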
@@ -344,8 +361,9 @@ VoxelsOnCartesianGrid<elemT>::grow_z_range(const int min_z, const int max_z)
   */
   CartesianCoordinate3D<int> min_indices;
   CartesianCoordinate3D<int> max_indices;
-
+#ifndef STIR_WITH_TORCH
   this->get_regular_range(min_indices, max_indices);
+#else
+  // TensorWrapper has no get_regular_range(); fill the coordinates from the
+  // per-dimension index bounds (z, y, x) so they are not left uninitialised
+  min_indices = CartesianCoordinate3D<int>(this->get_min_index(0), this->get_min_index(1), this->get_min_index(2));
+  max_indices = CartesianCoordinate3D<int>(this->get_max_index(0), this->get_max_index(1), this->get_max_index(2));
+#endif
   assert(min_z <= min_indices.z());
   assert(max_z >= max_indices.z());
   min_indices.z() = min_z;
@@ -366,7 +384,9 @@ VoxelsOnCartesianGrid<elemT>::get_min_indices() const
 {
   CartesianCoordinate3D<int> min_indices;
   CartesianCoordinate3D<int> max_indices;
+#ifndef STIR_WITH_TORCH
   this->get_regular_range(min_indices, max_indices);
+#else
+  min_indices = CartesianCoordinate3D<int>(this->get_min_index(0), this->get_min_index(1), this->get_min_index(2));
+#endif
   return min_indices;
 }
 
@@ -376,7 +396,9 @@ VoxelsOnCartesianGrid<elemT>::get_max_indices() const
 {
   CartesianCoordinate3D<int> min_indices;
   CartesianCoordinate3D<int> max_indices;
+#ifndef STIR_WITH_TORCH
   this->get_regular_range(min_indices, max_indices);
+#else
+  max_indices = CartesianCoordinate3D<int>(this->get_max_index(0), this->get_max_index(1), this->get_max_index(2));
+#endif
   return max_indices;
 }
 #if 0
diff --git a/src/buildblock/recon_array_functions.cxx b/src/buildblock/recon_array_functions.cxx
index b8c25a7884..98b3374244 100644
--- a/src/buildblock/recon_array_functions.cxx
+++ b/src/buildblock/recon_array_functions.cxx
@@ -126,6 +126,9 @@ truncate_rim(DiscretisedDensity<3, float>& input_image, const int rim_truncation
   DiscretisedDensityOnCartesianGrid<3, float>& input_image_cartesian
       = dynamic_cast<DiscretisedDensityOnCartesianGrid<3, float>&>(input_image);
 
+#ifdef STIR_WITH_TORCH
+  input_image.mask_cyl(-1);
+#else
   if (!input_image_cartesian.is_regular())
     error("truncate_rim called for non-regular grid. Not implemented");
 
@@ -136,13 +139,12 @@ truncate_rim(DiscretisedDensity<3, float>& input_image, const int rim_truncation
   const int ze = input_image_cartesian.get_max_index();
   const int ye = input_image_cartesian[zs].get_max_index();
   const int xe = input_image_cartesian[zs][ys].get_max_index();
-
-  // TODO check what happens with even-sized images (i.e. where is the centre?)
-
-  // const int zm=(zs+ze)/2;
+  // const int zm=(zs+ze)/2;
   const int ym = (ys + ye) / 2;
   const int xm = (xs + xe) / 2;
 
+  // TODO check what happens with even-sized images (i.e. where is the centre?)
+
   const float truncated_radius = static_cast<float>((xe - xs) / 2 - rim_truncation_image);
 
   if (strictly_less_than_radius)
@@ -165,6 +167,7 @@
           input_image[z][y][x] = 0;
       }
   }
+#endif
 }
 
 // AZ&KT 04/10/99: added rim_truncation_sino
@@ -311,6 +314,9 @@ divide_array(SegmentByView<float>& numerator, const SegmentByView<float>& denominator)
 void
 divide_array(DiscretisedDensity<3, float>& numerator, const DiscretisedDensity<3, float>& denominator)
 {
+#ifdef STIR_WITH_TORCH
+  numerator /= denominator;
+#else
   assert(numerator.get_index_range() == denominator.get_index_range());
   float small_value = numerator.find_max() * SMALL_NUM;
   small_value = (small_value > 0.0F) ? 
small_value : 0.0F; @@ -328,6 +334,7 @@ divide_array(DiscretisedDensity<3, float>& numerator, const DiscretisedDensity<3 else numerator[z][y][x] /= denominator[z][y][x]; } +#endif } // MJ 03/01/2000 for loglikelihood computation @@ -380,6 +387,10 @@ void multiply_and_add(DiscretisedDensity<3, float>& image_res, const DiscretisedDensity<3, float>& image_scaled, float scalar) { +#ifdef STIR_WITH_TORCH + float one = 1.f; + image_res.sapyb(one, image_scaled, scalar); +#else assert(image_res.get_index_range() == image_scaled.get_index_range()); // TODO rewrite in terms of 'full' iterator @@ -389,6 +400,7 @@ multiply_and_add(DiscretisedDensity<3, float>& image_res, const DiscretisedDensi { image_res[z][y][x] += image_scaled[z][y][x] * scalar; } +#endif } // to be used with in_place_function @@ -402,24 +414,24 @@ neg_trunc(float x) void truncate_end_planes(DiscretisedDensity<3, float>& input_image, int input_num_planes) { +//NE:WaitTomorrow +// // TODO this function does not make a lot of sense in general +// #ifndef NDEBUG +// // this will throw an exception when the cast is invalid +// dynamic_cast&>(input_image); +// #endif - // TODO this function does not make a lot of sense in general -#ifndef NDEBUG - // this will throw an exception when the cast is invalid - dynamic_cast&>(input_image); -#endif - - const int zs = input_image.get_min_index(); - const int ze = input_image.get_max_index(); +// const int zs = input_image.get_min_index(); +// const int ze = input_image.get_max_index(); - int upper_limit = (input_image.get_length() % 2 == 1) ? input_image.get_length() / 2 + 1 : input_image.get_length() / 2; +// int upper_limit = (input_image.get_length() % 2 == 1) ? input_image.get_length() / 2 + 1 : input_image.get_length() / 2; - int num_planes = input_num_planes <= upper_limit ? input_num_planes : upper_limit; +// int num_planes = input_num_planes <= upper_limit ? 
input_num_planes : upper_limit; - for (int j = 0; j < num_planes; j++) - { - input_image[zs + j].fill(0.0); - input_image[ze - j].fill(0.0); - } +// for (int j = 0; j < num_planes; j++) +// { +// input_image[zs + j].fill(0.0); +// input_image[ze - j].fill(0.0); +// } } END_NAMESPACE_STIR diff --git a/src/cmake/STIRConfig.cmake.in b/src/cmake/STIRConfig.cmake.in index cc20a4526b..eb01b9b6f9 100644 --- a/src/cmake/STIRConfig.cmake.in +++ b/src/cmake/STIRConfig.cmake.in @@ -153,6 +153,13 @@ if(@STIR_WITH_CUDA@) set(STIR_WITH_CUDA TRUE) endif() + +if(@STIR_WITH_TORCH@) + find_package(CUDAToolkit REQUIRED) + enable_language(CUDA) + set(STIR_WITH_TORCH TRUE) +endif() + if(@STIR_WITH_Parallelproj_PROJECTOR@) find_package(parallelproj REQUIRED CONFIG) set(STIR_WITH_Parallelproj_PROJECTOR TRUE) diff --git a/src/cmake/STIRConfig.h.in b/src/cmake/STIRConfig.h.in index efc241dae9..5dfc618e00 100644 --- a/src/cmake/STIRConfig.h.in +++ b/src/cmake/STIRConfig.h.in @@ -80,6 +80,8 @@ namespace stir { #cmakedefine STIR_OPENMP +#cmakedefine STIR_WITH_TORCH + #cmakedefine STIR_MPI #cmakedefine nlohmann_json_FOUND "@nlohmann_json_FOUND@" diff --git a/src/cmake/stir_exe_targets.cmake b/src/cmake/stir_exe_targets.cmake index 1a1769aa0d..caab452f1d 100644 --- a/src/cmake/stir_exe_targets.cmake +++ b/src/cmake/stir_exe_targets.cmake @@ -22,6 +22,10 @@ foreach(source ${${dir_EXE_SOURCES}}) if(BUILD_EXECUTABLES) get_filename_component(executable ${source} NAME_WE) add_executable(${executable} ${ALL_HEADERS} ${ALL_INLINES} ${ALL_TXXS} ${source} $) +if (STIR_WITH_TORCH) + target_link_libraries(${executable} ${TORCH_LIBRARIES}) + target_include_directories(${executable} PUBLIC ${TORCH_INCLUDE_DIRS}) +endif() target_link_libraries(${executable} ${STIR_LIBRARIES}) SET_PROPERTY(TARGET ${executable} PROPERTY FOLDER "Executables") target_include_directories(${executable} PUBLIC ${Boost_INCLUDE_DIR}) diff --git a/src/cmake/stir_lib_target.cmake b/src/cmake/stir_lib_target.cmake index 56884b214f..e86d3e1c5d 100644 --- a/src/cmake/stir_lib_target.cmake +++ b/src/cmake/stir_lib_target.cmake @@ -21,6 +21,7 @@ target_include_directories(${dir} PUBLIC # make sure that if you use STIR, the compiler will be set to what was set via UseCXX target_compile_features(${dir} PUBLIC cxx_std_${CMAKE_CXX_STANDARD}) target_include_directories(${dir} PUBLIC ${Boost_INCLUDE_DIR}) +target_include_directories(${dir} PUBLIC ${TORCH_INCLUDE_DIRS}) SET_PROPERTY(TARGET ${dir} PROPERTY FOLDER "Libs") diff --git a/src/cmake/stir_test_exe_targets.cmake b/src/cmake/stir_test_exe_targets.cmake index c70c5a3c77..5542da4dbe 100644 --- a/src/cmake/stir_test_exe_targets.cmake +++ b/src/cmake/stir_test_exe_targets.cmake @@ -59,6 +59,10 @@ macro (create_stir_involved_test source libraries dependencies) if(BUILD_TESTING) get_filename_component(executable ${source} NAME_WE) add_executable(${executable} ${source} ${dependencies}) + if (STIR_WITH_TORCH) + target_link_libraries(${executable} ${TORCH_LIBRARIES}) + target_include_directories(${executable} PUBLIC ${TORCH_INCLUDE_DIRS}) + endif() target_link_libraries(${executable} ${libraries}) SET_PROPERTY(TARGET ${executable} PROPERTY FOLDER "Tests") target_include_directories(${executable} PUBLIC ${Boost_INCLUDE_DIR}) diff --git a/src/data_buildblock/CMakeLists.txt b/src/data_buildblock/CMakeLists.txt index e230390b10..f241d39eaa 100644 --- a/src/data_buildblock/CMakeLists.txt +++ b/src/data_buildblock/CMakeLists.txt @@ -35,7 +35,6 @@ endif() include(stir_lib_target) -target_link_libraries(${dir} PUBLIC buildblock) 
if (HAVE_HDF5) # for GEHDF5, TODO remove once IO dependency added or GEHDF5Wrapper no longer includes H5Cpp.h diff --git a/src/include/stir/Array.h b/src/include/stir/Array.h index 69f01d47f9..fddc3787b1 100644 --- a/src/include/stir/Array.h +++ b/src/include/stir/Array.h @@ -93,6 +93,7 @@ class Array : public NumericVectorWithOffset, e typedef typename base_type::iterator iterator; typedef typename base_type::const_iterator const_iterator; //@} + #ifdef ARRAY_FULL /*! @name typedefs for full_iterator support Full iterators provide a 1-dimensional view on a multi-dimensional diff --git a/src/include/stir/ArrayFunction.inl b/src/include/stir/ArrayFunction.inl index 612b7d8c03..757668276f 100644 --- a/src/include/stir/ArrayFunction.inl +++ b/src/include/stir/ArrayFunction.inl @@ -25,6 +25,10 @@ #include "stir/array_index_functions.h" #include "stir/modulo.h" +#ifdef STIR_WITH_TORCH +#include "stir/TensorWrapper.h" +#endif + #include #include #ifdef BOOST_NO_STDC_NAMESPACE @@ -348,4 +352,76 @@ transform_array_from_periodic_indices(Array& out_array, c } while (next(index, out_array)); } +#ifdef STIR_WITH_TORCH + +template +inline TensorWrapper<1, elemT>& +in_place_log(TensorWrapper<1, elemT>& v) +{ + for (int i = v.get_min_index(); i <= v.get_max_index(); i++) + v.at(i) = std::log(v.at(i)); + return v; +} + +template +inline TensorWrapper& +in_place_log(TensorWrapper& v) +{ + for (int i = v.get_min_index(); i <= v.get_max_index(); i++) + in_place_log(v.at(i)); + return v; +} + +template +inline TensorWrapper<1, elemT>& +in_place_exp(TensorWrapper<1, elemT>& v) +{ + for (int i = v.get_min_index(); i <= v.get_max_index(); i++) + v.at(i) = std::exp(v.at(i)); + return v; +} + +template +inline TensorWrapper& +in_place_exp(TensorWrapper& v) +{ + for (int i = v.get_min_index(); i <= v.get_max_index(); i++) + in_place_exp(v.at(i)); + return v; +} + +template +inline TensorWrapper<1, elemT>& +in_place_abs(TensorWrapper<1, elemT>& v) +{ + for (int i = v.get_min_index(); i <= v.get_max_index(); i++) + if (v.at(i) < 0) + v.at(i) = -v.at(i); + return v; +} + +template +inline TensorWrapper& +in_place_abs(TensorWrapper& v) +{ + for (int i = v.get_min_index(); i <= v.get_max_index(); i++) + in_place_abs(v.at(i)); + return v; +} + +// template +// inline T& +// in_place_apply_function(T& v, FUNCTION f) +// { +// typename T::full_iterator iter = v.begin_all(); +// const typename T::full_iterator end_iter = v.end_all(); +// while (iter != end_iter) +// { +// *iter = f(*iter); +// ++iter; +// } +// return v; +// } +#endif + END_NAMESPACE_STIR diff --git a/src/include/stir/DiscretisedDensity.h b/src/include/stir/DiscretisedDensity.h index d04120aa1b..c73039b1c1 100644 --- a/src/include/stir/DiscretisedDensity.h +++ b/src/include/stir/DiscretisedDensity.h @@ -32,7 +32,12 @@ */ #include "stir/CartesianCoordinate3D.h" +#ifdef STIR_WITH_TORCH +#include "stir/TensorWrapper.h" +#include +#else #include "stir/Array.h" +#endif #include "stir/ExamData.h" #include "stir/shared_ptr.h" #include @@ -91,8 +96,13 @@ START_NAMESPACE_STIR */ +#ifdef STIR_WITH_TORCH +template +class DiscretisedDensity : public ExamData, public TensorWrapper +#else template class DiscretisedDensity : public ExamData, public Array +#endif { #ifdef SWIG // work-around swig problem. 
It gets confused when using a private (or protected) @@ -101,7 +111,12 @@ class DiscretisedDensity : public ExamData, public Array #else private: #endif + #ifdef STIR_WITH_TORCH + typedef TensorWrapper base_type; + #else typedef Array base_type; + #endif + typedef DiscretisedDensity self_type; public: diff --git a/src/include/stir/DiscretisedDensity.inl b/src/include/stir/DiscretisedDensity.inl index b43d76541b..b0102e4624 100644 --- a/src/include/stir/DiscretisedDensity.inl +++ b/src/include/stir/DiscretisedDensity.inl @@ -39,8 +39,9 @@ DiscretisedDensity::DiscretisedDensity() template DiscretisedDensity::DiscretisedDensity(const IndexRange& range_v, const CartesianCoordinate3D& origin_v) - : Array(range_v), - origin(origin_v) + :base_type(range_v), + origin(origin_v) + {} template @@ -48,7 +49,7 @@ DiscretisedDensity::DiscretisedDensity(const shared_ptr& range_v, const CartesianCoordinate3D& origin_v) : ExamData(exam_info_sptr), - Array(range_v), + base_type(range_v), origin(origin_v) {} diff --git a/src/include/stir/IO/interfile.h b/src/include/stir/IO/interfile.h index a27742053d..475bfd8b2f 100644 --- a/src/include/stir/IO/interfile.h +++ b/src/include/stir/IO/interfile.h @@ -40,6 +40,10 @@ template class IndexRange; template class Array; +#ifdef STIR_WITH_TORCH +template +class TensorWrapper; +#endif template class DiscretisedDensity; template @@ -166,7 +170,12 @@ const VectorWithOffset compute_file_offsets(int number_of_time_fr template Succeeded write_basic_interfile(const std::string& filename, +#ifdef STIR_WITH_TORCH + const TensorWrapper<3, elemT>& image, +#else const Array<3, elemT>& image, +#endif + // const Array<3, elemT>& image, const CartesianCoordinate3D& voxel_size, const CartesianCoordinate3D& origin, const NumericType output_type = NumericType::FLOAT, @@ -183,7 +192,12 @@ Succeeded write_basic_interfile(const std::string& filename, template Succeeded write_basic_interfile(const std::string& filename, const ExamInfo& exam_info, +#ifdef STIR_WITH_TORCH + const TensorWrapper<3, elemT>& image, +#else const Array<3, elemT>& image, +#endif + // const Array<3, elemT>& image, const CartesianCoordinate3D& voxel_size, const CartesianCoordinate3D& origin, const NumericType output_type = NumericType::FLOAT, @@ -203,7 +217,12 @@ Succeeded write_basic_interfile(const std::string& filename, template Succeeded write_basic_interfile(const std::string& filename, +#ifdef STIR_WITH_TORCH + const TensorWrapper<3, elemT>& image, +#else const Array<3, elemT>& image, +#endif + // const Array<3, elemT>& image, const NumericType output_type = NumericType::FLOAT, const float scale = 0, const ByteOrder byte_order = ByteOrder::native); diff --git a/src/include/stir/IO/read_data.h b/src/include/stir/IO/read_data.h index 1baefc8185..dee6696730 100644 --- a/src/include/stir/IO/read_data.h +++ b/src/include/stir/IO/read_data.h @@ -28,6 +28,10 @@ template class NumericInfo; template class Array; +#ifdef STIR_WITH_TORCH +template +class TensorWrapper; +#endif /*! \ingroup Array_IO \brief Read the data of an Array from file. 
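
Note (illustration, not part of the patch): the TensorWrapper overloads
declared below read the stream straight into the tensor's buffer. The core of
that, assuming only libtorch and a pre-existing raw file "t.raw" (placeholder
name) holding 16 native-order floats:

    #include <torch/torch.h>
    #include <fstream>

    int main()
    {
      torch::Tensor t = torch::empty({4, 4}, torch::kFloat);
      std::ifstream in("t.raw", std::ios::binary);
      in.read(reinterpret_cast<char*>(t.data_ptr<float>()),
              t.numel() * sizeof(float));
      return in ? 0 : 1; // mirrors the Succeeded::yes/no stream check
    }
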
@@ -44,6 +48,26 @@ class Array; template inline Succeeded read_data(IStreamT& s, Array& data, const ByteOrder byte_order = ByteOrder::native); +#ifdef STIR_WITH_TORCH +template +inline Succeeded read_data(IStreamT& s, + TensorWrapper& data, const ByteOrder byte_order = ByteOrder::native); + +template +inline Succeeded read_data(IStreamT& s, + TensorWrapper& data, + NumericInfo input_type, + ScaleT& scale_factor, + const ByteOrder byte_order = ByteOrder::native); + +template +inline Succeeded read_data(IStreamT& s, + TensorWrapper& data, + NumericType type, + ScaleT& scale, + const ByteOrder byte_order = ByteOrder::native); +#endif + /*! \ingroup Array_IO \brief Read the data of an Array from file as a different type. diff --git a/src/include/stir/IO/read_data.inl b/src/include/stir/IO/read_data.inl index e8c6b30818..933c5d5b0c 100644 --- a/src/include/stir/IO/read_data.inl +++ b/src/include/stir/IO/read_data.inl @@ -16,6 +16,7 @@ */ #include "stir/Array.h" +#include "stir/TensorWrapper.h" #include "stir/convert_array.h" #include "stir/NumericType.h" #include "stir/NumericInfo.h" @@ -65,6 +66,101 @@ read_data(IStreamT& s, Array& data, const ByteOrder byte_ return detail::read_data_help(detail::test_if_1d(), s, data, byte_order); } +#ifdef STIR_WITH_TORCH +template +inline Succeeded +read_data(IStreamT& s, TensorWrapper& data, const ByteOrder byte_order) +{ + if (!data.is_contiguous()) { + throw std::runtime_error("Tensor must be contiguous to read data."); + } + + // Ensure the tensor has the correct data type + if (data.getTensor().scalar_type() != torch::CppTypeToScalarType()) { + throw std::runtime_error("Tensor data type does not match the specified element type."); + } + + // Get the total number of elements in the tensor + int64_t num_elements = data.size_all(); + + // Calculate the total number of bytes to read + std::size_t num_bytes = num_elements * sizeof(elemT); + + // Read binary data into the tensor's memory + s.read(reinterpret_cast(data.begin_all()), num_bytes); + + if (!s) { + return Succeeded::no; + } + + // if (byte_order != ByteOrder::native) { + // elemT* data_ptr = data.begin_all(); + // for (int64_t i = 0; i < num_elements; ++i) { + // swap_byte_order(data_ptr[i]); + // } + // } + + return Succeeded::yes; +} + +template +inline Succeeded +read_data(IStreamT& s, + TensorWrapper& data, + NumericInfo input_type, + ScaleT& scale_factor, + const ByteOrder byte_order) +{ + if (typeid(InputType) == typeid(elemT)) + { + // TODO? you might want to use the scale even in this case, + // but at the moment we don't + scale_factor = ScaleT(1); + return read_data(s, data, byte_order); + } + else + { + //TODO NE Another time. 
+ return Succeeded::no; + // Array in_data(data.get_index_range()); + // Succeeded success = read_data(s, in_data, byte_order); + // if (success == Succeeded::no) + // return Succeeded::no; + // convert_array(data, scale_factor, in_data); + // return Succeeded::yes; + } +} + +template +inline Succeeded +read_data(IStreamT& s, TensorWrapper& data, NumericType type, ScaleT& scale, const ByteOrder byte_order) +{ + switch (type.id) + { + // define macro what to do with a specific NumericType +#define CASE(NUMERICTYPE) \ + case NUMERICTYPE: \ + return read_data(s, data, NumericInfo::type>(), scale, byte_order) + + // now list cases that we want + CASE(NumericType::SCHAR); + CASE(NumericType::UCHAR); + CASE(NumericType::SHORT); + CASE(NumericType::USHORT); + CASE(NumericType::INT); + CASE(NumericType::UINT); + CASE(NumericType::LONG); + CASE(NumericType::ULONG); + CASE(NumericType::FLOAT); + CASE(NumericType::DOUBLE); +#undef CASE + default: + warning("read_data : type not yet supported\n, at line %d in file %s", __LINE__, __FILE__); + return Succeeded::no; + } +} +#endif + template inline Succeeded read_data(IStreamT& s, @@ -103,14 +199,7 @@ read_data(IStreamT& s, Array& data, NumericType type, Sca return read_data(s, data, NumericInfo::type>(), scale, byte_order) // now list cases that we want - CASE(NumericType::SCHAR); - CASE(NumericType::UCHAR); - CASE(NumericType::SHORT); - CASE(NumericType::USHORT); CASE(NumericType::INT); - CASE(NumericType::UINT); - CASE(NumericType::LONG); - CASE(NumericType::ULONG); CASE(NumericType::FLOAT); CASE(NumericType::DOUBLE); #undef CASE @@ -120,4 +209,7 @@ read_data(IStreamT& s, Array& data, NumericType type, Sca } } +#ifdef STIR_WITH_TORCH +#endif + END_NAMESPACE_STIR diff --git a/src/include/stir/IO/write_data.h b/src/include/stir/IO/write_data.h index 2342d2d2c4..c8bb33ae1d 100644 --- a/src/include/stir/IO/write_data.h +++ b/src/include/stir/IO/write_data.h @@ -26,7 +26,10 @@ template class NumericInfo; template class Array; - +#ifdef STIR_WITH_TORCH +template +class TensorWrapper; +#endif /*! \ingroup Array_IO \brief Write the data of an Array to file. 
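
Note (illustration, not part of the patch): the write path added below is the
mirror image, with a contiguity requirement because the raw bytes go out in a
single call. A sketch, assuming only libtorch ("t.raw" is a placeholder name):

    #include <torch/torch.h>
    #include <fstream>

    int main()
    {
      // contiguous() guarantees data_ptr() walks elements in row-major order
      torch::Tensor t = torch::rand({4, 4}, torch::kFloat).contiguous();
      std::ofstream out("t.raw", std::ios::binary);
      out.write(reinterpret_cast<const char*>(t.data_ptr<float>()),
                t.numel() * sizeof(float));
      return out ? 0 : 1;
    }
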
@@ -126,6 +129,30 @@ inline Succeeded write_data(OStreamT& s, const ByteOrder byte_order = ByteOrder::native, const bool can_corrupt_data = false); +#ifdef STIR_WITH_TORCH + +template +inline Succeeded write_data(OStreamT& s, + const TensorWrapper& data, + const ByteOrder byte_order = ByteOrder::native, + const bool can_corrupt_data = false); + +template +inline Succeeded write_data(OStreamT& s, + const TensorWrapper& data, + NumericInfo output_type, + ScaleT& scale_factor, + const ByteOrder byte_order = ByteOrder::native, + const bool can_corrupt_data = false); + +template +inline Succeeded write_data(OStreamT& s, + const TensorWrapper& data, + NumericType type, + ScaleT& scale, + const ByteOrder byte_order = ByteOrder::native, + const bool can_corrupt_data = false); +#endif END_NAMESPACE_STIR #include "stir/IO/write_data.inl" diff --git a/src/include/stir/IO/write_data.inl b/src/include/stir/IO/write_data.inl index bf8f0a9a9e..c67cf92891 100644 --- a/src/include/stir/IO/write_data.inl +++ b/src/include/stir/IO/write_data.inl @@ -24,7 +24,9 @@ #include "stir/IO/write_data_1d.h" #include "stir/warning.h" #include - +#ifdef STIR_WITH_TORCH +#include "stir/TensorWrapper.h" +#endif START_NAMESPACE_STIR namespace detail @@ -151,4 +153,96 @@ write_data(OStreamT& s, } } +#ifdef STIR_WITH_TORCH + +template +Succeeded +write_data(OStreamT& s, + const TensorWrapper& data, + NumericInfo output_type, + ScaleT& scale_factor, + const ByteOrder byte_order, + const bool can_corrupt_data) +{ + + // find_scale_factor(scale_factor, data, NumericInfo()); + + if (!data.is_contiguous()) { + throw std::runtime_error("Tensor must be contiguous to read data."); + } + + // if (!byte_order.is_native_order()) + // { + // Array& data_ref = const_cast&>(data); + // for (auto iter = data_ref.begin_all(); iter != data_ref.end_all(); ++iter) + // ByteOrder::swap_order(*iter); + // } + + + const std::streamsize num_to_write = static_cast(data.size_all()) * sizeof(elemT); + bool writing_ok = true; + try + { + s.write(reinterpret_cast(data.get_const_full_data_ptr()), num_to_write); + } + catch (...) 
+ { + writing_ok = false; + } + + // data.release_const_full_data_ptr(); + + // if (!can_corrupt_data && !byte_order.is_native_order()) + // { + // Array& data_ref = const_cast&>(data); + // for (auto iter = data_ref.begin_all(); iter != data_ref.end_all(); ++iter) + // ByteOrder::swap_order(*iter); + // } + + if (!writing_ok || !s) + { + warning("write_data: error after writing to stream.\n"); + return Succeeded::no; + } + + return Succeeded::yes; +} + + +template +Succeeded +write_data(OStreamT& s, + const TensorWrapper& data, + NumericType type, + ScaleT& scale, + const ByteOrder byte_order, + const bool can_corrupt_data) +{ + if (NumericInfo().type_id() == type) + { + // you might want to use the scale even in this case, + // but at the moment we don't + scale = ScaleT(1); + return write_data(s, data, NumericInfo(), scale, byte_order, can_corrupt_data); + } + switch (type.id) + { + // define macro what to do with a specific NumericType +#define CASE(NUMERICTYPE) \ + case NUMERICTYPE: \ + return write_data(s, data, NumericInfo::type>(), scale, byte_order, can_corrupt_data) + + // now list cases that we want + CASE(NumericType::INT); + CASE(NumericType::FLOAT); + CASE(NumericType::DOUBLE); +#undef CASE + default: + warning("write_data : type not yet supported\n, at line %d in file %s", __LINE__, __FILE__); + return Succeeded::no; + } +} + +#endif + END_NAMESPACE_STIR diff --git a/src/include/stir/KeyParser.h b/src/include/stir/KeyParser.h index b53218a90d..9554627dd7 100644 --- a/src/include/stir/KeyParser.h +++ b/src/include/stir/KeyParser.h @@ -185,6 +185,8 @@ class KeyParser void add_key(const std::string& keyword, float* variable_ptr); //! add a vectorised keyword. When parsing, parse its value as a float and put it in \c (*variable_ptr)[current_index] void add_vectorised_key(const std::string& keyword, std::vector* variable_ptr); + + void add_vectorised_key(const std::string& keyword, std::vector>* variable_ptr); //! add a keyword. When parsing, parse its value as a double and put it in *variable_ptr void add_key(const std::string& keyword, double* variable_ptr); //! add a vectorised keyword. When parsing, parse its value as a double and put it in \c (*variable_ptr)[current_index] diff --git a/src/include/stir/NestedIterator.h b/src/include/stir/NestedIterator.h index 1d86280123..37f4cdbc83 100644 --- a/src/include/stir/NestedIterator.h +++ b/src/include/stir/NestedIterator.h @@ -26,7 +26,7 @@ #include "stir/NestedIteratorHelpers.h" #include "boost/iterator/iterator_traits.hpp" - +#include "boost/iterator/iterator_adaptor.hpp" START_NAMESPACE_STIR /*! diff --git a/src/include/stir/PixelsOnCartesianGrid.h b/src/include/stir/PixelsOnCartesianGrid.h index b2b5274311..8e5cc406b0 100644 --- a/src/include/stir/PixelsOnCartesianGrid.h +++ b/src/include/stir/PixelsOnCartesianGrid.h @@ -45,7 +45,11 @@ class PixelsOnCartesianGrid : public DiscretisedDensityOnCartesianGrid<2, elemT> inline PixelsOnCartesianGrid(); //! 
Construct PixelsOnCartesianGrid with the given array, origin and grid spacing +#ifdef STIR_WITH_TORCH + inline PixelsOnCartesianGrid(const TensorWrapper<2, elemT>& v, +#else inline PixelsOnCartesianGrid(const Array<2, elemT>& v, +#endif const CartesianCoordinate3D& origin, const BasicCoordinate<2, float>& grid_spacing); diff --git a/src/include/stir/PixelsOnCartesianGrid.inl b/src/include/stir/PixelsOnCartesianGrid.inl index 7fdaca8459..32bb40c2a7 100644 --- a/src/include/stir/PixelsOnCartesianGrid.inl +++ b/src/include/stir/PixelsOnCartesianGrid.inl @@ -26,7 +26,11 @@ PixelsOnCartesianGrid::PixelsOnCartesianGrid() {} template -PixelsOnCartesianGrid::PixelsOnCartesianGrid(const Array<2, elemT>& v, +#ifdef STIR_WITH_TORCH + PixelsOnCartesianGrid::PixelsOnCartesianGrid(const TensorWrapper<2, elemT>& v, +#else + PixelsOnCartesianGrid::PixelsOnCartesianGrid(const Array<2, elemT>& v, +#endif const CartesianCoordinate3D& origin, const BasicCoordinate<2, float>& grid_spacing) : DiscretisedDensityOnCartesianGrid<2, elemT>(v.get_index_range(), origin, grid_spacing) diff --git a/src/include/stir/RunTests.h b/src/include/stir/RunTests.h index b406962608..15df391914 100644 --- a/src/include/stir/RunTests.h +++ b/src/include/stir/RunTests.h @@ -33,6 +33,9 @@ #include #include #include +#ifdef STIR_WITH_TORCH +#include "stir/TensorWrapper.h" +#endif START_NAMESPACE_STIR @@ -157,6 +160,47 @@ class RunTests } return true; } +#ifdef STIR_WITH_TORCH + template + bool check_if_equal(const TensorWrapper<1, T>& t1, const TensorWrapper<1, T>& t2, const std::string& str = "") + { + if (t1.get_min_index() != t2.get_min_index() || t1.get_max_index() != t2.get_max_index()) + { + std::cerr << "Error: unequal ranges. " << str << std::endl; + return everything_ok = false; + } + + for (int i = t1.get_min_index(); i <= t1.get_max_index(); i++) + { + if (!check_if_equal(t1.at(i), t2.at(i), str)) + { + std::cerr << "(at TensorWrapper<" << typeid(T).name() << "> first mismatch at index " << i << ")\n"; + return everything_ok = false; + } + } + return true; + } + template + bool check_if_equal(const TensorWrapper<2, T>& t1, const TensorWrapper<2, T>& t2, const std::string& str = "") + { + if (t1.get_min_index() != t2.get_min_index() || t1.get_max_index() != t2.get_max_index()) + { + std::cerr << "Error: unequal ranges. " << str << std::endl; + return everything_ok = false; + } + + auto j = t2.begin_all_const(); + for (auto i = t1.begin_all_const(); i <= t1.end_all_const(); ++i, ++j) + { + if (!check_if_equal(*i, *j, str)) + { + std::cerr << "(at TensorWrapper<" << typeid(T).name() << "> first mismatch at index " << i << ")\n"; + return everything_ok = false; + } + } + return true; + } +#endif // VC 6.0 needs definition of template members in the class def unfortunately. //! check equality by comparing size and calling check_if_equal on all elements template @@ -244,6 +288,24 @@ class RunTests return true; } +#ifdef STIR_WITH_TORCH + + template + bool check_if_zero(const TensorWrapper& t, const std::string& str = "") + { + size_t index = 0; + for (auto it = t.begin_all_const(); it!= t.end_all_const(); ++it, ++index) + { + if (!check_if_zero(*it, str)) + { + std::cerr << "(at TensorWrapper<" << typeid(T).name() << "> first mismatch at index " << index << ")\n"; + return false; + } + } + return true; + } + +#endif //! 
compare norm with tolerance template bool check_if_zero(const BasicCoordinate& a, const std::string& str = "") diff --git a/src/include/stir/TensorWrapper.h b/src/include/stir/TensorWrapper.h new file mode 100644 index 0000000000..9e3df28f65 --- /dev/null +++ b/src/include/stir/TensorWrapper.h @@ -0,0 +1,699 @@ +#ifndef __stir_TENSOR_WRAPPER_H_ +#define __stir_TENSOR_WRAPPER_H_ + +#include +#include +#include +#include +#include +#include "stir/IndexRange.h" +#include "stir/shared_ptr.h" +#include "stir/BasicCoordinate.h" + +START_NAMESPACE_STIR + + // Template class declaration for TensorWrapper + template + class TensorWrapper { +private: + typedef TensorWrapper self; +public: + //! \name typedefs for iterator support + //@{ + typedef elemT value_type; + typedef value_type& reference; + typedef const value_type& const_reference; + typedef ptrdiff_t difference_type; + typedef elemT* iterator; + typedef elemT const* const_iterator; + + typedef std::reverse_iterator reverse_iterator; + typedef std::reverse_iterator const_reverse_iterator; + + // Define full_iterator as a type alias for iterator + using full_iterator = iterator; + using const_full_iterator = const_iterator; + + //@} + typedef size_t size_type; + + //!\name basic iterator support + //@{ + //! use to initialise an iterator to the first element of the vector + inline iterator begin(){ + if (is_on_gpu()) { + throw std::runtime_error("Iterator is not supported for tensors on the GPU."); + } + return tensor.data_ptr(); // Return a pointer to the start of the tensor's data + } + //! use to initialise an iterator to the first element of the (const) vector + inline const_iterator begin() const{ + if (is_on_gpu()) { + throw std::runtime_error("Iterator is not supported for tensors on the GPU."); + } + return tensor.data_ptr(); + } + + inline iterator end(){ + if (is_on_gpu()) { + throw std::runtime_error("Iterator is not supported for tensors on the GPU."); + } + return tensor.data_ptr() + tensor.numel(); // Return a pointer to the end of the tensor's data + } + //! 
iterator 'past' the last element of the (const) vector + inline const_iterator end() const{ + if (is_on_gpu()) { + throw std::runtime_error("Iterator is not supported for tensors on the GPU."); + } + return tensor.data_ptr() + tensor.numel(); + } + + inline reverse_iterator rbegin(){ + if (is_on_gpu()) { + throw std::runtime_error("Iterator is not supported for tensors on the GPU."); + } + return std::make_reverse_iterator(begin()); + } + inline reverse_iterator rend(){ + if (is_on_gpu()) { + throw std::runtime_error("Iterator is not supported for tensors on the GPU."); + } + return std::make_reverse_iterator(end()); + } + inline const_reverse_iterator rbegin() const{ + if (is_on_gpu()) { + throw std::runtime_error("Iterator is not supported for tensors on the GPU."); + } + return std::make_reverse_iterator(end()); + } + inline const_reverse_iterator rend() const{ + if (is_on_gpu()) { + throw std::runtime_error("Iterator is not supported for tensors on the GPU."); + } + return std::make_reverse_iterator(begin()); + } + + // Begin iterator for all elements + full_iterator begin_all() { + if (is_on_gpu()) { + throw std::runtime_error("full_iterator is not supported for tensors on the GPU."); + } + return tensor.data_ptr(); + } + + const_full_iterator begin_all_const() const { + if (is_on_gpu()) { + throw std::runtime_error("full_iterator is not supported for tensors on the GPU."); + } + return tensor.data_ptr(); + } + + const_full_iterator begin_all() const{ + return begin_all_const(); + } + + full_iterator end_all(){ + if (is_on_gpu()) { + throw std::runtime_error("full_iterator is not supported for tensors on the GPU."); + } + return tensor.data_ptr() + tensor.numel(); + } + + const_full_iterator end_all_const() const{ + if (is_on_gpu()) { + throw std::runtime_error("full_iterator is not supported for tensors on the GPU."); + } + return tensor.data_ptr() + tensor.numel(); + } + + const_full_iterator end_all() const { + return end_all_const(); + } + //@} + + TensorWrapper(); + TensorWrapper(const torch::Tensor& tensor, const std::vector& offsets, const std::string& device = "cpu"); + explicit TensorWrapper(const IndexRange &range, const std::string& device = "cpu"); + TensorWrapper(const IndexRange& range, shared_ptr data_sptr, const std::string& device = "cpu"); + TensorWrapper(const shared_ptr> shape, const std::string& device = "cpu"); + // TensorWrapper(const shared_ptr > discretised_density, const std::string& device = "cpu"); + + ~TensorWrapper() { + tensor.reset(); + } + + //! Copy constructor + inline TensorWrapper(const TensorWrapper& other) + : tensor(other.tensor.clone()), + device(other.device), + offsets(other.offsets), + ends(other.ends), strides(other.strides){} + + // friend inline void swap(TensorWrapper& first, TensorWrapper& second) noexcept// nothrow + // { + // using std::swap; + // // Swap the member variables + // swap(first.tensor, second.tensor); + // swap(first.device, second.device); + // swap(first.offsets, second.offsets); + // swap(first.ends, second.ends); + // swap(first.strides, second.strides); + // } + + //! move constructor + /*! implementation uses the copy-and-swap idiom, see e.g. https://stackoverflow.com/a/3279550 */ + // inline TensorWrapper(TensorWrapper&& other) noexcept + // { + // swap(*this, other); + // } + + + + template + elemT& at(Indices... 
indices) { + static_assert(sizeof...(indices) == num_dimensions, "Number of indices must match the number of dimensions."); + if (is_on_gpu()) { + throw std::runtime_error("at() is not supported for tensors on the GPU."); + } + return *(tensor.data_ptr() + compute_linear_index(indices...)); + } + + template + const elemT& at(Indices... indices) const { + static_assert(sizeof...(indices) == num_dimensions, "Number of indices must match the number of dimensions."); + if (is_on_gpu()) { + throw std::runtime_error("at() is not supported for tensors on the GPU."); + } + return *(tensor.data_ptr() + compute_linear_index(indices...)); + } + + //! TODO: I would love to remove the loop + elemT& at(const BasicCoordinate& coord) { + if (is_on_gpu()) { + throw std::runtime_error("at() with BasicCoordinate is not supported for tensors on the GPU."); + } + + // std::array indices; + // std::iota(indices.begin(), indices.end(), 1); + std::array indices; + for (int i = 0; i < num_dimensions; ++i) { + indices[i] = coord[i + 1]; // BasicCoordinate uses 1-based indexing + } + return *(tensor.data_ptr() + compute_linear_index(indices)); + } + + //! TODO: I would love to remove the loop + const elemT& at(const BasicCoordinate& coord) const { + if (is_on_gpu()) { + throw std::runtime_error("at() with BasicCoordinate is not supported for tensors on the GPU."); + } + + std::array indices; + // std::iota(indices.begin(), indices.end(), 1); + // std::array indices; + for (int i = 0; i < num_dimensions; ++i) { + indices[i] = coord[i + 1]; // BasicCoordinate uses 1-based indexing + } + return *(tensor.data_ptr() + compute_linear_index(indices)); + } + +private: + + void compute_strides() { + strides.resize(num_dimensions); + strides[num_dimensions - 1] = 1; // Last dimension has a stride of 1 + if(num_dimensions > 1) + for (int i = num_dimensions - 2; i >= 0; --i) { + strides[i] = tensor.size(i + 1) * strides[i + 1]; + } + } + + // template + // size_t compute_linear_index(Indices... indices) const { + // std::array dims = {indices...}; + // size_t linear_index = 0; + // size_t stride = 1; + // for (int i = num_dimensions - 1; i >= 0; --i) { + // linear_index += dims[i] * stride; + // stride *= tensor.size(i); + // } + // return linear_index; + // } + + //! Maybe I should be checking if strides has been initialized. + //! But this will be called often, so skip. + template + size_t compute_linear_index(Indices... indices) const { + if (is_on_gpu()) { + // Use PyTorch's tensor indexing for GPU tensors + torch::Tensor indices_tensor = torch::tensor({indices...}, torch::kInt64).to(tensor.device()); + torch::Tensor linear_index_tensor = indices_tensor * torch::tensor(strides, torch::kInt64).to(tensor.device()); + return linear_index_tensor.sum().item(); + + // torch::Tensor offset_tensor = torch::tensor(offsets, torch::kInt64).to(tensor.device()); + // torch::Tensor indices_tensor = torch::tensor({indices...}, torch::kInt64).to(tensor.device()); + // torch::Tensor linear_index_tensor = (indices_tensor - offset_tensor) * torch::tensor(strides, torch::kInt64).to(tensor.device()); + // return linear_index_tensor.sum().item(); + + } else { + // NE: No offset version, safe. + // // Use cached strides for CPU tensors + // std::array dims = {indices...}; + // size_t linear_index = 0; + // for (int i = 0; i < num_dimensions; ++i) { + // linear_index += dims[i] * strides[i]; + // } + // return linear_index; + //NE: with offsets. 
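+      // Worked example (hypothetical values): with offsets = {-1, -2} and
+      // strides = {5, 1} for a 3x5 tensor, index (1, 0) maps to
+      // (1 - (-1)) * 5 + (0 - (-2)) * 1 = 12 in the flat buffer.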
+ std::array dims = {indices...}; + + size_t linear_index = 0; + for (int i = 0; i < num_dimensions; ++i) { + int adjusted_index = dims[i] - offsets[i]; // Subtract the offset + if (adjusted_index < 0 || adjusted_index >= tensor.size(i)) { + throw std::out_of_range("Index out of bounds after applying offset."); + } + linear_index += adjusted_index * strides[i]; + } + return linear_index; + } + } + +public: + + inline bool operator==(const self& iv) const{ + if(offsets!=iv.offsets) + return false; + return torch::equal(tensor, iv.tensor); + } + + inline bool operator!=(const self& iv) const{ + return !(*this == iv); + } + + template ::value>::type> + inline TensorWrapper& operator+=(const Scalar v){ + tensor.add_(static_cast(v)); + return *this; + } + + template ::value>::type> + inline TensorWrapper& operator*=(const Scalar v){ + tensor.mul_(static_cast(v)); + return *this; + } + + template ::value>::type> + inline TensorWrapper& operator-=(const Scalar v){ + tensor.sub_(static_cast(v)); + return *this; + } + + template ::value>::type> + inline TensorWrapper& operator/=(const Scalar v){ + if (v == 0) { + throw std::invalid_argument("Division by zero is not allowed."); + } + tensor.div_(static_cast(v)); + return *this; + } + + template ::value>::type> + inline TensorWrapper operator+(const Scalar v){ + return TensorWrapper(tensor + static_cast(v), offsets, device); + } + + template ::value>::type> + inline TensorWrapper operator*(const Scalar v){ + return TensorWrapper(tensor * static_cast(v), offsets, device); + } + + template ::value>::type> + inline TensorWrapper operator-(const Scalar v){ + return TensorWrapper(tensor - static_cast(v), offsets, device); + } + + template ::value>::type> + inline TensorWrapper operator/(const Scalar v){ + if (v == 0) { + throw std::invalid_argument("Division by zero is not allowed."); + } + return TensorWrapper(tensor / static_cast(v), offsets, device); + } + + inline TensorWrapper& operator+=(const TensorWrapper& v) + { + if (!tensor.sizes().equals(v.tensor.sizes())) { + throw std::invalid_argument("Tensors must have the same shape for element-wise addition."); + } + if (is_on_gpu() != v.is_on_gpu()) { + throw std::invalid_argument("Both tensors must be on the same device for element-wise division."); + } + tensor.add_(v.tensor); // In-place addition of another tensor + return *this; + } + + inline TensorWrapper& operator*=(const TensorWrapper& v) + { + if (!tensor.sizes().equals(v.tensor.sizes())) { + throw std::invalid_argument("Tensors must have the same shape for element-wise addition."); + } + if (is_on_gpu() != v.is_on_gpu()) { + throw std::invalid_argument("Both tensors must be on the same device for element-wise division."); + } + tensor.mul_(v.tensor); // In-place addition of another tensor + return *this; + } + + // In-place element-wise subtraction with another TensorWrapper + inline TensorWrapper& operator-=(const TensorWrapper& other) + { + if (tensor.sizes() != other.tensor.sizes()) { + throw std::invalid_argument("Tensors must have the same shape for element-wise subtraction."); + } + if (is_on_gpu() != other.is_on_gpu()) { + throw std::invalid_argument("Both tensors must be on the same device for element-wise subtraction."); + } + tensor.sub_(other.tensor); // In-place subtraction + return *this; + } + + inline TensorWrapper& operator/=(const TensorWrapper& v) + { + if (!tensor.sizes().equals(v.tensor.sizes())) { + throw std::invalid_argument("Tensors must have the same shape for element-wise division."); // TODO + } + if (is_on_gpu() != 
v.is_on_gpu()) { + throw std::invalid_argument("Both tensors must be on the same device for element-wise division."); + } + tensor.div_(v.tensor); + return *this; + } + + // Element-wise subtraction between two TensorWrapper objects + inline TensorWrapper operator+(const TensorWrapper& v) const { + if (tensor.sizes() != v.tensor.sizes()) { + throw std::invalid_argument("Tensors must have the same shape for element-wise subtraction."); + } + if (is_on_gpu() != v.is_on_gpu()) { + throw std::invalid_argument("Both tensors must be on the same device for element-wise subtraction."); + } + return TensorWrapper(tensor + v.tensor, offsets, device); + } + + // Element-wise subtraction between two TensorWrapper objects + inline TensorWrapper operator-(const TensorWrapper& v) const { + if (tensor.sizes() != v.tensor.sizes()) { + throw std::invalid_argument("Tensors must have the same shape for element-wise subtraction."); + } + if (is_on_gpu() != v.is_on_gpu()) { + throw std::invalid_argument("Both tensors must be on the same device for element-wise subtraction."); + } + return TensorWrapper(tensor - v.tensor, offsets, device); + } + + //! Efficent slice-wise tensor multiplication the the values in a vector that holds a scalar for each slice + // inline TensorWrapper& operator*=(const std::vector& scalars) + inline void slice_wise_mult(const std::vector& scalars) + { + // The tensor must be 3D + if (tensor.dim() != 3) { + throw std::invalid_argument("Input tensor must be 3D."); + } + + if (tensor.size(0) != static_cast(scalars.size())) { + throw std::invalid_argument("Number of scalars must match the number of slices in the tensor."); + } + + // Convert the scalars vector to a 1D tensor + torch::Tensor scalar_tensor = torch::tensor(scalars, tensor.options()); + scalar_tensor = scalar_tensor.view({-1, 1, 1}); // Shape: [num_slices, 1, 1] + tensor *= scalar_tensor; + } + + + //! 
+  //! assignment operator - Deep copy
+  TensorWrapper& operator=(const TensorWrapper& other)
+  {
+    if (this == &other) {
+      return *this;
+    }
+    // swap(*this, other);
+    // Perform a deep copy of the tensor
+    tensor = other.tensor.clone(); // Create a new tensor with the same data
+    // Copy other members
+    device = other.device;
+    offsets = other.offsets;
+    ends = other.ends;
+    strides = other.strides;
+
+    return *this;
+  }
+
+  inline bool is_regular() const { return true; }
+
+private:
+  inline void init(const IndexRange<num_dimensions>& range, elemT* const data_ptr, bool copy_data = true)
+  {
+    // Determine the shape of the tensor
+    std::vector<int64_t> shape = convertIndexRangeToShape(range);
+
+    if (data_ptr == nullptr)
+      {
+        // Allocate a new tensor with the required shape
+        tensor = torch::zeros(shape, torch::TensorOptions().dtype(getTorchDtype()));
+      }
+    else
+      {
+        // number of elements implied by the requested shape (size_all()
+        // cannot be used here, as the tensor has not been allocated yet)
+        int64_t num_elems = 1;
+        for (const auto d : shape)
+          num_elems *= d;
+        tensor = torch::tensor(
+            std::vector<elemT>(data_ptr, data_ptr + num_elems),           // Copy data into a vector
+            torch::TensorOptions().dtype(getTorchDtype())).reshape(shape); // Set dtype and shape
+      }
+
+    // Extract offsets
+    offsets = extract_offsets_recursive(range);
+    std::cerr << "OFFSETS SIZE " << offsets.size() << std::endl;
+    ends.resize(offsets.size());
+    for (int i = 0; i < static_cast<int>(offsets.size()); ++i)
+      ends[i] = offsets[i] + static_cast<int>(tensor.size(i)) - 1;
+  }
+
+public:
+  inline int get_length(int dim = 0) const
+  {
+    return static_cast<int>(tensor.size(dim));
+  }
+
+  inline size_t size(int dim = 0) const
+  {
+    return static_cast<size_t>(tensor.size(dim));
+  }
+
+  inline size_t size_all() const { return tensor.numel(); }
+
+  elemT find_max() const;
+
+  elemT sum() const;
+
+  elemT sum_positive() const;
+
+  elemT find_min() const;
+
+  void fill(const elemT& n);
+
+  void mask_cyl(const float radius, const bool in = false);
+
+  void apply_lower_threshold(const elemT& l);
+
+  void apply_upper_threshold(const elemT& u);
+
+  // TensorWrapper get_empty_copy() const;
+
+  inline void _xapyb(const torch::Tensor& x, const elemT a, const torch::Tensor& y, const elemT b)
+  {
+    if (!x.sizes().equals(y.sizes())) {
+      throw std::invalid_argument("Tensors x and y must have the same shape");
+    }
+    tensor = a * x + b * y;
+  }
+
+  inline void xapyb(const TensorWrapper& x, const elemT a, const TensorWrapper& y, const elemT b)
+  {
+    _xapyb(x.tensor, a, y.tensor, b);
+  }
+
+  inline void _xapyb(const torch::Tensor& x, const torch::Tensor& a, const torch::Tensor& y, const torch::Tensor& b)
+  {
+    if (!x.sizes().equals(a.sizes()) ||
+        !x.sizes().equals(y.sizes()) ||
+        !x.sizes().equals(b.sizes())) {
+      throw std::invalid_argument("Tensors x, a, y and b must have the same shape");
+    }
+    tensor = a * x + b * y;
+  }
+
+  //! set values of the array to x*a+y*b, where a and b are arrays
+  inline void xapyb(const TensorWrapper& x, const TensorWrapper& a, const TensorWrapper& y, const TensorWrapper& b)
+  {
+    _xapyb(x.tensor, a.tensor, y.tensor, b.tensor);
+  }
+
+  //! set values of the array to self*a+y*b where a and b are scalar or arrays
+  //! (a worked example follows after is_contiguous() below)
+  template <typename T>
+  inline void sapyb(const T& a, const TensorWrapper& y, const T& b)
+  {
+    this->xapyb(*this, a, y, b);
+  }
+
+  // std::unique_ptr<TensorWrapper> clone() const;
+  torch::Tensor& getTensor();
+  const torch::Tensor& getTensor() const;
+  void printSizes() const;
+
+  stir::IndexRange<num_dimensions> get_index_range() const;
+
+  void resize(const stir::IndexRange<num_dimensions>& range);
+
+  void grow(const stir::IndexRange<num_dimensions>& range);
+
+  inline bool is_contiguous() const{
+    return tensor.is_contiguous();
+  }
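+
+  // Worked example of the xapyb/sapyb members above (a sketch; names are
+  // hypothetical). xapyb computes this->tensor = a*x + b*y, and
+  // sapyb(a, y, b) is the same with x = *this:
+  //
+  //   TensorWrapper<3, float> x(range), y(range), out(range);
+  //   x.fill(1.f); y.fill(2.f);
+  //   out.xapyb(x, 2.f, y, 3.f); // every element of out becomes 2*1 + 3*2 = 8
+  //   x.sapyb(2.f, y, 3.f);      // x = 2*x + 3*y, element-wise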
+  //! member function for access to the data via an elemT*
+  inline elemT* get_full_data_ptr(){
+    if (!this->is_contiguous())
+      error("TensorWrapper::get_full_data_ptr() called for non-contiguous array.");
+    return &(*this->begin_all());
+  }
+
+  inline const elemT* get_const_full_data_ptr() const{
+    if (!this->is_contiguous())
+      error("TensorWrapper::get_const_full_data_ptr() called for non-contiguous array.");
+    return &(*this->begin_all_const());
+  }
+
+  void print() const;
+
+  inline int get_min_index(int dim = 0) const { return offsets[dim]; }
+
+  inline int get_max_index(int dim = 0) const { return ends[dim]; }
+
+  inline void set_offset(const int min_index, int dim = 0){
+    if (dim >= tensor.dim()){
+      throw std::invalid_argument("dim must be less than the number of dimensions.");
+    }
+    // shift the end index by the same amount as the start index,
+    // so that the length of the dimension is preserved
+    ends[dim] += min_index - offsets[dim];
+    offsets[dim] = min_index;
+  }
+
+  inline void set_offset(const std::vector<int>& new_offsets) {
+    if (new_offsets.size() != static_cast<size_t>(tensor.dim())) {
+      throw std::invalid_argument("Number of offsets must match the number of dimensions.");
+    }
+    for (size_t d = 0; d < new_offsets.size(); ++d)
+      ends[d] += new_offsets[d] - offsets[d];
+    offsets = new_offsets;
+  }
+
+protected:
+  torch::Tensor tensor;
+  std::string device; // Change from torch::Device to std::string
+  std::vector<int> offsets;
+  std::vector<int> ends;
+  std::vector<size_t> strides;
+
+  inline std::vector<int64_t> convertIndexRangeToShape(const stir::IndexRange<num_dimensions>& range) const
+  {
+    std::vector<int64_t> shape;
+    computeShapeRecursive(range, shape);
+    return shape;
+  }
+
+  template <int current_dim>
+  inline void computeShapeRecursive(const stir::IndexRange<current_dim>& range, std::vector<int64_t>& shape) const
+  {
+    // Get the size of the current dimension
+    shape.push_back(range.get_max_index() - range.get_min_index() + 1);
+    if constexpr (current_dim > 1) {
+      computeShapeRecursive(range[range.get_min_index()], shape);
+    }
+  }
+
+  // std::vector<int64_t> getShapeFromDiscretisedDensity(const stir::shared_ptr<DiscretisedDensity<num_dimensions, elemT>> discretised_density) const
+  // {
+  //   const stir::IndexRange<num_dimensions> range = discretised_density->get_index_range();
+  //   std::vector<int64_t> shape = convertIndexRangeToShape(range);
+  //   return shape;
+  // }
+
+  // void fillTensorFromDiscretisedDensity(const stir::shared_ptr<DiscretisedDensity<num_dimensions, elemT>> discretised_density)
+  // {
+  //   auto accessor = tensor.accessor<elemT, num_dimensions>();
+  //   fillRecursive((*discretised_density),
+  //                 accessor,
+  //                 discretised_density->get_index_range());
+  // }
+
+  // template <int current_dim, typename Accessor>
+  // inline void fillRecursive(const stir::Array<current_dim, elemT>& array,
+  //                           Accessor accessor, const stir::IndexRange<current_dim>& range)
+  // {
+  //   if constexpr (current_dim == 1) {
+  //     // Base case: fill the last dimension
+  //     int min_index = range.get_min_index();
+  //     for (int i = range.get_min_index(); i < range.get_max_index(); ++i) {
+  //       accessor[i - min_index] = array[i];
+  //     }
+  //   } else {
+  //     // Recursive case: traverse the current dimension
+  //     int min_index = range.get_min_index();
+  //     for (int i = range.get_min_index(); i < range.get_max_index(); ++i) {
+  //       fillRecursive(array[i], accessor[i - min_index], range[i]);
+  //     }
+  //   }
+  // }
+
+private:
+  torch::Dtype getTorchDtype() const {
+    if constexpr (std::is_same<elemT, float>::value) {
+      return torch::kFloat;
+    } else if constexpr (std::is_same<elemT, double>::value) {
+      return torch::kDouble;
+    } else if constexpr (std::is_same<elemT, int>::value) {
+      return torch::kInt;
+    }
+    else {
+      throw std::invalid_argument("Unsupported data type");
+    }
+  }
+
+  template <int current_dim>
+  inline std::vector<int> extract_offsets_recursive(const stir::IndexRange<current_dim>& range) {
+    std::vector<int> result;
+    result.push_back(range.get_min_index()); // Get the minimum index for the current dimension
+
+    if constexpr (current_dim > 1) {
+      // Recurse into the next dimension
+      auto sub_offsets = extract_offsets_recursive(range[range.get_min_index()]);
+      result.insert(result.end(), sub_offsets.begin(), sub_offsets.end());
+    }
+
+    return result;
+  }
+};
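+
+// Minimal end-to-end sketch of constructing and using the wrapper (names and
+// extents are illustrative; this assumes the IndexRange-based constructor
+// that forwards to init() above):
+//
+//   IndexRange<3> range(Coordinate3D<int>(0, -4, -4), Coordinate3D<int>(9, 4, 4));
+//   TensorWrapper<3, float> tw(range);  // zero-initialised torch tensor
+//   tw.fill(1.f);
+//   assert(tw.get_min_index(1) == -4);  // per-dimension offsets are preserved
+//   assert(tw.size_all() == 10 * 9 * 9);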
+
+END_NAMESPACE_STIR
+
+#endif // TENSOR_WRAPPER_H
diff --git a/src/include/stir/VoxelsOnCartesianGrid.h b/src/include/stir/VoxelsOnCartesianGrid.h
index dd06b6291f..112c0874ae 100644
--- a/src/include/stir/VoxelsOnCartesianGrid.h
+++ b/src/include/stir/VoxelsOnCartesianGrid.h
@@ -25,6 +25,7 @@
 */
 
+#include "stir/Array.h"
 #include "stir/DiscretisedDensityOnCartesianGrid.h"
 #include "stir/CartesianCoordinate3D.h"
 
@@ -55,7 +56,11 @@ static VoxelsOnCartesianGrid ask_parameters();
   VoxelsOnCartesianGrid();
 
   //! Construct a VoxelsOnCartesianGrid, initialising data from the Array<3,elemT> object.
+#ifdef STIR_WITH_TORCH
+  VoxelsOnCartesianGrid(const TensorWrapper<3, elemT>& v,
+#else
   VoxelsOnCartesianGrid(const Array<3, elemT>& v,
+#endif
                         const CartesianCoordinate3D<float>& origin,
                         const BasicCoordinate<3, float>& grid_spacing);
 
@@ -67,7 +72,11 @@ static VoxelsOnCartesianGrid ask_parameters();
   //! Construct a VoxelsOnCartesianGrid, initialising data from the Array<3,elemT> object.
   VoxelsOnCartesianGrid(const shared_ptr<const ExamInfo>& exam_info_sptr,
-                        const Array<3, elemT>& v,
+#ifdef STIR_WITH_TORCH
+                        const TensorWrapper<3, elemT>& v,
+#else
+                        const Array<3, elemT>& v,
+#endif
                         const CartesianCoordinate3D<float>& origin,
                         const BasicCoordinate<3, float>& grid_spacing);
 
@@ -149,13 +158,13 @@ static VoxelsOnCartesianGrid ask_parameters();
   VoxelsOnCartesianGrid*
 #endif
   clone() const override;
-
+#ifndef STIR_WITH_TORCH
   //! Extract a single plane
   PixelsOnCartesianGrid<elemT> get_plane(const int z) const;
 
   //! Set a single plane
   void set_plane(const PixelsOnCartesianGrid<elemT>& plane, const int z);
-
+#endif
   //! is the same as get_grid_spacing(), but now returns CartesianCoordinate3D for convenience
   inline CartesianCoordinate3D<float> get_voxel_size() const;
 
diff --git a/src/include/stir/VoxelsOnCartesianGrid.inl b/src/include/stir/VoxelsOnCartesianGrid.inl
index dda78bb60e..7833d1e4ce 100644
--- a/src/include/stir/VoxelsOnCartesianGrid.inl
+++ b/src/include/stir/VoxelsOnCartesianGrid.inl
@@ -33,35 +33,52 @@
 template <class elemT>
 int
 VoxelsOnCartesianGrid<elemT>::get_min_z() const
 {
-  return this->get_min_index();
+  return this->get_min_index(0);
 }
 
 template <class elemT>
 int
 VoxelsOnCartesianGrid<elemT>::get_min_y() const
 {
+#ifdef STIR_WITH_TORCH
+  return this->get_min_index(1);
+#else
   return this->get_length() == 0 ? 0 : (*this)[get_min_z()].get_min_index();
+#endif
 }
 
 template <class elemT>
 int
 VoxelsOnCartesianGrid<elemT>::get_min_x() const
 {
+#ifdef STIR_WITH_TORCH
+  return this->get_min_index(2);
+#else
   return this->get_length() == 0 ? 0 : (*this)[get_min_z()][get_min_y()].get_min_index();
+#endif
 }
 
 template <class elemT>
 int
 VoxelsOnCartesianGrid<elemT>::get_x_size() const
 {
+#ifdef STIR_WITH_TORCH
+  return this->get_length(2);
+#else
   return this->get_length() == 0 ? 0 : (*this)[get_min_z()][get_min_y()].get_length();
+#endif
 }
 
 template <class elemT>
 int
 VoxelsOnCartesianGrid<elemT>::get_y_size() const
 {
+#ifdef STIR_WITH_TORCH
+  return this->get_length(1);
+#else
   return this->get_length() == 0 ? 0 : (*this)[get_min_z()].get_length();
+#endif
 }
 
 template <class elemT>
@@ -75,14 +92,23 @@
 template <class elemT>
 int
 VoxelsOnCartesianGrid<elemT>::get_max_x() const
 {
+#ifdef STIR_WITH_TORCH
+  // x is the fastest-running (last) dimension, consistent with get_min_x()/get_x_size()
+  return this->get_max_index(2);
+#else
   return this->get_length() == 0 ? 0 : (*this)[get_min_z()][get_min_y()].get_max_index();
+#endif
 }
 
 template <class elemT>
 int
 VoxelsOnCartesianGrid<elemT>::get_max_y() const
 {
+#ifdef STIR_WITH_TORCH
+  return this->get_max_index(1);
+#else
   return this->get_length() == 0 ? 0 : (*this)[get_min_z()].get_max_index();
+#endif
 }
 
 template <class elemT>
diff --git a/src/include/stir/make_array.h b/src/include/stir/make_array.h
index f5359a4811..fb6f62d0a7 100644
--- a/src/include/stir/make_array.h
+++ b/src/include/stir/make_array.h
@@ -178,6 +178,18 @@ inline Array<num_dimensions + 1, T> make_array(const Array<num_dimensions, T>& a8, const Array<num_dimensions, T>& a9);
 
+#ifdef STIR_WITH_TORCH
+
+// #include "stir/TensorWrapper.h"
+
+// template <typename T>
+// inline TensorWrapper<1, T> make_1d_tensor(const T& a0);
+
+// template <typename T>
+// inline TensorWrapper<1, T> make_1d_tensor(const T& a0, const T& a1);
+
+#endif
+
 END_NAMESPACE_STIR
 
 #include "stir/make_array.inl"
diff --git a/src/include/stir/make_array.inl b/src/include/stir/make_array.inl
index dbb4568093..8c5dcf5f21 100644
--- a/src/include/stir/make_array.inl
+++ b/src/include/stir/make_array.inl
@@ -380,4 +380,23 @@ make_array(const Array<num_dimensions, T>& a0,
   return a;
 }
 
+#ifdef STIR_WITH_TORCH
+// template <typename T>
+// Array<1, T>
+// make_array(const Array<1, T>& a0)
+// {
+//   const Array<1, T> a = NumericVectorWithOffset<T, T>(make_vector(a0));
+//   return a;
+// }
+
+// template <typename T>
+// Array<2, T>
+// make_array(const Array<1, T>& a0, const Array<1, T>& a1)
+// {
+//   const Array<2, T> a = NumericVectorWithOffset<Array<1, T>, T>(make_vector(a0, a1));
+//   return a;
+// }
+
+#endif
+
 END_NAMESPACE_STIR
diff --git a/src/include/stir/modelling/ParametricDiscretisedDensity.h b/src/include/stir/modelling/ParametricDiscretisedDensity.h
index 6e5b192b2e..b201e7ef22 100644
--- a/src/include/stir/modelling/ParametricDiscretisedDensity.h
+++ b/src/include/stir/modelling/ParametricDiscretisedDensity.h
@@ -20,7 +20,6 @@
 */
 
-#include "stir/DiscretisedDensity.h"
 #include "stir/NestedIterator.h"
 // for ParametricVoxelsOnCartesianGrid typedef
 #include "stir/VoxelsOnCartesianGrid.h"
diff --git a/src/swig/CMakeLists.txt b/src/swig/CMakeLists.txt
index bc52432edc..d8177ed321 100644
--- a/src/swig/CMakeLists.txt
+++ b/src/swig/CMakeLists.txt
@@ -92,6 +92,7 @@ set(swig_stir_dependencies
   stir_coordinates.i
   stir_dataprocessors.i
   stir_exam.i
+  stir_tensorwrapper.i
   stir_LOR.i
   stir_normalisation.i
   stir_objectivefunctions.i
@@ -119,6 +120,13 @@ if(BUILD_SWIG_PYTHON)
     ".... Python_NumPy_INCLUDE_DIRS: ${Python_NumPy_INCLUDE_DIRS}")
   set(STIR_Python_dependency Python::Module Python::NumPy)
 
+  # Ensure PyTorch Python bindings are linked
+  find_library(TORCH_PYTHON_LIBRARY torch_python PATHS ${TORCH_INSTALL_PREFIX}/lib)
+  if (NOT TORCH_PYTHON_LIBRARY)
+    message(FATAL_ERROR "torch_python library not found. 
Ensure PyTorch is installed with Python bindings.") + endif() + message(STATUS "Found torch_python library: ${TORCH_PYTHON_LIBRARY}") + # TODO probably better to call the module stirpy or something # TODO -builtin option only appropriate for python # while the next statement sets it for all modules called stir @@ -129,7 +137,7 @@ if(BUILD_SWIG_PYTHON) set_property(TARGET ${SWIG_MODULE_stir_REAL_NAME} PROPERTY SWIG_GENERATED_COMPILE_OPTIONS /bigobj) endif() SWIG_WORKAROUND(${SWIG_MODULE_stir_REAL_NAME}) - SWIG_LINK_LIBRARIES(stir PUBLIC ${STIR_LIBRARIES} ${STIR_Python_dependency}) + SWIG_LINK_LIBRARIES(stir PUBLIC ${STIR_LIBRARIES} ${STIR_Python_dependency} ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) target_link_libraries(${SWIG_MODULE_stir_REAL_NAME} PUBLIC ${OpenMP_EXE_LINKER_FLAGS}) CONFIGURE_FILE(./pyfragments.swg ./ COPYONLY) diff --git a/src/swig/stir.i b/src/swig/stir.i index 284c89a041..aca17e762a 100644 --- a/src/swig/stir.i +++ b/src/swig/stir.i @@ -1037,6 +1037,7 @@ ADD_REPR(stir::Succeeded, %arg($self->succeeded() ? "yes" : "no")); %include "stir_array.i" %include "stir_exam.i" +%include "stir_tensorwrapper.i" %shared_ptr(stir::DataSymmetriesForViewSegmentNumbers); %include "stir_projdata_coords.i" diff --git a/src/swig/stir_tensorwrapper.i b/src/swig/stir_tensorwrapper.i new file mode 100644 index 0000000000..d4ba70d19a --- /dev/null +++ b/src/swig/stir_tensorwrapper.i @@ -0,0 +1,59 @@ +/* + SWIG interface file for TensorWrapper<3, float> + Provides Python bindings for TensorWrapper<3, float> with PyTorch tensor support. +*/ + +// %module stir_tensorwrapper + +%{ + +#include +#include "stir/TensorWrapper.h" +#include +%} + +%include "stir/TensorWrapper.h" + +%shared_ptr(stir::TensorWrapper<3, float>); + +namespace stir +{ +// Explicitly instantiate TensorWrapper<3, float> +%template(TensorWrapper3DFloat) TensorWrapper<3, float>; + +// %extend stir::TensorWrapper<3, float> { +// // Convert TensorWrapper to a PyTorch tensor +// PyObject* to_python_tensor() { +// return THPVariable_Wrap(self->getTensor()); +// } + +// // Set a PyTorch tensor into TensorWrapper +// void from_python_tensor(PyObject* py_tensor) { +// if (!THPVariable_Check(py_tensor)) { +// // SWIG_exception(SWIG_TypeError, "Expected a PyTorch tensor"); +// } +// auto tensor = THPVariable_Unpack(py_tensor); +// self->getTensor() = tensor; +// } + +// // Constructor to create TensorWrapper from a PyTorch tensor +// TensorWrapper<3, float>(PyObject* py_tensor, const std::string& device = "cpu") { +// if (!THPVariable_Check(py_tensor)) { +// // SWIG_exception(SWIG_TypeError, "Expected a PyTorch tensor"); +// } +// auto tensor = THPVariable_Unpack(py_tensor); +// return new stir::TensorWrapper<3, float>(tensor, device); +// } + +// // Print the device of the tensor +// void print_device() { +// self->print_device(); +// } + +// // Print the sizes of the tensor +// void print_sizes() { +// self->printSizes(); +// } +// } + +} \ No newline at end of file diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index a7c1621b36..5be2d5deba 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -17,7 +17,13 @@ set(buildblock_simple_tests test_Array.cxx test_VectorWithOffset.cxx ) - + +if (STIR_WITH_TORCH) + list(APPEND buildblock_simple_tests + test_Tensor.cxx + ) +endif() + if (NOT MINI_STIR) set(${dir_SIMPLE_TEST_EXE_SOURCES} diff --git a/src/test/test_Tensor.cxx b/src/test/test_Tensor.cxx new file mode 100644 index 0000000000..b57cc899e6 --- /dev/null +++ b/src/test/test_Tensor.cxx @@ -0,0 +1,402 @@ 
+/*! + \file + \ingroup test + \ingroup TensorWrapper + + \brief tests for the stir::TensorWrapper class + + \author Nikos Efthimiou +*/ + +#ifndef NDEBUG +// set to high level of debugging +# ifdef _DEBUG +# undef _DEBUG +# endif +# define _DEBUG 2 +#endif + +#include "stir/Array.h" +#include "stir/TensorWrapper.h" +#include "stir/make_array.h" +#include "stir/Coordinate2D.h" +#include "stir/Coordinate3D.h" +#include "stir/Coordinate4D.h" +#include "stir/convert_array.h" +#include "stir/Succeeded.h" +#include "stir/IO/write_data.h" +#include "stir/IO/read_data.h" + +#include "stir/RunTests.h" + +#include "stir/ArrayFunction.h" +#include "stir/array_index_functions.h" +#include "stir/copy_fill.h" +#include +#include + +// for open_read/write_binary +#include "stir/utilities.h" +#include "stir/info.h" +#include "stir/error.h" + +#include "stir/HighResWallClockTimer.h" + +#include +#include +#include +#include +using std::ofstream; +using std::ifstream; +using std::plus; +using std::cerr; +using std::endl; + +START_NAMESPACE_STIR + +namespace detail +{ + +// static TensorWrapper<2, float> +// test_make_array() +// { +// return make_array(make_1d_array(1.F, 0.F, 0.F), make_1d_array(0.F, 1.F, 1.F), make_1d_array(0.F, -2.F, 2.F)); +// } +} // namespace detail + +/*! + \brief Tests Array functionality + \ingroup test + \warning Running this will create and delete 2 files with names + output.flt and output.other. Existing files with these names will be overwritten. + +*/ +class TensorTests : public RunTests +{ +private: + // // this function tests the next() function and compare it to using full_iterators + // // sadly needs to be declared in the class for VC 6.0 + // template + // void run_tests_on_next(const Array& test) + // { + // // exit if empty array (as do..while() loop would fail) + // if (test.size() == 0) + // return; + + // BasicCoordinate index = get_min_indices(test); + // typename Array::const_full_iterator iter = test.begin_all(); + // do + // { + // check(*iter == test[index], "test on next(): element out of sequence?"); + // ++iter; + // } while (next(index, test) && (iter != test.end_all())); + // check(iter == test.end_all(), "test on next() : did we cover all elements?"); + // } + + // // functions that runs IO tests for an array of arbitrary dimension + // // sadly needs to be declared in the class for VC 6.0 + // template + // void run_IO_tests(const Array& t1) + // { + // std::fstream os; + // std::fstream is; + // run_IO_tests_with_file_args(os, is, t1); + // FILE* ofptr; + // FILE* ifptr; + // run_IO_tests_with_file_args(ofptr, is, t1); + // run_IO_tests_with_file_args(ofptr, ifptr, t1); + // } + // template + // void run_IO_tests_with_file_args(OFSTREAM& os, IFSTREAM& is, const Array& t1) + // { + // { + // open_write_binary(os, "output.flt"); + // check(write_data(os, t1) == Succeeded::yes, "write_data could not write array"); + // close_file(os); + // } + // Array t2(t1.get_index_range()); + // { + // open_read_binary(is, "output.flt"); + // check(read_data(is, t2) == Succeeded::yes, "read_data could not read from output.flt"); + // close_file(is); + // } + // check_if_equal(t1, t2, "test out/in"); + // remove("output.flt"); + + // { + // open_write_binary(os, "output.flt"); + // const Array copy = t1; + // check(write_data(os, t1, ByteOrder::swapped) == Succeeded::yes, "write_data could not write array with swapped byte order"); + // check_if_equal(t1, copy, "test out with byte-swapping didn't change the array"); + // close_file(os); + // } + // { + // 
open_read_binary(is, "output.flt"); + // check(read_data(is, t2, ByteOrder::swapped) == Succeeded::yes, "read_data could not read from output.flt"); + // close_file(is); + // } + // check_if_equal(t1, t2, "test out/in (swapped byte order)"); + // remove("output.flt"); + + // cerr << "\tTests writing as shorts\n"; + // run_IO_tests_mixed(os, is, t1, NumericInfo()); + // cerr << "\tTests writing as floats\n"; + // run_IO_tests_mixed(os, is, t1, NumericInfo()); + // cerr << "\tTests writing as signed chars\n"; + // run_IO_tests_mixed(os, is, t1, NumericInfo()); + + // /* check on failed IO. + // Note: needs to be after the others, as we would have to call os.clear() + // for ostream to be able to write again, but that's not defined for FILE*. + // */ + // { + // const Array copy = t1; + // cerr << "\n\tYou should now see a warning that writing failed. That's by intention.\n"; + // check(write_data(os, t1, ByteOrder::swapped) != Succeeded::yes, "write_data with swapped byte order should have failed"); + // check_if_equal(t1, copy, "test out with byte-swapping didn't change the array even with failed IO"); + // } + // } + + // //! function that runs IO tests with mixed types for array of arbitrary dimension + // // sadly needs to be implemented in the class for VC 6.0 + // template + // void run_IO_tests_mixed(OFSTREAM& os, + // IFSTREAM& is, + // const Array& orig, + // NumericInfo output_type_info) + // { + // { + // open_write_binary(os, "output.orig"); + // elemT scale(1); + // check(write_data(os, orig, NumericInfo(), scale) == Succeeded::yes, + // "write_data could not write array in original data type"); + // close_file(os); + // check_if_equal(scale, static_cast(1), "test out/in: data written in original data type: scale factor should be 1"); + // } + // elemT scale(1); + // bool write_data_ok; + // { + // ofstream os; + // open_write_binary(os, "output.other"); + // write_data_ok = check(write_data(os, orig, output_type_info, scale) == Succeeded::yes, + // "write_data could not write array as other_type"); + // close_file(os); + // } + + // if (write_data_ok) + // { + // // only do reading test if data was written + // Array data_read_back(orig.get_index_range()); + // { + // open_read_binary(is, "output.other"); + // check(read_data(is, data_read_back) == Succeeded::yes, "read_data could not read from output.other"); + // close_file(is); + // remove("output.other"); + // } + + // // compare with convert() + // { + // float newscale = static_cast(scale); + // Array origconverted = convert_array(newscale, orig, NumericInfo()); + // check_if_equal(newscale, scale, "test read_data <-> convert : scale factor "); + // check_if_equal(origconverted, data_read_back, "test read_data <-> convert : data"); + // } + + // // compare orig/scale with data_read_back + // { + // const Array orig_scaled(orig / scale); + // this->check_array_equality_with_rounding( + // orig_scaled, data_read_back, "test out/in: data written as other_type, read as other_type"); + // } + + // // compare data written as original, but read as other_type + // { + // Array data_read_back2(orig.get_index_range()); + + // ifstream is; + // open_read_binary(is, "output.orig"); + + // elemT in_scale = 0; + // check(read_data(is, data_read_back2, NumericInfo(), in_scale) == Succeeded::yes, + // "read_data could not read from output.orig"); + // // compare orig/in_scale with data_read_back2 + // const Array orig_scaled(orig / in_scale); + // this->check_array_equality_with_rounding( + // orig_scaled, data_read_back2, "test out/in: 
data written as original_type, read as other_type"); + // } + // } // end of if(write_data_ok) + // remove("output.orig"); + // } + + // //! a special version of check_if_equal just for this class + // /*! we check up to .5 if output_type is integer, and up to tolerance otherwise + // */ + // template + // bool check_array_equality_with_rounding(const Array& orig, + // const Array& data_read_back, + // const char* const message) + // { + // NumericInfo output_type_info; + // bool test_failed = false; + // typename Array::const_full_iterator diff_iter = orig.begin_all(); + // typename Array::const_full_iterator data_read_back_iter = data_read_back.begin_all_const(); + // while (diff_iter != orig.end_all()) + // { + // if (output_type_info.integer_type()) + // { + // std::stringstream full_message; + // // construct useful error message even though we use a boolean check + // full_message << boost::format("unequal values are %2% and %3%. %1%: difference larger than .5") % message + // % static_cast(*data_read_back_iter) % *diff_iter; + // // difference should be maximum .5 (but we test with slightly larger tolerance to accomodate numerical precision) + // test_failed = check(fabs(*diff_iter - *data_read_back_iter) <= .502, full_message.str().c_str()); + // } + // else + // { + // std::string full_message = message; + // full_message += ": difference larger than tolerance"; + // test_failed = check_if_equal(static_cast(*data_read_back_iter), *diff_iter, full_message.c_str()); + // } + // if (test_failed) + // break; + // diff_iter++; + // data_read_back_iter++; + // } + // return test_failed; + // } + +public: + void run_tests() override; +}; + +// // helper function to create a shared_ptr that doesn't delete the data (as it's still owned by the vector) +// template +// shared_ptr +// vec_to_shared(std::vector& v) +// { +// shared_ptr sptr(v.data(), [](auto) {}); +// return sptr; +// } + +void +TensorTests::run_tests() +{ + cerr << "Testing Tensor classes\n"; + { + cerr << "Testing 1D stuff" << endl; + { + TensorWrapper<1, int> testint(IndexRange<1>(5)); + testint.at(1) = 2; + check_if_equal(testint.size(), size_t(5), "test size()"); + check_if_equal(testint.size_all(), size_t(5), "test size_all()"); + TensorWrapper<1, float> test(IndexRange<1>(10)); + check_if_zero(test, "Array1D not initialised to 0"); + test.at(1) = 10.5f; + test.set_offset(-1); + check_if_equal(test.size(), size_t(10), "test size() with non-zero offset"); + check_if_equal(test.size_all(), size_t(10), "test size_all() with non-zero offset"); + check_if_equal(test.at(0), 10.5F, "test indexing of Array1D"); + test += 1; + check_if_equal(test.at(0), 11.5F, "test operator+=(float)"); + check_if_equal(test.sum(), 20.5F, "test operator+=(float) and sum()"); + check_if_zero(test - test, "test operator-(Tensor1D)"); + BasicCoordinate<1, int> c; + c[1] = 0; + check_if_equal(test.at(c), 11.5F, "test at(BasicCoordinate)"); + test.at(c) = 12.5; + check_if_equal(test.at(c), 12.5F, "test at(BasicCoordinate)"); + { + //! NE: Here, test_Array calls for partial specialisations, or support for + //! factories. 
I will skip that in favour of IndexRange<>
+        // TensorWrapper<1, float> ref(-1, 2);
+        const IndexRange<1> range(-1, 1);
+        TensorWrapper<1, float> ref(range);
+        ref.at(-1) = 1.F;
+        ref.at(0) = 3.F;
+        ref.at(1) = 3.14F;
+        TensorWrapper<1, float> test = ref;
+
+        test += 1.f;
+        for (int i = ref.get_min_index(); i <= ref.get_max_index(); ++i){
+          check_if_equal(test.at(i), ref.at(i) + 1, "test operator+=(float)");
+        }
+        test = ref;
+        test -= 4;
+        for (int i = ref.get_min_index(); i <= ref.get_max_index(); ++i)
+          check_if_equal(test.at(i), ref.at(i) - 4, "test operator-=(float)");
+        test = ref;
+        test *= 3;
+        for (int i = ref.get_min_index(); i <= ref.get_max_index(); ++i)
+          check_if_equal(test.at(i), ref.at(i) * 3, "test operator*=(float)");
+        test = ref;
+        test /= 3;
+        for (int i = ref.get_min_index(); i <= ref.get_max_index(); ++i)
+          check_if_equal(test.at(i), ref.at(i) / 3, "test operator/=(float)");
+      }
+      TensorWrapper<1, float> test2;
+      test2 = test * 2;
+      check_if_equal(2 * test.at(0), test2.at(0), "test operator*(float)");
+#if 1
+      {
+        // tests on log/exp
+        const IndexRange<1> range(-3, 9);
+        TensorWrapper<1, float> test(range);
+        test.fill(1.F);
+        in_place_log(test);
+        {
+          TensorWrapper<1, float> testeq(range);
+          check_if_equal(test, testeq, "test in_place_log of TensorWrapper<1,float>");
+        }
+        {
+          for (int i = test.get_min_index(); i <= test.get_max_index(); i++)
+            test.at(i) = 3.5F * i + 100;
+        }
+        TensorWrapper<1, float> test_copy = test;
+        in_place_log(test);
+        in_place_exp(test);
+        check_if_equal(test, test_copy, "test log/exp of TensorWrapper<1,float>");
+      }
+#endif
+    }
+
+    {
+      cerr << "Testing 2D stuff" << endl;
+      {
+        const IndexRange<2> range(Coordinate2D<int>(0, 0), Coordinate2D<int>(9, 9));
+        TensorWrapper<2, float> test2(range);
+        check_if_equal(test2.size(), size_t(10), "test size()");
+        check_if_equal(test2.size_all(), size_t(100), "test size_all()");
+        check_if_zero(test2, "test TensorWrapper<2, float> not initialised to 0");
+        test2.at(3, 4) = 23.3f;
+      }
+      {
+        IndexRange<2> range(Coordinate2D<int>(0, 0), Coordinate2D<int>(3, 3));
+        TensorWrapper<2, float> testfp(range);
+        TensorWrapper<2, float> t2fp(range);
+        testfp.at(3, 2) = 3.3F;
+        t2fp.at(3, 2) = 2.2F;
+        TensorWrapper<2, float> t2 = t2fp + testfp;
+        check_if_equal(t2.at(3, 2), 5.5F, "test operator+(TensorWrapper2D)");
+        t2fp += testfp;
+        check_if_equal(t2fp.at(3, 2), 5.5F, "test operator+=(TensorWrapper2D)");
+        check_if_equal(t2, t2fp, "test comparing TensorWrapper2D += and +");
+      }
+    }
+  }
+}
+
+END_NAMESPACE_STIR
+
+USING_NAMESPACE_STIR
+
+int
+main()
+{
+  TensorTests tests;
+  tests.run_tests();
+  return tests.main_return_value();
+}
diff --git a/src/utilities/CMakeLists.txt b/src/utilities/CMakeLists.txt
index d726cfe460..b6946926f3 100644
--- a/src/utilities/CMakeLists.txt
+++ b/src/utilities/CMakeLists.txt
@@ -75,6 +75,7 @@ if (NOT MINI_STIR)
     find_sum_projection_of_viewgram_and_sinogram.cxx
     separate_true_from_random_scatter_for_necr.cxx
     stir_timings.cxx
+    pytorch_playground.cxx
   )
 
   if (HAVE_ITK)
diff --git a/src/utilities/pytorch_playground.cxx b/src/utilities/pytorch_playground.cxx
new file mode 100644
index 0000000000..b2b7431b5b
--- /dev/null
+++ b/src/utilities/pytorch_playground.cxx
@@ -0,0 +1,309 @@
+//
+//
+
+/*!
+\file
+\ingroup utilities
+\brief this executable is not meant to do anything specific; it only exists to facilitate the development of the PyTorch interface.
+Heavily inspired by compare_images.cxx
+
+\author Nikos Efthimiou
+*/
+
+#include "stir/DiscretisedDensity.h"
+#include "stir/ArrayFunction.h"
+#include "stir/recon_array_functions.h"
+#include "stir/IO/read_from_file.h"
+#include "stir/is_null_ptr.h"
+#include "stir/warning.h"
+#include <torch/torch.h>
+#include <cuda_runtime.h>
+#include <chrono> // For timing
+#include <iostream>
+#include <cstring>
+#include <cstdlib>
+
+using std::cerr;
+using std::cout;
+using std::endl;
+
+USING_NAMESPACE_STIR
+
+//********************** main
+
+int
+main(int argc, char* argv[])
+{
+  if (argc < 3 || argc > 7)
+    {
+      cerr << "Usage: \n"
+           << argv[0] << "\n\t"
+           << "[-r rimsize] \n\t"
+           << "[-t tolerance] \n\t"
+           << "old_image new_image \n\t"
+           << "'rimsize' has to be a nonnegative integer.\n\t"
+           << "'tolerance' is by default .0005 \n\t"
+           << "When the -r option is used, the (radial) rim of the\n\t"
+           << "images will be set to 0, for 'rimsize' pixels.\n";
+      return (EXIT_FAILURE);
+    }
+  // skip program name
+  --argc;
+  ++argv;
+  int rim_truncation_image = -1;
+  float tolerance = .0005F;
+
+  int device_count = 0;
+  cudaError_t err = cudaGetDeviceCount(&device_count);
+
+  if (err != cudaSuccess) {
+    std::cerr << "CUDA error: " << cudaGetErrorString(err) << std::endl;
+    return -1;
+  }
+
+  if (device_count == 0) {
+    std::cout << "No CUDA devices available." << std::endl;
+    return 0;
+  }
+
+  std::cout << "Number of CUDA devices: " << device_count << std::endl;
+
+  for (int device = 0; device < device_count; ++device) {
+    cudaDeviceProp device_prop;
+    cudaGetDeviceProperties(&device_prop, device);
+
+    std::cout << "Device " << device << ": " << device_prop.name << std::endl;
+    std::cout << "  CUDA Capability: " << device_prop.major << "." << device_prop.minor << std::endl;
+    std::cout << "  Total Memory: " << device_prop.totalGlobalMem / (1024 * 1024) << " MB" << std::endl;
+    std::cout << "  Multiprocessors: " << device_prop.multiProcessorCount << std::endl;
+    std::cout << "  Clock Rate: " << device_prop.clockRate / 1000 << " MHz" << std::endl;
+  }
+
+  // first process command line options
+  while (argc > 0 && argv[0][0] == '-')
+    {
+      if (strcmp(argv[0], "-r") == 0)
+        {
+          if (argc < 2)
+            {
+              cerr << "Option '-r' expects a nonnegative (integer) argument\n";
+              exit(EXIT_FAILURE);
+            }
+          rim_truncation_image = atoi(argv[1]);
+          argc -= 2;
+          argv += 2;
+        }
+      if (strcmp(argv[0], "-t") == 0)
+        {
+          if (argc < 2)
+            {
+              cerr << "Option '-t' expects a (float) argument\n";
+              exit(EXIT_FAILURE);
+            }
+          tolerance = static_cast<float>(atof(argv[1]));
+          argc -= 2;
+          argv += 2;
+        }
+    }
+
+  shared_ptr<DiscretisedDensity<3, float>> first_operand(read_from_file<DiscretisedDensity<3, float>>(argv[0]));
+
+  if (stir::is_null_ptr(first_operand))
+    {
+      cerr << "Could not read first file\n";
+      exit(EXIT_FAILURE);
+    }
+
+  stir::IndexRange<3> d = first_operand->get_index_range();
+
+  std::cout << "D:" << d.get_min_index() << " " << d.get_max_index() << " " << d.get_length() << std::endl;
+  std::cout << "D:" << d[0].get_min_index() << " " << d[0].get_max_index() << " " << d[0].get_length() << std::endl;
+  std::cout << "G:" << d[0][0].get_min_index() << " " << d[0][0].get_max_index() << " " << d[0][0].get_length() << std::endl;
+
+  std::vector<int64_t> shape({d.get_length(), d[0].get_length(), d[0][0].get_length()});
+  // // Allocate Tensor
+  // // TensorWrapper<3, float> tw(shape);
+  // TensorWrapper tw(first_operand);
+
+  // // Print the tensor's device
+  // tw.print_device();
+  // // Move the tensor to the GPU (if available)
+  // try {
+  //   tw.to_gpu();
+  //   std::cout << "Moved tw to GPU." 
<< std::endl; + // } catch (const std::runtime_error& e) { + // std::cerr << e.what() << std::endl; + // } + // // Print the tensor's device again + // tw.print_device(); + + // std::cout << " H " << std::endl; + // tw.printSizes(); + // std::cout << " H2 " << std::endl; + + // const float a = 1000; + // const float b = 5000; + // // Start timing + // { + // auto start = std::chrono::high_resolution_clock::now(); + // std::cout << "stir::Array MAX value: " << (*first_operand).find_max() << std::endl; + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start).count(); + // std::cout << "stir::Array: Time to compute max value: " << duration << " ms" << std::endl; + // std::cout << "\n" << std::endl; + // } + // { + // auto start = std::chrono::high_resolution_clock::now(); + // std::cout << "Tensored: MAX value: " << tw.find_max() << std::endl; + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start).count(); + // std::cout << "Tensored: Time to compute max value: " << duration << " ms" << std::endl; + // std::cout << "\n" << std::endl; + // } + + // { + // auto start = std::chrono::high_resolution_clock::now(); + // std::cout << "stir::Array SUM value: " << (*first_operand).sum() << std::endl; + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start).count(); + // std::cout << "stir::Array: Time to compute SUM value: " << duration << " ms" << std::endl; + // std::cout << "\n" << std::endl; + // } + // { + // auto start = std::chrono::high_resolution_clock::now(); + // std::cout << "Tensored: SUM value: " << tw.sum() << std::endl; + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start).count(); + // std::cout << "Tensored: Time to compute SUM value: " << duration << " ms" << std::endl; + // std::cout << "\n" << std::endl; + // } + + // { + // auto start = std::chrono::high_resolution_clock::now(); + // std::cout << "stir::Array SUM_pos value: " << (*first_operand).sum_positive() << std::endl; + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start).count(); + // std::cout << "stir::Array: Time to compute SUM_pos value: " << duration << " ms" << std::endl; + // std::cout << "\n" << std::endl; + // } + // { + // auto start = std::chrono::high_resolution_clock::now(); + // std::cout << "Tensored: SUM_pos value: " << tw.sum_positive() << std::endl; + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start).count(); + // std::cout << "Tensored: Time to compute SUM_pos value: " << duration << " ms" << std::endl; + // std::cout << "\n" << std::endl; + // } + + // std::cout << "____________XAPYB__________" << std::endl; + // { + // auto cloned_empty_first_operand = first_operand->get_empty_copy(); + // auto cloned_first_operand = first_operand->clone(); + // auto start = std::chrono::high_resolution_clock::now(); + // cloned_empty_first_operand->xapyb((*first_operand), a, *cloned_first_operand, b); + // std::cout << "stir::Array: MAX value after xapyb: " << cloned_empty_first_operand->find_max() << std::endl; + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start).count(); + // std::cout << "stir::Array: Time to compute max value after xapyb: : " << duration << " ms" << 
std::endl; + // std::cout << "\n" << std::endl; + // } + + // { + // auto cloned_empty_tw = tw;//.get_empty_copy(); + // cloned_empty_tw.fill(0); + // cloned_empty_tw.print_device(); + // try { + // cloned_empty_tw.to_gpu(); + // std::cout << "Moved cloned_empty_tw to GPU." << std::endl; + // } catch (const std::runtime_error& e) { + // std::cerr << e.what() << std::endl; + // } + // // Print the tensor's device again + // cloned_empty_tw.print_device(); + + // auto cloned_tw = tw;//.clone(); + // cloned_tw.print_device(); + // try { + // cloned_tw.to_gpu(); + // std::cout << "Moved cloned_tw to GPU." << std::endl; + // } catch (const std::runtime_error& e) { + // std::cerr << e.what() << std::endl; + // } + // // Print the tensor's device again + // cloned_tw.print_device(); + + // auto start = std::chrono::high_resolution_clock::now(); + // cloned_empty_tw.xapyb(tw, a, cloned_tw, b); + // std::cout << "Tensored: MAX value after xapyb: " << cloned_empty_tw.find_max() << std::endl; + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start).count(); + // std::cout << "Tensored: Time to compute max value: " << duration << " ms" << std::endl; + // std::cout << "\n" << std::endl; + // } + + // { + // auto start = std::chrono::high_resolution_clock::now(); + // std::cout << "Tensored first_operand MAX value: " << tw.find_max() << std::endl; + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start).count(); + // std::cout << "Time to compute max value: " << duration << " ms" << std::endl; + // } + + + std::cout << "STOP HERE" << std::endl; + return 1; + + // shared_ptr> second_operand(read_from_file>(argv[1])); + // if (is_null_ptr(second_operand)) + // { + // cerr << "Could not read 2nd file\n"; + // exit(EXIT_FAILURE); + // } + + // // check if images are compatible + // { + // std::string explanation; + // if (!first_operand->has_same_characteristics(*second_operand, explanation)) + // { + // warning("input images do not have the same characteristics.\n%s", explanation.c_str()); + // return EXIT_FAILURE; + // } + // } + + // if (rim_truncation_image >= 0) + // { + // truncate_rim(*first_operand, rim_truncation_image); + // truncate_rim(*second_operand, rim_truncation_image); + // } + + // float reference_max = first_operand->find_max(); + // float reference_min = first_operand->find_min(); + + // float amplitude = fabs(reference_max) > fabs(reference_min) ? fabs(reference_max) : fabs(reference_min); + + // *first_operand -= *second_operand; + // const float max_error = first_operand->find_max(); + // const float min_error = first_operand->find_min(); + // in_place_abs(*first_operand); + // const float max_abs_error = first_operand->find_max(); + + // const bool same = (max_abs_error / amplitude <= tolerance); + + // cout << "\nMaximum absolute error = " << max_abs_error << "\nMaximum in (1st - 2nd) = " << max_error + // << "\nMinimum in (1st - 2nd) = " << min_error << endl; + // cout << "Error relative to sup-norm of first image = " << (max_abs_error / amplitude) * 100 << " %" << endl; + + // cout << "\nImage arrays "; + + // if (same) + // { + // cout << (max_abs_error == 0 ? "are " : "deemed ") << "identical\n"; + // } + // else + // { + // cout << "deemed different\n"; + // } + // cout << "(tolerance used: " << tolerance * 100 << " %)\n\n"; + // return same ? 
EXIT_SUCCESS : EXIT_FAILURE;
+
+} // end main
diff --git a/src/utilities/stir_math.cxx b/src/utilities/stir_math.cxx
index 0e54d8594d..c08261671f 100644
--- a/src/utilities/stir_math.cxx
+++ b/src/utilities/stir_math.cxx
@@ -107,6 +107,8 @@
   \author Kris Thielemans
 */
+// NE: will be removed later
+#include <torch/torch.h>
 #include "stir/ArrayFunction.h"
 #include "stir/DiscretisedDensity.h"
 #include "stir/SegmentByView.h"
@@ -156,6 +158,22 @@ process_data(const string& output_file_name,
              const OutputFileFormat<DataT>& output_format)
 {
   unique_ptr<DataT> image_ptr = read_from_file<DataT>(*argv);
+  // NE: will be removed later
+  // Create a 3D tensor of size (3, 4, 5) filled with random values
+  auto tensor = torch::rand({3, 4, 5});
+
+  // Print the tensor
+  std::cout << "3D Tensor:" << std::endl;
+  std::cout << tensor << std::endl;
+
+  // Access a specific element (e.g., at position [1, 2, 3])
+  std::cout << "Element at [1, 2, 3]: " << tensor[1][2][3].item<float>() << std::endl;
+
+  // Perform operations on the tensor (e.g., multiply by 2)
+  auto modified_tensor = tensor * 2;
+  std::cout << "Modified Tensor (multiplied by 2):" << std::endl;
+  std::cout << modified_tensor << std::endl;
+
   if (!no_math_on_data && !except_first)
     in_place_apply_function(*image_ptr, pow_times_add_object);
 
diff --git a/src/utilities/stir_timings.cxx b/src/utilities/stir_timings.cxx
index 8377225aa3..0e2e08167a 100644
--- a/src/utilities/stir_timings.cxx
+++ b/src/utilities/stir_timings.cxx
@@ -26,9 +26,11 @@
 #include "stir/ProjDataInterfile.h"
 #include "stir/ProjDataInMemory.h"
 #include "stir/DiscretisedDensity.h"
+#include "stir/TimedObject.h"
 #include "stir/VoxelsOnCartesianGrid.h"
 #include "stir/IO/read_from_file.h"
 #include "stir/IO/write_to_file.h"
+#include "stir/recon_buildblock/ProjectorByBinPair.h"
 #ifndef MINI_STIR
 #  include "stir/recon_buildblock/ProjectorByBinPairUsingProjMatrixByBin.h"
 #endif