diff --git a/react-native-pytorch-core/android/src/main/cpp/torchlive/media/NativeJSRefBridge.cpp b/react-native-pytorch-core/android/src/main/cpp/torchlive/media/NativeJSRefBridge.cpp index fb2318d04..d338fa13d 100644 --- a/react-native-pytorch-core/android/src/main/cpp/torchlive/media/NativeJSRefBridge.cpp +++ b/react-native-pytorch-core/android/src/main/cpp/torchlive/media/NativeJSRefBridge.cpp @@ -18,6 +18,11 @@ #include "./image/Image.h" #include "./image/JIImage.h" +#if __has_include() +#define HAS_VISION_CAMERA +#include +#endif + namespace torchlive { using namespace facebook::jni; @@ -72,6 +77,23 @@ std::shared_ptr imageFromFile(std::string filepath) { return std::make_shared(make_global(image)); } +std::shared_ptr imageFromFrame(jsi::Runtime& runtime, jsi::Object frameHostObject) { +#ifdef HAS_VISION_CAMERA + auto hostObject = frameHostObject.asHostObject(runtime); + + auto mediaUtilsClass = getMediaUtilsClass(); + auto imageFromImageProxyMethod = + mediaUtilsClass->getStaticMethod(local_ref)>( + "imageFromImageProxy"); + // TODO: Figure out how to get Context here + local_ref image = + imageFromImageProxyMethod(mediaUtilsClass, hostObject->frame, nullptr); + return std::make_shared(make_global(image)); +#else + throw jsi::JSError(runtime, "Error converting Frame to Image - VisionCamera is not properly installed!"); +#endif +} + std::shared_ptr imageFromBlob(const Blob& blob, double width, double height) { auto mediaUtilsClass = getMediaUtilsClass(); diff --git a/react-native-pytorch-core/android/src/main/java/org/pytorch/rn/core/media/MediaUtils.java b/react-native-pytorch-core/android/src/main/java/org/pytorch/rn/core/media/MediaUtils.java index 412f6cecb..1bcdcc1e0 100644 --- a/react-native-pytorch-core/android/src/main/java/org/pytorch/rn/core/media/MediaUtils.java +++ b/react-native-pytorch-core/android/src/main/java/org/pytorch/rn/core/media/MediaUtils.java @@ -59,6 +59,12 @@ public static IImage imageFromFile(final String filepath) { return new Image(bitmap); } + @DoNotStrip + @Keep + public static IImage imageFromImageProxy(final ImageProxy imageProxy, Context context) { + return new Image(imageProxy, context); + } + @DoNotStrip @Keep public static IImage imageFromBlob( diff --git a/react-native-pytorch-core/cxx/src/torchlive/media/BlobHostObject.cpp b/react-native-pytorch-core/cxx/src/torchlive/media/BlobHostObject.cpp index 9dd405efa..2148e8e1d 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/media/BlobHostObject.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/media/BlobHostObject.cpp @@ -28,80 +28,95 @@ jsi::Value BlobObjectWithNoData(jsi::Runtime& runtime) { } // namespace -static jsi::Value arrayBufferImpl( - jsi::Runtime& runtime, - const jsi::Value& thisValue, - const jsi::Value* arguments, - size_t count) { - utils::ArgumentParser args(runtime, thisValue, arguments, count); - const auto& blob = args.thisAsHostObject()->blob; - auto promiseValue = torchlive::createPromiseAsJSIValue( - runtime, - [&blob](jsi::Runtime& rt, std::shared_ptr promise) { - auto buffer = blob->getDirectBytes(); - auto size = blob->getDirectSize(); - jsi::ArrayBuffer arrayBuffer = - rt.global() - .getPropertyAsFunction(rt, "ArrayBuffer") - .callAsConstructor(rt, static_cast(size)) - .asObject(rt) - .getArrayBuffer(rt); - std::memcpy(arrayBuffer.data(rt), buffer, size); - auto typedArray = rt.global() - .getPropertyAsFunction(rt, "Uint8Array") - .callAsConstructor(rt, std::move(arrayBuffer)) - .asObject(rt); - promise->resolve(std::move(typedArray)); - }); - return promiseValue; -} - -static jsi::Value sliceImpl( - jsi::Runtime& runtime, - const jsi::Value& thisValue, - const jsi::Value* arguments, - size_t count) { - utils::ArgumentParser args(runtime, thisValue, arguments, count); - const auto& blob = args.thisAsHostObject()->blob; - auto blobSize = static_cast(blob->getDirectSize()); - - // Default values - int start = 0; - int end = blobSize; - - // Optinal inputs - if (args.count() > 0) { - start = args.asInteger(0); - } - if (args.count() > 1) { - end = args.asInteger(1); - } - - // Invalid cases - if (std::abs(start) > blobSize || std::abs(end) > blobSize) { - return BlobObjectWithNoData(runtime); - } +namespace { - if (start < 0) { - start = blobSize + start; - } - if (end < 0) { - end = blobSize + end; - } +jsi::Value arrayBufferImpl( + jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) { + utils::ArgumentParser args(runtime, thisValue, arguments, count); + const auto& blob = args.thisAsHostObject()->blob; + auto promiseValue = torchlive::createPromiseAsJSIValue( + runtime, + [&blob](jsi::Runtime& rt, std::shared_ptr promise) { + auto buffer = blob->getDirectBytes(); + auto size = blob->getDirectSize(); + jsi::ArrayBuffer arrayBuffer = + rt.global() + .getPropertyAsFunction(rt, "ArrayBuffer") + .callAsConstructor(rt, static_cast(size)) + .asObject(rt) + .getArrayBuffer(rt); + std::memcpy(arrayBuffer.data(rt), buffer, size); + auto typedArray = rt.global() + .getPropertyAsFunction(rt, "Uint8Array") + .callAsConstructor(rt, std::move(arrayBuffer)) + .asObject(rt); + promise->resolve(std::move(typedArray)); + }); + return promiseValue; +} - // More invalid cases - if (start >= end) { - return BlobObjectWithNoData(runtime); - } +jsi::Value sliceImpl( + jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) { + utils::ArgumentParser args(runtime, thisValue, arguments, count); + const auto& blob = args.thisAsHostObject()->blob; + auto blobSize = static_cast(blob->getDirectSize()); + + // Default values + int start = 0; + int end = blobSize; + + // Optinal inputs + if (args.count() > 0) { + start = args.asInteger(0); + } + if (args.count() > 1) { + end = args.asInteger(1); + } + + // Invalid cases + if (std::abs(start) > blobSize || std::abs(end) > blobSize) { + return BlobObjectWithNoData(runtime); + } + + if (start < 0) { + start = blobSize + start; + } + if (end < 0) { + end = blobSize + end; + } + + // More invalid cases + if (start >= end) { + return BlobObjectWithNoData(runtime); + } + + // Implement slice(start, end) + auto size = end - start; + auto buffer = std::unique_ptr(new uint8_t[size]); + std::memcpy(buffer.get(), blob->getDirectBytes() + start, size); + + auto blobHostObject = std::make_shared( + runtime, std::make_unique(std::move(buffer), size)); + return jsi::Object::createFromHostObject(runtime, std::move(blobHostObject)); +} - // Implement slice(start, end) - auto size = end - start; - auto buffer = std::unique_ptr(new uint8_t[size]); - std::memcpy(buffer.get(), blob->getDirectBytes() + start, size); +jsi::Value releaseImpl( + jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) { +utils::ArgumentParser args(runtime, thisValue, arguments, count); + args.requireNumArguments(0); + args.thisAsHostObject()->blob = nullptr; + return jsi::Value::undefined(); +} - auto blobHostObject = std::make_shared( - runtime, std::make_unique(std::move(buffer), size)); - return jsi::Object::createFromHostObject(runtime, std::move(blobHostObject)); } BlobHostObject::BlobHostObject( @@ -116,6 +131,7 @@ BlobHostObject::BlobHostObject( // Functions setPropertyHostFunction(runtime, "arrayBuffer", 0, arrayBufferImpl); setPropertyHostFunction(runtime, "slice", 0, sliceImpl); + setPropertyHostFunction(runtime, "release", 0, releaseImpl); } } // namespace media diff --git a/react-native-pytorch-core/cxx/src/torchlive/media/MediaNamespace.cpp b/react-native-pytorch-core/cxx/src/torchlive/media/MediaNamespace.cpp index 9e6eb2373..462955197 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/media/MediaNamespace.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/media/MediaNamespace.cpp @@ -170,6 +170,26 @@ jsi::Value imageFromTensorImpl( runtime, std::move(image)); } +jsi::Value imageFromFrameImpl( + jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) { + auto args = utils::ArgumentParser(runtime, thisValue, arguments, count); + args.requireNumArguments(1); + + std::shared_ptr image; + try { + image = torchlive::media::imageFromFrame(runtime, args[0].asObject(runtime)); + } catch (const std::exception& e) { + throw jsi::JSError( + runtime, + std::string("error on converting frame to image!\n") + e.what()); + } + return utils::helpers::createFromHostObject( + runtime, std::move(image)); +} + jsi::Value toBlobImpl( jsi::Runtime& runtime, const jsi::Value& thisValue, @@ -225,6 +245,7 @@ jsi::Object buildNamespace(jsi::Runtime& rt, RuntimeExecutor rte) { jsi::Object ns(rt); setPropertyHostFunction(rt, ns, "imageFromBlob", 3, imageFromBlobImpl); setPropertyHostFunction(rt, ns, "imageFromTensor", 1, imageFromTensorImpl); + setPropertyHostFunction(rt, ns, "imageFromFrame", 1, imageFromFrameImpl); setPropertyHostFunction(rt, ns, "imageFromFile", 1, imageFromFileImpl); setPropertyHostFunction(rt, ns, "toBlob", 1, toBlobImpl); setPropertyHostFunction(rt, ns, "imageToFile", 1, imageToFileImpl); diff --git a/react-native-pytorch-core/cxx/src/torchlive/media/NativeJSRefBridge.h b/react-native-pytorch-core/cxx/src/torchlive/media/NativeJSRefBridge.h index eaa982765..e16220c0c 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/media/NativeJSRefBridge.h +++ b/react-native-pytorch-core/cxx/src/torchlive/media/NativeJSRefBridge.h @@ -18,6 +18,8 @@ namespace torchlive { namespace media { +using namespace facebook; + /** * The resolveNativeJSRefToImage_DO_NOT_USE function is needed to resolve * NativeJSRef objects to IImage. This function will be removed without @@ -35,6 +37,9 @@ std::shared_ptr imageFromFile(std::string filepath); std::shared_ptr imageFromBlob(const Blob& blob, double width, double height); +std::shared_ptr +imageFromFrame(jsi::Runtime& runtime, jsi::Object frameHostObject); + std::unique_ptr toBlob(const std::string& refId); std::unique_ptr toBlob(std::shared_ptr image); diff --git a/react-native-pytorch-core/cxx/src/torchlive/media/NativeJSRefBridgeCxx.cpp b/react-native-pytorch-core/cxx/src/torchlive/media/NativeJSRefBridgeCxx.cpp index 82dc7d35c..2bab7fad1 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/media/NativeJSRefBridgeCxx.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/media/NativeJSRefBridgeCxx.cpp @@ -15,6 +15,8 @@ namespace torchlive { namespace media { +using namespace facebook; + std::shared_ptr resolveNativeJSRefToImage_DO_NOT_USE( const std::string& refId) { return nullptr; @@ -35,6 +37,10 @@ std::shared_ptr imageFromFile(std::string filepath) { return nullptr; } +std::shared_ptr imageFromFrame(jsi::Runtime& runtime, jsi::Object frameHostObject) { + return nullptr; +} + std::unique_ptr toBlob(const std::string& refId) { size_t const size = 0; auto data = std::unique_ptr(0); diff --git a/react-native-pytorch-core/cxx/src/torchlive/media/image/ImageHostObject.cpp b/react-native-pytorch-core/cxx/src/torchlive/media/image/ImageHostObject.cpp index b21164a0e..aa6207f2b 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/media/image/ImageHostObject.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/media/image/ImageHostObject.cpp @@ -81,17 +81,10 @@ jsi::Value scaleImpl( utils::ArgumentParser args(runtime, thisValue, arguments, count); args.requireNumArguments(2); const auto& image = args.thisAsHostObject()->getImage(); - auto promiseValue = torchlive::createPromiseAsJSIValue( - runtime, - [&image, sx = args[0].asNumber(), sy = args[1].asNumber()]( - jsi::Runtime& rt, std::shared_ptr promise) { - auto scaledImage = image->scale(sx, sy); - auto imageObject = - utils::helpers::createFromHostObject( - rt, std::move(scaledImage)); - promise->resolve(std::move(imageObject)); - }); - return promiseValue; + double sx = args[0].asNumber(); + double sy = args[1].asNumber(); + auto scaledImage = image->scale(sx, sy); + return utils::helpers::createFromHostObject(runtime, std::move(scaledImage)); }; jsi::Value releaseImpl( @@ -99,24 +92,9 @@ jsi::Value releaseImpl( const jsi::Value& thisValue, const jsi::Value* arguments, size_t count) { - auto image = thisValue.asObject(runtime) - .asHostObject(runtime) - ->getImage(); - auto promiseValue = torchlive::createPromiseAsJSIValue( - runtime, - [image](jsi::Runtime& rt, std::shared_ptr promise) { - try { - image->close(); - promise->resolve(jsi::Value::undefined()); - } catch (std::exception& e) { - promise->reject("error on release: " + std::string(e.what())); - } catch (const char* error) { - promise->reject("error on release: " + std::string(error)); - } catch (...) { - promise->reject("error on release"); - } - }); - return promiseValue; + auto image = thisValue.asObject(runtime).asHostObject(runtime); + image->release(); + return jsi::Value::undefined(); }; } // namespace @@ -143,5 +121,9 @@ std::shared_ptr ImageHostObject::getImage() const noexcept { return image_; } +void ImageHostObject::release() noexcept { + image_ = nullptr; +} + } // namespace media } // namespace torchlive diff --git a/react-native-pytorch-core/cxx/src/torchlive/media/image/ImageHostObject.h b/react-native-pytorch-core/cxx/src/torchlive/media/image/ImageHostObject.h index 6845ab0c2..c3566b753 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/media/image/ImageHostObject.h +++ b/react-native-pytorch-core/cxx/src/torchlive/media/image/ImageHostObject.h @@ -21,9 +21,10 @@ class JSI_EXPORT ImageHostObject : public common::BaseHostObject { std::shared_ptr image); std::shared_ptr getImage() const noexcept; - + void release() noexcept; + private: - std::shared_ptr image_; + std::shared_ptr image_; }; } // namespace media diff --git a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp index d83fbbab7..c69aa2018 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp @@ -433,6 +433,17 @@ jsi::Value permuteImpl( runtime, std::move(tensor)); } +jsi::Value releaseImpl( + jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) { + utils::ArgumentParser args(runtime, thisValue, arguments, count); + args.requireNumArguments(0); + args.thisAsHostObject()->tensor.reset(); + return jsi::Value::undefined(); +} + jsi::Value reshapeImpl( jsi::Runtime& runtime, const jsi::Value& thisValue, @@ -672,6 +683,7 @@ TensorHostObject::TensorHostObject(jsi::Runtime& runtime, torch_::Tensor t) setPropertyHostFunction(runtime, "matmul", 1, matmulImpl); setPropertyHostFunction(runtime, "mul", 1, mulImpl); setPropertyHostFunction(runtime, "permute", 1, permuteImpl); + setPropertyHostFunction(runtime, "release", 0, releaseImpl); setPropertyHostFunction(runtime, "reshape", 1, reshapeImpl); setPropertyHostFunction(runtime, "softmax", 1, softmaxImpl); setPropertyHostFunction(runtime, "squeeze", 1, squeezeImpl); diff --git a/react-native-pytorch-core/cxx/src/torchlive/torch/jit/JITNamespace.cpp b/react-native-pytorch-core/cxx/src/torchlive/torch/jit/JITNamespace.cpp index da75b99f0..83a812b7d 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/torch/jit/JITNamespace.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/torch/jit/JITNamespace.cpp @@ -128,6 +128,66 @@ _LoadForMobileAsyncTask _loadForMobileImpl( runtime, std::move(moduleHostObject)); }); +jsi::Value syncImpl(jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) { + + utils::ArgumentParser args(runtime, thisValue, arguments, count); + args.requireNumArguments(1); + + std::string filename = args[0].asString(runtime).utf8(runtime); + + c10::optional device = c10::nullopt; + if (count > 1) { + auto deviceType = args[1].asString(runtime).utf8(runtime); + if (deviceType == "cpu") { + device = torch_::kCPU; + } else { + throw facebook::jsi::JSError( + runtime, "only 'cpu' device is currently supported"); + } + } + + std::unordered_map extraFiles; + std::shared_ptr extraFilesObject = nullptr; + if (count > 2) { + jsi::Object obj = args[2].asObject(runtime); + auto arr = obj.getPropertyNames(runtime); + for (size_t i = 0; i < arr.length(runtime); i++) { + auto propName = + arr.getValueAtIndex(runtime, i).asString(runtime).utf8(runtime); + extraFiles[propName] = ""; + } + // Move jsi::Object to pass it through to the result worker function + // to update its values with the extra files values after loading the + // model. + extraFilesObject = std::make_shared(std::move(obj)); + } + + auto model = torch_::jit::_load_for_mobile(filename, device, extraFiles); + + // Update the extra files object passed in as third argument with the + // extra files values retrieved on _load_for_mobile in the worker thread. + // Note, this will only run if a JavaScript object was used as third + // argument and if the model included any extra files for the given keys. + if (extraFilesObject != nullptr && extraFilesObject->isObject() && + extraFiles.size() > 0) { + auto obj = extraFilesObject->asObject(runtime); + for (auto it : extraFiles) { + auto key = jsi::PropNameID::forUtf8(runtime, it.first); + auto value = jsi::String::createFromUtf8(runtime, it.second); + obj.setProperty(runtime, key, value); + } + } + + auto moduleHostObject = + std::make_shared( + runtime, nullptr, std::move(model)); + + return jsi::Object::createFromHostObject(runtime, std::move(moduleHostObject)); +} + } // namespace jsi::Object buildNamespace(jsi::Runtime& rt, torchlive::RuntimeExecutor rte) { @@ -137,7 +197,7 @@ jsi::Object buildNamespace(jsi::Runtime& rt, torchlive::RuntimeExecutor rte) { setPropertyHostFunction( rt, ns, "_loadForMobile", 1, _loadForMobileImpl.asyncPromiseFunc(rte)); setPropertyHostFunction( - rt, ns, "_loadForMobileSync", 1, _loadForMobileImpl.syncFunc(rte)); + rt, ns, "_loadForMobileSync", 1, syncImpl); return ns; } diff --git a/react-native-pytorch-core/cxx/src/torchlive/torch/jit/mobile/ModuleHostObject.cpp b/react-native-pytorch-core/cxx/src/torchlive/torch/jit/mobile/ModuleHostObject.cpp index 8dfe8be6d..bde21be19 100644 --- a/react-native-pytorch-core/cxx/src/torchlive/torch/jit/mobile/ModuleHostObject.cpp +++ b/react-native-pytorch-core/cxx/src/torchlive/torch/jit/mobile/ModuleHostObject.cpp @@ -96,6 +96,42 @@ MethodAsyncTask createMethodAsyncTask( return utils::converter::ivalueToJSIValue(runtime, value); }); } + +jsi::Value syncImpl(jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) { + auto thiz = + thisValue.asObject(runtime).asHostObject(runtime); + + auto args = thiz->mobileModule.get_method("forward") + .function() + .getSchema() + .arguments(); + + // Two Cases in terms of number of argument required and argument + // provided + // Case 1 (n_required < n_provided) we ignore the extra provided args, + // respecting Js convention + // Case 2 (n_required >= n_provided) we process the provided argument + // and let libtorch check if they are enough, this would handle module + // with default parameters + int argCount = std::min(count, args.size() - 1); + + std::vector inputs = {}; + for (int i = 0; i < argCount; i++) { + c10::DynamicType& dynType = + args[i + 1].type()->expectRef(); + inputs.push_back(utils::converter::jsiValuetoIValue( + runtime, arguments[i], dynType)); + } + + c10::InferenceMode guard; + auto ivalue = thiz->mobileModule.get_method("forward")(inputs); + + return utils::converter::ivalueToJSIValue(runtime, ivalue); +} + } // namespace ModuleHostObject::ModuleHostObject( @@ -109,8 +145,7 @@ ModuleHostObject::ModuleHostObject( "forward", createMethodAsyncTask(mobileModule, "forward")); setPropertyHostFunction( rt, "forward", 1, methodAsyncTasks.at("forward").asyncPromiseFunc(rte)); - setPropertyHostFunction( - rt, "forwardSync", 1, methodAsyncTasks.at("forward").syncFunc(rte)); + setPropertyHostFunction(rt, "forwardSync", 1, syncImpl); } jsi::Value ModuleHostObject::get( jsi::Runtime& runtime, diff --git a/react-native-pytorch-core/ios/Image/ImageModule.swift b/react-native-pytorch-core/ios/Image/ImageModule.swift index d4601df36..b95ab5201 100644 --- a/react-native-pytorch-core/ios/Image/ImageModule.swift +++ b/react-native-pytorch-core/ios/Image/ImageModule.swift @@ -183,6 +183,13 @@ public class ImageModule: NSObject { return nil } return refID + } else if let ciImage = image.ciImage { + let bitmapImage = Image(image: ciImage) + let ref = JSContext.wrapObject(object: bitmapImage).getJSRef() + guard let refID = ref["ID"] as? NSString else { + return nil + } + return refID } else { return nil } diff --git a/react-native-pytorch-core/ios/Media/Image/Image.mm b/react-native-pytorch-core/ios/Media/Image/Image.mm index 26d795107..8aed888d5 100644 --- a/react-native-pytorch-core/ios/Media/Image/Image.mm +++ b/react-native-pytorch-core/ios/Media/Image/Image.mm @@ -12,16 +12,10 @@ namespace torchlive { namespace media { -Image::Image(UIImage *image) : image_(image) { - NSString *refId = [ImageModule wrapImage:image]; - if (refId == nil) { - throw "error on wrapImage"; - } - id_ = std::string([refId UTF8String]); -} +Image::Image(UIImage *image) : image_(image) {} std::string Image::getId() const { - return id_; + return "LEGACY_VALUE_DO_NOT_USE"; } double Image::getWidth() const noexcept { @@ -65,14 +59,7 @@ return std::make_shared(scaledImage); } -void Image::close() const { - // This is not needed once we fully migrate to JSI. - NSError *error = nil; - [PTLJSContext releaseWithJsRef:@{@"ID": [NSString stringWithUTF8String:id_.c_str()]} error:&error]; - if (error != nil) { - throw [error.localizedDescription UTF8String]; - } -} +void Image::close() const {} } // namespace media } // namespace torchlive diff --git a/react-native-pytorch-core/ios/Media/MediaUtils.h b/react-native-pytorch-core/ios/Media/MediaUtils.h index c35324188..0f199512c 100644 --- a/react-native-pytorch-core/ios/Media/MediaUtils.h +++ b/react-native-pytorch-core/ios/Media/MediaUtils.h @@ -6,6 +6,7 @@ */ #import +#import #import NS_ASSUME_NONNULL_BEGIN @@ -24,6 +25,8 @@ UIImage *MediaUtilsImageFromBlob(const torchlive::media::Blob& blob, double width, double height); +UIImage *MediaUtilsImageFromCMSampleBuffer(CMSampleBufferRef sampleBuffer); + NSData *MediaUtilsPrependWAVHeader(const std::vector& bytes, int sampleRate); diff --git a/react-native-pytorch-core/ios/Media/MediaUtils.mm b/react-native-pytorch-core/ios/Media/MediaUtils.mm index 329633818..d61ee732c 100644 --- a/react-native-pytorch-core/ios/Media/MediaUtils.mm +++ b/react-native-pytorch-core/ios/Media/MediaUtils.mm @@ -94,6 +94,14 @@ } } +UIImage *MediaUtilsImageFromCMSampleBuffer(CMSampleBufferRef sampleBuffer) { + @autoreleasepool { + CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); + CIImage *ciImage = [CIImage imageWithCVPixelBuffer:pixelBuffer]; + return [UIImage imageWithCIImage:ciImage]; + } +} + #pragma mark - Audio static void write(std::stringstream &stream, int value, int size) diff --git a/react-native-pytorch-core/ios/Media/NativeJSRefBridge.mm b/react-native-pytorch-core/ios/Media/NativeJSRefBridge.mm index 58fe06391..060ee8633 100644 --- a/react-native-pytorch-core/ios/Media/NativeJSRefBridge.mm +++ b/react-native-pytorch-core/ios/Media/NativeJSRefBridge.mm @@ -16,6 +16,16 @@ #import "MediaUtils.h" #import "PyTorchCore-Swift-Header.h" +#if __has_include() + #define HAS_VISION_CAMERA + #import + // forward declaration for the Frame Host Object since we only care about `Frame*` + class FrameHostObject: public facebook::jsi::HostObject { + public: + Frame* frame; + }; +#endif + namespace torchlive { namespace media { @@ -59,6 +69,16 @@ return std::make_shared(image); } +std::shared_ptr imageFromFrame(jsi::Runtime& runtime, jsi::Object frameHostObject) { +#ifdef HAS_VISION_CAMERA + const auto& frame = frameHostObject.asHostObject(runtime); + auto image = MediaUtilsImageFromCMSampleBuffer(frame->frame.buffer); + return std::make_shared(image); +#else + throw jsi::JSError(runtime, "Error converting Frame to Image - VisionCamera is not properly installed!"); +#endif +} + std::unique_ptr toBlob(const std::string& refId) { auto idRef = [NSString stringWithUTF8String:refId.c_str()]; NSError *error = nil; @@ -107,9 +127,13 @@ buffer[i * 3 + 1] = imageData[i * 4 + 1]; // G buffer[i * 3 + 2] = imageData[i * 4 + 2]; // B } - + auto data = std::unique_ptr(new uint8_t[finalDataSize]); std::memcpy(data.get(), buffer, finalDataSize); + + free(imageData); + free(buffer); + std::string blobType = Blob::kBlobTypeImageRGB; return std::make_unique( std::move(data), dataSize, blobType); diff --git a/react-native-pytorch-core/package.json b/react-native-pytorch-core/package.json index 7344294b0..3eec88b72 100644 --- a/react-native-pytorch-core/package.json +++ b/react-native-pytorch-core/package.json @@ -1,6 +1,6 @@ { "name": "react-native-pytorch-core", - "version": "0.0.0", + "version": "0.3.0-alpha.1", "description": "PyTorch core library for React Native", "main": "lib/commonjs/index", "module": "lib/module/index", @@ -73,6 +73,7 @@ "react": "17.0.1", "react-native": "0.64.3", "react-native-builder-bob": "^0.18.1", + "react-native-vision-camera": "^2.15.4", "release-it": "^14.10.0", "typescript": "^4.3.4" }, diff --git a/react-native-pytorch-core/src/ImageModule.ts b/react-native-pytorch-core/src/ImageModule.ts index 31d1db2e2..b582e9fdf 100644 --- a/react-native-pytorch-core/src/ImageModule.ts +++ b/react-native-pytorch-core/src/ImageModule.ts @@ -88,7 +88,7 @@ export interface Image extends NativeJSRef { * @param sx Scaling factor in the horizontal direction. A negative value flips pixels across the vertical axis. A value of `1` results in no horizontal scaling. * @param sy Scaling factor in the vertical direction. A negative value flips pixels across the horizontal axis. A value of `1` results in no vertical scaling. */ - scale(sx: number, sy: number): Promise; + scale(sx: number, sy: number): Image; } export const wrapRef = (ref: NativeJSRef): Image => ({ diff --git a/react-native-pytorch-core/src/torchlive/frame.d.ts b/react-native-pytorch-core/src/torchlive/frame.d.ts new file mode 100644 index 000000000..f8c2f86e0 --- /dev/null +++ b/react-native-pytorch-core/src/torchlive/frame.d.ts @@ -0,0 +1,4 @@ +import type {Frame} from 'react-native-vision-camera'; + +// If VisionCamera is not installed, this type is `never`. +export type VisionCameraFrame = Frame extends object ? Frame : never; diff --git a/react-native-pytorch-core/src/torchlive/media.ts b/react-native-pytorch-core/src/torchlive/media.ts index d5fae2b07..c4fb33438 100644 --- a/react-native-pytorch-core/src/torchlive/media.ts +++ b/react-native-pytorch-core/src/torchlive/media.ts @@ -10,8 +10,9 @@ import type {Tensor} from 'react-native-pytorch-core'; import type {NativeJSRef} from '../NativeJSRef'; import type {Image} from '../ImageModule'; +import type {VisionCameraFrame} from './frame'; -export interface Blob { +export interface PlayTorchBlob { /** * The Blob interface's size property returns the size of the Blob in bytes. */ @@ -54,6 +55,13 @@ export interface Blob { * within the blob on which this method was called. The original blob is not altered. */ slice(start?: number, end?: number): Blob; + /** + * @experimental + * + * Release blob memory immediately rather than waiting for JavaScript GC + * collecting the host object. + */ + release(): void; } export interface Media { @@ -126,13 +134,23 @@ export interface Media { imageFromFile(filepath: string): Image; /** - * Converts a [[Tensor]] or [[NativeJSRef]] into a [[Blob]]. The blob can be + * Converts a VisionCamera [`Frame`](https://mrousavy.com/react-native-vision-camera/docs/api/interfaces/Frame) into an [[Image]]. This function has to be called inside a Frame Processor, as the Frame only exists inside a Frame Processor. + * + * Requires [react-native-vision-camera](https://github.com/mrousavy/react-native-vision-camera) to be installed. + * + * @param frame [[VisionCameraFrame]] to turn into an [[Image]]. + * @returns An [[Image]] object created from the [[VisionCameraFrame]]. + */ + imageFromFrame(frame: VisionCameraFrame): Image; + + /** + * Converts a [[Tensor]] or [[NativeJSRef]] into a [[PlayTorchBlob]]. The blob can be * used to create a [[Tensor]] object or convert into a [[NativeJSRef]] like * an image or audio. * - * @param obj Object to turn into a [[Blob]]. + * @param obj Object to turn into a [[PlayTorchBlob]]. */ - toBlob(obj: Tensor | NativeJSRef): Blob; + toBlob(obj: Tensor | NativeJSRef): PlayTorchBlob; } type Torchlive = { diff --git a/react-native-pytorch-core/src/torchlive/torch.ts b/react-native-pytorch-core/src/torchlive/torch.ts index f24f63ef7..c154b536f 100644 --- a/react-native-pytorch-core/src/torchlive/torch.ts +++ b/react-native-pytorch-core/src/torchlive/torch.ts @@ -296,6 +296,13 @@ export interface Tensor { * {@link https://pytorch.org/docs/1.12/generated/torch.Tensor.item.html} */ item(): number; + /** + * @experimental + * + * Release tensor memory immediately rather than waiting for JavaScript GC + * collecting the host object. + */ + release(): void; /** * Returns a tensor with the same data and number of elements as input, but * with the specified shape. diff --git a/react-native-pytorch-core/yarn.lock b/react-native-pytorch-core/yarn.lock index 62e95123a..d20e6d147 100644 --- a/react-native-pytorch-core/yarn.lock +++ b/react-native-pytorch-core/yarn.lock @@ -7626,6 +7626,11 @@ react-native-codegen@^0.0.6: jscodeshift "^0.11.0" nullthrows "^1.1.1" +react-native-vision-camera@^2.15.4: + version "2.15.4" + resolved "https://registry.yarnpkg.com/react-native-vision-camera/-/react-native-vision-camera-2.15.4.tgz#821f0505fc8c63b87c1ae4697d2bb4f670333576" + integrity sha512-SJXSWH1pu4V3Kj4UuX/vSgOxc9d5wb5+nHqBHd+5iUtVyVLEp0F6Jbbaha7tDoU+kUBwonhlwr2o8oV6NZ7Ibg== + react-native@0.64.3: version "0.64.3" resolved "https://registry.yarnpkg.com/react-native/-/react-native-0.64.3.tgz#40db6385963b4b17325f9cc86dd19132394b03fc"