diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe9da39..00d0e36 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: cachix/install-nix-action@v30 + - uses: ners/simply-nix@main - uses: ryanccn/attic-action@v0 with: endpoint: https://cache.ners.ch diff --git a/examples/ADT.hs b/examples/ADT.hs index 8bfb7f5..01a1039 100644 --- a/examples/ADT.hs +++ b/examples/ADT.hs @@ -15,7 +15,7 @@ data Point a = Point a a deriving (Show) -- data Either ... {- already defined in 'Data.Either' -- Make some 'Storable' instances -mkStorable [t| forall a. Storable a => Storable (Point a) |] +mkMarshalable [t| forall a. Storable a => Storable (Point a) |] -- Generate corresponding Rust types extendContext (rustTyCtx [t| forall a. Point a |]) diff --git a/examples/PreludeStuff.hs b/examples/PreludeStuff.hs index 58f10ee..189c810 100644 --- a/examples/PreludeStuff.hs +++ b/examples/PreludeStuff.hs @@ -15,7 +15,7 @@ setCrateRoot [] -- Some ADTs data Point a = Point a a deriving (Show) -mkStorable [t| forall a. Storable a => Storable (Point a) |] +mkMarshalable [t| forall a. Storable a => Storable (Point a) |] extendContext (rustTyCtx [t| forall a. Point a |]) main = do diff --git a/flake.lock b/flake.lock index 9e23b27..7c1eea8 100644 --- a/flake.lock +++ b/flake.lock @@ -3,11 +3,11 @@ "language-rust": { "flake": false, "locked": { - "lastModified": 1729283594, - "narHash": "sha256-nVexe9Jrj0DtpjyPX+S2t/Op2iJyU1IRncW/BO8Ties=", + "lastModified": 1736519943, + "narHash": "sha256-5b4PRIpNCaZiTL3M1DZsKtSJWz5PPmyMpHatB7zBe+M=", "owner": "GaloisInc", "repo": "language-rust", - "rev": "ae59c7355b2fbbe818c1fdd23894df6bfa0432e5", + "rev": "f23c2b055217479bd3d1525dc9455d0f8bfcc394", "type": "github" }, "original": { @@ -18,11 +18,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1731319897, - "narHash": "sha256-PbABj4tnbWFMfBp6OcUK5iGy1QY+/Z96ZcLpooIbuEI=", + "lastModified": 1761114652, + "narHash": "sha256-f/QCJM/YhrV/lavyCVz8iU3rlZun6d+dAiC3H+CDle4=", "owner": "nixos", "repo": "nixpkgs", - "rev": "dc460ec76cbff0e66e269457d7b728432263166c", + "rev": "01f116e4df6a15f4ccdffb1bcd41096869fb385c", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 8c273ca..f24a53f 100644 --- a/flake.nix +++ b/flake.nix @@ -23,38 +23,47 @@ ); sourceFilter = root: with lib.fileset; toSource { inherit root; - fileset = fileFilter (file: any file.hasExt [ "cabal" "hs" "hsc" "md" ]) root; + fileset = fileFilter + (file: any file.hasExt [ "cabal" "hs" "md" ]) + root; }; ghcsFor = pkgs: with lib; foldlAttrs - (acc: name: hp: + (acc: name: hp': let - version = getVersion hp.ghc; + hp = tryEval hp'; + version = getVersion hp.value.ghc; majorMinor = versions.majorMinor version; ghcName = "ghc${replaceStrings ["."] [""] majorMinor}"; in - if hp ? ghc && ! acc ? ${ghcName} && versionAtLeast version "9.2" && versionOlder version "9.11" - then acc // { ${ghcName} = hp; } + if hp.value ? ghc && ! acc ? ${ghcName} && versionAtLeast version "9.4" && versionOlder version "9.13" + then acc // { ${ghcName} = hp.value; } else acc ) { } pkgs.haskell.packages; hpsFor = pkgs: { default = pkgs.haskellPackages; } // ghcsFor pkgs; pname = "inline-rust"; - src = sourceFilter ./.; - overlay = lib.composeManyExtensions [ - (final: prev: { - haskell = prev.haskell // { - packageOverrides = lib.composeManyExtensions [ - prev.haskell.packageOverrides - (hfinal: hprev: with prev.haskell.lib.compose; { - language-rust = hfinal.callCabal2nix "language-rust" inputs.language-rust { }; - ${pname} = addBuildDepend prev.cargo (hfinal.callCabal2nix pname src { }); - }) - ]; - }; - }) - ]; + pnames = [ pname ]; + haskell-overlay = final: prev: hfinal: hprev: with prev.haskell.lib.compose; { + language-rust = hfinal.callCabal2nix "language-rust" inputs.language-rust { }; + ${pname} = addBuildDepend prev.cargo (hfinal.callCabal2nix pname (sourceFilter ./.) { }); + }; + overlay = final: prev: { + haskell = prev.haskell // { + packageOverrides = lib.composeManyExtensions [ + prev.haskell.packageOverrides + (haskell-overlay final prev) + ]; + }; + }; in + { + overlays = { + default = overlay; + haskell = haskell-overlay; + }; + } + // foreach inputs.nixpkgs.legacyPackages (system: pkgs': let @@ -62,11 +71,20 @@ hps = hpsFor pkgs; libs = pkgs.buildEnv { name = "${pname}-libs"; - paths = map (hp: hp.${pname}) (attrValues hps); + paths = + lib.mapCartesianProduct + ({ hp, pname }: hp.${pname}) + { hp = attrValues hps; pname = pnames; }; pathsToLink = [ "/lib" ]; }; - docs = pkgs.haskell.lib.documentationTarball hps.default.${pname}; - sdist = pkgs.haskell.lib.sdistTarball hps.default.${pname}; + docs = pkgs.buildEnv { + name = "${pname}-docs"; + paths = map (pname: pkgs.haskell.lib.documentationTarball hps.default.${pname}) pnames; + }; + sdist = pkgs.buildEnv { + name = "${pname}-sdist"; + paths = map (pname: pkgs.haskell.lib.sdistTarball hps.default.${pname}) pnames; + }; docsAndSdist = pkgs.linkFarm "${pname}-docsAndSdist" { inherit docs sdist; }; in { @@ -80,18 +98,18 @@ devShells.${system} = foreach hps (ghcName: hp: { ${ghcName} = hp.shellFor { - packages = ps: [ hp.${pname} ]; + packages = ps: map (pname: ps.${pname}) pnames; nativeBuildInputs = with pkgs'; with haskellPackages; [ - pkgs'.haskellPackages.cabal-install - fourmolu - haskell-language-server + cabal-gild + cabal-install cargo + fourmolu rustc + ] ++ lib.optionals (lib.versionAtLeast (lib.getVersion hp.ghc) "9.4") [ + hp.haskell-language-server ]; }; }); } - ) // { - overlays.default = overlay; - }; + ); } diff --git a/inline-rust.cabal b/inline-rust.cabal index 907a777..086816f 100644 --- a/inline-rust.cabal +++ b/inline-rust.cabal @@ -22,21 +22,25 @@ source-repository head library hs-source-dirs: src - ghc-options: -Wall + ghc-options: -Wall -ddump-splices -ddump-to-file -g default-language: Haskell2010 exposed-modules: Language.Rust.Inline Language.Rust.Inline.TH other-modules: Language.Rust.Inline.Context - Language.Rust.Inline.Context.Prelude Language.Rust.Inline.Context.ByteString + Language.Rust.Inline.Context.Marshalable + Language.Rust.Inline.Context.Prelude + Language.Rust.Inline.Context.Vector + Language.Rust.Inline.Internal Language.Rust.Inline.Marshal Language.Rust.Inline.Parser Language.Rust.Inline.Pretty - Language.Rust.Inline.Internal + Language.Rust.Inline.Storable.Tuple + Language.Rust.Inline.TH.Marshalable + Language.Rust.Inline.TH.ReprC Language.Rust.Inline.TH.Storable Language.Rust.Inline.TH.Utilities - Language.Rust.Inline.TH.ReprC other-extensions: DeriveDataTypeable , CPP @@ -62,7 +66,7 @@ library test-suite spec hs-source-dirs: tests - ghc-options: -threaded + ghc-options: -threaded -ddump-splices -ddump-to-file -g if os(windows) extra-libraries: @@ -75,7 +79,11 @@ test-suite spec main-is: Main.hs type: exitcode-stdio-1.0 - default-language: Haskell2010 + default-language: Haskell2010 + default-extensions: ExplicitForAll + , QuasiQuotes + , ScopedTypeVariables + , TemplateHaskell other-modules: SimpleTypes , GhcUnboxedTypes , PointerTypes @@ -85,10 +93,13 @@ test-suite spec , Submodule , Submodule.Submodule , ByteStrings + , Vectors , ForeignPtr - build-depends: base + build-depends: async + , base , inline-rust , language-rust , hspec , ghc-prim , bytestring + , unliftio diff --git a/src/Language/Rust/Inline.hs b/src/Language/Rust/Inline.hs index d700f4d..04d7411 100644 --- a/src/Language/Rust/Inline.hs +++ b/src/Language/Rust/Inline.hs @@ -1,6 +1,7 @@ {-# LANGUAGE ForeignFunctionInterface #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE ScopedTypeVariables #-} {- | Module : Language.Rust.Inline @@ -60,6 +61,7 @@ module Language.Rust.Inline ( pointers, prelude, bytestrings, + vectors, foreignPointers, -- ** Marshalling @@ -76,7 +78,7 @@ module Language.Rust.Inline ( newArray, withByteString, unsafeLocalState, - mkStorable, + mkMarshalable, mkReprC, -- * Top-level Rust items @@ -86,15 +88,15 @@ module Language.Rust.Inline ( import Language.Rust.Inline.Context import Language.Rust.Inline.Context.ByteString (bytestrings) +import Language.Rust.Inline.Context.Vector (vectors) import Language.Rust.Inline.Context.Prelude (prelude) import Language.Rust.Inline.Internal import Language.Rust.Inline.Marshal import Language.Rust.Inline.Parser import Language.Rust.Inline.Pretty import Language.Rust.Inline.TH.ReprC (mkReprC) -import Language.Rust.Inline.TH.Storable (mkStorable) +import Language.Rust.Inline.TH.Marshalable (mkMarshalable) -import Language.Haskell.TH (pprParendType) import Language.Haskell.TH.Lib import Language.Haskell.TH.Quote (QuasiQuoter (..)) import Language.Haskell.TH.Syntax @@ -103,16 +105,15 @@ import Foreign.Marshal.Alloc (alloca, free) import Foreign.Marshal.Array (newArray, withArrayLen) import Foreign.Marshal.Unsafe (unsafeLocalState) import Foreign.Marshal.Utils (new, with) -import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr, nullPtr) +import Foreign.Ptr (freeHaskellFunPtr) +import qualified Foreign import Control.Monad (void) import Data.List (intercalate) import Data.Traversable (for) -import Data.Word (Word8) import System.Random (randomIO) -import qualified Data.ByteString.Unsafe as ByteString -import Foreign.Storable (Storable (..)) +import qualified Language.Rust.Inline.Context.Marshalable as Marshalable {- $overview @@ -279,9 +280,6 @@ rustQuasiQuoter safety isPure supportDecs = | supportDecs = emitCodeBlock | otherwise = err -showTy :: Type -> String -showTy = show . pprParendType - {- | This function sums up the packages. What it does: 1. Map the Rust type annotations in the quasiquote to their Haskell types. @@ -314,72 +312,23 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do -- Convert the Haskell return type to a marshallable FFI type (returnFfi, haskRet') <- do marshalForm <- ghcMarshallable haskRet - let fptrRet haskRet' = [t|Ptr (Ptr $(pure haskRet'), FunPtr (Ptr $(pure haskRet') -> IO ())) -> IO ()|] - let bsRet = [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) -> IO ()|] - ret <- case marshalForm of - BoxedDirect -> [t|IO $(pure haskRet)|] - BoxedIndirect -> [t|Ptr $(pure haskRet) -> IO ()|] - UnboxedDirect - | isPure -> pure haskRet - | otherwise -> - let retTy = showTy haskRet - in fail ("Cannot put unlifted type ‘" ++ retTy ++ "’ in IO") - ByteString -> bsRet - OptionalByteString -> bsRet - ForeignPtr - | AppT _ haskRet' <- haskRet -> fptrRet haskRet' - | otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " using the ForeignPtr calling convention") - OptionalForeignPtr - | AppT _ (AppT _ haskRet') <- haskRet -> fptrRet haskRet' - | otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " as an optional ForeignPtr") - pure (marshalForm, pure ret) + pure (marshalForm, returnType marshalForm haskRet) -- Convert the Haskell arguments to marshallable FFI types (marshalForms, haskArgs') <- fmap unzip $ for haskArgs $ \haskArg -> do marshalForm <- ghcMarshallable haskArg - case marshalForm of - BoxedIndirect - | returnFfi == UnboxedDirect -> - let argTy = showTy haskArg - retTy = showTy haskRet - in fail - ( "Cannot pass an argument ‘" - ++ argTy - ++ "’" - ++ " indirectly when returning an unlifted type " - ++ "‘" - ++ retTy - ++ "’" - ) - | otherwise -> do - ptr <- [t|Ptr $(pure haskArg)|] - pure (BoxedIndirect, ptr) - ByteString -> do - rbsT <- [t|Ptr (Ptr Word8, Word)|] - pure (ByteString, rbsT) - OptionalByteString -> do - rbsT <- [t|Ptr (Ptr Word8, Word)|] - pure (OptionalByteString, rbsT) - ForeignPtr - | AppT _ haskArg' <- haskArg -> do - ptr <- [t|Ptr $(pure haskArg')|] - pure (ForeignPtr, ptr) - | otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " using the ForeignPtr calling convention") - OptionalForeignPtr - | AppT _ (AppT _ haskArg') <- haskArg -> do - ptr <- [t|Ptr $(pure haskArg')|] - pure (OptionalForeignPtr, ptr) - | otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " as an optional ForeignPtr") - _ -> pure (marshalForm, haskArg) + ret <- argumentType marshalForm haskArg + pure (marshalForm, ret) -- Generate the Haskell FFI import declaration and emit it - bsFree <- newName $ "bsFree" ++ show (abs q) - bsFreeSig <- [t|FunPtr (Ptr Word8 -> Word -> IO ()) -> Ptr Word8 -> Word -> IO ()|] - haskSig <- foldr (\l r -> [t|$(pure l) -> $r|]) haskRet' haskArgs' + -- bsFree <- newName $ "bsFree" ++ show (abs q) + -- bsFreeSig <- [t|FunPtr (Ptr Word8 -> Word -> IO ()) -> Ptr Word8 -> Word -> IO ()|] + haskRet'' <- if addIOUnit returnFfi then [t|$(haskRet') -> IO ()|] else haskRet' + haskSig <- foldr (\l r -> [t|$(pure l) -> $r|]) (pure haskRet'') haskArgs' let ffiImport = ForeignD (ImportF CCall safety qqStrName qqName haskSig) - let ffiBsFree = ForeignD (ImportF CCall Safe "dynamic" bsFree bsFreeSig) - addTopDecls [ffiImport, ffiBsFree] + -- let ffiBsFree = ForeignD (ImportF CCall Safe "dynamic" bsFree bsFreeSig) + addTopDecls [ffiImport] -- Generate the Haskell FFI call let goArgs :: @@ -394,77 +343,18 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do -- accumulated arguments. If the return value is not marshallable, we have to -- 'alloca' some space to put the return value. goArgs acc [] - | returnFfi == ByteString = do - ret <- newName "ret" - ptr <- newName "ptr" - len <- newName "len" - finalizer <- newName "finalizer" - [e| - alloca - ( \($(varP ret)) -> do - $(appsE (varE qqName : reverse (varE ret : acc))) - ($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret) - ByteString.unsafePackCStringFinalizer - $(varE ptr) - (fromIntegral $(varE len)) - ($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len)) - ) - |] - | returnFfi == ForeignPtr = do - finalizer <- newName "finalizer" - ptr <- newName "ptr" - ret <- newName "ret" - [e| - alloca - ( \($(varP ret)) -> do - $(appsE (varE qqName : reverse (varE ret : acc))) - ($(varP ptr), $(varP finalizer)) <- peek $(varE ret) - newForeignPtr $(varE finalizer) $(varE ptr) - ) - |] - | returnFfi == OptionalForeignPtr = do - finalizer <- newName "finalizer" - ptr <- newName "ptr" - ret <- newName "ret" - [e| - alloca - ( \($(varP ret)) -> do - $(appsE (varE qqName : reverse (varE ret : acc))) - ($(varP ptr), $(varP finalizer)) <- peek $(varE ret) - if $(varE ptr) == nullPtr - then pure Nothing - else Just <$> newForeignPtr $(varE finalizer) $(varE ptr) - ) - |] - | returnFfi == OptionalByteString = do - ret <- newName "ret" - ptr <- newName "ptr" - len <- newName "len" - finalizer <- newName "finalizer" - [e| - alloca - ( \($(varP ret)) -> do - $(appsE (varE qqName : reverse (varE ret : acc))) - ($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret) - if $(varE ptr) == nullPtr - then pure Nothing - else - Just - <$> ByteString.unsafePackCStringFinalizer - $(varE ptr) - (fromIntegral $(varE len)) - ($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len)) - ) - |] | returnByValue returnFfi = appsE (varE qqName : reverse acc) | otherwise = do ret <- newName "ret" [e| - alloca + Foreign.allocaBytesAligned + (Marshalable.sizeOfPeek (undefined :: $(pure haskRet))) + (Marshalable.alignmentPeek (undefined :: $(pure haskRet))) ( \($(varP ret)) -> do $(appsE (varE qqName : reverse (varE ret : acc))) - peek $(varE ret) + r :: $(pure haskRet) <- Marshalable.peek $(varE ret) + pure r ) |] @@ -475,37 +365,11 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do case arg of Nothing -> fail ("Could not find Haskell variable ‘" ++ argStr ++ "’") Just argName - | marshalForm == ByteString -> do - ptr <- newName "ptr" - len <- newName "len" - bsp <- newName "bsp" - [e| - withByteString - $(varE argName) - ( \($(varP ptr)) ($(varP len)) -> - with ($(varE ptr), $(varE len)) (\($(varP bsp)) -> $(goArgs (varE bsp : acc) args)) - ) - |] - | marshalForm == ForeignPtr -> do - ptr <- newName "ptr" - [e| - withForeignPtr $(varE argName) (\($(varP ptr)) -> $(goArgs (varE ptr : acc) args)) - |] - | marshalForm == OptionalForeignPtr -> do - ptr <- newName "ptr" - fptr <- newName "fptr" - [e| - case $(varE argName) of - Nothing -> let $(varP ptr) = nullPtr in $(goArgs (varE ptr : acc) args) - Just $(varP fptr) -> - withForeignPtr $(varE fptr) (\($(varP ptr)) -> $(goArgs (varE ptr : acc) args)) - |] - | marshalForm == OptionalByteString -> fail "Don't" | passByValue marshalForm -> goArgs (varE argName : acc) args | otherwise -> do x <- newName "x" [e| - with + Marshalable.with $(varE argName) ( \($(varP x)) -> $(goArgs (varE x : acc) args) @@ -514,7 +378,7 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do let haskCall' = goArgs [] (rustArgNames `zip` marshalForms) haskCall = - if isPure && returnFfi /= UnboxedDirect + if isPure && runsInIO returnFfi then [e|unsafeLocalState $haskCall'|] else haskCall' diff --git a/src/Language/Rust/Inline/Context.hs b/src/Language/Rust/Inline/Context.hs index 624d8df..01fcca0 100644 --- a/src/Language/Rust/Inline/Context.hs +++ b/src/Language/Rust/Inline/Context.hs @@ -46,6 +46,7 @@ import GHC.Exts ( Int#, Word#, ) +import Data.Void (Void) -- Easier on the eyes type RType = Ty () @@ -61,17 +62,17 @@ type. newtype Context = Context ( [RType -> Context -> First (Q HType, Maybe (Q RType))] - , -- Given a Rust type in a quasiquote, we need to look up the + -- Given a Rust type in a quasiquote, we need to look up the -- corresponding Haskell type (for the FFI import) as well as the -- C-compatible Rust type (if the initial Rust type isn't already -- @#[repr(C)]@. - [HType -> Context -> First (Q RType)] - , -- Given a field in a Haskell ADT, we need to figure out which + , [HType -> Context -> First (Q RType)] + -- Given a field in a Haskell ADT, we need to figure out which -- (not-necessarily @#[repr(C)]@) Rust type normally maps into this -- Haskell type. - [String] + , [String] -- Source for the trait impls of @MarshalTo@ ) deriving (Semigroup, Monoid, Typeable) @@ -319,13 +320,13 @@ foreignPointers = do foreignPtrT <- [t|ForeignPtr|] pure $ Context ([rule], [rev foreignPtrT], [foreignPtr, constPtr, mutPtr]) where + htype _ (Just _) = pure ([t| ForeignPtr Void|], Nothing) -- if the pointee needs marshalling, forbid peeking from Haskell + htype t Nothing = pure ([t|ForeignPtr $t|], Nothing) + rule (Rptr _ _ t _) context - | First (Just (t', Nothing)) <- lookupRTypeInContext t context = pure ([t|ForeignPtr $t'|], Nothing) + | First (Just (t', inter)) <- lookupRTypeInContext t context = htype t' inter rule (PathTy Nothing (Path False [PathSegment "ForeignPtr" (Just (AngleBracketed [] [t] [] _)) _] _) _) context - | First (Just (t', Nothing)) <- lookupRTypeInContext t context = pure ([t|ForeignPtr $t'|], Nothing) - rule (PathTy Nothing (Path False [PathSegment "Option" (Just (AngleBracketed [] [PathTy Nothing (Path False [PathSegment "ForeignPtr" (Just (AngleBracketed [] [t] [] _)) _] _) _] [] _)) _] _) _) context - | First (Just (t', Nothing)) <- lookupRTypeInContext t context = - pure ([t|Maybe (ForeignPtr $t')|], pure . pure $ PathTy Nothing (Path False [PathSegment "ForeignPtr" (Just (AngleBracketed [] [t] [] ())) ()] ()) ()) + | First (Just (t', inter)) <- lookupRTypeInContext t context = htype t' inter rule _ _ = mempty rev _ _ _ = mempty @@ -334,6 +335,8 @@ foreignPointers = do unlines [ "#[repr(C)]" , "pub struct ForeignPtr(pub *mut T, pub extern \"C\" fn (*mut T));" + , "impl Copy for ForeignPtr {}" + , "impl Clone for ForeignPtr { fn clone(&self) -> Self { ForeignPtr(self.0, self.1) } }" ] constPtr = @@ -371,15 +374,6 @@ foreignPointers = do , "impl<'a, T> MarshalInto<&'a mut T> for &'a mut T {" , " fn marshal(self) -> &'a mut T { self }" , "}" - , "" - , "impl MarshalInto> for Option> {" - , " fn marshal(self) -> ForeignPtr {" - , " extern fn panic(_ptr: *mut T) {" - , " panic!(\"Attempted to free a null ForeignPtr\")" - , " }" - , " self.unwrap_or(ForeignPtr(std::ptr::null_mut(), panic))" - , " }" - , "}" ] {- | This maps a Rust function type into the corresponding 'FunPtr' wrapped diff --git a/src/Language/Rust/Inline/Context/ByteString.hs b/src/Language/Rust/Inline/Context/ByteString.hs index dcc8647..25db20b 100644 --- a/src/Language/Rust/Inline/Context/ByteString.hs +++ b/src/Language/Rust/Inline/Context/ByteString.hs @@ -32,15 +32,18 @@ import Foreign.Ptr (Ptr) import Debug.Trace (traceM) bytestrings :: Q Context -bytestrings = - pure $ Context ([rule], [], [rustByteString, impl]) +bytestrings = do + bs <- [t|ByteString|] + pure $ Context ([rule], [rev bs], [rustByteString, impl]) where rule rty _ | rty == void [ty| &[u8] |] = pure ([t|ByteString|], pure . pure $ void [ty| RustByteString |]) | rty == void [ty| Vec |] = pure ([t|ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) - | rty == void [ty| Option> |] = pure ([t|Maybe ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) rule _ _ = mempty + rev bs hty _ | bs == hty = pure . pure $ void [ty|Vec|] + rev _ _ _ = mempty + rustByteString = unlines [ "#[repr(C)]" @@ -73,13 +76,4 @@ bytestrings = , " RustOwnedByteString(bytes.as_mut_ptr(), len, free)" , " }" , "}" - , "" - , "impl MarshalInto for Option> {" - , " fn marshal(self) -> RustOwnedByteString {" - , " extern fn panic(ptr: *mut u8, len: usize) {" - , " panic!(\"Attempted to free a null ByteString\");" - , " }" - , " self.map(|bs| bs.marshal()).unwrap_or(RustOwnedByteString(std::ptr::null_mut(), 0, panic))" - , " }" - , "}" ] diff --git a/src/Language/Rust/Inline/Context/Marshalable.hs b/src/Language/Rust/Inline/Context/Marshalable.hs new file mode 100644 index 0000000..7eb3f0f --- /dev/null +++ b/src/Language/Rust/Inline/Context/Marshalable.hs @@ -0,0 +1,152 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE DefaultSignatures #-} + +module Language.Rust.Inline.Context.Marshalable where + +import Foreign + ( Word8, + Ptr, + FunPtr, + ForeignPtr, + plusPtr, + newForeignPtr, + withForeignPtr, + castPtr, + sizeOf, + alignment, Storable, poke ) +import qualified Foreign +import Data.ByteString (ByteString) +import Data.ByteString.Internal (ByteString(PS)) +import qualified Data.ByteString.Unsafe as ByteString +import Foreign.C.Types +import Data.Int +import Data.Word +import Language.Rust.Inline.Storable.Tuple () + +-- | A generalisation of `Storable`'s `with` that respects finalizers and lets us avoid copies for types that can not be `Storable`. +class Marshalable a where + -- | The size of the `Storable` representation of `a` + sizeOfWith :: a -> Int + default sizeOfWith :: Storable a => a -> Int + sizeOfWith = sizeOf + + alignmentWith :: a -> Int + default alignmentWith :: Storable a => a -> Int + alignmentWith = alignment + + -- | Holds a reference to its argument while making it available to foreign code. + -- By default allocates space to store `WithPtrType a` and makes `a` available as `WithPtrType a` there. + with + :: a + -- ^ The data to marshal + -> (Foreign.Ptr a -> IO b) + -- ^ The continuation that takes a pointer to the marshaled data + -> IO b + with x k = Foreign.allocaBytesAligned (sizeOfWith (undefined :: a)) (alignmentWith (undefined :: a)) $ \loc -> withLoc x loc (k loc) + + -- | Hold a reference to `a` and `poke`s the `WithPtrType a` into a preallocated location. + -- Call this if you have a specific memory layout requirement, e.g. to marshal `a` as part of a larger data structure. + withLoc + :: a + -- ^ The data to marshal + -> Ptr a + -- ^ The location to marshal into (where to put the pointer to the data) + -> IO b + -- ^ The action to run while holding the reference to the data + -> IO b + default withLoc :: Storable a => a -> Ptr a -> IO b -> IO b + withLoc p loc k = poke loc p >> k + + sizeOfPeek :: a -> Int + default sizeOfPeek :: Storable a => a -> Int + sizeOfPeek = sizeOf + + alignmentPeek :: a -> Int + default alignmentPeek :: Storable a => a -> Int + alignmentPeek = alignment + + -- | `peek` the `Storable` representation and convert it to `a` + peek + :: Ptr b + -- ^ The pointer to peek at + -> IO a + default peek :: Storable a => Ptr b -> IO a + peek = Foreign.peek . castPtr + +instance Marshalable ByteString where + sizeOfWith _ = Foreign.sizeOf (undefined :: (Foreign.Ptr Foreign.Word8, Word)) + alignmentWith _ = Foreign.alignment (undefined :: (Foreign.Ptr Foreign.Word8, Word)) + withLoc (PS ptr off len) loc k = Foreign.withForeignPtr ptr $ \ptr' -> do + Foreign.poke @(Foreign.Ptr Foreign.Word8, Word) (castPtr loc) (ptr' `Foreign.plusPtr` off, fromIntegral len) + k + sizeOfPeek _ = Foreign.sizeOf (undefined :: (Foreign.Ptr Foreign.Word8, Word, Foreign.FunPtr (Foreign.Ptr Foreign.Word8 -> Word -> IO ()))) + alignmentPeek _ = Foreign.alignment (undefined :: (Foreign.Ptr Foreign.Word8, Word, Foreign.FunPtr (Foreign.Ptr Foreign.Word8 -> Word -> IO ()))) + peek p = do + (ptr, len, finalizer) <- Foreign.peek (castPtr p) + ByteString.unsafePackCStringFinalizer ptr (fromIntegral len) (freeByteString finalizer ptr len) + +instance Marshalable (Foreign.ForeignPtr a) where + sizeOfWith = const $ sizeOf (undefined :: (Foreign.Ptr a)) + alignmentWith = const $ alignment (undefined :: (Foreign.Ptr a)) + withLoc fp loc k = Foreign.withForeignPtr fp $ \ptr -> do + Foreign.poke (castPtr loc) ptr + k + sizeOfPeek _ = Foreign.sizeOf (undefined :: (Foreign.Ptr a, Foreign.FunPtr (Foreign.Ptr a -> IO ()))) + alignmentPeek _ = Foreign.alignment (undefined :: (Foreign.Ptr a, Foreign.FunPtr (Foreign.Ptr a -> IO ()))) + peek p = do + (ptr, finalizer) <- Foreign.peek (castPtr p) + Foreign.newForeignPtr finalizer ptr + +foreign import ccall safe "dynamic" freeByteString :: Foreign.FunPtr (Foreign.Ptr Foreign.Word8 -> Word -> IO ()) -> Foreign.Ptr Foreign.Word8 -> Word -> IO () + +instance Marshalable CChar +instance Marshalable CSChar +instance Marshalable CUChar +instance Marshalable CShort +instance Marshalable CUShort +instance Marshalable CInt +instance Marshalable CUInt +instance Marshalable CLong +instance Marshalable CULong +instance Marshalable CPtrdiff +instance Marshalable CSize +instance Marshalable CWchar +instance Marshalable CLLong +instance Marshalable CULLong +instance Marshalable CBool +instance Marshalable CIntPtr +instance Marshalable CUIntPtr +instance Marshalable CIntMax +instance Marshalable CUIntMax +instance Marshalable CClock +instance Marshalable CTime +instance Marshalable CUSeconds +instance Marshalable CSUSeconds +instance Marshalable CFloat +instance Marshalable CDouble + +-- TODO: Should we marshal these? We have them in the libc context ... +-- instance Marshalable CFile +-- instance Marshalable CFpos + +instance Marshalable Int8 +instance Marshalable Int16 +instance Marshalable Int32 +instance Marshalable Int64 + +instance Marshalable Word8 +instance Marshalable Word16 +instance Marshalable Word32 +instance Marshalable Word64 + +instance Marshalable Char +instance Marshalable Float +instance Marshalable Double +instance Marshalable Int +instance Marshalable Word +instance Marshalable () diff --git a/src/Language/Rust/Inline/Context/Prelude.hs b/src/Language/Rust/Inline/Context/Prelude.hs index 08448cc..b72b57f 100644 --- a/src/Language/Rust/Inline/Context/Prelude.hs +++ b/src/Language/Rust/Inline/Context/Prelude.hs @@ -8,7 +8,6 @@ Stability : experimental Portability : GHC -} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -w #-} @@ -16,7 +15,10 @@ Portability : GHC module Language.Rust.Inline.Context.Prelude where import Language.Rust.Inline.Context +import Language.Rust.Inline.Context.Marshalable import Language.Rust.Inline.TH +import Language.Rust.Inline.TH.Storable (mkStorable, mkTupleStorable) +import Language.Rust.Inline.TH.Marshalable (mkMarshalable, mkTupleMarshalable) import Language.Rust.Data.Ident ( Ident(..), mkIdent ) @@ -38,9 +40,10 @@ import Data.Maybe ( fromMaybe ) -- * Tuples up and including to arity 16 -- -- Note that arity 0 is in 'Foreign.Storable' and arity 1 makes no sense in Haskell. -mkStorable [t| forall a. Storable a => Storable (Maybe a) |] -mkStorable [t| forall l r. (Storable l, Storable r) => Storable (Either l r) |] -fmap join (traverse mkTupleStorable [2..16]) +mkMarshalable [t| forall a. Marshalable a => Marshalable (Maybe a) |] +mkMarshalable [t| forall l r. (Marshalable l, Marshalable r) => Marshalable (Either l r) |] + +fmap join (traverse mkTupleMarshalable [2..16]) -- | Make a generic path type (e.g. something like @Vec@). mkGenPathTy :: Ident -> [Ty ()] -> Ty () @@ -64,7 +67,7 @@ maybeContext = do where rule (PathTy Nothing (Path False [PathSegment "Option" (Just (AngleBracketed [] [t] [] _)) _] _) _) context = do (t', rInterOpt) <- lookupRTypeInContext t context - let inter = mkGenPathTy "MaybeC" <$> ((\x -> [x]) <$> maybe (pure t) id rInterOpt) + let inter = mkGenPathTy "MaybeC" . pure <$> fromMaybe (pure t) rInterOpt pure ([t| Maybe $t' |], Just inter) rule _ _ = mempty @@ -202,7 +205,7 @@ eitherItems = map unlines , "struct RightC(T);" ] -- impl MarshalInto> for Result - , [ "impl + Copy, R: Copy, R1: MarshalInto + Copy> MarshalInto> for Result {" + , [ "impl, R: Copy, R1: MarshalInto> MarshalInto> for Result {" , " fn marshal(self) -> EitherC {" , " match self {" , " Err(l) => EitherC { tag: 0, payload: TaggedEitherC { left: LeftC(l.marshal()) } }," diff --git a/src/Language/Rust/Inline/Context/Vector.hs b/src/Language/Rust/Inline/Context/Vector.hs new file mode 100644 index 0000000..1f8057c --- /dev/null +++ b/src/Language/Rust/Inline/Context/Vector.hs @@ -0,0 +1,87 @@ +{-# LANGUAGE OverloadedStrings#-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# OPTIONS_GHC -Wno-orphans #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Language.Rust.Inline.Context.Vector where +import Language.Rust.Syntax (Ty(PathTy), Path (Path), PathSegment (PathSegment), PathParameters (AngleBracketed)) +import Language.Rust.Inline.Context.Prelude (mkGenPathTy) +import Language.Haskell.TH +import Language.Rust.Quote (ty) +import Data.Maybe (fromMaybe) +import Language.Rust.Inline.Context.Marshalable +import qualified Foreign +import Data.Foldable (foldrM) +import qualified Language.Rust.Inline.Context.Marshalable as Marshalable +import Control.Monad (void) +import Language.Rust.Inline.Context (Context(..), lookupRTypeInContext) + +vectors :: Q Context +vectors = pure $ Context ([rule], [], [rustList, impl]) + where + rule (PathTy Nothing (Path False [PathSegment "Vec" (Just (AngleBracketed [] [t'] [] _)) _] _) _) ctx + | t' /= void [ty|u8|] = do + (t'', rInterOpt) <- lookupRTypeInContext t' ctx + let inter = mkGenPathTy "Vector" . pure <$> fromMaybe (pure t') rInterOpt + pure ([t| [$t''] |], Just inter) + rule _ _ = mempty + + rustList = unlines + [ "#[repr(C)]" + , "pub struct Vector(*mut T, usize, extern \"C\" fn (*mut T, usize));" + , "impl Copy for Vector { }" + , "impl Clone for Vector { fn clone(&self) -> Self { Vector(self.0, self.1, self.2) } }" + ] + + impl = unlines + [ "impl + Copy> MarshalInto> for Vector {" + , " fn marshal(self) -> Vec {" + , " let Vector(ptr, len, _) = self;" + , " let inputs = unsafe { std::slice::from_raw_parts(ptr, len) };" + , " let mut vec = Vec::with_capacity(len);" + , " for i in 0..len {" + , " vec.push(inputs[i].marshal());" + , " }" + , " vec" + , " }" + , "}" + , "" + , "impl> MarshalInto> for Vec {" + , " fn marshal(self) -> Vector {" + , " let mut vec = Vec::with_capacity(self.len());" + , " for x in self.into_iter() {" + , " vec.push(x.marshal());" + , " }" + , "" + , " let slice = Box::leak(vec.into_boxed_slice());" + , " let len = slice.len();" + , "" + , " extern fn free(ptr: *mut U, len: usize) {" + , " let data = unsafe { Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr, len)) };" + , " drop(data);" + , " }" + , " Vector(slice.as_mut_ptr(), len, free)" + , " }" + , "}" + ] + +instance Marshalable a => Marshalable [a] where + sizeOfWith _ = Foreign.sizeOf (undefined :: (Foreign.Ptr (), Word, Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()))) + alignmentWith _ = Foreign.alignment (undefined :: (Foreign.Ptr (), Word, Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()))) + withLoc as loc k = Foreign.allocaBytesAligned (length as * sizeOfWith (undefined :: a)) (alignmentWith (undefined :: a)) $ \space -> do + let pokeSlice = Foreign.poke (Foreign.castPtr loc) (Foreign.castPtr space, length as, Foreign.nullFunPtr) + let pokeAt i a = withLoc a $ space `Foreign.plusPtr` (i * sizeOfWith (undefined :: a)) + foldr ($) (pokeSlice >> k) $ zipWith pokeAt [0..] as + + sizeOfPeek _ = Foreign.sizeOf (undefined :: (Foreign.Ptr (), Word, Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()))) + alignmentPeek _ = Foreign.alignment (undefined :: (Foreign.Ptr (), Word, Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()))) + peek p = do + (ptr, len, finalizer) <- Foreign.peek (Foreign.castPtr p) + let peekAtOffset :: Word -> IO a + peekAtOffset offset = Marshalable.peek $ ptr `Foreign.plusPtr` (fromIntegral offset * sizeOfPeek (undefined :: a)) + list <- foldrM (\a b -> (:) <$> peekAtOffset a <*> pure b) [] $ take (fromIntegral len) [0..] + freeVector finalizer ptr len + pure list + +foreign import ccall safe "dynamic" freeVector :: Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()) -> Foreign.Ptr () -> Word -> IO () diff --git a/src/Language/Rust/Inline/Marshal.hs b/src/Language/Rust/Inline/Marshal.hs index 506df15..8061c0e 100644 --- a/src/Language/Rust/Inline/Marshal.hs +++ b/src/Language/Rust/Inline/Marshal.hs @@ -10,6 +10,7 @@ Portability : GHC {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE MagicHash #-} +{-# LANGUAGE LambdaCase #-} module Language.Rust.Inline.Marshal where @@ -32,21 +33,14 @@ import Data.Array.Storable ( StorableArray, Ix, withStorableArray, import GHC.Exts -data MarshalForm - = UnboxedDirect -- ^ value is marshallable and must be passed directly to the FFI - | BoxedDirect -- ^ value is marshallable and can be passed directly to the FFI - | BoxedIndirect -- ^ value isn't marshallable directly but may be passed indirectly via a 'Ptr' - | ByteString - | ForeignPtr - | OptionalForeignPtr - | OptionalByteString - deriving (Eq, Show) - -passByValue :: MarshalForm -> Bool -passByValue = (`elem` [UnboxedDirect, BoxedDirect, ForeignPtr]) - -returnByValue :: MarshalForm -> Bool -returnByValue = (`elem` [UnboxedDirect, BoxedDirect]) +data MarshalForm = MarshalForm + { passByValue :: Bool + , returnByValue :: Bool + , returnType :: Type -> Q Type + , argumentType :: Type -> Q Type + , runsInIO :: Bool + , addIOUnit :: Bool + } -- | Identify which types can be marshalled by the GHC FFI and which types are -- unlifted. A negative response to the first of these questions doesn't mean @@ -61,21 +55,58 @@ ghcMarshallable ty = do simpleB <- sequence qSimpleBoxed tyconsU <- sequence qTyconsUnboxed tyconsB <- sequence qTyconsBoxed + unitType <- [t| () |] bytestring <- [t| ByteString |] fptrCons <- [t| ForeignPtr |] - maybeCons <- [t| Maybe |] + + let unboxedDirect = MarshalForm + { passByValue = True + , returnByValue = True + , returnType = pure + , argumentType = pure + , runsInIO = False + , addIOUnit = False + } + boxedDirect = unboxedDirect{ returnType = \t -> [t|IO $(pure t)|], runsInIO = True } + unitDirect = boxedDirect { passByValue = False, argumentType = \t -> [t|Ptr $(pure t)|] } + boxedIndirect = MarshalForm + { passByValue = False + , returnByValue = False + , returnType = const [t|Ptr ()|] + , argumentType = \t -> [t|Ptr $(pure t)|] + , runsInIO = True + , addIOUnit = True + } + foreignPtr = MarshalForm + { passByValue = False + , returnByValue = False + , returnType = \case + AppT _ r -> [t|Ptr (Ptr $(pure r), FunPtr (Ptr $(pure r) -> IO ()))|] + t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" + , argumentType = \case + AppT _ r -> [t|Ptr (ForeignPtr $(pure r))|] + t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" + , runsInIO = True + , addIOUnit = True + } + byteString = MarshalForm + { passByValue = False + , returnByValue = False + , returnType = const [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ()))|] + , argumentType = const [t|Ptr ByteString|] + , runsInIO = True + , addIOUnit = True + } case ty of - _ | ty `elem` simpleU -> pure UnboxedDirect - | ty `elem` simpleB -> pure BoxedDirect - | ty == bytestring -> pure ByteString - AppT con _ | con `elem` tyconsU -> pure UnboxedDirect - | con `elem` tyconsB -> pure BoxedDirect - | con == fptrCons -> pure ForeignPtr - AppT mb (AppT c _) - | mb == maybeCons && c == fptrCons -> pure OptionalForeignPtr - AppT mb c | mb == maybeCons && c == bytestring -> pure OptionalByteString - _ -> pure BoxedIndirect + _ | ty `elem` simpleU -> pure unboxedDirect + | ty `elem` simpleB -> pure boxedDirect + | ty == unitType -> pure unitDirect + | ty == bytestring -> pure byteString + AppT con _ | con `elem` tyconsU -> pure unboxedDirect + | con `elem` tyconsB -> pure boxedDirect + | con == fptrCons -> pure foreignPtr + _ -> pure boxedIndirect where qSimpleUnboxed = [ [t| Char# |] , [t| Int# |] @@ -97,7 +128,7 @@ ghcMarshallable ty = do , [t| Double |] , [t| Float |] - , [t| Bool |], [t| () |] -- TODO: let through `IO ()` but not `()` + , [t| Bool |] , [t| Int8 |], [t| Int16 |], [t| Int32 |], [t| Int64 |] , [t| Word8 |], [t| Word16 |], [t| Word32 |], [t| Word64 |] diff --git a/src/Language/Rust/Inline/Storable/Tuple.hs b/src/Language/Rust/Inline/Storable/Tuple.hs new file mode 100644 index 0000000..1aefebf --- /dev/null +++ b/src/Language/Rust/Inline/Storable/Tuple.hs @@ -0,0 +1,11 @@ +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -w #-} + +module Language.Rust.Inline.Storable.Tuple where + +import Control.Monad (join) +import Foreign.Storable +import Language.Rust.Inline.TH.Storable + +fmap join (traverse mkTupleStorable [2..16]) diff --git a/src/Language/Rust/Inline/TH.hs b/src/Language/Rust/Inline/TH.hs index 36e0deb..da3f866 100644 --- a/src/Language/Rust/Inline/TH.hs +++ b/src/Language/Rust/Inline/TH.hs @@ -1,9 +1,19 @@ -module Language.Rust.Inline.TH ( adtCtx, rustTyCtx, mkStorable, mkTupleStorable ) where +module Language.Rust.Inline.TH + ( adtCtx + , rustTyCtx + , mkStorable + , mkTupleStorable + , Marshalable(..) + , mkMarshalable + , mkTupleMarshalable + ) where import Language.Rust.Inline.TH.Utilities ( getTyConOpt, getTyCon ) import Language.Rust.Inline.TH.ReprC import Language.Rust.Inline.TH.Storable ( mkStorable, mkTupleStorable ) +import Language.Rust.Inline.Context.Marshalable (Marshalable(..)) +import Language.Rust.Inline.TH.Marshalable (mkMarshalable, mkTupleMarshalable ) import Language.Rust.Inline.Context import Language.Rust.Inline.Internal import Language.Rust.Inline.Pretty @@ -83,7 +93,7 @@ rustTyCtx tyq = do (_, ty) <- case ty' of ForallT tyvars [] t -> pure (tyvars, t) - ForallT _ _ _ -> fail "rustTyCtx: type cannot have context" + ForallT {} -> fail "rustTyCtx: type cannot have context" t -> pure ([], t) -- Get the type and its name diff --git a/src/Language/Rust/Inline/TH/Marshalable.hs b/src/Language/Rust/Inline/TH/Marshalable.hs new file mode 100644 index 0000000..00bff0c --- /dev/null +++ b/src/Language/Rust/Inline/TH/Marshalable.hs @@ -0,0 +1,364 @@ +{-| +Module : Language.Rust.Inline.TH.Marshalable +Description : Generate Marshalable instances +Copyright : (c) Alec Theriault, 2018 +License : BSD-style +Maintainer : ners +Stability : experimental +Portability : GHC +-} + +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE FlexibleInstances #-} +{-# OPTIONS_GHC -Wwarn #-} -- TODO: GHC bug around "unused pattern binds" in splices +{-# LANGUAGE TypeApplications #-} + -- TODO: GHC feature around setting extensions from within TH +module Language.Rust.Inline.TH.Marshalable ( + mkMarshalable, + mkTupleMarshalable, +) where + +import Language.Rust.Inline.TH.Utilities + +import Language.Haskell.TH +import Language.Haskell.TH.Syntax hiding (lift) +import Control.Monad.Trans.State ( StateT(..), get, put ) +import Control.Monad.Trans.Class ( lift ) +import Data.Traversable ( for ) +import Foreign.Ptr ( alignPtr, plusPtr, castPtr, Ptr ) +import Data.Word ( Word8, Word16, Word32, Word64 ) +import Language.Rust.Inline.Context.Marshalable +import qualified Foreign +import Data.List (intercalate) + +-- | Generate 'Marshalable' instance for a non-recursive simple algebraic data +-- type. The instance follows the usual C layout for determining alignment and +-- size. +-- +-- Sum types are implemented as tagged unions. +-- +-- >>> mkMarshalable [t| forall a. Marshalable a => Marshalable (Maybe a) |] +-- +-- Remember to have 'ScopedTypeVariables', 'ExplicitForall', and 'EmptyCase' +-- enabled when calling this! +mkMarshalable :: TypeQ -- ^ a type representing the desired instance head + -> Q [Dec] -- ^ the instance declaration +mkMarshalable tyq = do + pseudoInstHead <- tyq + + -- Extract the context + marshalable <- [t| Marshalable |] + (ctx, ty') <- + case pseudoInstHead of + ForallT _ ctx (AppT s ty) | s == marshalable -> pure (ctx, ty) + AppT s ty | s == marshalable -> pure ([], ty) + _ -> fail "mkMarshalable: malformed 'Marshalable' instance head" + + -- Get the type constructors name + (name,cons') <- getConstructors ty' + + -- Produce the instance + decs' <- processADT name [ (nameCon n, tyArgs) | (n,tyArgs) <- cons' ] + pure . pure $ InstanceD Nothing ctx (AppT marshalable ty') decs' + +mkTupleMarshalable :: Int -- ^ arity of tuple + -> Q [Dec] -- ^ the instance declaration +mkTupleMarshalable n = do + marshalable <- [t| Marshalable |] + tyVars <- sequence (take n [ newName (c : show i) + | i <- [(1 :: Int)..] + , c <- ['a'..'z'] + ]) + let name = mkName $ "(" <> intercalate "," (show <$> tyVars) <> ")" + let ctx c = [ AppT c (VarT tyVar) | tyVar <- tyVars ] + let instHead c = AppT c (foldl AppT (TupleT n) (map VarT tyVars)) + + decs' <- processADT name [ (tupCon, map VarT tyVars) ] + pure . pure $ InstanceD Nothing (ctx marshalable) (instHead marshalable) decs' + +-- * Constructor utilities +data Constructor = Constructor + { conPat :: [Pat] -> Pat + , conExp :: [Exp] -> Exp + } + +nameCon :: Name -> Constructor +nameCon n = Constructor (ConP n []) (foldl AppE (ConE n)) + +tupCon :: Constructor +tupCon = Constructor TupP (TupE . fmap Just) + + +-- * Alignment + +-- | This is the information you need to carry along as you visit the fields of +-- a struct/union. +data Alignment = Alignment + { decs :: [Dec] + -- ^ declarations for variables relied on by offset and align + + , offsetSoFar :: Code Q Int + -- ^ total bytes occupied so far by fields + + , alignSoFar :: Code Q Int + -- ^ size (in bytes) of the largest member in the struct + } + +-- | Combining alignment means concatenating the dependent declarations, and +-- take the maximum for offset and alignment. +instance Semigroup Alignment where + a1 <> a2 = Alignment + { decs = decs a1 <> decs a2 + , offsetSoFar = [|| $$(offsetSoFar a1) `max` $$(offsetSoFar a2) ||] + , alignSoFar = [|| $$(alignSoFar a1) `max` $$(alignSoFar a2) ||] + } + +-- | The 'mconcat' method calls 'maximum' +instance Monoid Alignment where + mempty = Alignment [] [|| 0 ||] [|| 1 ||] + mappend = (<>) + mconcat as = Alignment + { decs = concatMap decs as + , offsetSoFar = [|| maximum $$(liftCode $ listTE <$> mapM (examineCode . offsetSoFar) as) ||] + , alignSoFar = [|| maximum $$(liftCode $ listTE <$> mapM (examineCode . alignSoFar) as) ||] + } + +-- | This is the state we will bundle along while visiting fields. +type StructState = StateT Alignment Q + +-- | Make a typed list. This function is like 'listE', but for 'TExp'. +listTE :: [TExp a] -> TExp [a] +listTE = TExp . ListE . map unType + +-- * With and Peek helper functions + +-- | TODO: vkleen will write docs +withCon :: Constructor -- ^ name of the constructor + -> [Exp -> Q Exp] -- ^ how to offset to every field + -> Name -- ^ the base pointer + -> Name -- ^ the name of the continuation parameter + -> Q (Pat, Exp) -- ^ an expression for poking the constructor +withCon con fieldOffsets ptr k = do + (ns, fields) <- unzip <$> do + for fieldOffsets $ \offset -> do + n <- newName "n" + pure (VarP n, [e| withLoc $(varE n) $(offset (VarE ptr)) |]) + let pat = conPat con ns + f <- foldr (\b e -> [e| $b $e |]) (varE k) fields + pure (pat, f) + +-- | Produces a 'do' block for peeking a constructor. The generated code has the +-- following shape: +-- +-- @ +-- do f1 <- ... ptr +-- f2 <- ... ptr +-- ... +-- fn <- ... ptr +-- return (Con f1 f2 ... fn) +-- @ +-- +peekCon :: Constructor -- ^ name of the constructor + -> [Exp -> Q Exp] -- ^ how to offset to every field + -> Name -- ^ the base pointer + -> Q Exp -- ^ a 'do' expression for peeking the constructor +peekCon con fieldOffsets ptr = do + (ns, binds) <- unzip <$> do + for fieldOffsets $ \offset -> do + n <- newName "n" + pure (varE n, bindS (varP n) [e| peek $(offset (VarE ptr)) |]) + let ret = [e| return $(conExp con <$> sequence ns) |] + doE (binds ++ [noBindS ret]) + +alignQInt :: Q Exp -> Q Exp -> Q Exp +alignQInt size alignment = [e| (($size + $alignment - 1) `div` $alignment) * $alignment |] + +alignCodeInt :: Code Q Int -> Code Q Int -> Code Q Int +alignCodeInt size alignment = [|| (($$size + $$alignment - 1) `div` $$alignment) * $$alignment ||] + +-- * Traversing fields (putting everything together) + +-- | Process a field of a given type. +processField :: Name -> Name -> Type -> StructState (Exp -> Q Exp) +processField alignment sizeOf ty = do + let alignTy, sizeTy :: Code Q Int + alignTy = Code $ TExp <$> [e| $(varE alignment) (undefined :: $(pure ty)) |] + sizeTy = Code $ TExp <$> [e| $(varE sizeOf) (undefined :: $(pure ty)) |] + + -- get state at the end of the last field + (Alignment prevDecs prevOff prevAlign) <- get + + -- where to peek: align (prevOff) currentAlign + -- new total alignment: max prevAlign currentAlign + -- new offset: where to peek + currentSize + + -- beginning offset + beginOffV <- lift $ newName "beginOff" + let beginOffE, beginOff :: Code Q Int + beginOffE = alignCodeInt prevOff alignTy + beginOff = Code $ TExp <$> varE beginOffV + assignBeginOff <- lift [d| $(varP beginOffV) = $(unType <$> examineCode beginOffE) |] + + -- offset after this field + newOffV <- lift $ newName "afterOff" + let newOffE :: Code Q Int + newOffE = [|| $$beginOff + $$sizeTy ||] + newOff <- lift (TExp <$> varE newOffV) + assignNewOff <- lift [d| $(varP newOffV) = $(unType <$> examineCode newOffE) |] + + -- alignment after this field + newAlignV <- lift $ newName "algn" + let newAlignE :: Code Q Int + newAlignE = [|| max $$alignTy $$prevAlign ||] + newAlign <- lift (TExp <$> varE newAlignV) + assignNewAlign <- lift [d| $(varP newAlignV) = $(unType <$> examineCode newAlignE) |] + + -- update state + put (Alignment { decs = concat [ assignBeginOff + , assignNewOff + , assignNewAlign + , prevDecs + ] + , offsetSoFar = liftCode (pure newOff) + , alignSoFar = liftCode (pure newAlign) + }) + + -- TODO: consider degenerate sizeof(..) = 0 cases + pure $ \addrE -> [e| (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] + + +-- | Process an algebraic data type. +-- +-- TODO: think about the zero constructor case... +processADT :: Name -- ^ name of the type + -> [(Constructor, [Type])] -- ^ constructors and the types of their fields + -> Q [Dec] -- ^ marshalable implementations + +-- The one constructor case is special - we don't need to specify a tag +processADT _ [(con, fields)] = do + initAlign <- mempty + (offsetsWith, Alignment dsWith sizeWith algnWith) <- runStateT (traverse (processField 'alignmentWith 'sizeOfWith) fields) initAlign + (offsetsPeek, Alignment dsPeek sizePeek algnPeek) <- runStateT (traverse (processField 'alignmentPeek 'sizeOfPeek) fields) initAlign + + sizeOfWith' <- funD + (mkName "sizeOfWith") + [clause [wildP] + (NormalB . unType <$> examineCode (alignCodeInt sizeWith algnWith)) + (pure <$> dsWith)] + + alignmentWith' <- funD + (mkName "alignmentWith") + [clause [wildP] + (NormalB . unType <$> examineCode algnWith) + (pure <$> dsWith)] + + withLoc' <- do + ptr <- newName "ptr" + k <- newName "k" + (pat, body) <- withCon con offsetsWith ptr k + funD (mkName "withLoc") [clause [pure pat, varP ptr, varP k] (normalB $ pure body) (pure <$> dsWith)] + + sizeOfPeek' <- funD + (mkName "sizeOfPeek") + [clause [wildP] + (NormalB . unType <$> examineCode (alignCodeInt sizePeek algnPeek)) + (pure <$> dsPeek)] + + alignmentPeek' <- funD + (mkName "alignmentPeek") + [clause [wildP] + (NormalB . unType <$> examineCode algnPeek) + (pure <$> dsPeek)] + + peek' <- do + ptr <- newName "ptr" + funD (mkName "peek") [clause [varP ptr] (normalB (peekCon con offsetsPeek ptr)) (pure <$> dsPeek)] + + pure [sizeOfWith', alignmentWith', withLoc', sizeOfPeek', alignmentPeek', peek'] + +processADT name cons = do + let discNum = length cons + discTy <- snd . head . dropWhile (\(m,_) -> discNum > m + 1) $ + [ (fromIntegral (maxBound :: Word8), [t| Word8 |]) + , (fromIntegral (maxBound :: Word16), [t| Word16 |]) + , (fromIntegral (maxBound :: Word32), [t| Word32 |]) + , (fromIntegral (maxBound :: Word64), [t| Word64 |]) + ] + + initAlign <- mempty + (conWithsPeeks, algnsWith, algnsPeek) <- unzip3 <$> do + for cons $ \(con, fields) -> do + (offsetsWith, algnWith) <- runStateT (traverse (processField 'alignmentWith 'sizeOfWith) fields) initAlign + (offsetsPeek, algnPeek) <- runStateT (traverse (processField 'alignmentPeek 'sizeOfPeek) fields) initAlign + pure ((con, offsetsWith, offsetsPeek), algnWith, algnPeek) + let (Alignment dsWith offWith algnWith) = mconcat algnsWith + let (Alignment dsPeek offPeek algnPeek) = mconcat algnsPeek + let discSizeOf = [e| Foreign.sizeOf (undefined :: $(pure discTy)) |] + discAlign = [e| Foreign.alignment (undefined :: $(pure discTy)) |] + + sizeOfWith' <- funD + (mkName "sizeOfWith") + [clause [wildP] + (normalB [e| $(alignQInt discSizeOf (unType <$> examineCode algnWith)) + $(unType <$> examineCode offWith) |]) + (pure <$> dsWith)] + + alignmentWith' <- funD + (mkName "alignmentWith") + [clause [wildP] + (normalB [e| max $(discAlign) $(unType <$> examineCode algnWith) |]) + (pure <$> dsWith)] + + withLoc' <- do + ptr <- newName "ptr" + ptrOff <- newName "ptrOff" + k <- newName "k" + d' <- [d| $(varP ptrOff) = ($(varE ptr) `plusPtr` $(discSizeOf)) `alignPtr` $(unType <$> examineCode algnWith) |] + x <- newName "x" + + let mtchs = [ do (pat, body) <- patBody + match (pure pat) + (normalB . doE $ noBindS <$> [ [e| Foreign.poke (Foreign.castPtr $(varE ptr) :: Ptr $(pure discTy)) $(litE n') |] + , pure body + ]) + [] + | (n, (con, offsetsWith, _)) <- zip [0..] conWithsPeeks + , let patBody = withCon con offsetsWith ptrOff k + , let n' = IntegerL n + ] + + funD (mkName "withLoc") [clause [varP x, varP ptr, varP k] (normalB $ caseE (varE x) mtchs) (pure <$> d' ++ dsWith)] + + sizeOfPeek' <- funD + (mkName "sizeOfPeek") + [clause [wildP] + (normalB [e| $(alignQInt discSizeOf (unType <$> examineCode algnPeek)) + $(unType <$> examineCode offPeek) |]) + (pure <$> dsPeek)] + + alignmentPeek' <- funD + (mkName "alignmentPeek") + [clause [wildP] + (NormalB . unType <$> examineCode algnPeek) + (pure <$> dsPeek)] + + peek' <- do + ptr <- newName "ptr" + ptrOff <- newName "ptrOff" + d' <- [d| $(varP ptrOff) = ($(varE ptr) `plusPtr` $(discSizeOf)) `alignPtr` $(unType <$> examineCode algnPeek) |] + disc <- newName "disc" + let mtchs = [ match (litP n') (normalB (peekCon con offsetsPeek ptrOff)) [] + | (n, (con, _, offsetsPeek)) <- zip [0..] conWithsPeeks + , let n' = IntegerL n + ] + <> + [ match wildP (normalB [e| + fail $ "Unknown discriminator for " <> $(liftString $ show name) <> ": " <> show $(varE disc) + |]) [] + ] + funD (mkName "peek") + [clause [varP ptr] + (normalB (doE [ bindS (varP disc) [e| Foreign.peek (castPtr $(varE ptr) :: Ptr $(pure discTy)) |] + , noBindS (caseE (varE disc) mtchs) + ])) + (pure <$> d' ++ dsPeek)] + + pure [sizeOfWith', alignmentWith', withLoc', sizeOfPeek', alignmentPeek', peek'] diff --git a/src/Language/Rust/Inline/TH/Storable.hs b/src/Language/Rust/Inline/TH/Storable.hs index 989577a..54a822b 100644 --- a/src/Language/Rust/Inline/TH/Storable.hs +++ b/src/Language/Rust/Inline/TH/Storable.hs @@ -245,32 +245,26 @@ processADT [(con, fields)] = do let ds' = map pure ds -- sizeOf - sizeOf_ <- do - Just sizeOfN <- lookupValueName "sizeOf" - funD sizeOfN [clause [wildP] + sizeOf_ <- funD (mkName "sizeOf") [clause [wildP] (normalB [e| let c = $(unType <$> examineCode off) in c + mod (negate c) $(unType <$> examineCode algn) |]) ds'] -- alignment - alignment_ <- do - Just alignmentN <- lookupValueName "alignment" - funD alignmentN [clause [wildP] (normalB (unType <$> examineCode algn)) ds'] + alignment_ <- funD (mkName "alignment") [clause [wildP] (normalB (unType <$> examineCode algn)) ds'] let (peekFields, pokeFields) = unzip peekPokes -- peek peek_ <- do ptr <- newName "ptr" - Just peekN <- lookupValueName "peek" - funD peekN [clause [varP ptr] (normalB (peekCon con peekFields ptr)) ds'] + funD (mkName "peek") [clause [varP ptr] (normalB (peekCon con peekFields ptr)) ds'] -- poke poke_ <- do ptr <- newName "ptr" (cPat,body) <- pokeCon con pokeFields ptr - Just pokeN <- lookupValueName "poke" - funD pokeN [clause [varP ptr, pure cPat] (normalB (pure body)) ds'] + funD (mkName "poke") [clause [varP ptr, pure cPat] (normalB (pure body)) ds'] pure [sizeOf_, alignment_, peek_, poke_] @@ -296,17 +290,13 @@ processADT cons = do let ds' = map pure ds -- sizeOf - sizeOf_ <- do - Just sizeOfN <- lookupValueName "sizeOf" - funD sizeOfN [clause [wildP] + sizeOf_ <- funD (mkName "sizeOf") [clause [wildP] (normalB [e| let c = $(unType <$> examineCode off) in $algn' + c + mod (negate c) $algn' |]) ds'] -- alignment - alignment_ <- do - Just alignmentN <- lookupValueName "alignment" - funD alignmentN [clause [wildP] (normalB algn') ds'] + alignment_ <- funD (mkName "alignment") [clause [wildP] (normalB algn') ds'] -- peek peek_ <- do @@ -318,8 +308,7 @@ processADT cons = do | (n, (con, peekFields, _)) <- zip [0..] conPeekPokess , let n' = IntegerL n ] - Just peekN <- lookupValueName "peek" - funD peekN + funD (mkName "peek") [clause [varP ptr] (normalB (doE [ bindS (varP disc) [e| peek (castPtr $(varE ptr) :: Ptr $(pure discTy)) |] , noBindS (caseE (varE disc) mtchs) @@ -343,8 +332,7 @@ processADT cons = do , let patBody = pokeCon con pokeFields ptrOff , let n' = IntegerL n ] - Just pokeN <- lookupValueName "poke" - funD pokeN + funD (mkName "poke") [clause [varP ptr, varP disc] (normalB (caseE (varE disc) mtchs)) (map pure d' ++ ds')] pure [sizeOf_, alignment_, peek_, poke_] diff --git a/tests/AlgebraicDataTypes.hs b/tests/AlgebraicDataTypes.hs index aa044b5..e67c4dc 100644 --- a/tests/AlgebraicDataTypes.hs +++ b/tests/AlgebraicDataTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell, ExplicitForAll, ScopedTypeVariables #-} module AlgebraicDataTypes where import Language.Rust.Inline @@ -19,11 +18,11 @@ import Data.Int ( Int8, Int16, Int32, Int64 ) -- | A struct-like ADT where the fields have different sizes data StructLike = StructLike Int16 Int64 deriving (Show, Eq) -mkStorable [t| Storable StructLike |] +mkMarshalable [t| Marshalable StructLike |] -- | A struct-like newtype ADT where the field is compound newtype StructLike2 = StructLike2 (Int16, Int64) deriving (Show, Eq) -mkStorable [t| Storable StructLike2 |] +mkMarshalable [t| Marshalable StructLike2 |] -- | An ADT where: @@ -36,14 +35,15 @@ data Foo | Baz Char Int | Qux (Complex Float) Char deriving (Show, Eq) -mkStorable [t| Storable Foo |] +mkMarshalable [t| forall a. Marshalable a => Marshalable (Complex a) |] +mkMarshalable [t| Marshalable Foo |] -- | An ADT where fields are nested ADTs data Croc = Lob (Maybe Foo) Int | Boo Int8 Int8 deriving (Show, Eq) -mkStorable [t| Storable Croc |] +mkMarshalable [t| Marshalable Croc |] -- | A polymorphic ADT. (From the @these@ package). data These a b @@ -51,7 +51,7 @@ data These a b | That b | Both a b deriving (Show, Eq) -mkStorable [t| forall a b. (Storable a, Storable b) => Storable (These a b) |] +mkMarshalable [t| forall a b. (Marshalable a, Marshalable b) => Marshalable (These a b) |] -- | An ADT that needs more that a 'Word8' to store the tag data Big a @@ -86,7 +86,7 @@ data Big a | C280 | C281 | C282 | C283 | C284 | C285 | C286 | C287 | C288 | C289 | C290 | C291 | C292 | C293 | C294 | C295 | C296 | C297 | C298 | C299 a deriving (Show, Eq) -mkStorable [t| forall a. Storable a => Storable (Big a) |] +mkMarshalable [t| forall a. Marshalable a => Marshalable (Big a) |] -- | An ADT with a mixture of polymorphism and not. data Foo2 a @@ -95,7 +95,7 @@ data Foo2 a | Qux2 a a | Quux2 Int a deriving (Show, Eq) -mkStorable [t| forall a. Storable a => Storable (Foo2 a) |] +mkMarshalable [t| forall a. Marshalable a => Marshalable (Foo2 a) |] -- Set the context @@ -226,28 +226,28 @@ algebraicDataTypes = describe "Algebraic data types" $ do z1 = 1.3 :+ 4.5 z2 = 6.7 :+ 8.9 [rust| Cpx { $(z1: Cpx) + $(z2: Cpx) } |] `shouldBe` z1 + z2 - + it "Can marshal a custom single-constructor ADT argument/return" $ do let s1 = StructLike 78 (negate 267) s2 = StructLike 92 45223 s3 = StructLike2 (34, -92391) s4 = StructLike2 (576, 1234) - + for_ [s1,s2] $ \si -> [rust| StructLike2 { $(si: StructLike).in2() } |] `shouldBe` in2 si for_ [s3,s4] $ \si -> [rust| StructLike { $(si: StructLike2).out2() } |] `shouldBe` out2 si - + it "Can marshal a custom monomorphic ADT argument/return" $ do let f1, f2, f3, f4 :: Foo f1 = Baz 'a' 0 f2 = Baz 'b' 2 f3 = Qux (7.1 :+ 3.4) 'f' f4 = Bar - + for_ [f1,f2,f3,f4] $ \fi -> [rust| Foo { $(fi: Foo).quux() } |] `shouldBe` quux fi - + it "Can marshal nested monomorphic ADT arguments/returns" $ do let c1, c2, c3, c4, c5, c6, c7 :: Croc c1 = Lob (Just (Baz 'a' 0)) 2 @@ -257,23 +257,23 @@ algebraicDataTypes = describe "Algebraic data types" $ do c5 = Lob Nothing 3 c6 = Boo 3 (-2) c7 = Boo (-4) 2 - + for_ [c1,c2,c3,c4,c5,c6,c7] $ \ci -> [rust| Croc { $(ci: Croc).croc() } |] `shouldBe` croc ci - + it "Can marshal polymorphic ADT arguments/returns" $ do let t1, t2, t3 :: These Int8 Int64 t1 = This maxBound t2 = That 432442 t3 = Both (maxBound - 3) 879 - + for_ [t1,t2,t3] $ \ti -> let v1 = [rust| These { $(ti: These).bimap(|x| x as i16 * 2, |y| y + 2) } |] v2 = bimap (\x -> fromIntegral x * 2) (+ 2) ti in v1 `shouldBe` v2 - + it "Can marshal nested polymorphic ADT arguments/returns" $ do let t1, t2, t3 :: These (Maybe Int) (These Int8 Int8) t1 = This (Just 6) @@ -284,7 +284,7 @@ algebraicDataTypes = describe "Algebraic data types" $ do t6 = Both (Just 3) (That 8) t7 = Both Nothing (Both 3 5) t8 = Both (Just 213) (Both 78 98) - + for_ [t1,t2,t3,t4,t5,t6,t7,t8] $ \ti -> let v1 = [rust| These,These> { $(ti: These,These>).bimap( @@ -294,14 +294,14 @@ algebraicDataTypes = describe "Algebraic data types" $ do } |] v2 = bimap (fmap (+5)) (bimap (+2) (*3)) ti in v1 `shouldBe` v2 - + it "Can marshal a big ADT whose tag needs more than a `Word8`" $ do let b1, b2, b3, b4 :: Big Int64 b1 = C000 b2 = C160 b3 = C298 b4 = C299 89 - + for_ [b1,b2,b3,b4] $ \bi -> let v1 = [rust| Big { match $(bi: Big) { @@ -315,20 +315,20 @@ algebraicDataTypes = describe "Algebraic data types" $ do C299 i -> C299 (i + 1) b -> b in v1 `shouldBe` v2 - + it "Can marshal a custom `Foo2 Int` and `Foo2 (Foo2 Int)` return" $ do let f1, f2, f3, f4 :: Foo2 Int f1 = Bar2 f2 = Baz2 3 f3 = Qux2 (-1) 2 f4 = Quux2 (-8) 3 - + let fooed f = case f of Qux2 x y -> Qux2 (Qux2 x y) (Qux2 y x) Quux2 i x -> Quux2 (i + 1) (Qux2 x x) Bar2 -> Bar2 Baz2 w -> Baz2 w - + let fooed' f = [rust| Foo2> { match $(f: Foo2) { Foo2::Qux2(x,y) => Foo2::Qux2(Foo2::Qux2(x,y), Foo2::Qux2(y,x)), @@ -337,7 +337,7 @@ algebraicDataTypes = describe "Algebraic data types" $ do Foo2::Baz2(w) => Foo2::Baz2(w), } } |] - + fooed f1 `shouldBe` fooed' f1 fooed f2 `shouldBe` fooed' f2 fooed f3 `shouldBe` fooed' f3 diff --git a/tests/ByteStrings.hs b/tests/ByteStrings.hs index 714315f..a390ba6 100644 --- a/tests/ByteStrings.hs +++ b/tests/ByteStrings.hs @@ -1,6 +1,3 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} - module ByteStrings where import Language.Rust.Inline @@ -11,8 +8,11 @@ import qualified Data.ByteString as ByteString import qualified Data.ByteString.Unsafe as ByteString import Data.Maybe (fromJust) import Data.String +import Data.Either (fromRight) +import Language.Rust.Inline.TH extendContext basic +extendContext prelude extendContext bytestrings setCrateModule @@ -36,9 +36,36 @@ bytestringSpec = describe "ByteStrings" $ do `shouldBe` rustBs ByteString.unsafeFinalize rustBs + it "can marshal optional ByteString arguments and return values" $ do + let rustSum inputs = + [rust| Option { + let inputs = $( inputs: Option<&[u8]> ); + inputs.map(|inputs| inputs.iter().sum()) + } |] + let inputs = ByteString.pack [0, 1, 2, 3] + rustSum Nothing `shouldBe` Nothing + rustSum (Just inputs) `shouldBe` Just (sum $ ByteString.unpack inputs) + + it "can marshal result ByteString arguments" $ do + let rustSum inputs = + [rust| Result { + let inputs = $( inputs: Result<&[u8], ()> ); + inputs.map(|inputs| inputs.iter().sum()) + } |] + let inputs = ByteString.pack [0, 1, 2, 3] + rustSum (Left ()) `shouldBe` Left () + rustSum (Right inputs) `shouldBe` Right (sum $ ByteString.unpack inputs) + it "can marshal optional ByteString return values" $ do let noRustBs = [rust| Option> { None } |] noRustBs `shouldBe` Nothing let rustBs = [rust| Option> { Some(vec![0, 1, 2, 3]) } |] - fromJust rustBs `shouldBe` ByteString.pack [0, 1, 2, 3] + rustBs `shouldBe` Just (ByteString.pack [0, 1, 2, 3]) + + it "can marshal result ByteString return values" $ do + let errRustBs = [rust| Result, ()> { Err(()) } |] + errRustBs `shouldBe` Left () + + let okRustBs = [rust| Result, ()> { Ok(vec![0, 1, 2, 3]) } |] + okRustBs `shouldBe` Right (ByteString.pack [0, 1, 2, 3]) diff --git a/tests/ForeignPtr.hs b/tests/ForeignPtr.hs index 833c626..b2c8e98 100644 --- a/tests/ForeignPtr.hs +++ b/tests/ForeignPtr.hs @@ -1,17 +1,15 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} - module ForeignPtr where import Language.Rust.Inline +import Language.Rust.Quote -import Data.Maybe (fromJust) +import Data.Maybe (fromJust, isJust) import Data.Word (Word64) import Foreign (Storable (..)) import Foreign.ForeignPtr import Foreign.Ptr import Test.Hspec +import Data.Either (fromRight) extendContext foreignPointers extendContext pointers @@ -19,6 +17,12 @@ extendContext prelude extendContext basic setCrateModule + +extendContext (singleton [ty| NotCopy |] [t| () |]) +[rust| +pub struct NotCopy(()); +|] + foreignPtrTypes :: Spec foreignPtrTypes = describe "ForeignPtr types" $ do it "Can marshal ForeignPtr arguments as references" $ do @@ -59,6 +63,19 @@ foreignPtrTypes = describe "ForeignPtr types" $ do } |] withForeignPtr (fromJust mp) peek >>= (`shouldBe` 42) + it "Can marshal result ForeignPtr returns" $ do + let mp = + [rust| Result, ()> { + Err(()) + } |] + mp `shouldBe` Left () + + let mp = + [rust| Result, ()> { + Ok(Box::new(42).into()) + } |] + withForeignPtr (fromRight undefined mp) peek >>= (`shouldBe` 42) + it "still has working pointers" $ alloca $ \p -> do [rustIO| () { @@ -71,3 +88,8 @@ foreignPtrTypes = describe "ForeignPtr types" $ do unsafe { *$(p: *u64) } } |] val `shouldBe` 42 + + it "still works without Copy" $ do + let mp = [rust| Option> { Some(Box::new(NotCopy(())).into()) } |] + mp `shouldSatisfy` isJust + diff --git a/tests/FunctionPointerTypes.hs b/tests/FunctionPointerTypes.hs index f0b5663..e75df56 100644 --- a/tests/FunctionPointerTypes.hs +++ b/tests/FunctionPointerTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell #-} module FunctionPointerTypes where import Language.Rust.Inline diff --git a/tests/GhcUnboxedTypes.hs b/tests/GhcUnboxedTypes.hs index 75b7f01..c458dad 100644 --- a/tests/GhcUnboxedTypes.hs +++ b/tests/GhcUnboxedTypes.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell, MagicHash, UnliftedFFITypes #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnliftedFFITypes #-} + module GhcUnboxedTypes where import Language.Rust.Inline diff --git a/tests/Main.hs b/tests/Main.hs index bff3ca2..beb5ef6 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE TemplateHaskell, QuasiQuotes, CPP #-} +{-# LANGUAGE CPP #-} #ifdef darwin_HOST_OS {-# OPTIONS_GHC -optl-Wl,-all_load #-} @@ -10,21 +10,22 @@ module Main where import Language.Rust.Inline -import SimpleTypes +import AlgebraicDataTypes +import ByteStrings +import Data.Word +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.Storable +import ForeignPtr +import FunctionPointerTypes import GhcUnboxedTypes import PointerTypes -import FunctionPointerTypes import PreludeTypes -import AlgebraicDataTypes -import ByteStrings +import SimpleTypes import Submodule import Submodule.Submodule -import ForeignPtr -import Data.Word import Test.Hspec -import Foreign.Storable -import Foreign.Ptr -import Foreign.Marshal.Array +import Vectors extendContext basic setCrateRoot [] @@ -32,13 +33,14 @@ setCrateRoot [] main :: IO () main = hspec $ describe "Rust quasiquoter" $ do - simpleTypes + algebraicDataTypes + bytestringSpec + vectorsSpec + foreignPtrTypes + funcPointerTypes ghcUnboxedTypes pointerTypes - funcPointerTypes preludeTypes - algebraicDataTypes + simpleTypes submoduleTest subsubmoduleTest - bytestringSpec - foreignPtrTypes diff --git a/tests/PointerTypes.hs b/tests/PointerTypes.hs index d8cad90..91287d9 100644 --- a/tests/PointerTypes.hs +++ b/tests/PointerTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell #-} module PointerTypes where import Language.Rust.Inline diff --git a/tests/PreludeTypes.hs b/tests/PreludeTypes.hs index d52e356..2ee27c9 100644 --- a/tests/PreludeTypes.hs +++ b/tests/PreludeTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell #-} module PreludeTypes where import Language.Rust.Inline @@ -22,38 +21,38 @@ preludeTypes = describe "Common Prelude types" $ do x2 = Nothing [rust| Option { $(x1: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x1 [rust| Option { $(x2: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x2 - + it "Can marshal an `Either Int32 Char` argument/return" $ do let x1, x2 :: Either Int32 Char x1 = Left 9 x2 = Right 'e' - + [rust| Result { $(x1: Result) .map(|c| c.to_uppercase().next().unwrap()) .map_err(|n| n+2) } |] `shouldBe` bimap (+2) toUpper x1 - + [rust| Result { $(x2: Result) .map(|c| c.to_uppercase().next().unwrap()) .map_err(|n| n+2) } |] `shouldBe` bimap (+2) toUpper x2 - + it "Can marshal `(Int32, Char)` argument and `(Int32, Char, Word32)` return" $ do let x = (-9, 'c') - + [rust| (i32, char, u32) { let (n, c) = $(x: (i32, char)); (n * 3, c.to_uppercase().next().unwrap(), n.abs() as u32) } |] `shouldBe` (fst x * 3, toUpper (snd x), fromIntegral (abs (fst x))) - + it "Can marshal `Maybe (Int32, Either Char Word32)` argument/return" $ do let x1, x2, x3 :: Maybe (Int32, Either Char Word32) x1 = Nothing x2 = Just (3, Left 'c') x3 = Just (4, Right 8) - + let f x = [rust| Option<(i32, Result)> { $(x: Option<(i32, Result)>).map(|x| { match x { @@ -62,12 +61,11 @@ preludeTypes = describe "Common Prelude types" $ do } }) } |] - + let f' = fmap (\x -> case x of (x1, Right x2) -> (x1 * fromIntegral x2, Left 'a') (x1, e) -> (x1 - 1, e)) - + f x1 `shouldBe` f' x1 f x2 `shouldBe` f' x2 f x3 `shouldBe` f' x3 - diff --git a/tests/SimpleTypes.hs b/tests/SimpleTypes.hs index ae16b07..85e3e0c 100644 --- a/tests/SimpleTypes.hs +++ b/tests/SimpleTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell #-} module SimpleTypes where import Language.Rust.Inline @@ -67,7 +66,7 @@ simpleTypes = describe "Simple types" $ do it "Can marshal a `Bool` argument/return" $ do let x = 0 :: Word8 [rust| bool { !$(x: bool) } |] `shouldBe` (1 :: Word8) - --- it "Can marshal a `()` argument/return" $ do --- let x = () --- [rust| () { $(x: ()) } |] `shouldBe` () + + it "Can marshal a `()` argument/return" $ do + let x = () + [rust| () { $(x: ()) } |] `shouldBe` () diff --git a/tests/Submodule.hs b/tests/Submodule.hs index 66e6744..282504e 100644 --- a/tests/Submodule.hs +++ b/tests/Submodule.hs @@ -1,6 +1,3 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} - module Submodule where import Data.Int diff --git a/tests/Submodule/Submodule.hs b/tests/Submodule/Submodule.hs index 41b5e66..557465b 100644 --- a/tests/Submodule/Submodule.hs +++ b/tests/Submodule/Submodule.hs @@ -1,6 +1,3 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} - module Submodule.Submodule where import Data.Int diff --git a/tests/Vectors.hs b/tests/Vectors.hs new file mode 100644 index 0000000..94dbdb2 --- /dev/null +++ b/tests/Vectors.hs @@ -0,0 +1,36 @@ +module Vectors where + +import Language.Rust.Inline +import Language.Rust.Inline.TH +import Test.Hspec +import Control.Monad (forM) +import Foreign (withForeignPtr) +import Data.Word +import qualified Foreign + +extendContext basic +extendContext prelude +extendContext foreignPointers +extendContext vectors +setCrateModule + +vectorsSpec :: Spec +vectorsSpec = describe "Vectors" $ do + it "can marshal list return values" $ do + [rust| Vec { vec![] } |] `shouldBe` [] + + let mints = [rust| Vec> { vec![Some(17), None] } |] + mints `shouldBe` [Just 17, Nothing] + + let fps = [rust| Vec> { vec![Box::new(17).into(), Box::new(42).into() ] } |] + values <- forM fps $ flip withForeignPtr Foreign.peek + values `shouldBe` [17, 42] + + it "can marshal list arguments" $ do + let ints = [17, 42] :: [Word64] + let rsum = [rust| u64 { $(ints: Vec).iter().sum() } |] + rsum `shouldBe` sum ints + + it "can marshal pairs of lists" $ do + let yes = [rust| (Vec, Vec) { (vec![], vec![]) } |] + yes `shouldBe` ([], [])