Initial PyTorch support #1583
Conversation
if (STIR_WITH_TORCH)
  target_link_libraries(${executable} ${TORCH_LIBRARIES})
  target_include_directories(${executable} PUBLIC ${TORCH_INCLUDE_DIRS})
endif()
Irrelevant comment for now, but do we really need this for all exes? The change to src/buildblock/CMakeLists.txt should be enough (due to the transitive dependency).
Probably not ... but let's see :)
if (STIR_WITH_TORCH)
  target_link_libraries(${executable} ${TORCH_LIBRARIES})
  target_include_directories(${executable} PUBLIC ${TORCH_INCLUDE_DIRS})
endif()
as in src/cmake/stir_exe_targets.cmake
* Ongoing work to make SWIG happy.
* Almost ready to read the data into Tensors.
* Implemented Tensor iterators, strictly for CPU usage.
* A lot of work in the IO and buildblock modules. We should be able to read and write all the data types in the IO module into Tensors, and hopefully have not broken anything for Array in the process.
* My confusion comes from not knowing how many of these functions are used by ProjData.
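As a rough sketch of the general idea (plain libtorch, not the code in this PR; the helper name and dimensions are made up for illustration), raw data from a file can be copied into a tensor via torch::from_blob:

#include <torch/torch.h>
#include <cstddef>
#include <fstream>
#include <string>
#include <vector>

// Hypothetical helper: read nz*ny*nx floats from a raw binary file into a 3D tensor.
torch::Tensor read_raw_volume(const std::string& filename, int64_t nz, int64_t ny, int64_t nx)
{
  std::vector<float> buffer(static_cast<std::size_t>(nz * ny * nx));
  std::ifstream in(filename, std::ios::binary);
  in.read(reinterpret_cast<char*>(buffer.data()),
          static_cast<std::streamsize>(buffer.size() * sizeof(float)));
  // from_blob does not take ownership of the memory, so clone() to give the
  // returned tensor its own storage before 'buffer' goes out of scope.
  return torch::from_blob(buffer.data(), {nz, ny, nx}, torch::kFloat32).clone();
}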
This is where we are now:
It is a bigger headache than I thought.
We start the tensor tests. Tensor elements are accessed through the .at() method, which gives the best speed and efficiency on the CPU; GPU access should only be allowed via linear operators (see the sketch after the list below).
* test_Tensor replicates the tests from test_Array and is mostly ready for:
- testing simple operations with and without offsets
- 1D and 2D (some work still needed)
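As a rough illustration of that kind of .at()-style access on the CPU (a self-contained sketch using plain libtorch, not the TensorWrapper from this PR), a CPU accessor gives cheap per-element reads and writes:

#include <torch/torch.h>

int main()
{
  torch::Tensor t = torch::zeros({4, 5}, torch::kFloat32);

  // On the CPU, an accessor provides fast, unchecked per-element access,
  // which is what element-by-element tests need.
  auto acc = t.accessor<float, 2>();
  for (int64_t i = 0; i < t.size(0); ++i)
    for (int64_t j = 0; j < t.size(1); ++j)
      acc[i][j] = static_cast<float>(i * t.size(1) + j);

  return 0;
}

On the GPU this per-element style is not viable, which is why GPU use would be limited to whole-tensor (linear) operations.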
Made quite a bit of progress; a few more boring parts left:
KrisThielemans left a comment:
What about creating an alias to minimise #ifdef STIR_TORCH further? ATM, there are quite a few non-changes because of this, e.g.
ArrayTypeFwd.h
namespace stir
{
template <typename elemT, int num_dimensions>
class Array;
#ifndef STIR_TORCH
template <typename elemT, int num_dimensions>
using ArrayType = Array<elemT, num_dimensions>;
#else
template <typename elemT, int num_dimensions>
class TensorWrapper;
template <typename elemT, int num_dimensions>
using ArrayType = TensorWrapper<elemT, num_dimensions>;
#endif
}
and ArrayType.h
#include "stir/ArrayTypeFwd.h"
#include "stir/Array.h"
#ifdef STIR_TORCH
#include "stir/TensorWrapper.h"
#endif
then use ArrayType wherever you want to replace Array with TensorWrapper
I'm not 100% sure about the using syntax though.
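For what it's worth, a self-contained toy version of the alias idea (dummy classes standing in for Array and the proposed TensorWrapper, following the template-parameter order of the snippet above) shows that call sites need no #ifdef:

#include <iostream>

template <typename elemT, int num_dimensions>
class Array
{
public:
  void fill(elemT) { std::cout << "Array::fill\n"; }
};

template <typename elemT, int num_dimensions>
class TensorWrapper
{
public:
  void fill(elemT) { std::cout << "TensorWrapper::fill\n"; }
};

#ifndef STIR_TORCH
template <typename elemT, int num_dimensions>
using ArrayType = Array<elemT, num_dimensions>;
#else
template <typename elemT, int num_dimensions>
using ArrayType = TensorWrapper<elemT, num_dimensions>;
#endif

int main()
{
  // The call site only spells ArrayType; which class backs it is decided
  // once, at configure time, by STIR_TORCH.
  ArrayType<float, 3> data;
  data.fill(1.F);
  return 0;
}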
@KrisThielemans
Looks ok, but I guess I'd write it with a loop in terms of the index (which is closer to how STIR will be using it):

auto iter = test.begin_all_const();
for (index[1] = test.get_min_index(0); index[1] <= test.get_max_index(0); ++index[1])
  for (index[2] = test.get_min_index(1); index[2] <= test.get_max_index(1); ++index[2])
    {
      check(*iter == test.at(index), "test on next(): element out of sequence?");
      ++iter;
    }

The choice to start from 1 was a very bad one... At some point, we'll have to introduce a new index-type that starts from 0 (could probably just use
I know. I wanted to closely replicate the spelling of that test; it is based on a similar do .. while loop.
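For reference, the do .. while pattern being mirrored looks roughly like this (a generic, self-contained sketch over a flat buffer with a 2D index, not STIR's actual test code):

#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
  const int n0 = 3, n1 = 4;
  std::vector<int> data(n0 * n1);
  for (std::size_t i = 0; i < data.size(); ++i)
    data[i] = static_cast<int>(i);

  // Walk a flat iterator and a 2D index in lockstep, advancing the index
  // "odometer style" inside a do .. while loop.
  auto iter = data.cbegin();
  std::array<int, 2> index{0, 0};
  do
    {
      assert(*iter == index[0] * n1 + index[1]); // element order matches the index order
      ++iter;
      // advance the 2D index: last dimension runs fastest
      if (++index[1] == n1)
        {
          index[1] = 0;
          ++index[0];
        }
    }
  while (index[0] < n0);
  return 0;
}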
@KrisThielemans in the NumericVector operators you have this line. This automatically resizes the vector if it is smaller. I did it like this
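The resize-if-smaller behaviour being referred to can be sketched like this (a generic toy class, not STIR's NumericVectorWithOffset, whose real operators differ in detail):

#include <cassert>
#include <cstddef>
#include <vector>

// Toy vector whose += grows the left-hand side when the right-hand side is longer.
struct GrowingVector
{
  std::vector<float> v;

  GrowingVector& operator+=(const GrowingVector& other)
  {
    // Resize first so every element of 'other' has somewhere to be added;
    // newly created elements start at 0.
    if (v.size() < other.v.size())
      v.resize(other.v.size(), 0.F);
    for (std::size_t i = 0; i < other.v.size(); ++i)
      v[i] += other.v[i];
    return *this;
  }
};

int main()
{
  GrowingVector a{{1.F, 2.F}};
  GrowingVector b{{10.F, 10.F, 10.F}};
  a += b; // a grows from size 2 to size 3
  assert(a.v.size() == 3 && a.v[2] == 10.F);
  return 0;
}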
Changes in this pull request
Testing performed
Related issues
Checklist before requesting a review
documentation/release_XXX.md has been updated with any functionality change (if applicable)