diff --git a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp index c4a288fc0..3497a3c52 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp @@ -38,6 +38,17 @@ using namespace facebook; namespace { +/** + * Check if the input string is zero or a positive integer. + * + * @param s String to check for zero or positive integer. + * @return `true` if input string is zero or a positive integer, or `false` + * otherwise. + */ +bool isZeroOrPositiveInteger(const std::string& s) { + return !s.empty() && std::all_of(s.begin(), s.end(), ::isdigit); +} + jsi::Value absImpl( jsi::Runtime& runtime, const jsi::Value& thisValue, @@ -702,22 +713,20 @@ jsi::Value TensorHostObject::get( return jsi::Value(runtime, toString_); } - int idx = -1; - try { - idx = std::stoi(name.c_str()); - } catch (const std::exception& e) { - // Cannot parse name value to int. This can happen when the name in bracket - // or dot notion is not an int (e.g., tensor['foo']). - // Let's ignore this exception here since this function will return - // undefined if it reaches the function end. - } - // Check if index is within bounds of dimension 0 - if (idx >= 0 && idx < this->tensor.size(0)) { - auto outputTensor = this->tensor.index({idx}); - auto tensorHostObject = - std::make_shared( - runtime, std::move(outputTensor)); - return jsi::Object::createFromHostObject(runtime, tensorHostObject); + // Check if prop name is zero or a positive integer, and if so it will access + // the tensor via the tensor indexing API: + // + // https://pytorch.org/cppdocs/notes/tensor_indexing.html + if (isZeroOrPositiveInteger(name)) { + int idx = std::stoi(name.c_str()); + // Check if index is within bounds of dimension 0 + if (idx >= 0 && idx < this->tensor.size(0)) { + auto outputTensor = this->tensor.index({idx}); + auto tensorHostObject = + std::make_shared( + runtime, std::move(outputTensor)); + return jsi::Object::createFromHostObject(runtime, tensorHostObject); + } } return BaseHostObject::get(runtime, propNameId);