From a5b517c2eaa42f8312720788bf483478cee1c72b Mon Sep 17 00:00:00 2001 From: ners Date: Sat, 15 Feb 2025 19:21:01 +0100 Subject: [PATCH 01/13] wip --- flake.lock | 12 +- inline-rust.cabal | 9 +- src/Language/Rust/Inline.hs | 154 +------ .../Rust/Inline/Context/Marshalable.hs | 406 ++++++++++++++++++ src/Language/Rust/Inline/Context/Prelude.hs | 2 +- src/Language/Rust/Inline/Marshal.hs | 79 ++-- tests/ByteStrings.hs | 10 +- tests/ForeignPtr.hs | 36 +- 8 files changed, 515 insertions(+), 193 deletions(-) create mode 100644 src/Language/Rust/Inline/Context/Marshalable.hs diff --git a/flake.lock b/flake.lock index 9e23b27..074062a 100644 --- a/flake.lock +++ b/flake.lock @@ -3,11 +3,11 @@ "language-rust": { "flake": false, "locked": { - "lastModified": 1729283594, - "narHash": "sha256-nVexe9Jrj0DtpjyPX+S2t/Op2iJyU1IRncW/BO8Ties=", + "lastModified": 1736519943, + "narHash": "sha256-5b4PRIpNCaZiTL3M1DZsKtSJWz5PPmyMpHatB7zBe+M=", "owner": "GaloisInc", "repo": "language-rust", - "rev": "ae59c7355b2fbbe818c1fdd23894df6bfa0432e5", + "rev": "f23c2b055217479bd3d1525dc9455d0f8bfcc394", "type": "github" }, "original": { @@ -18,11 +18,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1731319897, - "narHash": "sha256-PbABj4tnbWFMfBp6OcUK5iGy1QY+/Z96ZcLpooIbuEI=", + "lastModified": 1739446958, + "narHash": "sha256-+/bYK3DbPxMIvSL4zArkMX0LQvS7rzBKXnDXLfKyRVc=", "owner": "nixos", "repo": "nixpkgs", - "rev": "dc460ec76cbff0e66e269457d7b728432263166c", + "rev": "2ff53fe64443980e139eaa286017f53f88336dd0", "type": "github" }, "original": { diff --git a/inline-rust.cabal b/inline-rust.cabal index 907a777..efbf292 100644 --- a/inline-rust.cabal +++ b/inline-rust.cabal @@ -28,15 +28,16 @@ library exposed-modules: Language.Rust.Inline Language.Rust.Inline.TH other-modules: Language.Rust.Inline.Context - Language.Rust.Inline.Context.Prelude Language.Rust.Inline.Context.ByteString + Language.Rust.Inline.Context.Marshalable + Language.Rust.Inline.Context.Prelude + Language.Rust.Inline.Internal Language.Rust.Inline.Marshal Language.Rust.Inline.Parser Language.Rust.Inline.Pretty - Language.Rust.Inline.Internal + Language.Rust.Inline.TH.ReprC Language.Rust.Inline.TH.Storable Language.Rust.Inline.TH.Utilities - Language.Rust.Inline.TH.ReprC other-extensions: DeriveDataTypeable , CPP @@ -62,7 +63,7 @@ library test-suite spec hs-source-dirs: tests - ghc-options: -threaded + ghc-options: -threaded -ddump-splices -ddump-to-file if os(windows) extra-libraries: diff --git a/src/Language/Rust/Inline.hs b/src/Language/Rust/Inline.hs index d700f4d..18944f1 100644 --- a/src/Language/Rust/Inline.hs +++ b/src/Language/Rust/Inline.hs @@ -113,6 +113,7 @@ import System.Random (randomIO) import qualified Data.ByteString.Unsafe as ByteString import Foreign.Storable (Storable (..)) +import qualified Language.Rust.Inline.Context.Marshalable as Marshalable {- $overview @@ -314,64 +315,15 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do -- Convert the Haskell return type to a marshallable FFI type (returnFfi, haskRet') <- do marshalForm <- ghcMarshallable haskRet - let fptrRet haskRet' = [t|Ptr (Ptr $(pure haskRet'), FunPtr (Ptr $(pure haskRet') -> IO ())) -> IO ()|] - let bsRet = [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) -> IO ()|] - ret <- case marshalForm of - BoxedDirect -> [t|IO $(pure haskRet)|] - BoxedIndirect -> [t|Ptr $(pure haskRet) -> IO ()|] - UnboxedDirect - | isPure -> pure haskRet - | otherwise -> - let retTy = showTy haskRet - in fail ("Cannot put unlifted type ‘" ++ retTy ++ "’ in IO") - ByteString -> bsRet - OptionalByteString -> bsRet - ForeignPtr - | AppT _ haskRet' <- haskRet -> fptrRet haskRet' - | otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " using the ForeignPtr calling convention") - OptionalForeignPtr - | AppT _ (AppT _ haskRet') <- haskRet -> fptrRet haskRet' - | otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " as an optional ForeignPtr") + ret <- returnType marshalForm haskRet pure (marshalForm, pure ret) -- Convert the Haskell arguments to marshallable FFI types (marshalForms, haskArgs') <- fmap unzip $ for haskArgs $ \haskArg -> do marshalForm <- ghcMarshallable haskArg - case marshalForm of - BoxedIndirect - | returnFfi == UnboxedDirect -> - let argTy = showTy haskArg - retTy = showTy haskRet - in fail - ( "Cannot pass an argument ‘" - ++ argTy - ++ "’" - ++ " indirectly when returning an unlifted type " - ++ "‘" - ++ retTy - ++ "’" - ) - | otherwise -> do - ptr <- [t|Ptr $(pure haskArg)|] - pure (BoxedIndirect, ptr) - ByteString -> do - rbsT <- [t|Ptr (Ptr Word8, Word)|] - pure (ByteString, rbsT) - OptionalByteString -> do - rbsT <- [t|Ptr (Ptr Word8, Word)|] - pure (OptionalByteString, rbsT) - ForeignPtr - | AppT _ haskArg' <- haskArg -> do - ptr <- [t|Ptr $(pure haskArg')|] - pure (ForeignPtr, ptr) - | otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " using the ForeignPtr calling convention") - OptionalForeignPtr - | AppT _ (AppT _ haskArg') <- haskArg -> do - ptr <- [t|Ptr $(pure haskArg')|] - pure (OptionalForeignPtr, ptr) - | otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " as an optional ForeignPtr") - _ -> pure (marshalForm, haskArg) + ret <- argumentType marshalForm haskArg + pure (marshalForm, ret) -- Generate the Haskell FFI import declaration and emit it bsFree <- newName $ "bsFree" ++ show (abs q) @@ -394,68 +346,6 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do -- accumulated arguments. If the return value is not marshallable, we have to -- 'alloca' some space to put the return value. goArgs acc [] - | returnFfi == ByteString = do - ret <- newName "ret" - ptr <- newName "ptr" - len <- newName "len" - finalizer <- newName "finalizer" - [e| - alloca - ( \($(varP ret)) -> do - $(appsE (varE qqName : reverse (varE ret : acc))) - ($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret) - ByteString.unsafePackCStringFinalizer - $(varE ptr) - (fromIntegral $(varE len)) - ($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len)) - ) - |] - | returnFfi == ForeignPtr = do - finalizer <- newName "finalizer" - ptr <- newName "ptr" - ret <- newName "ret" - [e| - alloca - ( \($(varP ret)) -> do - $(appsE (varE qqName : reverse (varE ret : acc))) - ($(varP ptr), $(varP finalizer)) <- peek $(varE ret) - newForeignPtr $(varE finalizer) $(varE ptr) - ) - |] - | returnFfi == OptionalForeignPtr = do - finalizer <- newName "finalizer" - ptr <- newName "ptr" - ret <- newName "ret" - [e| - alloca - ( \($(varP ret)) -> do - $(appsE (varE qqName : reverse (varE ret : acc))) - ($(varP ptr), $(varP finalizer)) <- peek $(varE ret) - if $(varE ptr) == nullPtr - then pure Nothing - else Just <$> newForeignPtr $(varE finalizer) $(varE ptr) - ) - |] - | returnFfi == OptionalByteString = do - ret <- newName "ret" - ptr <- newName "ptr" - len <- newName "len" - finalizer <- newName "finalizer" - [e| - alloca - ( \($(varP ret)) -> do - $(appsE (varE qqName : reverse (varE ret : acc))) - ($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret) - if $(varE ptr) == nullPtr - then pure Nothing - else - Just - <$> ByteString.unsafePackCStringFinalizer - $(varE ptr) - (fromIntegral $(varE len)) - ($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len)) - ) - |] | returnByValue returnFfi = appsE (varE qqName : reverse acc) | otherwise = do ret <- newName "ret" @@ -464,7 +354,7 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do ( \($(varP ret)) -> do $(appsE (varE qqName : reverse (varE ret : acc))) - peek $(varE ret) + Marshalable.peek $(varE ret) ) |] @@ -475,46 +365,20 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do case arg of Nothing -> fail ("Could not find Haskell variable ‘" ++ argStr ++ "’") Just argName - | marshalForm == ByteString -> do - ptr <- newName "ptr" - len <- newName "len" - bsp <- newName "bsp" - [e| - withByteString - $(varE argName) - ( \($(varP ptr)) ($(varP len)) -> - with ($(varE ptr), $(varE len)) (\($(varP bsp)) -> $(goArgs (varE bsp : acc) args)) - ) - |] - | marshalForm == ForeignPtr -> do - ptr <- newName "ptr" - [e| - withForeignPtr $(varE argName) (\($(varP ptr)) -> $(goArgs (varE ptr : acc) args)) - |] - | marshalForm == OptionalForeignPtr -> do - ptr <- newName "ptr" - fptr <- newName "fptr" - [e| - case $(varE argName) of - Nothing -> let $(varP ptr) = nullPtr in $(goArgs (varE ptr : acc) args) - Just $(varP fptr) -> - withForeignPtr $(varE fptr) (\($(varP ptr)) -> $(goArgs (varE ptr : acc) args)) - |] - | marshalForm == OptionalByteString -> fail "Don't" - | passByValue marshalForm -> goArgs (varE argName : acc) args - | otherwise -> do + | marshalStep marshalForm -> do x <- newName "x" [e| - with + Marshalable.with $(varE argName) ( \($(varP x)) -> $(goArgs (varE x : acc) args) ) |] + | otherwise -> goArgs (varE argName : acc) args let haskCall' = goArgs [] (rustArgNames `zip` marshalForms) haskCall = - if isPure && returnFfi /= UnboxedDirect + if isPure && runsInIO returnFfi then [e|unsafeLocalState $haskCall'|] else haskCall' diff --git a/src/Language/Rust/Inline/Context/Marshalable.hs b/src/Language/Rust/Inline/Context/Marshalable.hs new file mode 100644 index 0000000..9414494 --- /dev/null +++ b/src/Language/Rust/Inline/Context/Marshalable.hs @@ -0,0 +1,406 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE FlexibleContexts #-} + +module Language.Rust.Inline.Context.Marshalable where + +import Foreign + ( Word8, + Ptr, + FunPtr, + ForeignPtr, + Storable, + nullPtr, + plusPtr, + newForeignPtr, + withForeignPtr) +import qualified Foreign +import Data.ByteString (ByteString) +import Data.ByteString.Internal (ByteString(PS)) +import qualified Data.ByteString.Unsafe as ByteString +import Language.Rust.Inline.Context.Prelude () + +-- type family WithPtrType a where +-- WithPtrType ByteString = Ptr (Ptr Word8, Word) +-- WithPtrType (ForeignPtr a) = Ptr a +-- WithPtrType a = Ptr a + +class HasWith a where + type WithPtrType a + with :: a -> (Ptr (WithPtrType a) -> IO b) -> IO b + +instance {-# OVERLAPPING #-} HasWith ByteString where + type WithPtrType ByteString = (Ptr Word8, Word) + with (PS ptr off len) cont = withForeignPtr ptr $ \ptr' -> Foreign.with (ptr' `plusPtr` off, fromIntegral len) cont + +instance {-# OVERLAPPING #-} HasWith (ForeignPtr a) where + type WithPtrType (ForeignPtr a) = a + with = withForeignPtr + +-- instance {-# OVERLAPPABLE #-} (WithPtrType a ~ Ptr a, Storable a) => HasWith a where +-- with = Foreign.with + +-- type family PeekType a where +-- PeekType ByteString = (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) +-- PeekType (ForeignPtr a) = (Ptr a, FunPtr (Ptr a -> IO ())) +-- PeekType a = a + +class HasPeek a where + type PeekType a + peek :: Ptr (PeekType a) -> IO a + +foreign import ccall safe "dynamic" bytestringFree :: FunPtr (Ptr Word8 -> Word -> IO ()) -> Ptr Word8 -> Word -> IO () + +instance {-# OVERLAPPING #-} HasPeek ByteString where + type PeekType ByteString = (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) + peek ret = do + (ptr, len, finalizer) <- Foreign.peek ret + ByteString.unsafePackCStringFinalizer ptr (fromIntegral len) (bytestringFree finalizer ptr len) + +instance {-# OVERLAPPING #-} HasPeek (ForeignPtr a) where + type PeekType (ForeignPtr a) = (Ptr a, FunPtr (Ptr a -> IO ())) + peek ret = do + (ptr, finalizer) <- Foreign.peek ret + newForeignPtr finalizer ptr + +-- instance {-# OVERLAPPABLE #-} (PeekType a ~ a, Storable a) => HasPeek a where +-- peek = Foreign.peek + +class (HasWith a, HasPeek a) => Marshalable a where + +instance (Storable (PeekType a), HasPeek a) => HasPeek (Maybe a) where + type PeekType (Maybe a) = Maybe (PeekType a) + peek ret = do + d <- Foreign.peek $ Foreign.castPtr @_ @Word8 ret + case d of + 0 -> pure Nothing + _ -> Foreign.peek $ ret `plusPtr` (Foreign.alignment @(PeekType a) undefined) + +instance HasWith a => HasWith (Maybe a) where + type WithPtrType (Maybe a) = WithPtrType a + with Nothing f = f nullPtr + with (Just a) f = with a f + +-- -- | Generate 'Marshalable' instance for a non-recursive simple algebraic data +-- -- type. The instance follows the usual C layout for determining alignment and +-- -- size. +-- -- +-- -- Sum types are implemented as tagged unions. +-- -- +-- -- >>> mkMarshalable [t| forall a. Marshalable a => Marshalable (Maybe a) |] +-- -- +-- -- Remember to have 'ScopedTypeVariables', 'ExplicitForall', and 'EmptyCase' +-- -- enabled when calling this! +-- mkMarshalable :: TypeQ -- ^ a type representing the desired instance head +-- -> Q [Dec] -- ^ the instance declaration +-- mkMarshalable tyq = do +-- pseudoInstHead <- tyq +-- +-- -- Extract the context +-- marshalable <- [t| Marshalable |] +-- (ctx, ty') <- +-- case pseudoInstHead of +-- ForallT _ ctx (AppT s ty) | s == marshalable -> pure (ctx, ty) +-- AppT s ty | s == marshalable -> pure ([], ty) +-- _ -> fail "mkMarshalable: malformed 'Marshalable' instance head" +-- +-- -- Get the type constructors name +-- (_,cons') <- getConstructors ty' +-- +-- -- Produce the instance +-- methods <- processADT [ (nameCon n, tyArgs) | (n,tyArgs) <- cons' ] +-- dec <- instanceD (pure ctx) (pure (AppT marshalable ty')) (map pure methods) +-- pure [dec] +-- +-- mkTupleMarshalable :: Int -- ^ arity of tuple +-- -> Q [Dec] -- ^ the instance declaration +-- mkTupleMarshalable n = do +-- marshalable <- [t| Marshalable |] +-- tyVars <- sequence (take n [ newName (c : show i) +-- | i <- [(1 :: Int)..] +-- , c <- ['a'..'z'] +-- ]) +-- let ctx = [ AppT marshalable (VarT tyVar) | tyVar <- tyVars ] +-- let instHead = AppT marshalable (foldl AppT (TupleT n) (map VarT tyVars)) +-- +-- methods <- processADT [ (tupCon, map VarT tyVars) ] +-- let dec = InstanceD Nothing ctx instHead methods +-- pure [dec] +-- +-- -- * Constructor utilities +-- data Constructor = Constructor +-- { conPat :: [Pat] -> Pat +-- , conExp :: [Exp] -> Exp +-- } +-- +-- nameCon :: Name -> Constructor +-- nameCon n = Constructor (ConP n []) (foldl AppE (ConE n)) +-- +-- tupCon :: Constructor +-- tupCon = Constructor TupP (TupE . fmap Just) +-- +-- +-- -- * Alignment +-- +-- -- | This is the information you need to carry along as you visit the fields of +-- -- a struct/union. +-- data Alignment = Alignment +-- { decs :: [Dec] +-- -- ^ declarations for variables relied on by offset and align +-- +-- , offsetSoFar :: Code Q Int +-- -- ^ total bytes occupied so far by fields +-- +-- , alignSoFar :: Code Q Int +-- -- ^ size (in bytes) of the largest member in the struct +-- } +-- +-- -- | Combining alignment means concatenating the dependent declarations, and +-- -- take the maximum for offset and alignment. +-- instance Semigroup Alignment where +-- a1 <> a2 = Alignment +-- { decs = decs a1 <> decs a2 +-- , offsetSoFar = [|| $$(offsetSoFar a1) `max` $$(offsetSoFar a2) ||] +-- , alignSoFar = [|| $$(alignSoFar a1) `max` $$(alignSoFar a2) ||] +-- } +-- +-- -- | The 'mconcat' method calls 'maximum' +-- instance Monoid Alignment where +-- mempty = Alignment [] [|| 0 ||] [|| 1 ||] +-- mappend = (<>) +-- mconcat as = Alignment +-- { decs = concatMap decs as +-- , offsetSoFar = [|| maximum $$(liftCode $ listTE <$> mapM (examineCode . offsetSoFar) as) ||] +-- , alignSoFar = [|| maximum $$(liftCode $ listTE <$> mapM (examineCode . alignSoFar) as) ||] +-- } +-- +-- -- | This is the state we will bundle along while visiting fields. +-- type StructState = StateT Alignment Q +-- +-- -- | Make a typed list. This function is like 'listE', but for 'TExp'. +-- listTE :: [TExp a] -> TExp [a] +-- listTE = TExp . ListE . map unType +-- +-- +-- -- * Peek and poke helper functions +-- +-- -- | Produces a 'do' block for peeking a constructor. The generated code has the +-- -- following shape: +-- -- +-- -- @ +-- -- do f1 <- ... ptr +-- -- f2 <- ... ptr +-- -- ... +-- -- fn <- ... ptr +-- -- return (Con f1 f2 ... fn) +-- -- @ +-- -- +-- peekCon :: Constructor -- ^ name of the constructor +-- -> [Exp -> Q Exp] -- ^ how to peek every field +-- -> Name -- ^ the base pointer +-- -> Q Exp -- ^ a 'do' expression for peeking the constructor +-- peekCon con peekFields ptr = do +-- (ns, binds) <- unzip <$> do +-- for peekFields $ \fldCont -> do +-- n <- newName "n" +-- pure (varE n, bindS (varP n) (fldCont (VarE ptr))) +-- let ret = [e| return $(conExp con <$> sequence ns) |] +-- doE (binds ++ [noBindS ret]) +-- +-- -- | Produces a 'do' block for poking a constructor, along with a pattern for +-- -- extracting out the right fields. Given a pattern like @Con f1 f2 ... fn@, the +-- -- generated block has the following shape: +-- -- +-- -- @ +-- -- do ... ptr f1 +-- -- ... ptr f2 +-- -- ... +-- -- ... ptr fn +-- -- @ +-- pokeCon :: Constructor -- ^ name of the constructor +-- -> [Exp -> Q Exp] -- ^ how to poke every field +-- -> Name -- ^ the base poniter +-- -> Q (Pat, Exp) -- ^ a pattern to match, an expression for poking +-- pokeCon con pokeFields ptr = do +-- (ns, stmts) <- unzip <$> do +-- for pokeFields $ \fldCont -> do +-- n <- newName "n" +-- pure (varP n, noBindS [e| $(fldCont (VarE ptr)) $(varE n) |]) +-- pat <- conPat con <$> sequence ns +-- expr <- if null stmts then [e| pure () |] else doE stmts +-- return (pat, expr) +-- +-- +-- -- * Traversing fields (putting everything together) +-- +-- -- TODO: look at `alignPtr :: Ptr a -> Int -> Ptr a` +-- +-- -- | Process a field of a given type. +-- processField :: Type -> StructState (Exp -> Q Exp, Exp -> Q Exp) +-- processField ty = do +-- let alignTy, sizeTy :: Code Q Int +-- alignTy = Code $ TExp <$> [e| alignment (undefined :: $(pure ty)) |] +-- sizeTy = Code $ TExp <$> [e| sizeOf (undefined :: $(pure ty)) |] +-- +-- -- get state at the end of the last field +-- Alignment prevDecs prevOff prevAlign <- get +-- +-- -- beginning offset +-- beginOffV <- lift $ newName "beginOff" +-- let beginOffE, beginOff :: Code Q Int +-- beginOffE = [|| $$prevOff + mod (negate $$prevOff) $$alignTy ||] +-- beginOff = Code $ TExp <$> varE beginOffV +-- assignBeginOff <- lift [d| $(varP beginOffV) = $(unType <$> examineCode beginOffE) |] +-- +-- -- offset after this field +-- newOffV <- lift $ newName "afterOff" +-- let newOffE :: Code Q Int +-- newOffE = [|| $$beginOff + $$sizeTy ||] +-- newOff <- lift (TExp <$> varE newOffV) +-- assignNewOff <- lift [d| $(varP newOffV) = $(unType <$> examineCode newOffE) |] +-- +-- -- alignment after this field +-- newAlignV <- lift $ newName "algn" +-- let newAlignE :: Code Q Int +-- newAlignE = [|| $$alignTy `max` $$prevAlign ||] +-- newAlign <- lift (TExp <$> varE newAlignV) +-- assignNewAlign <- lift [d| $(varP newAlignV) = $(unType <$> examineCode newAlignE) |] +-- +-- -- update state +-- put (Alignment { decs = concat [ assignBeginOff +-- , assignNewOff +-- , assignNewAlign +-- , prevDecs +-- ] +-- , offsetSoFar = liftCode (pure newOff) +-- , alignSoFar = liftCode (pure newAlign) +-- }) +-- +-- -- TODO: consider degenerate sizeof(..) = 0 cases +-- pure ( \addrE -> [e| peek (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] +-- , \addrE -> [e| poke (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] +-- ) +-- +-- +-- -- | Process an algebraic data type. +-- -- +-- -- TODO: think about the zero constructor case... +-- processADT :: [(Constructor, [Type])] -- ^ constructors and the types of their fields +-- -> Q [Dec] -- ^ methods of the 'Marshalable' class +-- +-- -- The one constructor case is special - we don't need to specify a tag +-- processADT [(con, fields)] = do +-- +-- initAlign <- mempty +-- (peekPokes, Alignment ds off algn) +-- <- runStateT (traverse processField fields) initAlign +-- let ds' = map pure ds +-- +-- -- sizeOf +-- sizeOf_ <- do +-- Just sizeOfN <- lookupValueName "sizeOf" +-- funD sizeOfN [clause [wildP] +-- (normalB [e| let c = $(unType <$> examineCode off) +-- in c + mod (negate c) $(unType <$> examineCode algn) |]) +-- ds'] +-- +-- -- alignment +-- alignment_ <- do +-- Just alignmentN <- lookupValueName "alignment" +-- funD alignmentN [clause [wildP] (normalB (unType <$> examineCode algn)) ds'] +-- +-- let (peekFields, pokeFields) = unzip peekPokes +-- +-- -- peek +-- peek_ <- do +-- ptr <- newName "ptr" +-- Just peekN <- lookupValueName "peek" +-- funD peekN [clause [varP ptr] (normalB (peekCon con peekFields ptr)) ds'] +-- +-- -- poke +-- poke_ <- do +-- ptr <- newName "ptr" +-- (cPat,body) <- pokeCon con pokeFields ptr +-- Just pokeN <- lookupValueName "poke" +-- funD pokeN [clause [varP ptr, pure cPat] (normalB (pure body)) ds'] +-- +-- pure [sizeOf_, alignment_, peek_, poke_] +-- +-- processADT cons = do +-- +-- let discNum = length cons +-- discTy <- snd . head . dropWhile (\(m,_) -> discNum > m + 1) $ +-- [ (fromIntegral (maxBound :: Word8), [t| Word8 |]) +-- , (fromIntegral (maxBound :: Word16), [t| Word16 |]) +-- , (fromIntegral (maxBound :: Word32), [t| Word32 |]) +-- , (fromIntegral (maxBound :: Word64), [t| Word64 |]) +-- ] +-- +-- initAlign <- mempty +-- (conPeekPokess, algns) <- unzip <$> do +-- for cons $ \(con, fields) -> do +-- (peekPokes, algn) <- runStateT (traverse processField fields) initAlign +-- let (peekFields, pokeFields) = unzip peekPokes +-- pure ((con, peekFields, pokeFields), algn) +-- Alignment ds off algn <- mconcat (map pure algns) +-- let discSizeOf = [e| sizeOf (undefined :: $(pure discTy)) |] +-- algn' = [e| $discSizeOf `max` $(unType <$> examineCode algn) |] +-- let ds' = map pure ds +-- +-- -- sizeOf +-- sizeOf_ <- do +-- Just sizeOfN <- lookupValueName "sizeOf" +-- funD sizeOfN [clause [wildP] +-- (normalB [e| let c = $(unType <$> examineCode off) +-- in $algn' + c + mod (negate c) $algn' |]) +-- ds'] +-- +-- -- alignment +-- alignment_ <- do +-- Just alignmentN <- lookupValueName "alignment" +-- funD alignmentN [clause [wildP] (normalB algn') ds'] +-- +-- -- peek +-- peek_ <- do +-- ptr <- newName "ptr" +-- ptrOff <- newName "ptrOff" +-- d' <- [d| $(varP ptrOff) = $(varE ptr) `plusPtr` $algn' |] +-- disc <- newName "disc" +-- let mtchs = [ match (litP n') (normalB (peekCon con peekFields ptrOff)) [] +-- | (n, (con, peekFields, _)) <- zip [0..] conPeekPokess +-- , let n' = IntegerL n +-- ] +-- Just peekN <- lookupValueName "peek" +-- funD peekN +-- [clause [varP ptr] +-- (normalB (doE [ bindS (varP disc) [e| peek (castPtr $(varE ptr) :: Ptr $(pure discTy)) |] +-- , noBindS (caseE (varE disc) mtchs) +-- ])) +-- (map pure d' ++ ds')] +-- +-- -- poke +-- poke_ <- do +-- ptr <- newName "ptr" +-- ptrOff <- newName "ptrOff" +-- d' <- [d| $(varP ptrOff) = $(varE ptr) `plusPtr` $algn' |] +-- disc <- newName "disc" +-- let mtchs = [ do { (pat,body) <- patBody +-- ; match (pure pat) +-- (normalB (doE (map noBindS [ [e| poke (castPtr $(varE ptr) :: Ptr $(pure discTy)) $(litE n') |] +-- , pure body +-- ]))) +-- [] +-- } +-- | (n, (con, _, pokeFields)) <- zip [0..] conPeekPokess +-- , let patBody = pokeCon con pokeFields ptrOff +-- , let n' = IntegerL n +-- ] +-- Just pokeN <- lookupValueName "poke" +-- funD pokeN +-- [clause [varP ptr, varP disc] (normalB (caseE (varE disc) mtchs)) (map pure d' ++ ds')] +-- +-- pure [sizeOf_, alignment_, peek_, poke_] + diff --git a/src/Language/Rust/Inline/Context/Prelude.hs b/src/Language/Rust/Inline/Context/Prelude.hs index 08448cc..b6ada69 100644 --- a/src/Language/Rust/Inline/Context/Prelude.hs +++ b/src/Language/Rust/Inline/Context/Prelude.hs @@ -64,7 +64,7 @@ maybeContext = do where rule (PathTy Nothing (Path False [PathSegment "Option" (Just (AngleBracketed [] [t] [] _)) _] _) _) context = do (t', rInterOpt) <- lookupRTypeInContext t context - let inter = mkGenPathTy "MaybeC" <$> ((\x -> [x]) <$> maybe (pure t) id rInterOpt) + let inter = pure . mkGenPathTy "MaybeC" $ fromMaybe t rInterOpt pure ([t| Maybe $t' |], Just inter) rule _ _ = mempty diff --git a/src/Language/Rust/Inline/Marshal.hs b/src/Language/Rust/Inline/Marshal.hs index 506df15..907dc19 100644 --- a/src/Language/Rust/Inline/Marshal.hs +++ b/src/Language/Rust/Inline/Marshal.hs @@ -10,6 +10,7 @@ Portability : GHC {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE MagicHash #-} +{-# LANGUAGE LambdaCase #-} module Language.Rust.Inline.Marshal where @@ -32,21 +33,14 @@ import Data.Array.Storable ( StorableArray, Ix, withStorableArray, import GHC.Exts -data MarshalForm - = UnboxedDirect -- ^ value is marshallable and must be passed directly to the FFI - | BoxedDirect -- ^ value is marshallable and can be passed directly to the FFI - | BoxedIndirect -- ^ value isn't marshallable directly but may be passed indirectly via a 'Ptr' - | ByteString - | ForeignPtr - | OptionalForeignPtr - | OptionalByteString - deriving (Eq, Show) - -passByValue :: MarshalForm -> Bool -passByValue = (`elem` [UnboxedDirect, BoxedDirect, ForeignPtr]) - -returnByValue :: MarshalForm -> Bool -returnByValue = (`elem` [UnboxedDirect, BoxedDirect]) +data MarshalForm = MarshalForm + { passByValue :: Bool + , marshalStep :: Bool + , returnByValue :: Bool + , returnType :: Type -> Q Type + , argumentType :: Type -> Q Type + , runsInIO :: Bool + } -- | Identify which types can be marshalled by the GHC FFI and which types are -- unlifted. A negative response to the first of these questions doesn't mean @@ -65,17 +59,52 @@ ghcMarshallable ty = do fptrCons <- [t| ForeignPtr |] maybeCons <- [t| Maybe |] + let unboxedDirect = MarshalForm + { passByValue = True + , marshalStep = False + , returnByValue = True + , returnType = pure + , argumentType = pure + , runsInIO = False + } + boxedDirect = unboxedDirect{ returnType = \t -> [t|IO $(pure t)|], runsInIO = True } + boxedIndirect = MarshalForm + { passByValue = False + , marshalStep = True + , returnByValue = False + , returnType = \t -> [t|Ptr $(pure t) -> IO ()|] + , argumentType = \t -> [t|Ptr $(pure t)|] + , runsInIO = True + } + foreignPtr = MarshalForm + { passByValue = True + , marshalStep = True + , returnByValue = False + , returnType = \case + AppT _ r -> [t|Ptr (Ptr $(pure r), FunPtr (Ptr $(pure r) -> IO ())) -> IO ()|] + t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" + , argumentType = \case + AppT _ r -> [t|Ptr $(pure r)|] + t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" + , runsInIO = True + } + byteString = MarshalForm + { passByValue = False + , marshalStep = True + , returnByValue = False + , returnType = const [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) -> IO ()|] + , argumentType = const [t|Ptr (Ptr Word8, Word)|] + , runsInIO = True + } + case ty of - _ | ty `elem` simpleU -> pure UnboxedDirect - | ty `elem` simpleB -> pure BoxedDirect - | ty == bytestring -> pure ByteString - AppT con _ | con `elem` tyconsU -> pure UnboxedDirect - | con `elem` tyconsB -> pure BoxedDirect - | con == fptrCons -> pure ForeignPtr - AppT mb (AppT c _) - | mb == maybeCons && c == fptrCons -> pure OptionalForeignPtr - AppT mb c | mb == maybeCons && c == bytestring -> pure OptionalByteString - _ -> pure BoxedIndirect + _ | ty `elem` simpleU -> pure unboxedDirect + | ty `elem` simpleB -> pure boxedDirect + | ty == bytestring -> pure byteString + AppT con _ | con `elem` tyconsU -> pure unboxedDirect + | con `elem` tyconsB -> pure boxedDirect + | con == fptrCons -> pure foreignPtr + _ -> pure boxedIndirect where qSimpleUnboxed = [ [t| Char# |] , [t| Int# |] diff --git a/tests/ByteStrings.hs b/tests/ByteStrings.hs index 714315f..2aecffa 100644 --- a/tests/ByteStrings.hs +++ b/tests/ByteStrings.hs @@ -11,6 +11,7 @@ import qualified Data.ByteString as ByteString import qualified Data.ByteString.Unsafe as ByteString import Data.Maybe (fromJust) import Data.String +import Data.Either (fromRight) extendContext basic extendContext bytestrings @@ -41,4 +42,11 @@ bytestringSpec = describe "ByteStrings" $ do noRustBs `shouldBe` Nothing let rustBs = [rust| Option> { Some(vec![0, 1, 2, 3]) } |] - fromJust rustBs `shouldBe` ByteString.pack [0, 1, 2, 3] + rustBs `shouldBe` Just (ByteString.pack [0, 1, 2, 3]) + + -- it "can marshal result ByteString return values" $ do + -- let errRustBs = [rust| Result, ()> { Err(()) } |] + -- errRustBs `shouldBe` Left () + + -- let okRustBs = [rust| Result, ()> { Ok(vec![0, 1, 2, 3]) } |] + -- okRustBs `shouldBe` Right (ByteString.pack [0, 1, 2, 3]) diff --git a/tests/ForeignPtr.hs b/tests/ForeignPtr.hs index 833c626..9e8099a 100644 --- a/tests/ForeignPtr.hs +++ b/tests/ForeignPtr.hs @@ -12,6 +12,7 @@ import Foreign (Storable (..)) import Foreign.ForeignPtr import Foreign.Ptr import Test.Hspec +import Data.Either (fromRight) extendContext foreignPointers extendContext pointers @@ -46,18 +47,31 @@ foreignPtrTypes = describe "ForeignPtr types" $ do val <- withForeignPtr p peek val `shouldBe` 42 - it "Can marshal optional ForeignPtr returns" $ do - let mp = - [rust| Option> { - None - } |] - mp `shouldBe` Nothing + -- it "Can marshal optional ForeignPtr returns" $ do + -- let mp = + -- [rust| Option> { + -- None + -- } |] + -- mp `shouldBe` Nothing - let mp = - [rust| Option> { - Some(Box::new(42).into()) - } |] - withForeignPtr (fromJust mp) peek >>= (`shouldBe` 42) + -- let mp = + -- [rust| Option> { + -- Some(Box::new(42).into()) + -- } |] + -- withForeignPtr (fromJust mp) peek >>= (`shouldBe` 42) + + -- it "Can marshal result ForeignPtr returns" $ do + -- let mp = + -- [rust| Result, ()> { + -- Err(()) + -- } |] + -- mp `shouldBe` Left () + + -- let mp = + -- [rust| Result, ()> { + -- Ok(Box::new(42).into()) + -- } |] + -- withForeignPtr (fromRight mp) peek >>= (`shouldBe` 42) it "still has working pointers" $ alloca $ \p -> do From effa596ef9c2111e844e7a58593c47313f236087 Mon Sep 17 00:00:00 2001 From: Viktor Kleen Date: Sun, 16 Feb 2025 10:25:26 +0000 Subject: [PATCH 02/13] wip --- src/Language/Rust/Inline.hs | 18 +- .../Rust/Inline/Context/Marshalable.hs | 32 ++- src/Language/Rust/Inline/Context/Prelude.hs | 2 +- src/Language/Rust/Inline/Marshal.hs | 9 +- tests/AlgebraicDataTypes.hs | 243 +++++++++--------- tests/PreludeTypes.hs | 109 ++++---- 6 files changed, 211 insertions(+), 202 deletions(-) diff --git a/src/Language/Rust/Inline.hs b/src/Language/Rust/Inline.hs index 18944f1..af876bd 100644 --- a/src/Language/Rust/Inline.hs +++ b/src/Language/Rust/Inline.hs @@ -79,6 +79,8 @@ module Language.Rust.Inline ( mkStorable, mkReprC, + Marshalable.PeekType, + -- * Top-level Rust items ) where @@ -94,7 +96,6 @@ import Language.Rust.Inline.Pretty import Language.Rust.Inline.TH.ReprC (mkReprC) import Language.Rust.Inline.TH.Storable (mkStorable) -import Language.Haskell.TH (pprParendType) import Language.Haskell.TH.Lib import Language.Haskell.TH.Quote (QuasiQuoter (..)) import Language.Haskell.TH.Syntax @@ -103,7 +104,7 @@ import Foreign.Marshal.Alloc (alloca, free) import Foreign.Marshal.Array (newArray, withArrayLen) import Foreign.Marshal.Unsafe (unsafeLocalState) import Foreign.Marshal.Utils (new, with) -import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr, nullPtr) +import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr) import Control.Monad (void) import Data.List (intercalate) @@ -111,8 +112,6 @@ import Data.Traversable (for) import Data.Word (Word8) import System.Random (randomIO) -import qualified Data.ByteString.Unsafe as ByteString -import Foreign.Storable (Storable (..)) import qualified Language.Rust.Inline.Context.Marshalable as Marshalable {- $overview @@ -280,9 +279,6 @@ rustQuasiQuoter safety isPure supportDecs = | supportDecs = emitCodeBlock | otherwise = err -showTy :: Type -> String -showTy = show . pprParendType - {- | This function sums up the packages. What it does: 1. Map the Rust type annotations in the quasiquote to their Haskell types. @@ -326,12 +322,12 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do pure (marshalForm, ret) -- Generate the Haskell FFI import declaration and emit it - bsFree <- newName $ "bsFree" ++ show (abs q) - bsFreeSig <- [t|FunPtr (Ptr Word8 -> Word -> IO ()) -> Ptr Word8 -> Word -> IO ()|] + -- bsFree <- newName $ "bsFree" ++ show (abs q) + -- bsFreeSig <- [t|FunPtr (Ptr Word8 -> Word -> IO ()) -> Ptr Word8 -> Word -> IO ()|] haskSig <- foldr (\l r -> [t|$(pure l) -> $r|]) haskRet' haskArgs' let ffiImport = ForeignD (ImportF CCall safety qqStrName qqName haskSig) - let ffiBsFree = ForeignD (ImportF CCall Safe "dynamic" bsFree bsFreeSig) - addTopDecls [ffiImport, ffiBsFree] + -- let ffiBsFree = ForeignD (ImportF CCall Safe "dynamic" bsFree bsFreeSig) + addTopDecls [ffiImport] -- Generate the Haskell FFI call let goArgs :: diff --git a/src/Language/Rust/Inline/Context/Marshalable.hs b/src/Language/Rust/Inline/Context/Marshalable.hs index 9414494..5901410 100644 --- a/src/Language/Rust/Inline/Context/Marshalable.hs +++ b/src/Language/Rust/Inline/Context/Marshalable.hs @@ -2,6 +2,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE FlexibleContexts #-} @@ -13,7 +14,6 @@ import Foreign FunPtr, ForeignPtr, Storable, - nullPtr, plusPtr, newForeignPtr, withForeignPtr) @@ -28,17 +28,22 @@ import Language.Rust.Inline.Context.Prelude () -- WithPtrType (ForeignPtr a) = Ptr a -- WithPtrType a = Ptr a -class HasWith a where +class Storable (WithPtrType a) => HasWith a where type WithPtrType a with :: a -> (Ptr (WithPtrType a) -> IO b) -> IO b + with x k = Foreign.alloca $ \loc -> withLoc x loc (k loc) + + withLoc :: a -> Ptr (WithPtrType a) -> IO b -> IO b instance {-# OVERLAPPING #-} HasWith ByteString where type WithPtrType ByteString = (Ptr Word8, Word) - with (PS ptr off len) cont = withForeignPtr ptr $ \ptr' -> Foreign.with (ptr' `plusPtr` off, fromIntegral len) cont + withLoc (PS ptr off len) loc k = withForeignPtr ptr $ \ptr' -> + Foreign.poke loc (ptr' `plusPtr` off, fromIntegral len) >> k -instance {-# OVERLAPPING #-} HasWith (ForeignPtr a) where - type WithPtrType (ForeignPtr a) = a - with = withForeignPtr +instance {-# OVERLAPPING #-} Storable a => HasWith (ForeignPtr a) where + type WithPtrType (ForeignPtr a) = Ptr a + withLoc fp loc k = withForeignPtr fp $ \ptr -> + Foreign.poke loc ptr >> k -- instance {-# OVERLAPPABLE #-} (WithPtrType a ~ Ptr a, Storable a) => HasWith a where -- with = Foreign.with @@ -72,17 +77,22 @@ instance {-# OVERLAPPING #-} HasPeek (ForeignPtr a) where class (HasWith a, HasPeek a) => Marshalable a where instance (Storable (PeekType a), HasPeek a) => HasPeek (Maybe a) where - type PeekType (Maybe a) = Maybe (PeekType a) + type PeekType (Maybe a) = (Word8, PeekType a) + peek :: Ptr (Word8, PeekType a) -> IO (Maybe a) peek ret = do d <- Foreign.peek $ Foreign.castPtr @_ @Word8 ret case d of 0 -> pure Nothing - _ -> Foreign.peek $ ret `plusPtr` (Foreign.alignment @(PeekType a) undefined) + _ -> Just <$> peek @a (ret `plusPtr` Foreign.alignment @(PeekType a) undefined) instance HasWith a => HasWith (Maybe a) where - type WithPtrType (Maybe a) = WithPtrType a - with Nothing f = f nullPtr - with (Just a) f = with a f + type WithPtrType (Maybe a) = (Word8, WithPtrType a) + withLoc Nothing loc k = + Foreign.poke (Foreign.castPtr @_ @Word8 loc) 0 >> k + withLoc (Just a) loc k = + let align = Foreign.alignment @(WithPtrType a) undefined + in do Foreign.poke (Foreign.castPtr @_ @Word8 loc) 1 + withLoc a (Foreign.castPtr loc `plusPtr` align) k -- -- | Generate 'Marshalable' instance for a non-recursive simple algebraic data -- -- type. The instance follows the usual C layout for determining alignment and diff --git a/src/Language/Rust/Inline/Context/Prelude.hs b/src/Language/Rust/Inline/Context/Prelude.hs index b6ada69..e651298 100644 --- a/src/Language/Rust/Inline/Context/Prelude.hs +++ b/src/Language/Rust/Inline/Context/Prelude.hs @@ -64,7 +64,7 @@ maybeContext = do where rule (PathTy Nothing (Path False [PathSegment "Option" (Just (AngleBracketed [] [t] [] _)) _] _) _) context = do (t', rInterOpt) <- lookupRTypeInContext t context - let inter = pure . mkGenPathTy "MaybeC" $ fromMaybe t rInterOpt + let inter = mkGenPathTy "MaybeC" . pure <$> fromMaybe (pure t) rInterOpt pure ([t| Maybe $t' |], Just inter) rule _ _ = mempty diff --git a/src/Language/Rust/Inline/Marshal.hs b/src/Language/Rust/Inline/Marshal.hs index 907dc19..1d7681d 100644 --- a/src/Language/Rust/Inline/Marshal.hs +++ b/src/Language/Rust/Inline/Marshal.hs @@ -15,6 +15,7 @@ Portability : GHC module Language.Rust.Inline.Marshal where import Language.Rust.Inline.Context +import Language.Rust.Inline.Context.Marshalable (PeekType, WithPtrType) import Language.Haskell.TH import Language.Haskell.TH.Syntax ( addTopDecls ) @@ -72,19 +73,19 @@ ghcMarshallable ty = do { passByValue = False , marshalStep = True , returnByValue = False - , returnType = \t -> [t|Ptr $(pure t) -> IO ()|] - , argumentType = \t -> [t|Ptr $(pure t)|] + , returnType = \t -> [t|Ptr (PeekType $(pure t)) -> IO ()|] + , argumentType = \t -> [t|Ptr (WithPtrType $(pure t))|] , runsInIO = True } foreignPtr = MarshalForm - { passByValue = True + { passByValue = False , marshalStep = True , returnByValue = False , returnType = \case AppT _ r -> [t|Ptr (Ptr $(pure r), FunPtr (Ptr $(pure r) -> IO ())) -> IO ()|] t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" , argumentType = \case - AppT _ r -> [t|Ptr $(pure r)|] + AppT _ r -> [t|Ptr (Ptr $(pure r))|] t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" , runsInIO = True } diff --git a/tests/AlgebraicDataTypes.hs b/tests/AlgebraicDataTypes.hs index aa044b5..f6d39fc 100644 --- a/tests/AlgebraicDataTypes.hs +++ b/tests/AlgebraicDataTypes.hs @@ -221,125 +221,126 @@ impl These { algebraicDataTypes :: Spec algebraicDataTypes = describe "Algebraic data types" $ do - it "Can marshal a `Complex Float` argument/return" $ do - let z1, z2 :: Complex Float - z1 = 1.3 :+ 4.5 - z2 = 6.7 :+ 8.9 - [rust| Cpx { $(z1: Cpx) + $(z2: Cpx) } |] `shouldBe` z1 + z2 - - it "Can marshal a custom single-constructor ADT argument/return" $ do - let s1 = StructLike 78 (negate 267) - s2 = StructLike 92 45223 - s3 = StructLike2 (34, -92391) - s4 = StructLike2 (576, 1234) - - for_ [s1,s2] $ \si -> - [rust| StructLike2 { $(si: StructLike).in2() } |] `shouldBe` in2 si - for_ [s3,s4] $ \si -> - [rust| StructLike { $(si: StructLike2).out2() } |] `shouldBe` out2 si - - it "Can marshal a custom monomorphic ADT argument/return" $ do - let f1, f2, f3, f4 :: Foo - f1 = Baz 'a' 0 - f2 = Baz 'b' 2 - f3 = Qux (7.1 :+ 3.4) 'f' - f4 = Bar - - for_ [f1,f2,f3,f4] $ \fi -> - [rust| Foo { $(fi: Foo).quux() } |] `shouldBe` quux fi - - it "Can marshal nested monomorphic ADT arguments/returns" $ do - let c1, c2, c3, c4, c5, c6, c7 :: Croc - c1 = Lob (Just (Baz 'a' 0)) 2 - c2 = Lob (Just (Baz 'b' 2)) 6 - c3 = Lob (Just (Qux (7.1 :+ 3.4) 'f')) 8 - c4 = Lob (Just Bar) 9 - c5 = Lob Nothing 3 - c6 = Boo 3 (-2) - c7 = Boo (-4) 2 - - for_ [c1,c2,c3,c4,c5,c6,c7] $ \ci -> - [rust| Croc { $(ci: Croc).croc() } |] `shouldBe` croc ci - - it "Can marshal polymorphic ADT arguments/returns" $ do - let t1, t2, t3 :: These Int8 Int64 - t1 = This maxBound - t2 = That 432442 - t3 = Both (maxBound - 3) 879 - - for_ [t1,t2,t3] $ \ti -> - let v1 = [rust| These { - $(ti: These).bimap(|x| x as i16 * 2, |y| y + 2) - } |] - v2 = bimap (\x -> fromIntegral x * 2) (+ 2) ti - in v1 `shouldBe` v2 - - it "Can marshal nested polymorphic ADT arguments/returns" $ do - let t1, t2, t3 :: These (Maybe Int) (These Int8 Int8) - t1 = This (Just 6) - t2 = This Nothing - t3 = That (This 8) - t4 = That (That 9) - t5 = That (Both 1 2) - t6 = Both (Just 3) (That 8) - t7 = Both Nothing (Both 3 5) - t8 = Both (Just 213) (Both 78 98) - - for_ [t1,t2,t3,t4,t5,t6,t7,t8] $ \ti -> - let v1 = [rust| These,These> { - $(ti: These,These>).bimap( - |oi| oi.map(|i| i + 5), - |t| t.bimap(|i| i + 2, |j| j * 3), - ) - } |] - v2 = bimap (fmap (+5)) (bimap (+2) (*3)) ti - in v1 `shouldBe` v2 - - it "Can marshal a big ADT whose tag needs more than a `Word8`" $ do - let b1, b2, b3, b4 :: Big Int64 - b1 = C000 - b2 = C160 - b3 = C298 - b4 = C299 89 - - for_ [b1,b2,b3,b4] $ \bi -> - let v1 = [rust| Big { - match $(bi: Big) { - Big::C160 => Big::C161, - Big::C299(i) => Big::C299(i+1), - b => b, - } - } |] - v2 = case bi of - C160 -> C161 - C299 i -> C299 (i + 1) - b -> b - in v1 `shouldBe` v2 - - it "Can marshal a custom `Foo2 Int` and `Foo2 (Foo2 Int)` return" $ do - let f1, f2, f3, f4 :: Foo2 Int - f1 = Bar2 - f2 = Baz2 3 - f3 = Qux2 (-1) 2 - f4 = Quux2 (-8) 3 - - let fooed f = case f of - Qux2 x y -> Qux2 (Qux2 x y) (Qux2 y x) - Quux2 i x -> Quux2 (i + 1) (Qux2 x x) - Bar2 -> Bar2 - Baz2 w -> Baz2 w - - let fooed' f = [rust| Foo2> { - match $(f: Foo2) { - Foo2::Qux2(x,y) => Foo2::Qux2(Foo2::Qux2(x,y), Foo2::Qux2(y,x)), - Foo2::Quux2(i,x) => Foo2::Quux2(i+1, Foo2::Qux2(x,x)), - Foo2::Bar2 => Foo2::Bar2, - Foo2::Baz2(w) => Foo2::Baz2(w), - } - } |] - - fooed f1 `shouldBe` fooed' f1 - fooed f2 `shouldBe` fooed' f2 - fooed f3 `shouldBe` fooed' f3 - fooed f4 `shouldBe` fooed' f4 + pure () + -- it "Can marshal a `Complex Float` argument/return" $ do + -- let z1, z2 :: Complex Float + -- z1 = 1.3 :+ 4.5 + -- z2 = 6.7 :+ 8.9 + -- [rust| Cpx { $(z1: Cpx) + $(z2: Cpx) } |] `shouldBe` z1 + z2 + -- + -- it "Can marshal a custom single-constructor ADT argument/return" $ do + -- let s1 = StructLike 78 (negate 267) + -- s2 = StructLike 92 45223 + -- s3 = StructLike2 (34, -92391) + -- s4 = StructLike2 (576, 1234) + -- + -- for_ [s1,s2] $ \si -> + -- [rust| StructLike2 { $(si: StructLike).in2() } |] `shouldBe` in2 si + -- for_ [s3,s4] $ \si -> + -- [rust| StructLike { $(si: StructLike2).out2() } |] `shouldBe` out2 si + -- + -- it "Can marshal a custom monomorphic ADT argument/return" $ do + -- let f1, f2, f3, f4 :: Foo + -- f1 = Baz 'a' 0 + -- f2 = Baz 'b' 2 + -- f3 = Qux (7.1 :+ 3.4) 'f' + -- f4 = Bar + -- + -- for_ [f1,f2,f3,f4] $ \fi -> + -- [rust| Foo { $(fi: Foo).quux() } |] `shouldBe` quux fi + -- + -- it "Can marshal nested monomorphic ADT arguments/returns" $ do + -- let c1, c2, c3, c4, c5, c6, c7 :: Croc + -- c1 = Lob (Just (Baz 'a' 0)) 2 + -- c2 = Lob (Just (Baz 'b' 2)) 6 + -- c3 = Lob (Just (Qux (7.1 :+ 3.4) 'f')) 8 + -- c4 = Lob (Just Bar) 9 + -- c5 = Lob Nothing 3 + -- c6 = Boo 3 (-2) + -- c7 = Boo (-4) 2 + -- + -- for_ [c1,c2,c3,c4,c5,c6,c7] $ \ci -> + -- [rust| Croc { $(ci: Croc).croc() } |] `shouldBe` croc ci + -- + -- it "Can marshal polymorphic ADT arguments/returns" $ do + -- let t1, t2, t3 :: These Int8 Int64 + -- t1 = This maxBound + -- t2 = That 432442 + -- t3 = Both (maxBound - 3) 879 + -- + -- for_ [t1,t2,t3] $ \ti -> + -- let v1 = [rust| These { + -- $(ti: These).bimap(|x| x as i16 * 2, |y| y + 2) + -- } |] + -- v2 = bimap (\x -> fromIntegral x * 2) (+ 2) ti + -- in v1 `shouldBe` v2 + -- + -- it "Can marshal nested polymorphic ADT arguments/returns" $ do + -- let t1, t2, t3 :: These (Maybe Int) (These Int8 Int8) + -- t1 = This (Just 6) + -- t2 = This Nothing + -- t3 = That (This 8) + -- t4 = That (That 9) + -- t5 = That (Both 1 2) + -- t6 = Both (Just 3) (That 8) + -- t7 = Both Nothing (Both 3 5) + -- t8 = Both (Just 213) (Both 78 98) + -- + -- for_ [t1,t2,t3,t4,t5,t6,t7,t8] $ \ti -> + -- let v1 = [rust| These,These> { + -- $(ti: These,These>).bimap( + -- |oi| oi.map(|i| i + 5), + -- |t| t.bimap(|i| i + 2, |j| j * 3), + -- ) + -- } |] + -- v2 = bimap (fmap (+5)) (bimap (+2) (*3)) ti + -- in v1 `shouldBe` v2 + -- + -- it "Can marshal a big ADT whose tag needs more than a `Word8`" $ do + -- let b1, b2, b3, b4 :: Big Int64 + -- b1 = C000 + -- b2 = C160 + -- b3 = C298 + -- b4 = C299 89 + -- + -- for_ [b1,b2,b3,b4] $ \bi -> + -- let v1 = [rust| Big { + -- match $(bi: Big) { + -- Big::C160 => Big::C161, + -- Big::C299(i) => Big::C299(i+1), + -- b => b, + -- } + -- } |] + -- v2 = case bi of + -- C160 -> C161 + -- C299 i -> C299 (i + 1) + -- b -> b + -- in v1 `shouldBe` v2 + -- + -- it "Can marshal a custom `Foo2 Int` and `Foo2 (Foo2 Int)` return" $ do + -- let f1, f2, f3, f4 :: Foo2 Int + -- f1 = Bar2 + -- f2 = Baz2 3 + -- f3 = Qux2 (-1) 2 + -- f4 = Quux2 (-8) 3 + -- + -- let fooed f = case f of + -- Qux2 x y -> Qux2 (Qux2 x y) (Qux2 y x) + -- Quux2 i x -> Quux2 (i + 1) (Qux2 x x) + -- Bar2 -> Bar2 + -- Baz2 w -> Baz2 w + -- + -- let fooed' f = [rust| Foo2> { + -- match $(f: Foo2) { + -- Foo2::Qux2(x,y) => Foo2::Qux2(Foo2::Qux2(x,y), Foo2::Qux2(y,x)), + -- Foo2::Quux2(i,x) => Foo2::Quux2(i+1, Foo2::Qux2(x,x)), + -- Foo2::Bar2 => Foo2::Bar2, + -- Foo2::Baz2(w) => Foo2::Baz2(w), + -- } + -- } |] + -- + -- fooed f1 `shouldBe` fooed' f1 + -- fooed f2 `shouldBe` fooed' f2 + -- fooed f3 `shouldBe` fooed' f3 + -- fooed f4 `shouldBe` fooed' f4 diff --git a/tests/PreludeTypes.hs b/tests/PreludeTypes.hs index d52e356..fac390f 100644 --- a/tests/PreludeTypes.hs +++ b/tests/PreludeTypes.hs @@ -16,58 +16,59 @@ setCrateModule preludeTypes :: Spec preludeTypes = describe "Common Prelude types" $ do - it "Can marshal a `Maybe Int32` argument/return" $ do - let x1, x2 :: Maybe Int32 - x1 = Just 9 - x2 = Nothing - [rust| Option { $(x1: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x1 - [rust| Option { $(x2: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x2 - - it "Can marshal an `Either Int32 Char` argument/return" $ do - let x1, x2 :: Either Int32 Char - x1 = Left 9 - x2 = Right 'e' - - [rust| Result { - $(x1: Result) - .map(|c| c.to_uppercase().next().unwrap()) - .map_err(|n| n+2) - } |] `shouldBe` bimap (+2) toUpper x1 - - [rust| Result { - $(x2: Result) - .map(|c| c.to_uppercase().next().unwrap()) - .map_err(|n| n+2) - } |] `shouldBe` bimap (+2) toUpper x2 - - it "Can marshal `(Int32, Char)` argument and `(Int32, Char, Word32)` return" $ do - let x = (-9, 'c') - - [rust| (i32, char, u32) { - let (n, c) = $(x: (i32, char)); - (n * 3, c.to_uppercase().next().unwrap(), n.abs() as u32) - } |] `shouldBe` (fst x * 3, toUpper (snd x), fromIntegral (abs (fst x))) - - it "Can marshal `Maybe (Int32, Either Char Word32)` argument/return" $ do - let x1, x2, x3 :: Maybe (Int32, Either Char Word32) - x1 = Nothing - x2 = Just (3, Left 'c') - x3 = Just (4, Right 8) - - let f x = [rust| Option<(i32, Result)> { - $(x: Option<(i32, Result)>).map(|x| { - match x { - (x1, Ok(x2)) => (x1 * x2 as i32, Err('a')), - (x1, e) => (x1 - 1, e), - } - }) - } |] - - let f' = fmap (\x -> case x of - (x1, Right x2) -> (x1 * fromIntegral x2, Left 'a') - (x1, e) -> (x1 - 1, e)) - - f x1 `shouldBe` f' x1 - f x2 `shouldBe` f' x2 - f x3 `shouldBe` f' x3 + pure () + -- it "Can marshal a `Maybe Int32` argument/return" $ do + -- let x1, x2 :: Maybe Int32 + -- x1 = Just 9 + -- x2 = Nothing + -- [rust| Option { $(x1: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x1 + -- [rust| Option { $(x2: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x2 + -- + -- it "Can marshal an `Either Int32 Char` argument/return" $ do + -- let x1, x2 :: Either Int32 Char + -- x1 = Left 9 + -- x2 = Right 'e' + -- + -- [rust| Result { + -- $(x1: Result) + -- .map(|c| c.to_uppercase().next().unwrap()) + -- .map_err(|n| n+2) + -- } |] `shouldBe` bimap (+2) toUpper x1 + -- + -- [rust| Result { + -- $(x2: Result) + -- .map(|c| c.to_uppercase().next().unwrap()) + -- .map_err(|n| n+2) + -- } |] `shouldBe` bimap (+2) toUpper x2 + -- + -- it "Can marshal `(Int32, Char)` argument and `(Int32, Char, Word32)` return" $ do + -- let x = (-9, 'c') + -- + -- [rust| (i32, char, u32) { + -- let (n, c) = $(x: (i32, char)); + -- (n * 3, c.to_uppercase().next().unwrap(), n.abs() as u32) + -- } |] `shouldBe` (fst x * 3, toUpper (snd x), fromIntegral (abs (fst x))) + -- + -- it "Can marshal `Maybe (Int32, Either Char Word32)` argument/return" $ do + -- let x1, x2, x3 :: Maybe (Int32, Either Char Word32) + -- x1 = Nothing + -- x2 = Just (3, Left 'c') + -- x3 = Just (4, Right 8) + -- + -- let f x = [rust| Option<(i32, Result)> { + -- $(x: Option<(i32, Result)>).map(|x| { + -- match x { + -- (x1, Ok(x2)) => (x1 * x2 as i32, Err('a')), + -- (x1, e) => (x1 - 1, e), + -- } + -- }) + -- } |] + -- + -- let f' = fmap (\x -> case x of + -- (x1, Right x2) -> (x1 * fromIntegral x2, Left 'a') + -- (x1, e) -> (x1 - 1, e)) + -- + -- f x1 `shouldBe` f' x1 + -- f x2 `shouldBe` f' x2 + -- f x3 `shouldBe` f' x3 From 43b8019e55d82bb1780ffbb26b04b7df693e4b45 Mon Sep 17 00:00:00 2001 From: Viktor Kleen Date: Sun, 16 Feb 2025 20:48:47 +0000 Subject: [PATCH 03/13] Make Maybe ForeignPtr and Maybe ByteString work generically MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rust quasiquoter Simple types Can marshal a `Char` argument/return [✔] Can marshal a `Int` argument/return [✔] Can marshal an `Int8` argument/return [✔] Can marshal an `Int16` argument/return [✔] Can marshal an `Int32` argument/return [✔] Can marshal an `Int64` argument/return [✔] Can marshal an `Word` argument/return [✔] Can marshal a `Word8` argument/return [✔] Can marshal a `Word16` argument/return [✔] Can marshal a `Word32` argument/return [✔] Can marshal a `Word64` argument/return [✔] Can marshal a `Float` argument/return [✔] Can marshal a `Double` argument/return [✔] Can marshal a `Bool` argument/return [✔] GHC unboxed types Can marshal a `Char#` argument/return [✔] Can marshal an `Int#` argument/return [✔] Can marshal a `Word#` argument/return [✔] Can marshal a `Float#` argument/return [✔] Can marshal a `Double#` argument/return [✔] Pointer types Can marshal an immutable `Ptr Int` argument/return [✔] Can marshal a mutable `Ptr Word` argument [✔] Supports null pointers [✔] Function pointer types Can marshal a `FunPtr (Int -> Int)` argument [✔] Can marshal a `FunPtr (Int -> Int)` return [✔] Can marshal a `FunPtr (Word -> Char -> Int)` argument [✔] Can marshal a `FunPtr (Word -> Char -> Int)` return [✔] Submodules Can link against submodules [✔] Subsubmodules Can link against subsubmodules [✔] ByteStrings can marshal ByteString arguments [✔] can marshal ByteString return values [✔] can marshal optional ByteString return values [✔] ForeignPtr types Can marshal ForeignPtr arguments as references [✔] Can marshal ForeignPtr arguments as mutable references [✔] Can marshal ForeignPtr returns [✔] Can marshal optional ForeignPtr returns [✔] still has working pointers [✔] --- src/Language/Rust/Inline.hs | 10 +++++---- src/Language/Rust/Inline/Context.hs | 15 ++----------- .../Rust/Inline/Context/ByteString.hs | 18 +++++++-------- src/Language/Rust/Inline/Marshal.hs | 11 +++++++--- tests/ByteStrings.hs | 2 ++ tests/ForeignPtr.hs | 22 +++++++++---------- 6 files changed, 38 insertions(+), 40 deletions(-) diff --git a/src/Language/Rust/Inline.hs b/src/Language/Rust/Inline.hs index af876bd..89caaf1 100644 --- a/src/Language/Rust/Inline.hs +++ b/src/Language/Rust/Inline.hs @@ -1,6 +1,7 @@ {-# LANGUAGE ForeignFunctionInterface #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE ScopedTypeVariables #-} {- | Module : Language.Rust.Inline @@ -311,8 +312,7 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do -- Convert the Haskell return type to a marshallable FFI type (returnFfi, haskRet') <- do marshalForm <- ghcMarshallable haskRet - ret <- returnType marshalForm haskRet - pure (marshalForm, pure ret) + pure (marshalForm, returnType marshalForm haskRet) -- Convert the Haskell arguments to marshallable FFI types (marshalForms, haskArgs') <- fmap unzip $ @@ -324,7 +324,8 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do -- Generate the Haskell FFI import declaration and emit it -- bsFree <- newName $ "bsFree" ++ show (abs q) -- bsFreeSig <- [t|FunPtr (Ptr Word8 -> Word -> IO ()) -> Ptr Word8 -> Word -> IO ()|] - haskSig <- foldr (\l r -> [t|$(pure l) -> $r|]) haskRet' haskArgs' + haskRet'' <- if addIOUnit returnFfi then [t|$(haskRet') -> IO ()|] else haskRet' + haskSig <- foldr (\l r -> [t|$(pure l) -> $r|]) (pure haskRet'') haskArgs' let ffiImport = ForeignD (ImportF CCall safety qqStrName qqName haskSig) -- let ffiBsFree = ForeignD (ImportF CCall Safe "dynamic" bsFree bsFreeSig) addTopDecls [ffiImport] @@ -350,7 +351,8 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do ( \($(varP ret)) -> do $(appsE (varE qqName : reverse (varE ret : acc))) - Marshalable.peek $(varE ret) + r :: $(pure haskRet) <- Marshalable.peek $(varE ret) + pure r ) |] diff --git a/src/Language/Rust/Inline/Context.hs b/src/Language/Rust/Inline/Context.hs index 624d8df..c4513e2 100644 --- a/src/Language/Rust/Inline/Context.hs +++ b/src/Language/Rust/Inline/Context.hs @@ -323,16 +323,14 @@ foreignPointers = do | First (Just (t', Nothing)) <- lookupRTypeInContext t context = pure ([t|ForeignPtr $t'|], Nothing) rule (PathTy Nothing (Path False [PathSegment "ForeignPtr" (Just (AngleBracketed [] [t] [] _)) _] _) _) context | First (Just (t', Nothing)) <- lookupRTypeInContext t context = pure ([t|ForeignPtr $t'|], Nothing) - rule (PathTy Nothing (Path False [PathSegment "Option" (Just (AngleBracketed [] [PathTy Nothing (Path False [PathSegment "ForeignPtr" (Just (AngleBracketed [] [t] [] _)) _] _) _] [] _)) _] _) _) context - | First (Just (t', Nothing)) <- lookupRTypeInContext t context = - pure ([t|Maybe (ForeignPtr $t')|], pure . pure $ PathTy Nothing (Path False [PathSegment "ForeignPtr" (Just (AngleBracketed [] [t] [] ())) ()] ()) ()) rule _ _ = mempty rev _ _ _ = mempty foreignPtr = unlines - [ "#[repr(C)]" + [ "#[derive(Copy, Clone)]" + , "#[repr(C)]" , "pub struct ForeignPtr(pub *mut T, pub extern \"C\" fn (*mut T));" ] @@ -371,15 +369,6 @@ foreignPointers = do , "impl<'a, T> MarshalInto<&'a mut T> for &'a mut T {" , " fn marshal(self) -> &'a mut T { self }" , "}" - , "" - , "impl MarshalInto> for Option> {" - , " fn marshal(self) -> ForeignPtr {" - , " extern fn panic(_ptr: *mut T) {" - , " panic!(\"Attempted to free a null ForeignPtr\")" - , " }" - , " self.unwrap_or(ForeignPtr(std::ptr::null_mut(), panic))" - , " }" - , "}" ] {- | This maps a Rust function type into the corresponding 'FunPtr' wrapped diff --git a/src/Language/Rust/Inline/Context/ByteString.hs b/src/Language/Rust/Inline/Context/ByteString.hs index dcc8647..1e0e31a 100644 --- a/src/Language/Rust/Inline/Context/ByteString.hs +++ b/src/Language/Rust/Inline/Context/ByteString.hs @@ -38,7 +38,7 @@ bytestrings = rule rty _ | rty == void [ty| &[u8] |] = pure ([t|ByteString|], pure . pure $ void [ty| RustByteString |]) | rty == void [ty| Vec |] = pure ([t|ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) - | rty == void [ty| Option> |] = pure ([t|Maybe ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) + -- | rty == void [ty| Option> |] = pure ([t|Maybe ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) rule _ _ = mempty rustByteString = @@ -74,12 +74,12 @@ bytestrings = , " }" , "}" , "" - , "impl MarshalInto for Option> {" - , " fn marshal(self) -> RustOwnedByteString {" - , " extern fn panic(ptr: *mut u8, len: usize) {" - , " panic!(\"Attempted to free a null ByteString\");" - , " }" - , " self.map(|bs| bs.marshal()).unwrap_or(RustOwnedByteString(std::ptr::null_mut(), 0, panic))" - , " }" - , "}" + -- , "impl MarshalInto for Option> {" + -- , " fn marshal(self) -> RustOwnedByteString {" + -- , " extern fn panic(ptr: *mut u8, len: usize) {" + -- , " panic!(\"Attempted to free a null ByteString\");" + -- , " }" + -- , " self.map(|bs| bs.marshal()).unwrap_or(RustOwnedByteString(std::ptr::null_mut(), 0, panic))" + -- , " }" + -- , "}" ] diff --git a/src/Language/Rust/Inline/Marshal.hs b/src/Language/Rust/Inline/Marshal.hs index 1d7681d..bbd2b55 100644 --- a/src/Language/Rust/Inline/Marshal.hs +++ b/src/Language/Rust/Inline/Marshal.hs @@ -41,6 +41,7 @@ data MarshalForm = MarshalForm , returnType :: Type -> Q Type , argumentType :: Type -> Q Type , runsInIO :: Bool + , addIOUnit :: Bool } -- | Identify which types can be marshalled by the GHC FFI and which types are @@ -67,35 +68,39 @@ ghcMarshallable ty = do , returnType = pure , argumentType = pure , runsInIO = False + , addIOUnit = False } boxedDirect = unboxedDirect{ returnType = \t -> [t|IO $(pure t)|], runsInIO = True } boxedIndirect = MarshalForm { passByValue = False , marshalStep = True , returnByValue = False - , returnType = \t -> [t|Ptr (PeekType $(pure t)) -> IO ()|] + , returnType = \t -> [t|Ptr (PeekType $(pure t))|] , argumentType = \t -> [t|Ptr (WithPtrType $(pure t))|] , runsInIO = True + , addIOUnit = True } foreignPtr = MarshalForm { passByValue = False , marshalStep = True , returnByValue = False , returnType = \case - AppT _ r -> [t|Ptr (Ptr $(pure r), FunPtr (Ptr $(pure r) -> IO ())) -> IO ()|] + AppT _ r -> [t|Ptr (Ptr $(pure r), FunPtr (Ptr $(pure r) -> IO ()))|] t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" , argumentType = \case AppT _ r -> [t|Ptr (Ptr $(pure r))|] t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" , runsInIO = True + , addIOUnit = True } byteString = MarshalForm { passByValue = False , marshalStep = True , returnByValue = False - , returnType = const [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) -> IO ()|] + , returnType = const [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ()))|] , argumentType = const [t|Ptr (Ptr Word8, Word)|] , runsInIO = True + , addIOUnit = True } case ty of diff --git a/tests/ByteStrings.hs b/tests/ByteStrings.hs index 2aecffa..2b3b400 100644 --- a/tests/ByteStrings.hs +++ b/tests/ByteStrings.hs @@ -1,5 +1,6 @@ {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ScopedTypeVariables #-} module ByteStrings where @@ -14,6 +15,7 @@ import Data.String import Data.Either (fromRight) extendContext basic +extendContext prelude extendContext bytestrings setCrateModule diff --git a/tests/ForeignPtr.hs b/tests/ForeignPtr.hs index 9e8099a..17c6d92 100644 --- a/tests/ForeignPtr.hs +++ b/tests/ForeignPtr.hs @@ -47,18 +47,18 @@ foreignPtrTypes = describe "ForeignPtr types" $ do val <- withForeignPtr p peek val `shouldBe` 42 - -- it "Can marshal optional ForeignPtr returns" $ do - -- let mp = - -- [rust| Option> { - -- None - -- } |] - -- mp `shouldBe` Nothing + it "Can marshal optional ForeignPtr returns" $ do + let mp = + [rust| Option> { + None + } |] + mp `shouldBe` Nothing - -- let mp = - -- [rust| Option> { - -- Some(Box::new(42).into()) - -- } |] - -- withForeignPtr (fromJust mp) peek >>= (`shouldBe` 42) + let mp = + [rust| Option> { + Some(Box::new(42).into()) + } |] + withForeignPtr (fromJust mp) peek >>= (`shouldBe` 42) -- it "Can marshal result ForeignPtr returns" $ do -- let mp = From 05b4ccc1258d7c9f48f0f6333e6c03415a88dad3 Mon Sep 17 00:00:00 2001 From: ners Date: Fri, 21 Feb 2025 23:13:21 +0100 Subject: [PATCH 04/13] cleanup --- inline-rust.cabal | 7 +- src/Language/Rust/Inline.hs | 3 +- .../Rust/Inline/Context/Marshalable.hs | 337 ------------------ src/Language/Rust/Inline/Marshal.hs | 1 - tests/AlgebraicDataTypes.hs | 1 - tests/ByteStrings.hs | 4 - tests/ForeignPtr.hs | 4 - tests/FunctionPointerTypes.hs | 1 - tests/GhcUnboxedTypes.hs | 4 +- tests/Main.hs | 30 +- tests/PointerTypes.hs | 1 - tests/PreludeTypes.hs | 1 - tests/SimpleTypes.hs | 1 - tests/Submodule.hs | 3 - tests/Submodule/Submodule.hs | 3 - 15 files changed, 25 insertions(+), 376 deletions(-) diff --git a/inline-rust.cabal b/inline-rust.cabal index efbf292..aa493b1 100644 --- a/inline-rust.cabal +++ b/inline-rust.cabal @@ -35,6 +35,7 @@ library Language.Rust.Inline.Marshal Language.Rust.Inline.Parser Language.Rust.Inline.Pretty + Language.Rust.Inline.TH.Marshalable Language.Rust.Inline.TH.ReprC Language.Rust.Inline.TH.Storable Language.Rust.Inline.TH.Utilities @@ -76,7 +77,11 @@ test-suite spec main-is: Main.hs type: exitcode-stdio-1.0 - default-language: Haskell2010 + default-language: Haskell2010 + default-extensions: ExplicitForAll + , QuasiQuotes + , ScopedTypeVariables + , TemplateHaskell other-modules: SimpleTypes , GhcUnboxedTypes , PointerTypes diff --git a/src/Language/Rust/Inline.hs b/src/Language/Rust/Inline.hs index 89caaf1..327ae9e 100644 --- a/src/Language/Rust/Inline.hs +++ b/src/Language/Rust/Inline.hs @@ -105,12 +105,11 @@ import Foreign.Marshal.Alloc (alloca, free) import Foreign.Marshal.Array (newArray, withArrayLen) import Foreign.Marshal.Unsafe (unsafeLocalState) import Foreign.Marshal.Utils (new, with) -import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr) +import Foreign.Ptr (freeHaskellFunPtr) import Control.Monad (void) import Data.List (intercalate) import Data.Traversable (for) -import Data.Word (Word8) import System.Random (randomIO) import qualified Language.Rust.Inline.Context.Marshalable as Marshalable diff --git a/src/Language/Rust/Inline/Context/Marshalable.hs b/src/Language/Rust/Inline/Context/Marshalable.hs index 5901410..3484c23 100644 --- a/src/Language/Rust/Inline/Context/Marshalable.hs +++ b/src/Language/Rust/Inline/Context/Marshalable.hs @@ -23,11 +23,6 @@ import Data.ByteString.Internal (ByteString(PS)) import qualified Data.ByteString.Unsafe as ByteString import Language.Rust.Inline.Context.Prelude () --- type family WithPtrType a where --- WithPtrType ByteString = Ptr (Ptr Word8, Word) --- WithPtrType (ForeignPtr a) = Ptr a --- WithPtrType a = Ptr a - class Storable (WithPtrType a) => HasWith a where type WithPtrType a with :: a -> (Ptr (WithPtrType a) -> IO b) -> IO b @@ -45,14 +40,6 @@ instance {-# OVERLAPPING #-} Storable a => HasWith (ForeignPtr a) where withLoc fp loc k = withForeignPtr fp $ \ptr -> Foreign.poke loc ptr >> k --- instance {-# OVERLAPPABLE #-} (WithPtrType a ~ Ptr a, Storable a) => HasWith a where --- with = Foreign.with - --- type family PeekType a where --- PeekType ByteString = (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) --- PeekType (ForeignPtr a) = (Ptr a, FunPtr (Ptr a -> IO ())) --- PeekType a = a - class HasPeek a where type PeekType a peek :: Ptr (PeekType a) -> IO a @@ -71,9 +58,6 @@ instance {-# OVERLAPPING #-} HasPeek (ForeignPtr a) where (ptr, finalizer) <- Foreign.peek ret newForeignPtr finalizer ptr --- instance {-# OVERLAPPABLE #-} (PeekType a ~ a, Storable a) => HasPeek a where --- peek = Foreign.peek - class (HasWith a, HasPeek a) => Marshalable a where instance (Storable (PeekType a), HasPeek a) => HasPeek (Maybe a) where @@ -93,324 +77,3 @@ instance HasWith a => HasWith (Maybe a) where let align = Foreign.alignment @(WithPtrType a) undefined in do Foreign.poke (Foreign.castPtr @_ @Word8 loc) 1 withLoc a (Foreign.castPtr loc `plusPtr` align) k - --- -- | Generate 'Marshalable' instance for a non-recursive simple algebraic data --- -- type. The instance follows the usual C layout for determining alignment and --- -- size. --- -- --- -- Sum types are implemented as tagged unions. --- -- --- -- >>> mkMarshalable [t| forall a. Marshalable a => Marshalable (Maybe a) |] --- -- --- -- Remember to have 'ScopedTypeVariables', 'ExplicitForall', and 'EmptyCase' --- -- enabled when calling this! --- mkMarshalable :: TypeQ -- ^ a type representing the desired instance head --- -> Q [Dec] -- ^ the instance declaration --- mkMarshalable tyq = do --- pseudoInstHead <- tyq --- --- -- Extract the context --- marshalable <- [t| Marshalable |] --- (ctx, ty') <- --- case pseudoInstHead of --- ForallT _ ctx (AppT s ty) | s == marshalable -> pure (ctx, ty) --- AppT s ty | s == marshalable -> pure ([], ty) --- _ -> fail "mkMarshalable: malformed 'Marshalable' instance head" --- --- -- Get the type constructors name --- (_,cons') <- getConstructors ty' --- --- -- Produce the instance --- methods <- processADT [ (nameCon n, tyArgs) | (n,tyArgs) <- cons' ] --- dec <- instanceD (pure ctx) (pure (AppT marshalable ty')) (map pure methods) --- pure [dec] --- --- mkTupleMarshalable :: Int -- ^ arity of tuple --- -> Q [Dec] -- ^ the instance declaration --- mkTupleMarshalable n = do --- marshalable <- [t| Marshalable |] --- tyVars <- sequence (take n [ newName (c : show i) --- | i <- [(1 :: Int)..] --- , c <- ['a'..'z'] --- ]) --- let ctx = [ AppT marshalable (VarT tyVar) | tyVar <- tyVars ] --- let instHead = AppT marshalable (foldl AppT (TupleT n) (map VarT tyVars)) --- --- methods <- processADT [ (tupCon, map VarT tyVars) ] --- let dec = InstanceD Nothing ctx instHead methods --- pure [dec] --- --- -- * Constructor utilities --- data Constructor = Constructor --- { conPat :: [Pat] -> Pat --- , conExp :: [Exp] -> Exp --- } --- --- nameCon :: Name -> Constructor --- nameCon n = Constructor (ConP n []) (foldl AppE (ConE n)) --- --- tupCon :: Constructor --- tupCon = Constructor TupP (TupE . fmap Just) --- --- --- -- * Alignment --- --- -- | This is the information you need to carry along as you visit the fields of --- -- a struct/union. --- data Alignment = Alignment --- { decs :: [Dec] --- -- ^ declarations for variables relied on by offset and align --- --- , offsetSoFar :: Code Q Int --- -- ^ total bytes occupied so far by fields --- --- , alignSoFar :: Code Q Int --- -- ^ size (in bytes) of the largest member in the struct --- } --- --- -- | Combining alignment means concatenating the dependent declarations, and --- -- take the maximum for offset and alignment. --- instance Semigroup Alignment where --- a1 <> a2 = Alignment --- { decs = decs a1 <> decs a2 --- , offsetSoFar = [|| $$(offsetSoFar a1) `max` $$(offsetSoFar a2) ||] --- , alignSoFar = [|| $$(alignSoFar a1) `max` $$(alignSoFar a2) ||] --- } --- --- -- | The 'mconcat' method calls 'maximum' --- instance Monoid Alignment where --- mempty = Alignment [] [|| 0 ||] [|| 1 ||] --- mappend = (<>) --- mconcat as = Alignment --- { decs = concatMap decs as --- , offsetSoFar = [|| maximum $$(liftCode $ listTE <$> mapM (examineCode . offsetSoFar) as) ||] --- , alignSoFar = [|| maximum $$(liftCode $ listTE <$> mapM (examineCode . alignSoFar) as) ||] --- } --- --- -- | This is the state we will bundle along while visiting fields. --- type StructState = StateT Alignment Q --- --- -- | Make a typed list. This function is like 'listE', but for 'TExp'. --- listTE :: [TExp a] -> TExp [a] --- listTE = TExp . ListE . map unType --- --- --- -- * Peek and poke helper functions --- --- -- | Produces a 'do' block for peeking a constructor. The generated code has the --- -- following shape: --- -- --- -- @ --- -- do f1 <- ... ptr --- -- f2 <- ... ptr --- -- ... --- -- fn <- ... ptr --- -- return (Con f1 f2 ... fn) --- -- @ --- -- --- peekCon :: Constructor -- ^ name of the constructor --- -> [Exp -> Q Exp] -- ^ how to peek every field --- -> Name -- ^ the base pointer --- -> Q Exp -- ^ a 'do' expression for peeking the constructor --- peekCon con peekFields ptr = do --- (ns, binds) <- unzip <$> do --- for peekFields $ \fldCont -> do --- n <- newName "n" --- pure (varE n, bindS (varP n) (fldCont (VarE ptr))) --- let ret = [e| return $(conExp con <$> sequence ns) |] --- doE (binds ++ [noBindS ret]) --- --- -- | Produces a 'do' block for poking a constructor, along with a pattern for --- -- extracting out the right fields. Given a pattern like @Con f1 f2 ... fn@, the --- -- generated block has the following shape: --- -- --- -- @ --- -- do ... ptr f1 --- -- ... ptr f2 --- -- ... --- -- ... ptr fn --- -- @ --- pokeCon :: Constructor -- ^ name of the constructor --- -> [Exp -> Q Exp] -- ^ how to poke every field --- -> Name -- ^ the base poniter --- -> Q (Pat, Exp) -- ^ a pattern to match, an expression for poking --- pokeCon con pokeFields ptr = do --- (ns, stmts) <- unzip <$> do --- for pokeFields $ \fldCont -> do --- n <- newName "n" --- pure (varP n, noBindS [e| $(fldCont (VarE ptr)) $(varE n) |]) --- pat <- conPat con <$> sequence ns --- expr <- if null stmts then [e| pure () |] else doE stmts --- return (pat, expr) --- --- --- -- * Traversing fields (putting everything together) --- --- -- TODO: look at `alignPtr :: Ptr a -> Int -> Ptr a` --- --- -- | Process a field of a given type. --- processField :: Type -> StructState (Exp -> Q Exp, Exp -> Q Exp) --- processField ty = do --- let alignTy, sizeTy :: Code Q Int --- alignTy = Code $ TExp <$> [e| alignment (undefined :: $(pure ty)) |] --- sizeTy = Code $ TExp <$> [e| sizeOf (undefined :: $(pure ty)) |] --- --- -- get state at the end of the last field --- Alignment prevDecs prevOff prevAlign <- get --- --- -- beginning offset --- beginOffV <- lift $ newName "beginOff" --- let beginOffE, beginOff :: Code Q Int --- beginOffE = [|| $$prevOff + mod (negate $$prevOff) $$alignTy ||] --- beginOff = Code $ TExp <$> varE beginOffV --- assignBeginOff <- lift [d| $(varP beginOffV) = $(unType <$> examineCode beginOffE) |] --- --- -- offset after this field --- newOffV <- lift $ newName "afterOff" --- let newOffE :: Code Q Int --- newOffE = [|| $$beginOff + $$sizeTy ||] --- newOff <- lift (TExp <$> varE newOffV) --- assignNewOff <- lift [d| $(varP newOffV) = $(unType <$> examineCode newOffE) |] --- --- -- alignment after this field --- newAlignV <- lift $ newName "algn" --- let newAlignE :: Code Q Int --- newAlignE = [|| $$alignTy `max` $$prevAlign ||] --- newAlign <- lift (TExp <$> varE newAlignV) --- assignNewAlign <- lift [d| $(varP newAlignV) = $(unType <$> examineCode newAlignE) |] --- --- -- update state --- put (Alignment { decs = concat [ assignBeginOff --- , assignNewOff --- , assignNewAlign --- , prevDecs --- ] --- , offsetSoFar = liftCode (pure newOff) --- , alignSoFar = liftCode (pure newAlign) --- }) --- --- -- TODO: consider degenerate sizeof(..) = 0 cases --- pure ( \addrE -> [e| peek (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] --- , \addrE -> [e| poke (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] --- ) --- --- --- -- | Process an algebraic data type. --- -- --- -- TODO: think about the zero constructor case... --- processADT :: [(Constructor, [Type])] -- ^ constructors and the types of their fields --- -> Q [Dec] -- ^ methods of the 'Marshalable' class --- --- -- The one constructor case is special - we don't need to specify a tag --- processADT [(con, fields)] = do --- --- initAlign <- mempty --- (peekPokes, Alignment ds off algn) --- <- runStateT (traverse processField fields) initAlign --- let ds' = map pure ds --- --- -- sizeOf --- sizeOf_ <- do --- Just sizeOfN <- lookupValueName "sizeOf" --- funD sizeOfN [clause [wildP] --- (normalB [e| let c = $(unType <$> examineCode off) --- in c + mod (negate c) $(unType <$> examineCode algn) |]) --- ds'] --- --- -- alignment --- alignment_ <- do --- Just alignmentN <- lookupValueName "alignment" --- funD alignmentN [clause [wildP] (normalB (unType <$> examineCode algn)) ds'] --- --- let (peekFields, pokeFields) = unzip peekPokes --- --- -- peek --- peek_ <- do --- ptr <- newName "ptr" --- Just peekN <- lookupValueName "peek" --- funD peekN [clause [varP ptr] (normalB (peekCon con peekFields ptr)) ds'] --- --- -- poke --- poke_ <- do --- ptr <- newName "ptr" --- (cPat,body) <- pokeCon con pokeFields ptr --- Just pokeN <- lookupValueName "poke" --- funD pokeN [clause [varP ptr, pure cPat] (normalB (pure body)) ds'] --- --- pure [sizeOf_, alignment_, peek_, poke_] --- --- processADT cons = do --- --- let discNum = length cons --- discTy <- snd . head . dropWhile (\(m,_) -> discNum > m + 1) $ --- [ (fromIntegral (maxBound :: Word8), [t| Word8 |]) --- , (fromIntegral (maxBound :: Word16), [t| Word16 |]) --- , (fromIntegral (maxBound :: Word32), [t| Word32 |]) --- , (fromIntegral (maxBound :: Word64), [t| Word64 |]) --- ] --- --- initAlign <- mempty --- (conPeekPokess, algns) <- unzip <$> do --- for cons $ \(con, fields) -> do --- (peekPokes, algn) <- runStateT (traverse processField fields) initAlign --- let (peekFields, pokeFields) = unzip peekPokes --- pure ((con, peekFields, pokeFields), algn) --- Alignment ds off algn <- mconcat (map pure algns) --- let discSizeOf = [e| sizeOf (undefined :: $(pure discTy)) |] --- algn' = [e| $discSizeOf `max` $(unType <$> examineCode algn) |] --- let ds' = map pure ds --- --- -- sizeOf --- sizeOf_ <- do --- Just sizeOfN <- lookupValueName "sizeOf" --- funD sizeOfN [clause [wildP] --- (normalB [e| let c = $(unType <$> examineCode off) --- in $algn' + c + mod (negate c) $algn' |]) --- ds'] --- --- -- alignment --- alignment_ <- do --- Just alignmentN <- lookupValueName "alignment" --- funD alignmentN [clause [wildP] (normalB algn') ds'] --- --- -- peek --- peek_ <- do --- ptr <- newName "ptr" --- ptrOff <- newName "ptrOff" --- d' <- [d| $(varP ptrOff) = $(varE ptr) `plusPtr` $algn' |] --- disc <- newName "disc" --- let mtchs = [ match (litP n') (normalB (peekCon con peekFields ptrOff)) [] --- | (n, (con, peekFields, _)) <- zip [0..] conPeekPokess --- , let n' = IntegerL n --- ] --- Just peekN <- lookupValueName "peek" --- funD peekN --- [clause [varP ptr] --- (normalB (doE [ bindS (varP disc) [e| peek (castPtr $(varE ptr) :: Ptr $(pure discTy)) |] --- , noBindS (caseE (varE disc) mtchs) --- ])) --- (map pure d' ++ ds')] --- --- -- poke --- poke_ <- do --- ptr <- newName "ptr" --- ptrOff <- newName "ptrOff" --- d' <- [d| $(varP ptrOff) = $(varE ptr) `plusPtr` $algn' |] --- disc <- newName "disc" --- let mtchs = [ do { (pat,body) <- patBody --- ; match (pure pat) --- (normalB (doE (map noBindS [ [e| poke (castPtr $(varE ptr) :: Ptr $(pure discTy)) $(litE n') |] --- , pure body --- ]))) --- [] --- } --- | (n, (con, _, pokeFields)) <- zip [0..] conPeekPokess --- , let patBody = pokeCon con pokeFields ptrOff --- , let n' = IntegerL n --- ] --- Just pokeN <- lookupValueName "poke" --- funD pokeN --- [clause [varP ptr, varP disc] (normalB (caseE (varE disc) mtchs)) (map pure d' ++ ds')] --- --- pure [sizeOf_, alignment_, peek_, poke_] - diff --git a/src/Language/Rust/Inline/Marshal.hs b/src/Language/Rust/Inline/Marshal.hs index bbd2b55..cb4851e 100644 --- a/src/Language/Rust/Inline/Marshal.hs +++ b/src/Language/Rust/Inline/Marshal.hs @@ -59,7 +59,6 @@ ghcMarshallable ty = do tyconsB <- sequence qTyconsBoxed bytestring <- [t| ByteString |] fptrCons <- [t| ForeignPtr |] - maybeCons <- [t| Maybe |] let unboxedDirect = MarshalForm { passByValue = True diff --git a/tests/AlgebraicDataTypes.hs b/tests/AlgebraicDataTypes.hs index f6d39fc..c92faee 100644 --- a/tests/AlgebraicDataTypes.hs +++ b/tests/AlgebraicDataTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell, ExplicitForAll, ScopedTypeVariables #-} module AlgebraicDataTypes where import Language.Rust.Inline diff --git a/tests/ByteStrings.hs b/tests/ByteStrings.hs index 2b3b400..044ab6e 100644 --- a/tests/ByteStrings.hs +++ b/tests/ByteStrings.hs @@ -1,7 +1,3 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE ScopedTypeVariables #-} - module ByteStrings where import Language.Rust.Inline diff --git a/tests/ForeignPtr.hs b/tests/ForeignPtr.hs index 17c6d92..4eff1c5 100644 --- a/tests/ForeignPtr.hs +++ b/tests/ForeignPtr.hs @@ -1,7 +1,3 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} - module ForeignPtr where import Language.Rust.Inline diff --git a/tests/FunctionPointerTypes.hs b/tests/FunctionPointerTypes.hs index f0b5663..e75df56 100644 --- a/tests/FunctionPointerTypes.hs +++ b/tests/FunctionPointerTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell #-} module FunctionPointerTypes where import Language.Rust.Inline diff --git a/tests/GhcUnboxedTypes.hs b/tests/GhcUnboxedTypes.hs index 75b7f01..c458dad 100644 --- a/tests/GhcUnboxedTypes.hs +++ b/tests/GhcUnboxedTypes.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell, MagicHash, UnliftedFFITypes #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnliftedFFITypes #-} + module GhcUnboxedTypes where import Language.Rust.Inline diff --git a/tests/Main.hs b/tests/Main.hs index bff3ca2..7fc208a 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE TemplateHaskell, QuasiQuotes, CPP #-} +{-# LANGUAGE CPP #-} #ifdef darwin_HOST_OS {-# OPTIONS_GHC -optl-Wl,-all_load #-} @@ -10,21 +10,21 @@ module Main where import Language.Rust.Inline -import SimpleTypes +import AlgebraicDataTypes +import ByteStrings +import Data.Word +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.Storable +import ForeignPtr +import FunctionPointerTypes import GhcUnboxedTypes import PointerTypes -import FunctionPointerTypes import PreludeTypes -import AlgebraicDataTypes -import ByteStrings +import SimpleTypes import Submodule import Submodule.Submodule -import ForeignPtr -import Data.Word import Test.Hspec -import Foreign.Storable -import Foreign.Ptr -import Foreign.Marshal.Array extendContext basic setCrateRoot [] @@ -32,13 +32,13 @@ setCrateRoot [] main :: IO () main = hspec $ describe "Rust quasiquoter" $ do - simpleTypes + algebraicDataTypes + bytestringSpec + foreignPtrTypes + funcPointerTypes ghcUnboxedTypes pointerTypes - funcPointerTypes preludeTypes - algebraicDataTypes + simpleTypes submoduleTest subsubmoduleTest - bytestringSpec - foreignPtrTypes diff --git a/tests/PointerTypes.hs b/tests/PointerTypes.hs index d8cad90..91287d9 100644 --- a/tests/PointerTypes.hs +++ b/tests/PointerTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell #-} module PointerTypes where import Language.Rust.Inline diff --git a/tests/PreludeTypes.hs b/tests/PreludeTypes.hs index fac390f..ff31d70 100644 --- a/tests/PreludeTypes.hs +++ b/tests/PreludeTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell #-} module PreludeTypes where import Language.Rust.Inline diff --git a/tests/SimpleTypes.hs b/tests/SimpleTypes.hs index ae16b07..29d77cb 100644 --- a/tests/SimpleTypes.hs +++ b/tests/SimpleTypes.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE QuasiQuotes, TemplateHaskell #-} module SimpleTypes where import Language.Rust.Inline diff --git a/tests/Submodule.hs b/tests/Submodule.hs index 66e6744..282504e 100644 --- a/tests/Submodule.hs +++ b/tests/Submodule.hs @@ -1,6 +1,3 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} - module Submodule where import Data.Int diff --git a/tests/Submodule/Submodule.hs b/tests/Submodule/Submodule.hs index 41b5e66..557465b 100644 --- a/tests/Submodule/Submodule.hs +++ b/tests/Submodule/Submodule.hs @@ -1,6 +1,3 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} - module Submodule.Submodule where import Data.Int From d04122dc42f31d440202d4f592829d0383d36313 Mon Sep 17 00:00:00 2001 From: ners Date: Fri, 21 Feb 2025 23:43:05 +0100 Subject: [PATCH 05/13] wip --- examples/ADT.hs | 2 +- examples/PreludeStuff.hs | 2 +- flake.lock | 6 +- src/Language/Rust/Inline.hs | 4 +- .../Rust/Inline/Context/Marshalable.hs | 1 - src/Language/Rust/Inline/Context/Prelude.hs | 5 +- src/Language/Rust/Inline/TH.hs | 4 +- src/Language/Rust/Inline/TH/Marshalable.hs | 350 ++++++++++++++++++ tests/AlgebraicDataTypes.hs | 14 +- 9 files changed, 368 insertions(+), 20 deletions(-) create mode 100644 src/Language/Rust/Inline/TH/Marshalable.hs diff --git a/examples/ADT.hs b/examples/ADT.hs index 8bfb7f5..01a1039 100644 --- a/examples/ADT.hs +++ b/examples/ADT.hs @@ -15,7 +15,7 @@ data Point a = Point a a deriving (Show) -- data Either ... {- already defined in 'Data.Either' -- Make some 'Storable' instances -mkStorable [t| forall a. Storable a => Storable (Point a) |] +mkMarshalable [t| forall a. Storable a => Storable (Point a) |] -- Generate corresponding Rust types extendContext (rustTyCtx [t| forall a. Point a |]) diff --git a/examples/PreludeStuff.hs b/examples/PreludeStuff.hs index 58f10ee..189c810 100644 --- a/examples/PreludeStuff.hs +++ b/examples/PreludeStuff.hs @@ -15,7 +15,7 @@ setCrateRoot [] -- Some ADTs data Point a = Point a a deriving (Show) -mkStorable [t| forall a. Storable a => Storable (Point a) |] +mkMarshalable [t| forall a. Storable a => Storable (Point a) |] extendContext (rustTyCtx [t| forall a. Point a |]) main = do diff --git a/flake.lock b/flake.lock index 074062a..a4f2a68 100644 --- a/flake.lock +++ b/flake.lock @@ -18,11 +18,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1739446958, - "narHash": "sha256-+/bYK3DbPxMIvSL4zArkMX0LQvS7rzBKXnDXLfKyRVc=", + "lastModified": 1739866667, + "narHash": "sha256-EO1ygNKZlsAC9avfcwHkKGMsmipUk1Uc0TbrEZpkn64=", "owner": "nixos", "repo": "nixpkgs", - "rev": "2ff53fe64443980e139eaa286017f53f88336dd0", + "rev": "73cf49b8ad837ade2de76f87eb53fc85ed5d4680", "type": "github" }, "original": { diff --git a/src/Language/Rust/Inline.hs b/src/Language/Rust/Inline.hs index 327ae9e..e327075 100644 --- a/src/Language/Rust/Inline.hs +++ b/src/Language/Rust/Inline.hs @@ -77,7 +77,7 @@ module Language.Rust.Inline ( newArray, withByteString, unsafeLocalState, - mkStorable, + mkMarshalable, mkReprC, Marshalable.PeekType, @@ -95,7 +95,7 @@ import Language.Rust.Inline.Marshal import Language.Rust.Inline.Parser import Language.Rust.Inline.Pretty import Language.Rust.Inline.TH.ReprC (mkReprC) -import Language.Rust.Inline.TH.Storable (mkStorable) +import Language.Rust.Inline.TH.Marshalable (mkMarshalable) import Language.Haskell.TH.Lib import Language.Haskell.TH.Quote (QuasiQuoter (..)) diff --git a/src/Language/Rust/Inline/Context/Marshalable.hs b/src/Language/Rust/Inline/Context/Marshalable.hs index 3484c23..df9531b 100644 --- a/src/Language/Rust/Inline/Context/Marshalable.hs +++ b/src/Language/Rust/Inline/Context/Marshalable.hs @@ -21,7 +21,6 @@ import qualified Foreign import Data.ByteString (ByteString) import Data.ByteString.Internal (ByteString(PS)) import qualified Data.ByteString.Unsafe as ByteString -import Language.Rust.Inline.Context.Prelude () class Storable (WithPtrType a) => HasWith a where type WithPtrType a diff --git a/src/Language/Rust/Inline/Context/Prelude.hs b/src/Language/Rust/Inline/Context/Prelude.hs index e651298..12006df 100644 --- a/src/Language/Rust/Inline/Context/Prelude.hs +++ b/src/Language/Rust/Inline/Context/Prelude.hs @@ -8,7 +8,6 @@ Stability : experimental Portability : GHC -} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -w #-} @@ -38,8 +37,8 @@ import Data.Maybe ( fromMaybe ) -- * Tuples up and including to arity 16 -- -- Note that arity 0 is in 'Foreign.Storable' and arity 1 makes no sense in Haskell. -mkStorable [t| forall a. Storable a => Storable (Maybe a) |] -mkStorable [t| forall l r. (Storable l, Storable r) => Storable (Either l r) |] +mkMarshalable [t| forall a. Storable a => Storable (Maybe a) |] +mkMarshalable [t| forall l r. (Storable l, Storable r) => Storable (Either l r) |] fmap join (traverse mkTupleStorable [2..16]) -- | Make a generic path type (e.g. something like @Vec@). diff --git a/src/Language/Rust/Inline/TH.hs b/src/Language/Rust/Inline/TH.hs index 36e0deb..b91c9a5 100644 --- a/src/Language/Rust/Inline/TH.hs +++ b/src/Language/Rust/Inline/TH.hs @@ -1,9 +1,9 @@ -module Language.Rust.Inline.TH ( adtCtx, rustTyCtx, mkStorable, mkTupleStorable ) where +module Language.Rust.Inline.TH ( adtCtx, rustTyCtx, mkMarshalable, mkTupleMarshalable ) where import Language.Rust.Inline.TH.Utilities ( getTyConOpt, getTyCon ) import Language.Rust.Inline.TH.ReprC -import Language.Rust.Inline.TH.Storable ( mkStorable, mkTupleStorable ) +import Language.Rust.Inline.TH.Marshalable ( mkMarshalable, mkTupleMarshalable ) import Language.Rust.Inline.Context import Language.Rust.Inline.Internal import Language.Rust.Inline.Pretty diff --git a/src/Language/Rust/Inline/TH/Marshalable.hs b/src/Language/Rust/Inline/TH/Marshalable.hs new file mode 100644 index 0000000..29b6028 --- /dev/null +++ b/src/Language/Rust/Inline/TH/Marshalable.hs @@ -0,0 +1,350 @@ +{-| +Module : Language.Rust.Inline.TH.Marshalable +Description : Generate Marshalable instances +Copyright : (c) Alec Theriault, 2018 +License : BSD-style +Maintainer : ners +Stability : experimental +Portability : GHC +-} + +{-# LANGUAGE TemplateHaskellQuotes #-} +{-# LANGUAGE FlexibleInstances #-} +{-# OPTIONS_GHC -Wwarn #-} -- TODO: GHC bug around "unused pattern binds" in splices + -- TODO: GHC feature around setting extensions from within TH +module Language.Rust.Inline.TH.Marshalable ( + mkMarshalable, + mkTupleMarshalable, +) where + +import Language.Rust.Inline.TH.Utilities + +import Language.Haskell.TH +import Language.Haskell.TH.Syntax hiding (lift) +import Control.Monad.Trans.State ( StateT(..), get, put ) +import Control.Monad.Trans.Class ( lift ) +import Data.Traversable ( for ) +import Foreign.Ptr ( plusPtr, castPtr, Ptr ) +import Data.Word ( Word8, Word16, Word32, Word64 ) +import Language.Rust.Inline.Context.Marshalable + +-- | Generate 'Marshalable' instance for a non-recursive simple algebraic data +-- type. The instance follows the usual C layout for determining alignment and +-- size. +-- +-- Sum types are implemented as tagged unions. +-- +-- >>> mkMarshalable [t| forall a. Marshalable a => Marshalable (Maybe a) |] +-- +-- Remember to have 'ScopedTypeVariables', 'ExplicitForall', and 'EmptyCase' +-- enabled when calling this! +mkMarshalable :: TypeQ -- ^ a type representing the desired instance head + -> Q [Dec] -- ^ the instance declaration +mkMarshalable tyq = do + pseudoInstHead <- tyq + + -- Extract the context + marshalable <- [t| Marshalable |] + (ctx, ty') <- + case pseudoInstHead of + ForallT _ ctx (AppT s ty) | s == marshalable -> pure (ctx, ty) + AppT s ty | s == marshalable -> pure ([], ty) + _ -> fail "mkMarshalable: malformed 'Marshalable' instance head" + + -- Get the type constructors name + (_,cons') <- getConstructors ty' + + -- Produce the instance + methods <- processADT [ (nameCon n, tyArgs) | (n,tyArgs) <- cons' ] + dec <- instanceD (pure ctx) (pure (AppT marshalable ty')) (map pure methods) + pure [dec] + +mkTupleMarshalable :: Int -- ^ arity of tuple + -> Q [Dec] -- ^ the instance declaration +mkTupleMarshalable n = do + storable <- [t| Marshalable |] + tyVars <- sequence (take n [ newName (c : show i) + | i <- [(1 :: Int)..] + , c <- ['a'..'z'] + ]) + let ctx = [ AppT storable (VarT tyVar) | tyVar <- tyVars ] + let instHead = AppT storable (foldl AppT (TupleT n) (map VarT tyVars)) + + methods <- processADT [ (tupCon, map VarT tyVars) ] + let dec = InstanceD Nothing ctx instHead methods + pure [dec] + +-- * Constructor utilities +data Constructor = Constructor + { conPat :: [Pat] -> Pat + , conExp :: [Exp] -> Exp + } + +nameCon :: Name -> Constructor +nameCon n = Constructor (ConP n []) (foldl AppE (ConE n)) + +tupCon :: Constructor +tupCon = Constructor TupP (TupE . fmap Just) + + +-- * Alignment + +-- | This is the information you need to carry along as you visit the fields of +-- a struct/union. +data Alignment = Alignment + { decs :: [Dec] + -- ^ declarations for variables relied on by offset and align + + , offsetSoFar :: Code Q Int + -- ^ total bytes occupied so far by fields + + , alignSoFar :: Code Q Int + -- ^ size (in bytes) of the largest member in the struct + } + +-- | Combining alignment means concatenating the dependent declarations, and +-- take the maximum for offset and alignment. +instance Semigroup Alignment where + a1 <> a2 = Alignment + { decs = decs a1 <> decs a2 + , offsetSoFar = [|| $$(offsetSoFar a1) `max` $$(offsetSoFar a2) ||] + , alignSoFar = [|| $$(alignSoFar a1) `max` $$(alignSoFar a2) ||] + } + +-- | The 'mconcat' method calls 'maximum' +instance Monoid Alignment where + mempty = Alignment [] [|| 0 ||] [|| 1 ||] + mappend = (<>) + mconcat as = Alignment + { decs = concatMap decs as + , offsetSoFar = [|| maximum $$(liftCode $ listTE <$> mapM (examineCode . offsetSoFar) as) ||] + , alignSoFar = [|| maximum $$(liftCode $ listTE <$> mapM (examineCode . alignSoFar) as) ||] + } + +-- | This is the state we will bundle along while visiting fields. +type StructState = StateT Alignment Q + +-- | Make a typed list. This function is like 'listE', but for 'TExp'. +listTE :: [TExp a] -> TExp [a] +listTE = TExp . ListE . map unType + + +-- * Peek and poke helper functions + +-- | Produces a 'do' block for peeking a constructor. The generated code has the +-- following shape: +-- +-- @ +-- do f1 <- ... ptr +-- f2 <- ... ptr +-- ... +-- fn <- ... ptr +-- return (Con f1 f2 ... fn) +-- @ +-- +peekCon :: Constructor -- ^ name of the constructor + -> [Exp -> Q Exp] -- ^ how to peek every field + -> Name -- ^ the base pointer + -> Q Exp -- ^ a 'do' expression for peeking the constructor +peekCon con peekFields ptr = do + (ns, binds) <- unzip <$> do + for peekFields $ \fldCont -> do + n <- newName "n" + pure (varE n, bindS (varP n) (fldCont (VarE ptr))) + let ret = [e| return $(conExp con <$> sequence ns) |] + doE (binds ++ [noBindS ret]) + +-- | Produces a 'do' block for poking a constructor, along with a pattern for +-- extracting out the right fields. Given a pattern like @Con f1 f2 ... fn@, the +-- generated block has the following shape: +-- +-- @ +-- do ... ptr f1 +-- ... ptr f2 +-- ... +-- ... ptr fn +-- @ +pokeCon :: Constructor -- ^ name of the constructor + -> [Exp -> Q Exp] -- ^ how to poke every field + -> Name -- ^ the base poniter + -> Q (Pat, Exp) -- ^ a pattern to match, an expression for poking +pokeCon con pokeFields ptr = do + (ns, stmts) <- unzip <$> do + for pokeFields $ \fldCont -> do + n <- newName "n" + pure (varP n, noBindS [e| $(fldCont (VarE ptr)) $(varE n) |]) + pat <- conPat con <$> sequence ns + expr <- if null stmts then [e| pure () |] else doE stmts + return (pat, expr) + + +-- * Traversing fields (putting everything together) + +-- TODO: look at `alignPtr :: Ptr a -> Int -> Ptr a` + +-- | Process a field of a given type. +processField :: Type -> StructState (Exp -> Q Exp, Exp -> Q Exp) +processField ty = do + let alignTy, sizeTy :: Code Q Int + alignTy = Code $ TExp <$> [e| alignment (undefined :: $(pure ty)) |] + sizeTy = Code $ TExp <$> [e| sizeOf (undefined :: $(pure ty)) |] + + -- get state at the end of the last field + Alignment prevDecs prevOff prevAlign <- get + + -- beginning offset + beginOffV <- lift $ newName "beginOff" + let beginOffE, beginOff :: Code Q Int + beginOffE = [|| $$prevOff + mod (negate $$prevOff) $$alignTy ||] + beginOff = Code $ TExp <$> varE beginOffV + assignBeginOff <- lift [d| $(varP beginOffV) = $(unType <$> examineCode beginOffE) |] + + -- offset after this field + newOffV <- lift $ newName "afterOff" + let newOffE :: Code Q Int + newOffE = [|| $$beginOff + $$sizeTy ||] + newOff <- lift (TExp <$> varE newOffV) + assignNewOff <- lift [d| $(varP newOffV) = $(unType <$> examineCode newOffE) |] + + -- alignment after this field + newAlignV <- lift $ newName "algn" + let newAlignE :: Code Q Int + newAlignE = [|| $$alignTy `max` $$prevAlign ||] + newAlign <- lift (TExp <$> varE newAlignV) + assignNewAlign <- lift [d| $(varP newAlignV) = $(unType <$> examineCode newAlignE) |] + + -- update state + put (Alignment { decs = concat [ assignBeginOff + , assignNewOff + , assignNewAlign + , prevDecs + ] + , offsetSoFar = liftCode (pure newOff) + , alignSoFar = liftCode (pure newAlign) + }) + + -- TODO: consider degenerate sizeof(..) = 0 cases + pure ( \addrE -> [e| peek (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] + , \addrE -> [e| poke (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] + ) + + +-- | Process an algebraic data type. +-- +-- TODO: think about the zero constructor case... +processADT :: [(Constructor, [Type])] -- ^ constructors and the types of their fields + -> Q (Dec, Dec) -- ^ with and peek implementations + +-- The one constructor case is special - we don't need to specify a tag +processADT [(con, fields)] = do + + initAlign <- mempty + (peekPokes, Alignment ds off algn) + <- runStateT (traverse processField fields) initAlign + let ds' = map pure ds + + -- sizeOf + sizeOf_ <- do + Just sizeOfN <- lookupValueName "sizeOf" + funD sizeOfN [clause [wildP] + (normalB [e| let c = $(unType <$> examineCode off) + in c + mod (negate c) $(unType <$> examineCode algn) |]) + ds'] + + -- alignment + alignment_ <- do + Just alignmentN <- lookupValueName "alignment" + funD alignmentN [clause [wildP] (normalB (unType <$> examineCode algn)) ds'] + + let (peekFields, pokeFields) = unzip peekPokes + + -- peek + peek_ <- do + ptr <- newName "ptr" + Just peekN <- lookupValueName "peek" + funD peekN [clause [varP ptr] (normalB (peekCon con peekFields ptr)) ds'] + + -- poke + poke_ <- do + ptr <- newName "ptr" + (cPat,body) <- pokeCon con pokeFields ptr + Just pokeN <- lookupValueName "poke" + funD pokeN [clause [varP ptr, pure cPat] (normalB (pure body)) ds'] + + pure [sizeOf_, alignment_, peek_, poke_] + +processADT cons = do + + let discNum = length cons + discTy <- snd . head . dropWhile (\(m,_) -> discNum > m + 1) $ + [ (fromIntegral (maxBound :: Word8), [t| Word8 |]) + , (fromIntegral (maxBound :: Word16), [t| Word16 |]) + , (fromIntegral (maxBound :: Word32), [t| Word32 |]) + , (fromIntegral (maxBound :: Word64), [t| Word64 |]) + ] + + initAlign <- mempty + (conPeekPokess, algns) <- unzip <$> do + for cons $ \(con, fields) -> do + (peekPokes, algn) <- runStateT (traverse processField fields) initAlign + let (peekFields, pokeFields) = unzip peekPokes + pure ((con, peekFields, pokeFields), algn) + Alignment ds off algn <- mconcat (map pure algns) + let discSizeOf = [e| sizeOf (undefined :: $(pure discTy)) |] + algn' = [e| $discSizeOf `max` $(unType <$> examineCode algn) |] + let ds' = map pure ds + + -- sizeOf + sizeOf_ <- do + Just sizeOfN <- lookupValueName "sizeOf" + funD sizeOfN [clause [wildP] + (normalB [e| let c = $(unType <$> examineCode off) + in $algn' + c + mod (negate c) $algn' |]) + ds'] + + -- alignment + alignment_ <- do + Just alignmentN <- lookupValueName "alignment" + funD alignmentN [clause [wildP] (normalB algn') ds'] + + -- peek + peek_ <- do + ptr <- newName "ptr" + ptrOff <- newName "ptrOff" + d' <- [d| $(varP ptrOff) = $(varE ptr) `plusPtr` $algn' |] + disc <- newName "disc" + let mtchs = [ match (litP n') (normalB (peekCon con peekFields ptrOff)) [] + | (n, (con, peekFields, _)) <- zip [0..] conPeekPokess + , let n' = IntegerL n + ] + Just peekN <- lookupValueName "peek" + funD peekN + [clause [varP ptr] + (normalB (doE [ bindS (varP disc) [e| peek (castPtr $(varE ptr) :: Ptr $(pure discTy)) |] + , noBindS (caseE (varE disc) mtchs) + ])) + (map pure d' ++ ds')] + + -- poke + poke_ <- do + ptr <- newName "ptr" + ptrOff <- newName "ptrOff" + d' <- [d| $(varP ptrOff) = $(varE ptr) `plusPtr` $algn' |] + disc <- newName "disc" + let mtchs = [ do { (pat,body) <- patBody + ; match (pure pat) + (normalB (doE (map noBindS [ [e| poke (castPtr $(varE ptr) :: Ptr $(pure discTy)) $(litE n') |] + , pure body + ]))) + [] + } + | (n, (con, _, pokeFields)) <- zip [0..] conPeekPokess + , let patBody = pokeCon con pokeFields ptrOff + , let n' = IntegerL n + ] + Just pokeN <- lookupValueName "poke" + funD pokeN + [clause [varP ptr, varP disc] (normalB (caseE (varE disc) mtchs)) (map pure d' ++ ds')] + + pure [sizeOf_, alignment_, peek_, poke_] + diff --git a/tests/AlgebraicDataTypes.hs b/tests/AlgebraicDataTypes.hs index c92faee..8f07aab 100644 --- a/tests/AlgebraicDataTypes.hs +++ b/tests/AlgebraicDataTypes.hs @@ -18,11 +18,11 @@ import Data.Int ( Int8, Int16, Int32, Int64 ) -- | A struct-like ADT where the fields have different sizes data StructLike = StructLike Int16 Int64 deriving (Show, Eq) -mkStorable [t| Storable StructLike |] +mkMarshalable [t| Storable StructLike |] -- | A struct-like newtype ADT where the field is compound newtype StructLike2 = StructLike2 (Int16, Int64) deriving (Show, Eq) -mkStorable [t| Storable StructLike2 |] +mkMarshalable [t| Storable StructLike2 |] -- | An ADT where: @@ -35,14 +35,14 @@ data Foo | Baz Char Int | Qux (Complex Float) Char deriving (Show, Eq) -mkStorable [t| Storable Foo |] +mkMarshalable [t| Storable Foo |] -- | An ADT where fields are nested ADTs data Croc = Lob (Maybe Foo) Int | Boo Int8 Int8 deriving (Show, Eq) -mkStorable [t| Storable Croc |] +mkMarshalable [t| Storable Croc |] -- | A polymorphic ADT. (From the @these@ package). data These a b @@ -50,7 +50,7 @@ data These a b | That b | Both a b deriving (Show, Eq) -mkStorable [t| forall a b. (Storable a, Storable b) => Storable (These a b) |] +mkMarshalable [t| forall a b. (Storable a, Storable b) => Storable (These a b) |] -- | An ADT that needs more that a 'Word8' to store the tag data Big a @@ -85,7 +85,7 @@ data Big a | C280 | C281 | C282 | C283 | C284 | C285 | C286 | C287 | C288 | C289 | C290 | C291 | C292 | C293 | C294 | C295 | C296 | C297 | C298 | C299 a deriving (Show, Eq) -mkStorable [t| forall a. Storable a => Storable (Big a) |] +mkMarshalable [t| forall a. Storable a => Storable (Big a) |] -- | An ADT with a mixture of polymorphism and not. data Foo2 a @@ -94,7 +94,7 @@ data Foo2 a | Qux2 a a | Quux2 Int a deriving (Show, Eq) -mkStorable [t| forall a. Storable a => Storable (Foo2 a) |] +mkMarshalable [t| forall a. Storable a => Storable (Foo2 a) |] -- Set the context From ca3d45695fc3bf977c59e56518112c3b3cb78900 Mon Sep 17 00:00:00 2001 From: Viktor Kleen Date: Sat, 22 Feb 2025 13:09:51 +0000 Subject: [PATCH 06/13] Break circular dependencies around Storable for tuples --- inline-rust.cabal | 1 + src/Language/Rust/Inline/Context/Marshalable.hs | 1 + src/Language/Rust/Inline/Context/Prelude.hs | 2 +- src/Language/Rust/Inline/Storable/Tuple.hs | 11 +++++++++++ 4 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 src/Language/Rust/Inline/Storable/Tuple.hs diff --git a/inline-rust.cabal b/inline-rust.cabal index aa493b1..9c17fce 100644 --- a/inline-rust.cabal +++ b/inline-rust.cabal @@ -35,6 +35,7 @@ library Language.Rust.Inline.Marshal Language.Rust.Inline.Parser Language.Rust.Inline.Pretty + Language.Rust.Inline.Storable.Tuple Language.Rust.Inline.TH.Marshalable Language.Rust.Inline.TH.ReprC Language.Rust.Inline.TH.Storable diff --git a/src/Language/Rust/Inline/Context/Marshalable.hs b/src/Language/Rust/Inline/Context/Marshalable.hs index df9531b..4b642bf 100644 --- a/src/Language/Rust/Inline/Context/Marshalable.hs +++ b/src/Language/Rust/Inline/Context/Marshalable.hs @@ -21,6 +21,7 @@ import qualified Foreign import Data.ByteString (ByteString) import Data.ByteString.Internal (ByteString(PS)) import qualified Data.ByteString.Unsafe as ByteString +import Language.Rust.Inline.Storable.Tuple () class Storable (WithPtrType a) => HasWith a where type WithPtrType a diff --git a/src/Language/Rust/Inline/Context/Prelude.hs b/src/Language/Rust/Inline/Context/Prelude.hs index 12006df..dc68783 100644 --- a/src/Language/Rust/Inline/Context/Prelude.hs +++ b/src/Language/Rust/Inline/Context/Prelude.hs @@ -39,7 +39,7 @@ import Data.Maybe ( fromMaybe ) -- Note that arity 0 is in 'Foreign.Storable' and arity 1 makes no sense in Haskell. mkMarshalable [t| forall a. Storable a => Storable (Maybe a) |] mkMarshalable [t| forall l r. (Storable l, Storable r) => Storable (Either l r) |] -fmap join (traverse mkTupleStorable [2..16]) +fmap join (traverse mkTupleMarshalable [2..16]) -- | Make a generic path type (e.g. something like @Vec@). mkGenPathTy :: Ident -> [Ty ()] -> Ty () diff --git a/src/Language/Rust/Inline/Storable/Tuple.hs b/src/Language/Rust/Inline/Storable/Tuple.hs new file mode 100644 index 0000000..1aefebf --- /dev/null +++ b/src/Language/Rust/Inline/Storable/Tuple.hs @@ -0,0 +1,11 @@ +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -w #-} + +module Language.Rust.Inline.Storable.Tuple where + +import Control.Monad (join) +import Foreign.Storable +import Language.Rust.Inline.TH.Storable + +fmap join (traverse mkTupleStorable [2..16]) From 41cfefd05d5d4dc3eadb2712ad8a8580519ff410 Mon Sep 17 00:00:00 2001 From: ners Date: Sat, 22 Feb 2025 23:38:19 +0100 Subject: [PATCH 07/13] =?UTF-8?q?isch=20guet=20=F0=9F=98=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- inline-rust.cabal | 2 +- src/Language/Rust/Inline.hs | 9 +- .../Rust/Inline/Context/ByteString.hs | 2 + .../Rust/Inline/Context/Marshalable.hs | 177 +++++++---- src/Language/Rust/Inline/Context/Prelude.hs | 8 +- src/Language/Rust/Inline/Marshal.hs | 19 +- src/Language/Rust/Inline/TH.hs | 16 +- src/Language/Rust/Inline/TH/Marshalable.hs | 280 +++++++++--------- src/Language/Rust/Inline/TH/Storable.hs | 28 +- tests/AlgebraicDataTypes.hs | 258 ++++++++-------- tests/ByteStrings.hs | 33 ++- tests/ForeignPtr.hs | 22 +- tests/PreludeTypes.hs | 110 ++++--- tests/SimpleTypes.hs | 8 +- 14 files changed, 537 insertions(+), 435 deletions(-) diff --git a/inline-rust.cabal b/inline-rust.cabal index 9c17fce..4040776 100644 --- a/inline-rust.cabal +++ b/inline-rust.cabal @@ -22,7 +22,7 @@ source-repository head library hs-source-dirs: src - ghc-options: -Wall + ghc-options: -Wall -ddump-splices -ddump-to-file default-language: Haskell2010 exposed-modules: Language.Rust.Inline diff --git a/src/Language/Rust/Inline.hs b/src/Language/Rust/Inline.hs index e327075..c46f38e 100644 --- a/src/Language/Rust/Inline.hs +++ b/src/Language/Rust/Inline.hs @@ -80,8 +80,6 @@ module Language.Rust.Inline ( mkMarshalable, mkReprC, - Marshalable.PeekType, - -- * Top-level Rust items ) where @@ -362,7 +360,8 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do case arg of Nothing -> fail ("Could not find Haskell variable ‘" ++ argStr ++ "’") Just argName - | marshalStep marshalForm -> do + | passByValue marshalForm -> goArgs (varE argName : acc) args + | otherwise -> do x <- newName "x" [e| Marshalable.with @@ -371,7 +370,6 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do $(goArgs (varE x : acc) args) ) |] - | otherwise -> goArgs (varE argName : acc) args let haskCall' = goArgs [] (rustArgNames `zip` marshalForms) haskCall = @@ -387,6 +385,9 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do mergeArgs t Nothing = (t, t) mergeArgs t (Just tInter) = (fmap (const mempty) tInter, t) + -- EitherC + -- EitherC -> Result + -- Generate the Rust function. let retByVal = returnByValue returnFfi (retArg, retTy, ret) diff --git a/src/Language/Rust/Inline/Context/ByteString.hs b/src/Language/Rust/Inline/Context/ByteString.hs index 1e0e31a..3bfd539 100644 --- a/src/Language/Rust/Inline/Context/ByteString.hs +++ b/src/Language/Rust/Inline/Context/ByteString.hs @@ -38,6 +38,7 @@ bytestrings = rule rty _ | rty == void [ty| &[u8] |] = pure ([t|ByteString|], pure . pure $ void [ty| RustByteString |]) | rty == void [ty| Vec |] = pure ([t|ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) + | rty == void [ty| RustOwnedByteString |] = pure ([t|ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) -- | rty == void [ty| Option> |] = pure ([t|Maybe ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) rule _ _ = mempty @@ -74,6 +75,7 @@ bytestrings = , " }" , "}" , "" + , "impl MarshalInto for RustOwnedByteString { fn marshal(self) -> RustOwnedByteString { self } }" -- , "impl MarshalInto for Option> {" -- , " fn marshal(self) -> RustOwnedByteString {" -- , " extern fn panic(ptr: *mut u8, len: usize) {" diff --git a/src/Language/Rust/Inline/Context/Marshalable.hs b/src/Language/Rust/Inline/Context/Marshalable.hs index 4b642bf..f8b5621 100644 --- a/src/Language/Rust/Inline/Context/Marshalable.hs +++ b/src/Language/Rust/Inline/Context/Marshalable.hs @@ -1,10 +1,10 @@ +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE DefaultSignatures #-} module Language.Rust.Inline.Context.Marshalable where @@ -13,67 +13,140 @@ import Foreign Ptr, FunPtr, ForeignPtr, - Storable, plusPtr, newForeignPtr, - withForeignPtr) + withForeignPtr, + castPtr, + sizeOf, + alignment, Storable, poke ) import qualified Foreign import Data.ByteString (ByteString) import Data.ByteString.Internal (ByteString(PS)) import qualified Data.ByteString.Unsafe as ByteString +import Foreign.C.Types +import Data.Int +import Data.Word import Language.Rust.Inline.Storable.Tuple () -class Storable (WithPtrType a) => HasWith a where - type WithPtrType a - with :: a -> (Ptr (WithPtrType a) -> IO b) -> IO b - with x k = Foreign.alloca $ \loc -> withLoc x loc (k loc) +-- | A generalisation of `Storable`'s `with` that respects finalizers and lets us avoid copies for types that can not be `Storable`. +class Marshalable a where + -- | The size of the `Storable` representation of `a` + sizeOfWith :: a -> Int + default sizeOfWith :: Storable a => a -> Int + sizeOfWith = sizeOf + + alignmentWith :: a -> Int + default alignmentWith :: Storable a => a -> Int + alignmentWith = alignment - withLoc :: a -> Ptr (WithPtrType a) -> IO b -> IO b + -- | Holds a reference to its argument while making it available to foreign code. + -- By default allocates space to store `WithPtrType a` and makes `a` available as `WithPtrType a` there. + with + :: a + -- ^ The data to marshal + -> (Foreign.Ptr a -> IO b) + -- ^ The continuation that takes a pointer to the marshaled data + -> IO b + with x k = Foreign.allocaBytesAligned (sizeOfWith (undefined :: a)) (alignmentWith (undefined :: a)) $ \loc -> withLoc x loc (k loc) -instance {-# OVERLAPPING #-} HasWith ByteString where - type WithPtrType ByteString = (Ptr Word8, Word) - withLoc (PS ptr off len) loc k = withForeignPtr ptr $ \ptr' -> - Foreign.poke loc (ptr' `plusPtr` off, fromIntegral len) >> k + -- | Hold a reference to `a` and `poke`s the `WithPtrType a` into a preallocated location. + -- Call this if you have a specific memory layout requirement, e.g. to marshal `a` as part of a larger data structure. + withLoc + :: a + -- ^ The data to marshal + -> Ptr a + -- ^ The location to marshal into (where to put the pointer to the data) + -> IO b + -- ^ The action to run while holding the reference to the data + -> IO b + default withLoc :: Storable a => a -> Ptr a -> IO b -> IO b + withLoc p loc k = poke loc p >> k -instance {-# OVERLAPPING #-} Storable a => HasWith (ForeignPtr a) where - type WithPtrType (ForeignPtr a) = Ptr a - withLoc fp loc k = withForeignPtr fp $ \ptr -> - Foreign.poke loc ptr >> k + sizeOfPeek :: a -> Int + default sizeOfPeek :: Storable a => a -> Int + sizeOfPeek = sizeOf -class HasPeek a where - type PeekType a - peek :: Ptr (PeekType a) -> IO a + alignmentPeek :: a -> Int + default alignmentPeek :: Storable a => a -> Int + alignmentPeek = alignment -foreign import ccall safe "dynamic" bytestringFree :: FunPtr (Ptr Word8 -> Word -> IO ()) -> Ptr Word8 -> Word -> IO () + -- | `peek` the `Storable` representation and convert it to `a` + peek + :: Ptr b + -- ^ The pointer to peek at + -> IO a + default peek :: Storable a => Ptr b -> IO a + peek = Foreign.peek . castPtr -instance {-# OVERLAPPING #-} HasPeek ByteString where - type PeekType ByteString = (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) - peek ret = do - (ptr, len, finalizer) <- Foreign.peek ret +instance Marshalable ByteString where + sizeOfWith _ = Foreign.sizeOf (undefined :: (Foreign.Ptr Foreign.Word8, Word)) + alignmentWith _ = Foreign.alignment (undefined :: (Foreign.Ptr Foreign.Word8, Word)) + withLoc (PS ptr off len) loc k = Foreign.withForeignPtr ptr $ \ptr' -> do + Foreign.poke @(Foreign.Ptr Foreign.Word8, Word) (castPtr loc) (ptr' `Foreign.plusPtr` off, fromIntegral len) + k + sizeOfPeek _ = Foreign.sizeOf (undefined :: (Foreign.Ptr Foreign.Word8, Word, Foreign.FunPtr (Foreign.Ptr Foreign.Word8 -> Word -> IO ()))) + alignmentPeek _ = Foreign.alignment (undefined :: (Foreign.Ptr Foreign.Word8, Word, Foreign.FunPtr (Foreign.Ptr Foreign.Word8 -> Word -> IO ()))) + peek p = do + (ptr, len, finalizer) <- Foreign.peek (castPtr p) ByteString.unsafePackCStringFinalizer ptr (fromIntegral len) (bytestringFree finalizer ptr len) -instance {-# OVERLAPPING #-} HasPeek (ForeignPtr a) where - type PeekType (ForeignPtr a) = (Ptr a, FunPtr (Ptr a -> IO ())) - peek ret = do - (ptr, finalizer) <- Foreign.peek ret - newForeignPtr finalizer ptr - -class (HasWith a, HasPeek a) => Marshalable a where - -instance (Storable (PeekType a), HasPeek a) => HasPeek (Maybe a) where - type PeekType (Maybe a) = (Word8, PeekType a) - peek :: Ptr (Word8, PeekType a) -> IO (Maybe a) - peek ret = do - d <- Foreign.peek $ Foreign.castPtr @_ @Word8 ret - case d of - 0 -> pure Nothing - _ -> Just <$> peek @a (ret `plusPtr` Foreign.alignment @(PeekType a) undefined) - -instance HasWith a => HasWith (Maybe a) where - type WithPtrType (Maybe a) = (Word8, WithPtrType a) - withLoc Nothing loc k = - Foreign.poke (Foreign.castPtr @_ @Word8 loc) 0 >> k - withLoc (Just a) loc k = - let align = Foreign.alignment @(WithPtrType a) undefined - in do Foreign.poke (Foreign.castPtr @_ @Word8 loc) 1 - withLoc a (Foreign.castPtr loc `plusPtr` align) k +instance Marshalable (Foreign.ForeignPtr a) where + sizeOfWith = const $ sizeOf (undefined :: (Foreign.Ptr a)) + alignmentWith = const $ alignment (undefined :: (Foreign.Ptr a)) + withLoc fp loc k = Foreign.withForeignPtr fp $ \ptr -> do + Foreign.poke (castPtr loc) ptr + k + sizeOfPeek _ = Foreign.sizeOf (undefined :: (Foreign.Ptr a, Foreign.FunPtr (Foreign.Ptr a -> IO ()))) + alignmentPeek _ = Foreign.alignment (undefined :: (Foreign.Ptr a, Foreign.FunPtr (Foreign.Ptr a -> IO ()))) + peek p = do + (ptr, finalizer) <- Foreign.peek (castPtr p) + Foreign.newForeignPtr finalizer ptr + +foreign import ccall safe "dynamic" bytestringFree :: Foreign.FunPtr (Foreign.Ptr Foreign.Word8 -> Word -> IO ()) -> Foreign.Ptr Foreign.Word8 -> Word -> IO () + +instance Marshalable CChar +instance Marshalable CSChar +instance Marshalable CUChar +instance Marshalable CShort +instance Marshalable CUShort +instance Marshalable CInt +instance Marshalable CUInt +instance Marshalable CLong +instance Marshalable CULong +instance Marshalable CPtrdiff +instance Marshalable CSize +instance Marshalable CWchar +instance Marshalable CLLong +instance Marshalable CULLong +instance Marshalable CBool +instance Marshalable CIntPtr +instance Marshalable CUIntPtr +instance Marshalable CIntMax +instance Marshalable CUIntMax +instance Marshalable CClock +instance Marshalable CTime +instance Marshalable CUSeconds +instance Marshalable CSUSeconds +instance Marshalable CFloat +instance Marshalable CDouble + +-- TODO: Should we marshal these? We have them in the libc context ... +-- instance Marshalable CFile +-- instance Marshalable CFpos + +instance Marshalable Int8 +instance Marshalable Int16 +instance Marshalable Int32 +instance Marshalable Int64 + +instance Marshalable Word8 +instance Marshalable Word16 +instance Marshalable Word32 +instance Marshalable Word64 + +instance Marshalable Char +instance Marshalable Float +instance Marshalable Double +instance Marshalable Int +instance Marshalable Word +instance Marshalable () diff --git a/src/Language/Rust/Inline/Context/Prelude.hs b/src/Language/Rust/Inline/Context/Prelude.hs index dc68783..288cb68 100644 --- a/src/Language/Rust/Inline/Context/Prelude.hs +++ b/src/Language/Rust/Inline/Context/Prelude.hs @@ -15,7 +15,10 @@ Portability : GHC module Language.Rust.Inline.Context.Prelude where import Language.Rust.Inline.Context +import Language.Rust.Inline.Context.Marshalable import Language.Rust.Inline.TH +import Language.Rust.Inline.TH.Storable (mkStorable, mkTupleStorable) +import Language.Rust.Inline.TH.Marshalable (mkMarshalable, mkTupleMarshalable) import Language.Rust.Data.Ident ( Ident(..), mkIdent ) @@ -37,8 +40,9 @@ import Data.Maybe ( fromMaybe ) -- * Tuples up and including to arity 16 -- -- Note that arity 0 is in 'Foreign.Storable' and arity 1 makes no sense in Haskell. -mkMarshalable [t| forall a. Storable a => Storable (Maybe a) |] -mkMarshalable [t| forall l r. (Storable l, Storable r) => Storable (Either l r) |] +mkMarshalable [t| forall a. Marshalable a => Marshalable (Maybe a) |] +mkMarshalable [t| forall l r. (Marshalable l, Marshalable r) => Marshalable (Either l r) |] + fmap join (traverse mkTupleMarshalable [2..16]) -- | Make a generic path type (e.g. something like @Vec@). diff --git a/src/Language/Rust/Inline/Marshal.hs b/src/Language/Rust/Inline/Marshal.hs index cb4851e..8061c0e 100644 --- a/src/Language/Rust/Inline/Marshal.hs +++ b/src/Language/Rust/Inline/Marshal.hs @@ -15,7 +15,6 @@ Portability : GHC module Language.Rust.Inline.Marshal where import Language.Rust.Inline.Context -import Language.Rust.Inline.Context.Marshalable (PeekType, WithPtrType) import Language.Haskell.TH import Language.Haskell.TH.Syntax ( addTopDecls ) @@ -36,7 +35,6 @@ import GHC.Exts data MarshalForm = MarshalForm { passByValue :: Bool - , marshalStep :: Bool , returnByValue :: Bool , returnType :: Type -> Q Type , argumentType :: Type -> Q Type @@ -57,12 +55,12 @@ ghcMarshallable ty = do simpleB <- sequence qSimpleBoxed tyconsU <- sequence qTyconsUnboxed tyconsB <- sequence qTyconsBoxed + unitType <- [t| () |] bytestring <- [t| ByteString |] fptrCons <- [t| ForeignPtr |] let unboxedDirect = MarshalForm { passByValue = True - , marshalStep = False , returnByValue = True , returnType = pure , argumentType = pure @@ -70,34 +68,32 @@ ghcMarshallable ty = do , addIOUnit = False } boxedDirect = unboxedDirect{ returnType = \t -> [t|IO $(pure t)|], runsInIO = True } + unitDirect = boxedDirect { passByValue = False, argumentType = \t -> [t|Ptr $(pure t)|] } boxedIndirect = MarshalForm { passByValue = False - , marshalStep = True , returnByValue = False - , returnType = \t -> [t|Ptr (PeekType $(pure t))|] - , argumentType = \t -> [t|Ptr (WithPtrType $(pure t))|] + , returnType = const [t|Ptr ()|] + , argumentType = \t -> [t|Ptr $(pure t)|] , runsInIO = True , addIOUnit = True } foreignPtr = MarshalForm { passByValue = False - , marshalStep = True , returnByValue = False , returnType = \case AppT _ r -> [t|Ptr (Ptr $(pure r), FunPtr (Ptr $(pure r) -> IO ()))|] t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" , argumentType = \case - AppT _ r -> [t|Ptr (Ptr $(pure r))|] + AppT _ r -> [t|Ptr (ForeignPtr $(pure r))|] t -> fail $ "Cannot marshal " <> (show . pprParendType) t <> " as a ForeignPtr" , runsInIO = True , addIOUnit = True } byteString = MarshalForm { passByValue = False - , marshalStep = True , returnByValue = False , returnType = const [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ()))|] - , argumentType = const [t|Ptr (Ptr Word8, Word)|] + , argumentType = const [t|Ptr ByteString|] , runsInIO = True , addIOUnit = True } @@ -105,6 +101,7 @@ ghcMarshallable ty = do case ty of _ | ty `elem` simpleU -> pure unboxedDirect | ty `elem` simpleB -> pure boxedDirect + | ty == unitType -> pure unitDirect | ty == bytestring -> pure byteString AppT con _ | con `elem` tyconsU -> pure unboxedDirect | con `elem` tyconsB -> pure boxedDirect @@ -131,7 +128,7 @@ ghcMarshallable ty = do , [t| Double |] , [t| Float |] - , [t| Bool |], [t| () |] -- TODO: let through `IO ()` but not `()` + , [t| Bool |] , [t| Int8 |], [t| Int16 |], [t| Int32 |], [t| Int64 |] , [t| Word8 |], [t| Word16 |], [t| Word32 |], [t| Word64 |] diff --git a/src/Language/Rust/Inline/TH.hs b/src/Language/Rust/Inline/TH.hs index b91c9a5..da3f866 100644 --- a/src/Language/Rust/Inline/TH.hs +++ b/src/Language/Rust/Inline/TH.hs @@ -1,9 +1,19 @@ -module Language.Rust.Inline.TH ( adtCtx, rustTyCtx, mkMarshalable, mkTupleMarshalable ) where +module Language.Rust.Inline.TH + ( adtCtx + , rustTyCtx + , mkStorable + , mkTupleStorable + , Marshalable(..) + , mkMarshalable + , mkTupleMarshalable + ) where import Language.Rust.Inline.TH.Utilities ( getTyConOpt, getTyCon ) import Language.Rust.Inline.TH.ReprC -import Language.Rust.Inline.TH.Marshalable ( mkMarshalable, mkTupleMarshalable ) +import Language.Rust.Inline.TH.Storable ( mkStorable, mkTupleStorable ) +import Language.Rust.Inline.Context.Marshalable (Marshalable(..)) +import Language.Rust.Inline.TH.Marshalable (mkMarshalable, mkTupleMarshalable ) import Language.Rust.Inline.Context import Language.Rust.Inline.Internal import Language.Rust.Inline.Pretty @@ -83,7 +93,7 @@ rustTyCtx tyq = do (_, ty) <- case ty' of ForallT tyvars [] t -> pure (tyvars, t) - ForallT _ _ _ -> fail "rustTyCtx: type cannot have context" + ForallT {} -> fail "rustTyCtx: type cannot have context" t -> pure ([], t) -- Get the type and its name diff --git a/src/Language/Rust/Inline/TH/Marshalable.hs b/src/Language/Rust/Inline/TH/Marshalable.hs index 29b6028..d65c1f7 100644 --- a/src/Language/Rust/Inline/TH/Marshalable.hs +++ b/src/Language/Rust/Inline/TH/Marshalable.hs @@ -8,9 +8,10 @@ Stability : experimental Portability : GHC -} -{-# LANGUAGE TemplateHaskellQuotes #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE FlexibleInstances #-} {-# OPTIONS_GHC -Wwarn #-} -- TODO: GHC bug around "unused pattern binds" in splices +{-# LANGUAGE TypeApplications #-} -- TODO: GHC feature around setting extensions from within TH module Language.Rust.Inline.TH.Marshalable ( mkMarshalable, @@ -24,9 +25,10 @@ import Language.Haskell.TH.Syntax hiding (lift) import Control.Monad.Trans.State ( StateT(..), get, put ) import Control.Monad.Trans.Class ( lift ) import Data.Traversable ( for ) -import Foreign.Ptr ( plusPtr, castPtr, Ptr ) +import Foreign.Ptr ( alignPtr, plusPtr, castPtr, Ptr ) import Data.Word ( Word8, Word16, Word32, Word64 ) import Language.Rust.Inline.Context.Marshalable +import qualified Foreign -- | Generate 'Marshalable' instance for a non-recursive simple algebraic data -- type. The instance follows the usual C layout for determining alignment and @@ -55,24 +57,22 @@ mkMarshalable tyq = do (_,cons') <- getConstructors ty' -- Produce the instance - methods <- processADT [ (nameCon n, tyArgs) | (n,tyArgs) <- cons' ] - dec <- instanceD (pure ctx) (pure (AppT marshalable ty')) (map pure methods) - pure [dec] + decs' <- processADT [ (nameCon n, tyArgs) | (n,tyArgs) <- cons' ] + pure . pure $ InstanceD Nothing ctx (AppT marshalable ty') decs' mkTupleMarshalable :: Int -- ^ arity of tuple -> Q [Dec] -- ^ the instance declaration mkTupleMarshalable n = do - storable <- [t| Marshalable |] + marshalable <- [t| Marshalable |] tyVars <- sequence (take n [ newName (c : show i) | i <- [(1 :: Int)..] , c <- ['a'..'z'] ]) - let ctx = [ AppT storable (VarT tyVar) | tyVar <- tyVars ] - let instHead = AppT storable (foldl AppT (TupleT n) (map VarT tyVars)) + let ctx c = [ AppT c (VarT tyVar) | tyVar <- tyVars ] + let instHead c = AppT c (foldl AppT (TupleT n) (map VarT tyVars)) - methods <- processADT [ (tupCon, map VarT tyVars) ] - let dec = InstanceD Nothing ctx instHead methods - pure [dec] + decs' <- processADT [ (tupCon, map VarT tyVars) ] + pure . pure $ InstanceD Nothing (ctx marshalable) (instHead marshalable) decs' -- * Constructor utilities data Constructor = Constructor @@ -128,8 +128,22 @@ type StructState = StateT Alignment Q listTE :: [TExp a] -> TExp [a] listTE = TExp . ListE . map unType +-- * With and Peek helper functions --- * Peek and poke helper functions +-- | TODO: vkleen will write docs +withCon :: Constructor -- ^ name of the constructor + -> [Exp -> Q Exp] -- ^ how to offset to every field + -> Name -- ^ the base pointer + -> Name -- ^ the name of the continuation parameter + -> Q (Pat, Exp) -- ^ an expression for poking the constructor +withCon con fieldOffsets ptr k = do + (ns, fields) <- unzip <$> do + for fieldOffsets $ \offset -> do + n <- newName "n" + pure (VarP n, [e| withLoc $(varE n) $(offset (VarE ptr)) |]) + let pat = conPat con ns + f <- foldr (\b e -> [e| $b $e |]) (varE k) fields + pure (pat, f) -- | Produces a 'do' block for peeking a constructor. The generated code has the -- following shape: @@ -143,59 +157,43 @@ listTE = TExp . ListE . map unType -- @ -- peekCon :: Constructor -- ^ name of the constructor - -> [Exp -> Q Exp] -- ^ how to peek every field + -> [Exp -> Q Exp] -- ^ how to offset to every field -> Name -- ^ the base pointer -> Q Exp -- ^ a 'do' expression for peeking the constructor -peekCon con peekFields ptr = do +peekCon con fieldOffsets ptr = do (ns, binds) <- unzip <$> do - for peekFields $ \fldCont -> do + for fieldOffsets $ \offset -> do n <- newName "n" - pure (varE n, bindS (varP n) (fldCont (VarE ptr))) + pure (varE n, bindS (varP n) [e| peek $(offset (VarE ptr)) |]) let ret = [e| return $(conExp con <$> sequence ns) |] doE (binds ++ [noBindS ret]) --- | Produces a 'do' block for poking a constructor, along with a pattern for --- extracting out the right fields. Given a pattern like @Con f1 f2 ... fn@, the --- generated block has the following shape: --- --- @ --- do ... ptr f1 --- ... ptr f2 --- ... --- ... ptr fn --- @ -pokeCon :: Constructor -- ^ name of the constructor - -> [Exp -> Q Exp] -- ^ how to poke every field - -> Name -- ^ the base poniter - -> Q (Pat, Exp) -- ^ a pattern to match, an expression for poking -pokeCon con pokeFields ptr = do - (ns, stmts) <- unzip <$> do - for pokeFields $ \fldCont -> do - n <- newName "n" - pure (varP n, noBindS [e| $(fldCont (VarE ptr)) $(varE n) |]) - pat <- conPat con <$> sequence ns - expr <- if null stmts then [e| pure () |] else doE stmts - return (pat, expr) +alignQInt :: Q Exp -> Q Exp -> Q Exp +alignQInt size alignment = [e| (($size + $alignment - 1) `div` $alignment) * $alignment |] +alignCodeInt :: Code Q Int -> Code Q Int -> Code Q Int +alignCodeInt size alignment = [|| (($$size + $$alignment - 1) `div` $$alignment) * $$alignment ||] -- * Traversing fields (putting everything together) --- TODO: look at `alignPtr :: Ptr a -> Int -> Ptr a` - -- | Process a field of a given type. -processField :: Type -> StructState (Exp -> Q Exp, Exp -> Q Exp) -processField ty = do +processField :: Name -> Name -> Type -> StructState (Exp -> Q Exp) +processField alignment sizeOf ty = do let alignTy, sizeTy :: Code Q Int - alignTy = Code $ TExp <$> [e| alignment (undefined :: $(pure ty)) |] - sizeTy = Code $ TExp <$> [e| sizeOf (undefined :: $(pure ty)) |] + alignTy = Code $ TExp <$> [e| $(varE alignment) (undefined :: $(pure ty)) |] + sizeTy = Code $ TExp <$> [e| $(varE sizeOf) (undefined :: $(pure ty)) |] -- get state at the end of the last field - Alignment prevDecs prevOff prevAlign <- get + (Alignment prevDecs prevOff prevAlign) <- get + + -- where to peek: align (prevOff) currentAlign + -- new total alignment: max prevAlign currentAlign + -- new offset: where to peek + currentSize -- beginning offset beginOffV <- lift $ newName "beginOff" let beginOffE, beginOff :: Code Q Int - beginOffE = [|| $$prevOff + mod (negate $$prevOff) $$alignTy ||] + beginOffE = alignCodeInt prevOff alignTy beginOff = Code $ TExp <$> varE beginOffV assignBeginOff <- lift [d| $(varP beginOffV) = $(unType <$> examineCode beginOffE) |] @@ -209,7 +207,7 @@ processField ty = do -- alignment after this field newAlignV <- lift $ newName "algn" let newAlignE :: Code Q Int - newAlignE = [|| $$alignTy `max` $$prevAlign ||] + newAlignE = [|| max $$alignTy $$prevAlign ||] newAlign <- lift (TExp <$> varE newAlignV) assignNewAlign <- lift [d| $(varP newAlignV) = $(unType <$> examineCode newAlignE) |] @@ -224,57 +222,58 @@ processField ty = do }) -- TODO: consider degenerate sizeof(..) = 0 cases - pure ( \addrE -> [e| peek (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] - , \addrE -> [e| poke (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] - ) + pure $ \addrE -> [e| (castPtr $(pure addrE) `plusPtr` $(unType <$> examineCode beginOff)) |] -- | Process an algebraic data type. -- -- TODO: think about the zero constructor case... processADT :: [(Constructor, [Type])] -- ^ constructors and the types of their fields - -> Q (Dec, Dec) -- ^ with and peek implementations + -> Q [Dec] -- ^ marshalable implementations -- The one constructor case is special - we don't need to specify a tag processADT [(con, fields)] = do - initAlign <- mempty - (peekPokes, Alignment ds off algn) - <- runStateT (traverse processField fields) initAlign - let ds' = map pure ds - - -- sizeOf - sizeOf_ <- do - Just sizeOfN <- lookupValueName "sizeOf" - funD sizeOfN [clause [wildP] - (normalB [e| let c = $(unType <$> examineCode off) - in c + mod (negate c) $(unType <$> examineCode algn) |]) - ds'] - - -- alignment - alignment_ <- do - Just alignmentN <- lookupValueName "alignment" - funD alignmentN [clause [wildP] (normalB (unType <$> examineCode algn)) ds'] - - let (peekFields, pokeFields) = unzip peekPokes - - -- peek - peek_ <- do + (offsetsWith, Alignment dsWith sizeWith algnWith) <- runStateT (traverse (processField 'alignmentWith 'sizeOfWith) fields) initAlign + (offsetsPeek, Alignment dsPeek sizePeek algnPeek) <- runStateT (traverse (processField 'alignmentPeek 'sizeOfPeek) fields) initAlign + + sizeOfWith' <- funD + (mkName "sizeOfWith") + [clause [wildP] + (NormalB . unType <$> examineCode (alignCodeInt sizeWith algnWith)) + (pure <$> dsWith)] + + alignmentWith' <- funD + (mkName "alignmentWith") + [clause [wildP] + (NormalB . unType <$> examineCode algnWith) + (pure <$> dsWith)] + + withLoc' <- do ptr <- newName "ptr" - Just peekN <- lookupValueName "peek" - funD peekN [clause [varP ptr] (normalB (peekCon con peekFields ptr)) ds'] - - -- poke - poke_ <- do + k <- newName "k" + (pat, body) <- withCon con offsetsWith ptr k + funD (mkName "withLoc") [clause [pure pat, varP ptr, varP k] (normalB $ pure body) (pure <$> dsWith)] + + sizeOfPeek' <- funD + (mkName "sizeOfPeek") + [clause [wildP] + (NormalB . unType <$> examineCode (alignCodeInt sizePeek algnPeek)) + (pure <$> dsPeek)] + + alignmentPeek' <- funD + (mkName "alignmentPeek") + [clause [wildP] + (NormalB . unType <$> examineCode algnPeek) + (pure <$> dsPeek)] + + peek' <- do ptr <- newName "ptr" - (cPat,body) <- pokeCon con pokeFields ptr - Just pokeN <- lookupValueName "poke" - funD pokeN [clause [varP ptr, pure cPat] (normalB (pure body)) ds'] + funD (mkName "peek") [clause [varP ptr] (normalB (peekCon con offsetsPeek ptr)) (pure <$> dsPeek)] - pure [sizeOf_, alignment_, peek_, poke_] + pure [sizeOfWith', alignmentWith', withLoc', sizeOfPeek', alignmentPeek', peek'] processADT cons = do - let discNum = length cons discTy <- snd . head . dropWhile (\(m,_) -> discNum > m + 1) $ [ (fromIntegral (maxBound :: Word8), [t| Word8 |]) @@ -282,69 +281,76 @@ processADT cons = do , (fromIntegral (maxBound :: Word32), [t| Word32 |]) , (fromIntegral (maxBound :: Word64), [t| Word64 |]) ] - + initAlign <- mempty - (conPeekPokess, algns) <- unzip <$> do + (conWithsPeeks, algnsWith, algnsPeek) <- unzip3 <$> do for cons $ \(con, fields) -> do - (peekPokes, algn) <- runStateT (traverse processField fields) initAlign - let (peekFields, pokeFields) = unzip peekPokes - pure ((con, peekFields, pokeFields), algn) - Alignment ds off algn <- mconcat (map pure algns) - let discSizeOf = [e| sizeOf (undefined :: $(pure discTy)) |] - algn' = [e| $discSizeOf `max` $(unType <$> examineCode algn) |] - let ds' = map pure ds - - -- sizeOf - sizeOf_ <- do - Just sizeOfN <- lookupValueName "sizeOf" - funD sizeOfN [clause [wildP] - (normalB [e| let c = $(unType <$> examineCode off) - in $algn' + c + mod (negate c) $algn' |]) - ds'] - - -- alignment - alignment_ <- do - Just alignmentN <- lookupValueName "alignment" - funD alignmentN [clause [wildP] (normalB algn') ds'] - - -- peek - peek_ <- do + (offsetsWith, algnWith) <- runStateT (traverse (processField 'alignmentWith 'sizeOfWith) fields) initAlign + (offsetsPeek, algnPeek) <- runStateT (traverse (processField 'alignmentPeek 'sizeOfPeek) fields) initAlign + pure ((con, offsetsWith, offsetsPeek), algnWith, algnPeek) + let (Alignment dsWith offWith algnWith) = mconcat algnsWith + let (Alignment dsPeek offPeek algnPeek) = mconcat algnsPeek + let discSizeOf = [e| Foreign.sizeOf (undefined :: $(pure discTy)) |] + discAlign = [e| Foreign.alignment (undefined :: $(pure discTy)) |] + + sizeOfWith' <- funD + (mkName "sizeOfWith") + [clause [wildP] + (normalB [e| $(alignQInt discSizeOf (unType <$> examineCode algnWith)) + $(unType <$> examineCode offWith) |]) + (pure <$> dsWith)] + + alignmentWith' <- funD + (mkName "alignmentWith") + [clause [wildP] + (normalB [e| max $(discAlign) $(unType <$> examineCode algnWith) |]) + (pure <$> dsWith)] + + withLoc' <- do ptr <- newName "ptr" ptrOff <- newName "ptrOff" - d' <- [d| $(varP ptrOff) = $(varE ptr) `plusPtr` $algn' |] - disc <- newName "disc" - let mtchs = [ match (litP n') (normalB (peekCon con peekFields ptrOff)) [] - | (n, (con, peekFields, _)) <- zip [0..] conPeekPokess + k <- newName "k" + d' <- [d| $(varP ptrOff) = ($(varE ptr) `plusPtr` $(discSizeOf)) `alignPtr` $(unType <$> examineCode algnWith) |] + x <- newName "x" + + let mtchs = [ do (pat, body) <- patBody + match (pure pat) + (normalB . doE $ noBindS <$> [ [e| Foreign.poke (Foreign.castPtr $(varE ptr) :: Ptr $(pure discTy)) $(litE n') |] + , pure body + ]) + [] + | (n, (con, offsetsWith, _)) <- zip [0..] conWithsPeeks + , let patBody = withCon con offsetsWith ptrOff k , let n' = IntegerL n ] - Just peekN <- lookupValueName "peek" - funD peekN - [clause [varP ptr] - (normalB (doE [ bindS (varP disc) [e| peek (castPtr $(varE ptr) :: Ptr $(pure discTy)) |] - , noBindS (caseE (varE disc) mtchs) - ])) - (map pure d' ++ ds')] - -- poke - poke_ <- do + funD (mkName "withLoc") [clause [varP x, varP ptr, varP k] (normalB $ caseE (varE x) mtchs) (pure <$> d' ++ dsWith)] + + sizeOfPeek' <- funD + (mkName "sizeOfPeek") + [clause [wildP] + (normalB [e| $(alignQInt discSizeOf (unType <$> examineCode algnPeek)) + $(unType <$> examineCode offPeek) |]) + (pure <$> dsPeek)] + + alignmentPeek' <- funD + (mkName "alignmentPeek") + [clause [wildP] + (NormalB . unType <$> examineCode algnPeek) + (pure <$> dsPeek)] + + peek' <- do ptr <- newName "ptr" ptrOff <- newName "ptrOff" - d' <- [d| $(varP ptrOff) = $(varE ptr) `plusPtr` $algn' |] + d' <- [d| $(varP ptrOff) = ($(varE ptr) `plusPtr` $(discSizeOf)) `alignPtr` $(unType <$> examineCode algnPeek) |] disc <- newName "disc" - let mtchs = [ do { (pat,body) <- patBody - ; match (pure pat) - (normalB (doE (map noBindS [ [e| poke (castPtr $(varE ptr) :: Ptr $(pure discTy)) $(litE n') |] - , pure body - ]))) - [] - } - | (n, (con, _, pokeFields)) <- zip [0..] conPeekPokess - , let patBody = pokeCon con pokeFields ptrOff + let mtchs = [ match (litP n') (normalB (peekCon con offsetsPeek ptrOff)) [] + | (n, (con, _, offsetsPeek)) <- zip [0..] conWithsPeeks , let n' = IntegerL n ] - Just pokeN <- lookupValueName "poke" - funD pokeN - [clause [varP ptr, varP disc] (normalB (caseE (varE disc) mtchs)) (map pure d' ++ ds')] - - pure [sizeOf_, alignment_, peek_, poke_] - + funD (mkName "peek") + [clause [varP ptr] + (normalB (doE [ bindS (varP disc) [e| Foreign.peek (castPtr $(varE ptr) :: Ptr $(pure discTy)) |] + , noBindS (caseE (varE disc) mtchs) + ])) + (pure <$> d' ++ dsPeek)] + + pure [sizeOfWith', alignmentWith', withLoc', sizeOfPeek', alignmentPeek', peek'] diff --git a/src/Language/Rust/Inline/TH/Storable.hs b/src/Language/Rust/Inline/TH/Storable.hs index 989577a..54a822b 100644 --- a/src/Language/Rust/Inline/TH/Storable.hs +++ b/src/Language/Rust/Inline/TH/Storable.hs @@ -245,32 +245,26 @@ processADT [(con, fields)] = do let ds' = map pure ds -- sizeOf - sizeOf_ <- do - Just sizeOfN <- lookupValueName "sizeOf" - funD sizeOfN [clause [wildP] + sizeOf_ <- funD (mkName "sizeOf") [clause [wildP] (normalB [e| let c = $(unType <$> examineCode off) in c + mod (negate c) $(unType <$> examineCode algn) |]) ds'] -- alignment - alignment_ <- do - Just alignmentN <- lookupValueName "alignment" - funD alignmentN [clause [wildP] (normalB (unType <$> examineCode algn)) ds'] + alignment_ <- funD (mkName "alignment") [clause [wildP] (normalB (unType <$> examineCode algn)) ds'] let (peekFields, pokeFields) = unzip peekPokes -- peek peek_ <- do ptr <- newName "ptr" - Just peekN <- lookupValueName "peek" - funD peekN [clause [varP ptr] (normalB (peekCon con peekFields ptr)) ds'] + funD (mkName "peek") [clause [varP ptr] (normalB (peekCon con peekFields ptr)) ds'] -- poke poke_ <- do ptr <- newName "ptr" (cPat,body) <- pokeCon con pokeFields ptr - Just pokeN <- lookupValueName "poke" - funD pokeN [clause [varP ptr, pure cPat] (normalB (pure body)) ds'] + funD (mkName "poke") [clause [varP ptr, pure cPat] (normalB (pure body)) ds'] pure [sizeOf_, alignment_, peek_, poke_] @@ -296,17 +290,13 @@ processADT cons = do let ds' = map pure ds -- sizeOf - sizeOf_ <- do - Just sizeOfN <- lookupValueName "sizeOf" - funD sizeOfN [clause [wildP] + sizeOf_ <- funD (mkName "sizeOf") [clause [wildP] (normalB [e| let c = $(unType <$> examineCode off) in $algn' + c + mod (negate c) $algn' |]) ds'] -- alignment - alignment_ <- do - Just alignmentN <- lookupValueName "alignment" - funD alignmentN [clause [wildP] (normalB algn') ds'] + alignment_ <- funD (mkName "alignment") [clause [wildP] (normalB algn') ds'] -- peek peek_ <- do @@ -318,8 +308,7 @@ processADT cons = do | (n, (con, peekFields, _)) <- zip [0..] conPeekPokess , let n' = IntegerL n ] - Just peekN <- lookupValueName "peek" - funD peekN + funD (mkName "peek") [clause [varP ptr] (normalB (doE [ bindS (varP disc) [e| peek (castPtr $(varE ptr) :: Ptr $(pure discTy)) |] , noBindS (caseE (varE disc) mtchs) @@ -343,8 +332,7 @@ processADT cons = do , let patBody = pokeCon con pokeFields ptrOff , let n' = IntegerL n ] - Just pokeN <- lookupValueName "poke" - funD pokeN + funD (mkName "poke") [clause [varP ptr, varP disc] (normalB (caseE (varE disc) mtchs)) (map pure d' ++ ds')] pure [sizeOf_, alignment_, peek_, poke_] diff --git a/tests/AlgebraicDataTypes.hs b/tests/AlgebraicDataTypes.hs index 8f07aab..e67c4dc 100644 --- a/tests/AlgebraicDataTypes.hs +++ b/tests/AlgebraicDataTypes.hs @@ -18,11 +18,11 @@ import Data.Int ( Int8, Int16, Int32, Int64 ) -- | A struct-like ADT where the fields have different sizes data StructLike = StructLike Int16 Int64 deriving (Show, Eq) -mkMarshalable [t| Storable StructLike |] +mkMarshalable [t| Marshalable StructLike |] -- | A struct-like newtype ADT where the field is compound newtype StructLike2 = StructLike2 (Int16, Int64) deriving (Show, Eq) -mkMarshalable [t| Storable StructLike2 |] +mkMarshalable [t| Marshalable StructLike2 |] -- | An ADT where: @@ -35,14 +35,15 @@ data Foo | Baz Char Int | Qux (Complex Float) Char deriving (Show, Eq) -mkMarshalable [t| Storable Foo |] +mkMarshalable [t| forall a. Marshalable a => Marshalable (Complex a) |] +mkMarshalable [t| Marshalable Foo |] -- | An ADT where fields are nested ADTs data Croc = Lob (Maybe Foo) Int | Boo Int8 Int8 deriving (Show, Eq) -mkMarshalable [t| Storable Croc |] +mkMarshalable [t| Marshalable Croc |] -- | A polymorphic ADT. (From the @these@ package). data These a b @@ -50,7 +51,7 @@ data These a b | That b | Both a b deriving (Show, Eq) -mkMarshalable [t| forall a b. (Storable a, Storable b) => Storable (These a b) |] +mkMarshalable [t| forall a b. (Marshalable a, Marshalable b) => Marshalable (These a b) |] -- | An ADT that needs more that a 'Word8' to store the tag data Big a @@ -85,7 +86,7 @@ data Big a | C280 | C281 | C282 | C283 | C284 | C285 | C286 | C287 | C288 | C289 | C290 | C291 | C292 | C293 | C294 | C295 | C296 | C297 | C298 | C299 a deriving (Show, Eq) -mkMarshalable [t| forall a. Storable a => Storable (Big a) |] +mkMarshalable [t| forall a. Marshalable a => Marshalable (Big a) |] -- | An ADT with a mixture of polymorphism and not. data Foo2 a @@ -94,7 +95,7 @@ data Foo2 a | Qux2 a a | Quux2 Int a deriving (Show, Eq) -mkMarshalable [t| forall a. Storable a => Storable (Foo2 a) |] +mkMarshalable [t| forall a. Marshalable a => Marshalable (Foo2 a) |] -- Set the context @@ -220,126 +221,125 @@ impl These { algebraicDataTypes :: Spec algebraicDataTypes = describe "Algebraic data types" $ do - pure () - -- it "Can marshal a `Complex Float` argument/return" $ do - -- let z1, z2 :: Complex Float - -- z1 = 1.3 :+ 4.5 - -- z2 = 6.7 :+ 8.9 - -- [rust| Cpx { $(z1: Cpx) + $(z2: Cpx) } |] `shouldBe` z1 + z2 - -- - -- it "Can marshal a custom single-constructor ADT argument/return" $ do - -- let s1 = StructLike 78 (negate 267) - -- s2 = StructLike 92 45223 - -- s3 = StructLike2 (34, -92391) - -- s4 = StructLike2 (576, 1234) - -- - -- for_ [s1,s2] $ \si -> - -- [rust| StructLike2 { $(si: StructLike).in2() } |] `shouldBe` in2 si - -- for_ [s3,s4] $ \si -> - -- [rust| StructLike { $(si: StructLike2).out2() } |] `shouldBe` out2 si - -- - -- it "Can marshal a custom monomorphic ADT argument/return" $ do - -- let f1, f2, f3, f4 :: Foo - -- f1 = Baz 'a' 0 - -- f2 = Baz 'b' 2 - -- f3 = Qux (7.1 :+ 3.4) 'f' - -- f4 = Bar - -- - -- for_ [f1,f2,f3,f4] $ \fi -> - -- [rust| Foo { $(fi: Foo).quux() } |] `shouldBe` quux fi - -- - -- it "Can marshal nested monomorphic ADT arguments/returns" $ do - -- let c1, c2, c3, c4, c5, c6, c7 :: Croc - -- c1 = Lob (Just (Baz 'a' 0)) 2 - -- c2 = Lob (Just (Baz 'b' 2)) 6 - -- c3 = Lob (Just (Qux (7.1 :+ 3.4) 'f')) 8 - -- c4 = Lob (Just Bar) 9 - -- c5 = Lob Nothing 3 - -- c6 = Boo 3 (-2) - -- c7 = Boo (-4) 2 - -- - -- for_ [c1,c2,c3,c4,c5,c6,c7] $ \ci -> - -- [rust| Croc { $(ci: Croc).croc() } |] `shouldBe` croc ci - -- - -- it "Can marshal polymorphic ADT arguments/returns" $ do - -- let t1, t2, t3 :: These Int8 Int64 - -- t1 = This maxBound - -- t2 = That 432442 - -- t3 = Both (maxBound - 3) 879 - -- - -- for_ [t1,t2,t3] $ \ti -> - -- let v1 = [rust| These { - -- $(ti: These).bimap(|x| x as i16 * 2, |y| y + 2) - -- } |] - -- v2 = bimap (\x -> fromIntegral x * 2) (+ 2) ti - -- in v1 `shouldBe` v2 - -- - -- it "Can marshal nested polymorphic ADT arguments/returns" $ do - -- let t1, t2, t3 :: These (Maybe Int) (These Int8 Int8) - -- t1 = This (Just 6) - -- t2 = This Nothing - -- t3 = That (This 8) - -- t4 = That (That 9) - -- t5 = That (Both 1 2) - -- t6 = Both (Just 3) (That 8) - -- t7 = Both Nothing (Both 3 5) - -- t8 = Both (Just 213) (Both 78 98) - -- - -- for_ [t1,t2,t3,t4,t5,t6,t7,t8] $ \ti -> - -- let v1 = [rust| These,These> { - -- $(ti: These,These>).bimap( - -- |oi| oi.map(|i| i + 5), - -- |t| t.bimap(|i| i + 2, |j| j * 3), - -- ) - -- } |] - -- v2 = bimap (fmap (+5)) (bimap (+2) (*3)) ti - -- in v1 `shouldBe` v2 - -- - -- it "Can marshal a big ADT whose tag needs more than a `Word8`" $ do - -- let b1, b2, b3, b4 :: Big Int64 - -- b1 = C000 - -- b2 = C160 - -- b3 = C298 - -- b4 = C299 89 - -- - -- for_ [b1,b2,b3,b4] $ \bi -> - -- let v1 = [rust| Big { - -- match $(bi: Big) { - -- Big::C160 => Big::C161, - -- Big::C299(i) => Big::C299(i+1), - -- b => b, - -- } - -- } |] - -- v2 = case bi of - -- C160 -> C161 - -- C299 i -> C299 (i + 1) - -- b -> b - -- in v1 `shouldBe` v2 - -- - -- it "Can marshal a custom `Foo2 Int` and `Foo2 (Foo2 Int)` return" $ do - -- let f1, f2, f3, f4 :: Foo2 Int - -- f1 = Bar2 - -- f2 = Baz2 3 - -- f3 = Qux2 (-1) 2 - -- f4 = Quux2 (-8) 3 - -- - -- let fooed f = case f of - -- Qux2 x y -> Qux2 (Qux2 x y) (Qux2 y x) - -- Quux2 i x -> Quux2 (i + 1) (Qux2 x x) - -- Bar2 -> Bar2 - -- Baz2 w -> Baz2 w - -- - -- let fooed' f = [rust| Foo2> { - -- match $(f: Foo2) { - -- Foo2::Qux2(x,y) => Foo2::Qux2(Foo2::Qux2(x,y), Foo2::Qux2(y,x)), - -- Foo2::Quux2(i,x) => Foo2::Quux2(i+1, Foo2::Qux2(x,x)), - -- Foo2::Bar2 => Foo2::Bar2, - -- Foo2::Baz2(w) => Foo2::Baz2(w), - -- } - -- } |] - -- - -- fooed f1 `shouldBe` fooed' f1 - -- fooed f2 `shouldBe` fooed' f2 - -- fooed f3 `shouldBe` fooed' f3 - -- fooed f4 `shouldBe` fooed' f4 + it "Can marshal a `Complex Float` argument/return" $ do + let z1, z2 :: Complex Float + z1 = 1.3 :+ 4.5 + z2 = 6.7 :+ 8.9 + [rust| Cpx { $(z1: Cpx) + $(z2: Cpx) } |] `shouldBe` z1 + z2 + + it "Can marshal a custom single-constructor ADT argument/return" $ do + let s1 = StructLike 78 (negate 267) + s2 = StructLike 92 45223 + s3 = StructLike2 (34, -92391) + s4 = StructLike2 (576, 1234) + + for_ [s1,s2] $ \si -> + [rust| StructLike2 { $(si: StructLike).in2() } |] `shouldBe` in2 si + for_ [s3,s4] $ \si -> + [rust| StructLike { $(si: StructLike2).out2() } |] `shouldBe` out2 si + + it "Can marshal a custom monomorphic ADT argument/return" $ do + let f1, f2, f3, f4 :: Foo + f1 = Baz 'a' 0 + f2 = Baz 'b' 2 + f3 = Qux (7.1 :+ 3.4) 'f' + f4 = Bar + + for_ [f1,f2,f3,f4] $ \fi -> + [rust| Foo { $(fi: Foo).quux() } |] `shouldBe` quux fi + + it "Can marshal nested monomorphic ADT arguments/returns" $ do + let c1, c2, c3, c4, c5, c6, c7 :: Croc + c1 = Lob (Just (Baz 'a' 0)) 2 + c2 = Lob (Just (Baz 'b' 2)) 6 + c3 = Lob (Just (Qux (7.1 :+ 3.4) 'f')) 8 + c4 = Lob (Just Bar) 9 + c5 = Lob Nothing 3 + c6 = Boo 3 (-2) + c7 = Boo (-4) 2 + + for_ [c1,c2,c3,c4,c5,c6,c7] $ \ci -> + [rust| Croc { $(ci: Croc).croc() } |] `shouldBe` croc ci + + it "Can marshal polymorphic ADT arguments/returns" $ do + let t1, t2, t3 :: These Int8 Int64 + t1 = This maxBound + t2 = That 432442 + t3 = Both (maxBound - 3) 879 + + for_ [t1,t2,t3] $ \ti -> + let v1 = [rust| These { + $(ti: These).bimap(|x| x as i16 * 2, |y| y + 2) + } |] + v2 = bimap (\x -> fromIntegral x * 2) (+ 2) ti + in v1 `shouldBe` v2 + + it "Can marshal nested polymorphic ADT arguments/returns" $ do + let t1, t2, t3 :: These (Maybe Int) (These Int8 Int8) + t1 = This (Just 6) + t2 = This Nothing + t3 = That (This 8) + t4 = That (That 9) + t5 = That (Both 1 2) + t6 = Both (Just 3) (That 8) + t7 = Both Nothing (Both 3 5) + t8 = Both (Just 213) (Both 78 98) + + for_ [t1,t2,t3,t4,t5,t6,t7,t8] $ \ti -> + let v1 = [rust| These,These> { + $(ti: These,These>).bimap( + |oi| oi.map(|i| i + 5), + |t| t.bimap(|i| i + 2, |j| j * 3), + ) + } |] + v2 = bimap (fmap (+5)) (bimap (+2) (*3)) ti + in v1 `shouldBe` v2 + + it "Can marshal a big ADT whose tag needs more than a `Word8`" $ do + let b1, b2, b3, b4 :: Big Int64 + b1 = C000 + b2 = C160 + b3 = C298 + b4 = C299 89 + + for_ [b1,b2,b3,b4] $ \bi -> + let v1 = [rust| Big { + match $(bi: Big) { + Big::C160 => Big::C161, + Big::C299(i) => Big::C299(i+1), + b => b, + } + } |] + v2 = case bi of + C160 -> C161 + C299 i -> C299 (i + 1) + b -> b + in v1 `shouldBe` v2 + + it "Can marshal a custom `Foo2 Int` and `Foo2 (Foo2 Int)` return" $ do + let f1, f2, f3, f4 :: Foo2 Int + f1 = Bar2 + f2 = Baz2 3 + f3 = Qux2 (-1) 2 + f4 = Quux2 (-8) 3 + + let fooed f = case f of + Qux2 x y -> Qux2 (Qux2 x y) (Qux2 y x) + Quux2 i x -> Quux2 (i + 1) (Qux2 x x) + Bar2 -> Bar2 + Baz2 w -> Baz2 w + + let fooed' f = [rust| Foo2> { + match $(f: Foo2) { + Foo2::Qux2(x,y) => Foo2::Qux2(Foo2::Qux2(x,y), Foo2::Qux2(y,x)), + Foo2::Quux2(i,x) => Foo2::Quux2(i+1, Foo2::Qux2(x,x)), + Foo2::Bar2 => Foo2::Bar2, + Foo2::Baz2(w) => Foo2::Baz2(w), + } + } |] + + fooed f1 `shouldBe` fooed' f1 + fooed f2 `shouldBe` fooed' f2 + fooed f3 `shouldBe` fooed' f3 + fooed f4 `shouldBe` fooed' f4 diff --git a/tests/ByteStrings.hs b/tests/ByteStrings.hs index 044ab6e..62dc940 100644 --- a/tests/ByteStrings.hs +++ b/tests/ByteStrings.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE TypeApplications #-} + module ByteStrings where import Language.Rust.Inline @@ -9,6 +11,7 @@ import qualified Data.ByteString.Unsafe as ByteString import Data.Maybe (fromJust) import Data.String import Data.Either (fromRight) +import Language.Rust.Inline.TH extendContext basic extendContext prelude @@ -35,6 +38,26 @@ bytestringSpec = describe "ByteStrings" $ do `shouldBe` rustBs ByteString.unsafeFinalize rustBs + it "can marshal optional ByteString arguments and return values" $ do + let rustSum inputs = + [rust| Option { + let inputs = $( inputs: Option<&[u8]> ); + inputs.map(|inputs| inputs.iter().sum()) + } |] + let inputs = ByteString.pack [0, 1, 2, 3] + rustSum Nothing `shouldBe` Nothing + rustSum (Just inputs) `shouldBe` Just (sum $ ByteString.unpack inputs) + + it "can marshal result ByteString arguments" $ do + let rustSum inputs = + [rust| Result { + let inputs = $( inputs: Result<&[u8], ()> ); + inputs.map(|inputs| inputs.iter().sum()) + } |] + let inputs = ByteString.pack [0, 1, 2, 3] + rustSum (Left ()) `shouldBe` Left () + rustSum (Right inputs) `shouldBe` Right (sum $ ByteString.unpack inputs) + it "can marshal optional ByteString return values" $ do let noRustBs = [rust| Option> { None } |] noRustBs `shouldBe` Nothing @@ -42,9 +65,9 @@ bytestringSpec = describe "ByteStrings" $ do let rustBs = [rust| Option> { Some(vec![0, 1, 2, 3]) } |] rustBs `shouldBe` Just (ByteString.pack [0, 1, 2, 3]) - -- it "can marshal result ByteString return values" $ do - -- let errRustBs = [rust| Result, ()> { Err(()) } |] - -- errRustBs `shouldBe` Left () + it "can marshal result ByteString return values" $ do + let errRustBs = [rust| Result { Err(()) } |] + errRustBs `shouldBe` Left () - -- let okRustBs = [rust| Result, ()> { Ok(vec![0, 1, 2, 3]) } |] - -- okRustBs `shouldBe` Right (ByteString.pack [0, 1, 2, 3]) + let okRustBs = [rust| Result { Ok(vec![0, 1, 2, 3].marshal()) } |] + okRustBs `shouldBe` Right (ByteString.pack [0, 1, 2, 3]) diff --git a/tests/ForeignPtr.hs b/tests/ForeignPtr.hs index 4eff1c5..14edd19 100644 --- a/tests/ForeignPtr.hs +++ b/tests/ForeignPtr.hs @@ -56,18 +56,18 @@ foreignPtrTypes = describe "ForeignPtr types" $ do } |] withForeignPtr (fromJust mp) peek >>= (`shouldBe` 42) - -- it "Can marshal result ForeignPtr returns" $ do - -- let mp = - -- [rust| Result, ()> { - -- Err(()) - -- } |] - -- mp `shouldBe` Left () + it "Can marshal result ForeignPtr returns" $ do + let mp = + [rust| Result, ()> { + Err(()) + } |] + mp `shouldBe` Left () - -- let mp = - -- [rust| Result, ()> { - -- Ok(Box::new(42).into()) - -- } |] - -- withForeignPtr (fromRight mp) peek >>= (`shouldBe` 42) + let mp = + [rust| Result, ()> { + Ok(Box::new(42).into()) + } |] + withForeignPtr (fromRight undefined mp) peek >>= (`shouldBe` 42) it "still has working pointers" $ alloca $ \p -> do diff --git a/tests/PreludeTypes.hs b/tests/PreludeTypes.hs index ff31d70..2ee27c9 100644 --- a/tests/PreludeTypes.hs +++ b/tests/PreludeTypes.hs @@ -15,59 +15,57 @@ setCrateModule preludeTypes :: Spec preludeTypes = describe "Common Prelude types" $ do - pure () - -- it "Can marshal a `Maybe Int32` argument/return" $ do - -- let x1, x2 :: Maybe Int32 - -- x1 = Just 9 - -- x2 = Nothing - -- [rust| Option { $(x1: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x1 - -- [rust| Option { $(x2: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x2 - -- - -- it "Can marshal an `Either Int32 Char` argument/return" $ do - -- let x1, x2 :: Either Int32 Char - -- x1 = Left 9 - -- x2 = Right 'e' - -- - -- [rust| Result { - -- $(x1: Result) - -- .map(|c| c.to_uppercase().next().unwrap()) - -- .map_err(|n| n+2) - -- } |] `shouldBe` bimap (+2) toUpper x1 - -- - -- [rust| Result { - -- $(x2: Result) - -- .map(|c| c.to_uppercase().next().unwrap()) - -- .map_err(|n| n+2) - -- } |] `shouldBe` bimap (+2) toUpper x2 - -- - -- it "Can marshal `(Int32, Char)` argument and `(Int32, Char, Word32)` return" $ do - -- let x = (-9, 'c') - -- - -- [rust| (i32, char, u32) { - -- let (n, c) = $(x: (i32, char)); - -- (n * 3, c.to_uppercase().next().unwrap(), n.abs() as u32) - -- } |] `shouldBe` (fst x * 3, toUpper (snd x), fromIntegral (abs (fst x))) - -- - -- it "Can marshal `Maybe (Int32, Either Char Word32)` argument/return" $ do - -- let x1, x2, x3 :: Maybe (Int32, Either Char Word32) - -- x1 = Nothing - -- x2 = Just (3, Left 'c') - -- x3 = Just (4, Right 8) - -- - -- let f x = [rust| Option<(i32, Result)> { - -- $(x: Option<(i32, Result)>).map(|x| { - -- match x { - -- (x1, Ok(x2)) => (x1 * x2 as i32, Err('a')), - -- (x1, e) => (x1 - 1, e), - -- } - -- }) - -- } |] - -- - -- let f' = fmap (\x -> case x of - -- (x1, Right x2) -> (x1 * fromIntegral x2, Left 'a') - -- (x1, e) -> (x1 - 1, e)) - -- - -- f x1 `shouldBe` f' x1 - -- f x2 `shouldBe` f' x2 - -- f x3 `shouldBe` f' x3 - + it "Can marshal a `Maybe Int32` argument/return" $ do + let x1, x2 :: Maybe Int32 + x1 = Just 9 + x2 = Nothing + [rust| Option { $(x1: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x1 + [rust| Option { $(x2: Option).map(|n| n+2) } |] `shouldBe` fmap (+2) x2 + + it "Can marshal an `Either Int32 Char` argument/return" $ do + let x1, x2 :: Either Int32 Char + x1 = Left 9 + x2 = Right 'e' + + [rust| Result { + $(x1: Result) + .map(|c| c.to_uppercase().next().unwrap()) + .map_err(|n| n+2) + } |] `shouldBe` bimap (+2) toUpper x1 + + [rust| Result { + $(x2: Result) + .map(|c| c.to_uppercase().next().unwrap()) + .map_err(|n| n+2) + } |] `shouldBe` bimap (+2) toUpper x2 + + it "Can marshal `(Int32, Char)` argument and `(Int32, Char, Word32)` return" $ do + let x = (-9, 'c') + + [rust| (i32, char, u32) { + let (n, c) = $(x: (i32, char)); + (n * 3, c.to_uppercase().next().unwrap(), n.abs() as u32) + } |] `shouldBe` (fst x * 3, toUpper (snd x), fromIntegral (abs (fst x))) + + it "Can marshal `Maybe (Int32, Either Char Word32)` argument/return" $ do + let x1, x2, x3 :: Maybe (Int32, Either Char Word32) + x1 = Nothing + x2 = Just (3, Left 'c') + x3 = Just (4, Right 8) + + let f x = [rust| Option<(i32, Result)> { + $(x: Option<(i32, Result)>).map(|x| { + match x { + (x1, Ok(x2)) => (x1 * x2 as i32, Err('a')), + (x1, e) => (x1 - 1, e), + } + }) + } |] + + let f' = fmap (\x -> case x of + (x1, Right x2) -> (x1 * fromIntegral x2, Left 'a') + (x1, e) -> (x1 - 1, e)) + + f x1 `shouldBe` f' x1 + f x2 `shouldBe` f' x2 + f x3 `shouldBe` f' x3 diff --git a/tests/SimpleTypes.hs b/tests/SimpleTypes.hs index 29d77cb..85e3e0c 100644 --- a/tests/SimpleTypes.hs +++ b/tests/SimpleTypes.hs @@ -66,7 +66,7 @@ simpleTypes = describe "Simple types" $ do it "Can marshal a `Bool` argument/return" $ do let x = 0 :: Word8 [rust| bool { !$(x: bool) } |] `shouldBe` (1 :: Word8) - --- it "Can marshal a `()` argument/return" $ do --- let x = () --- [rust| () { $(x: ()) } |] `shouldBe` () + + it "Can marshal a `()` argument/return" $ do + let x = () + [rust| () { $(x: ()) } |] `shouldBe` () From ff2f9829898327a6982908a073c8866966c96ccb Mon Sep 17 00:00:00 2001 From: Viktor Kleen Date: Sun, 23 Feb 2025 15:07:35 +0000 Subject: [PATCH 08/13] Enable `Vec` in ADT return types in rust --- src/Language/Rust/Inline/Context.hs | 8 ++++---- .../Rust/Inline/Context/ByteString.hs | 20 ++++++------------- src/Language/Rust/Inline/Context/Prelude.hs | 2 +- tests/ByteStrings.hs | 4 ++-- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/src/Language/Rust/Inline/Context.hs b/src/Language/Rust/Inline/Context.hs index c4513e2..85a5c50 100644 --- a/src/Language/Rust/Inline/Context.hs +++ b/src/Language/Rust/Inline/Context.hs @@ -61,17 +61,17 @@ type. newtype Context = Context ( [RType -> Context -> First (Q HType, Maybe (Q RType))] - , -- Given a Rust type in a quasiquote, we need to look up the + -- Given a Rust type in a quasiquote, we need to look up the -- corresponding Haskell type (for the FFI import) as well as the -- C-compatible Rust type (if the initial Rust type isn't already -- @#[repr(C)]@. - [HType -> Context -> First (Q RType)] - , -- Given a field in a Haskell ADT, we need to figure out which + , [HType -> Context -> First (Q RType)] + -- Given a field in a Haskell ADT, we need to figure out which -- (not-necessarily @#[repr(C)]@) Rust type normally maps into this -- Haskell type. - [String] + , [String] -- Source for the trait impls of @MarshalTo@ ) deriving (Semigroup, Monoid, Typeable) diff --git a/src/Language/Rust/Inline/Context/ByteString.hs b/src/Language/Rust/Inline/Context/ByteString.hs index 3bfd539..25db20b 100644 --- a/src/Language/Rust/Inline/Context/ByteString.hs +++ b/src/Language/Rust/Inline/Context/ByteString.hs @@ -32,16 +32,18 @@ import Foreign.Ptr (Ptr) import Debug.Trace (traceM) bytestrings :: Q Context -bytestrings = - pure $ Context ([rule], [], [rustByteString, impl]) +bytestrings = do + bs <- [t|ByteString|] + pure $ Context ([rule], [rev bs], [rustByteString, impl]) where rule rty _ | rty == void [ty| &[u8] |] = pure ([t|ByteString|], pure . pure $ void [ty| RustByteString |]) | rty == void [ty| Vec |] = pure ([t|ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) - | rty == void [ty| RustOwnedByteString |] = pure ([t|ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) - -- | rty == void [ty| Option> |] = pure ([t|Maybe ByteString|], pure . pure $ void [ty| RustOwnedByteString |]) rule _ _ = mempty + rev bs hty _ | bs == hty = pure . pure $ void [ty|Vec|] + rev _ _ _ = mempty + rustByteString = unlines [ "#[repr(C)]" @@ -74,14 +76,4 @@ bytestrings = , " RustOwnedByteString(bytes.as_mut_ptr(), len, free)" , " }" , "}" - , "" - , "impl MarshalInto for RustOwnedByteString { fn marshal(self) -> RustOwnedByteString { self } }" - -- , "impl MarshalInto for Option> {" - -- , " fn marshal(self) -> RustOwnedByteString {" - -- , " extern fn panic(ptr: *mut u8, len: usize) {" - -- , " panic!(\"Attempted to free a null ByteString\");" - -- , " }" - -- , " self.map(|bs| bs.marshal()).unwrap_or(RustOwnedByteString(std::ptr::null_mut(), 0, panic))" - -- , " }" - -- , "}" ] diff --git a/src/Language/Rust/Inline/Context/Prelude.hs b/src/Language/Rust/Inline/Context/Prelude.hs index 288cb68..b72b57f 100644 --- a/src/Language/Rust/Inline/Context/Prelude.hs +++ b/src/Language/Rust/Inline/Context/Prelude.hs @@ -205,7 +205,7 @@ eitherItems = map unlines , "struct RightC(T);" ] -- impl MarshalInto> for Result - , [ "impl + Copy, R: Copy, R1: MarshalInto + Copy> MarshalInto> for Result {" + , [ "impl, R: Copy, R1: MarshalInto> MarshalInto> for Result {" , " fn marshal(self) -> EitherC {" , " match self {" , " Err(l) => EitherC { tag: 0, payload: TaggedEitherC { left: LeftC(l.marshal()) } }," diff --git a/tests/ByteStrings.hs b/tests/ByteStrings.hs index 62dc940..a450c6a 100644 --- a/tests/ByteStrings.hs +++ b/tests/ByteStrings.hs @@ -66,8 +66,8 @@ bytestringSpec = describe "ByteStrings" $ do rustBs `shouldBe` Just (ByteString.pack [0, 1, 2, 3]) it "can marshal result ByteString return values" $ do - let errRustBs = [rust| Result { Err(()) } |] + let errRustBs = [rust| Result, ()> { Err(()) } |] errRustBs `shouldBe` Left () - let okRustBs = [rust| Result { Ok(vec![0, 1, 2, 3].marshal()) } |] + let okRustBs = [rust| Result, ()> { Ok(vec![0, 1, 2, 3]) } |] okRustBs `shouldBe` Right (ByteString.pack [0, 1, 2, 3]) From 570f32992f2ae80f0fffd7bedef3ff6dd48af503 Mon Sep 17 00:00:00 2001 From: ners Date: Sun, 23 Feb 2025 16:43:15 +0100 Subject: [PATCH 09/13] still works without copy --- src/Language/Rust/Inline/Context.hs | 5 +++-- tests/ForeignPtr.hs | 14 +++++++++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/Language/Rust/Inline/Context.hs b/src/Language/Rust/Inline/Context.hs index 85a5c50..50aa588 100644 --- a/src/Language/Rust/Inline/Context.hs +++ b/src/Language/Rust/Inline/Context.hs @@ -329,9 +329,10 @@ foreignPointers = do foreignPtr = unlines - [ "#[derive(Copy, Clone)]" - , "#[repr(C)]" + [ "#[repr(C)]" , "pub struct ForeignPtr(pub *mut T, pub extern \"C\" fn (*mut T));" + , "impl Copy for ForeignPtr {}" + , "impl Clone for ForeignPtr { fn clone(&self) -> Self { ForeignPtr(self.0, self.1) } }" ] constPtr = diff --git a/tests/ForeignPtr.hs b/tests/ForeignPtr.hs index 14edd19..b2c8e98 100644 --- a/tests/ForeignPtr.hs +++ b/tests/ForeignPtr.hs @@ -1,8 +1,9 @@ module ForeignPtr where import Language.Rust.Inline +import Language.Rust.Quote -import Data.Maybe (fromJust) +import Data.Maybe (fromJust, isJust) import Data.Word (Word64) import Foreign (Storable (..)) import Foreign.ForeignPtr @@ -16,6 +17,12 @@ extendContext prelude extendContext basic setCrateModule + +extendContext (singleton [ty| NotCopy |] [t| () |]) +[rust| +pub struct NotCopy(()); +|] + foreignPtrTypes :: Spec foreignPtrTypes = describe "ForeignPtr types" $ do it "Can marshal ForeignPtr arguments as references" $ do @@ -81,3 +88,8 @@ foreignPtrTypes = describe "ForeignPtr types" $ do unsafe { *$(p: *u64) } } |] val `shouldBe` 42 + + it "still works without Copy" $ do + let mp = [rust| Option> { Some(Box::new(NotCopy(())).into()) } |] + mp `shouldSatisfy` isJust + From 703846421d90d181124bfc45e1cc482345d6fc63 Mon Sep 17 00:00:00 2001 From: ners Date: Sun, 23 Feb 2025 18:18:44 +0100 Subject: [PATCH 10/13] do not the stack --- flake.nix | 4 +++- inline-rust.cabal | 9 ++++++--- src/Language/Rust/Inline.hs | 8 ++++---- src/Language/Rust/Inline/TH/Marshalable.hs | 20 ++++++++++++++------ tests/ByteStrings.hs | 2 -- tests/Concurrency.hs | 22 ++++++++++++++++++++++ tests/Main.hs | 2 ++ 7 files changed, 51 insertions(+), 16 deletions(-) create mode 100644 tests/Concurrency.hs diff --git a/flake.nix b/flake.nix index 8c273ca..05b3b10 100644 --- a/flake.nix +++ b/flake.nix @@ -83,10 +83,12 @@ packages = ps: [ hp.${pname} ]; nativeBuildInputs = with pkgs'; with haskellPackages; [ pkgs'.haskellPackages.cabal-install + cargo fourmolu + gdb haskell-language-server - cargo rustc + valgrind ]; }; }); diff --git a/inline-rust.cabal b/inline-rust.cabal index 4040776..021db2e 100644 --- a/inline-rust.cabal +++ b/inline-rust.cabal @@ -22,7 +22,7 @@ source-repository head library hs-source-dirs: src - ghc-options: -Wall -ddump-splices -ddump-to-file + ghc-options: -Wall -ddump-splices -ddump-to-file -g default-language: Haskell2010 exposed-modules: Language.Rust.Inline @@ -65,7 +65,7 @@ library test-suite spec hs-source-dirs: tests - ghc-options: -threaded -ddump-splices -ddump-to-file + ghc-options: -threaded -ddump-splices -ddump-to-file -g if os(windows) extra-libraries: @@ -93,9 +93,12 @@ test-suite spec , Submodule.Submodule , ByteStrings , ForeignPtr - build-depends: base + , Concurrency + build-depends: async + , base , inline-rust , language-rust , hspec , ghc-prim , bytestring + , unliftio diff --git a/src/Language/Rust/Inline.hs b/src/Language/Rust/Inline.hs index c46f38e..4c858b9 100644 --- a/src/Language/Rust/Inline.hs +++ b/src/Language/Rust/Inline.hs @@ -104,6 +104,7 @@ import Foreign.Marshal.Array (newArray, withArrayLen) import Foreign.Marshal.Unsafe (unsafeLocalState) import Foreign.Marshal.Utils (new, with) import Foreign.Ptr (freeHaskellFunPtr) +import qualified Foreign import Control.Monad (void) import Data.List (intercalate) @@ -344,7 +345,9 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do | otherwise = do ret <- newName "ret" [e| - alloca + Foreign.allocaBytesAligned + (Marshalable.sizeOfPeek (undefined :: $(pure haskRet))) + (Marshalable.alignmentPeek (undefined :: $(pure haskRet))) ( \($(varP ret)) -> do $(appsE (varE qqName : reverse (varE ret : acc))) @@ -385,9 +388,6 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do mergeArgs t Nothing = (t, t) mergeArgs t (Just tInter) = (fmap (const mempty) tInter, t) - -- EitherC - -- EitherC -> Result - -- Generate the Rust function. let retByVal = returnByValue returnFfi (retArg, retTy, ret) diff --git a/src/Language/Rust/Inline/TH/Marshalable.hs b/src/Language/Rust/Inline/TH/Marshalable.hs index d65c1f7..00bff0c 100644 --- a/src/Language/Rust/Inline/TH/Marshalable.hs +++ b/src/Language/Rust/Inline/TH/Marshalable.hs @@ -29,6 +29,7 @@ import Foreign.Ptr ( alignPtr, plusPtr, castPtr, Ptr ) import Data.Word ( Word8, Word16, Word32, Word64 ) import Language.Rust.Inline.Context.Marshalable import qualified Foreign +import Data.List (intercalate) -- | Generate 'Marshalable' instance for a non-recursive simple algebraic data -- type. The instance follows the usual C layout for determining alignment and @@ -54,10 +55,10 @@ mkMarshalable tyq = do _ -> fail "mkMarshalable: malformed 'Marshalable' instance head" -- Get the type constructors name - (_,cons') <- getConstructors ty' + (name,cons') <- getConstructors ty' -- Produce the instance - decs' <- processADT [ (nameCon n, tyArgs) | (n,tyArgs) <- cons' ] + decs' <- processADT name [ (nameCon n, tyArgs) | (n,tyArgs) <- cons' ] pure . pure $ InstanceD Nothing ctx (AppT marshalable ty') decs' mkTupleMarshalable :: Int -- ^ arity of tuple @@ -68,10 +69,11 @@ mkTupleMarshalable n = do | i <- [(1 :: Int)..] , c <- ['a'..'z'] ]) + let name = mkName $ "(" <> intercalate "," (show <$> tyVars) <> ")" let ctx c = [ AppT c (VarT tyVar) | tyVar <- tyVars ] let instHead c = AppT c (foldl AppT (TupleT n) (map VarT tyVars)) - decs' <- processADT [ (tupCon, map VarT tyVars) ] + decs' <- processADT name [ (tupCon, map VarT tyVars) ] pure . pure $ InstanceD Nothing (ctx marshalable) (instHead marshalable) decs' -- * Constructor utilities @@ -228,11 +230,12 @@ processField alignment sizeOf ty = do -- | Process an algebraic data type. -- -- TODO: think about the zero constructor case... -processADT :: [(Constructor, [Type])] -- ^ constructors and the types of their fields +processADT :: Name -- ^ name of the type + -> [(Constructor, [Type])] -- ^ constructors and the types of their fields -> Q [Dec] -- ^ marshalable implementations -- The one constructor case is special - we don't need to specify a tag -processADT [(con, fields)] = do +processADT _ [(con, fields)] = do initAlign <- mempty (offsetsWith, Alignment dsWith sizeWith algnWith) <- runStateT (traverse (processField 'alignmentWith 'sizeOfWith) fields) initAlign (offsetsPeek, Alignment dsPeek sizePeek algnPeek) <- runStateT (traverse (processField 'alignmentPeek 'sizeOfPeek) fields) initAlign @@ -273,7 +276,7 @@ processADT [(con, fields)] = do pure [sizeOfWith', alignmentWith', withLoc', sizeOfPeek', alignmentPeek', peek'] -processADT cons = do +processADT name cons = do let discNum = length cons discTy <- snd . head . dropWhile (\(m,_) -> discNum > m + 1) $ [ (fromIntegral (maxBound :: Word8), [t| Word8 |]) @@ -346,6 +349,11 @@ processADT cons = do | (n, (con, _, offsetsPeek)) <- zip [0..] conWithsPeeks , let n' = IntegerL n ] + <> + [ match wildP (normalB [e| + fail $ "Unknown discriminator for " <> $(liftString $ show name) <> ": " <> show $(varE disc) + |]) [] + ] funD (mkName "peek") [clause [varP ptr] (normalB (doE [ bindS (varP disc) [e| Foreign.peek (castPtr $(varE ptr) :: Ptr $(pure discTy)) |] diff --git a/tests/ByteStrings.hs b/tests/ByteStrings.hs index a450c6a..a390ba6 100644 --- a/tests/ByteStrings.hs +++ b/tests/ByteStrings.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE TypeApplications #-} - module ByteStrings where import Language.Rust.Inline diff --git a/tests/Concurrency.hs b/tests/Concurrency.hs new file mode 100644 index 0000000..6daab9b --- /dev/null +++ b/tests/Concurrency.hs @@ -0,0 +1,22 @@ +module Concurrency where + +import Control.Concurrent.Async +import Control.Monad (replicateM_) +import Foreign (withForeignPtr, peek) +import Language.Rust.Inline +import Language.Rust.Inline.TH (sizeOfWith) +import Test.Hspec +import Data.Word + +extendContext foreignPointers +extendContext pointers +extendContext prelude +extendContext basic +setCrateModule + +concurrencySpec :: Spec +concurrencySpec = describe "Concurrency" $ do + it "does not crash" $ do + let p = [rust| ForeignPtr { Box::new(0).into() } |] + replicateConcurrently_ 10000 $ replicateM_ 100 [rustIO| Option<()> { *$(p: &mut u64) += 1; Some(()) } |] + withForeignPtr p peek `shouldNotReturn` 0 diff --git a/tests/Main.hs b/tests/Main.hs index 7fc208a..5c459be 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -12,6 +12,7 @@ import Language.Rust.Inline import AlgebraicDataTypes import ByteStrings +import Concurrency (concurrencySpec) import Data.Word import Foreign.Marshal.Array import Foreign.Ptr @@ -34,6 +35,7 @@ main = hspec $ describe "Rust quasiquoter" $ do algebraicDataTypes bytestringSpec + concurrencySpec foreignPtrTypes funcPointerTypes ghcUnboxedTypes From ed87ebd67505456e8483453e87bc64c11c8799f9 Mon Sep 17 00:00:00 2001 From: ners Date: Sun, 23 Feb 2025 20:51:41 +0100 Subject: [PATCH 11/13] we have vectors now --- inline-rust.cabal | 3 +- src/Language/Rust/Inline.hs | 2 + src/Language/Rust/Inline/Context.hs | 8 +- .../Rust/Inline/Context/Marshalable.hs | 4 +- src/Language/Rust/Inline/Context/Vector.hs | 87 +++++++++++++++++++ tests/Concurrency.hs | 22 ----- tests/Main.hs | 4 +- tests/Vectors.hs | 30 +++++++ 8 files changed, 131 insertions(+), 29 deletions(-) create mode 100644 src/Language/Rust/Inline/Context/Vector.hs delete mode 100644 tests/Concurrency.hs create mode 100644 tests/Vectors.hs diff --git a/inline-rust.cabal b/inline-rust.cabal index 021db2e..086816f 100644 --- a/inline-rust.cabal +++ b/inline-rust.cabal @@ -31,6 +31,7 @@ library Language.Rust.Inline.Context.ByteString Language.Rust.Inline.Context.Marshalable Language.Rust.Inline.Context.Prelude + Language.Rust.Inline.Context.Vector Language.Rust.Inline.Internal Language.Rust.Inline.Marshal Language.Rust.Inline.Parser @@ -92,8 +93,8 @@ test-suite spec , Submodule , Submodule.Submodule , ByteStrings + , Vectors , ForeignPtr - , Concurrency build-depends: async , base , inline-rust diff --git a/src/Language/Rust/Inline.hs b/src/Language/Rust/Inline.hs index 4c858b9..04d7411 100644 --- a/src/Language/Rust/Inline.hs +++ b/src/Language/Rust/Inline.hs @@ -61,6 +61,7 @@ module Language.Rust.Inline ( pointers, prelude, bytestrings, + vectors, foreignPointers, -- ** Marshalling @@ -87,6 +88,7 @@ module Language.Rust.Inline ( import Language.Rust.Inline.Context import Language.Rust.Inline.Context.ByteString (bytestrings) +import Language.Rust.Inline.Context.Vector (vectors) import Language.Rust.Inline.Context.Prelude (prelude) import Language.Rust.Inline.Internal import Language.Rust.Inline.Marshal diff --git a/src/Language/Rust/Inline/Context.hs b/src/Language/Rust/Inline/Context.hs index 50aa588..01fcca0 100644 --- a/src/Language/Rust/Inline/Context.hs +++ b/src/Language/Rust/Inline/Context.hs @@ -46,6 +46,7 @@ import GHC.Exts ( Int#, Word#, ) +import Data.Void (Void) -- Easier on the eyes type RType = Ty () @@ -319,10 +320,13 @@ foreignPointers = do foreignPtrT <- [t|ForeignPtr|] pure $ Context ([rule], [rev foreignPtrT], [foreignPtr, constPtr, mutPtr]) where + htype _ (Just _) = pure ([t| ForeignPtr Void|], Nothing) -- if the pointee needs marshalling, forbid peeking from Haskell + htype t Nothing = pure ([t|ForeignPtr $t|], Nothing) + rule (Rptr _ _ t _) context - | First (Just (t', Nothing)) <- lookupRTypeInContext t context = pure ([t|ForeignPtr $t'|], Nothing) + | First (Just (t', inter)) <- lookupRTypeInContext t context = htype t' inter rule (PathTy Nothing (Path False [PathSegment "ForeignPtr" (Just (AngleBracketed [] [t] [] _)) _] _) _) context - | First (Just (t', Nothing)) <- lookupRTypeInContext t context = pure ([t|ForeignPtr $t'|], Nothing) + | First (Just (t', inter)) <- lookupRTypeInContext t context = htype t' inter rule _ _ = mempty rev _ _ _ = mempty diff --git a/src/Language/Rust/Inline/Context/Marshalable.hs b/src/Language/Rust/Inline/Context/Marshalable.hs index f8b5621..7eb3f0f 100644 --- a/src/Language/Rust/Inline/Context/Marshalable.hs +++ b/src/Language/Rust/Inline/Context/Marshalable.hs @@ -88,7 +88,7 @@ instance Marshalable ByteString where alignmentPeek _ = Foreign.alignment (undefined :: (Foreign.Ptr Foreign.Word8, Word, Foreign.FunPtr (Foreign.Ptr Foreign.Word8 -> Word -> IO ()))) peek p = do (ptr, len, finalizer) <- Foreign.peek (castPtr p) - ByteString.unsafePackCStringFinalizer ptr (fromIntegral len) (bytestringFree finalizer ptr len) + ByteString.unsafePackCStringFinalizer ptr (fromIntegral len) (freeByteString finalizer ptr len) instance Marshalable (Foreign.ForeignPtr a) where sizeOfWith = const $ sizeOf (undefined :: (Foreign.Ptr a)) @@ -102,7 +102,7 @@ instance Marshalable (Foreign.ForeignPtr a) where (ptr, finalizer) <- Foreign.peek (castPtr p) Foreign.newForeignPtr finalizer ptr -foreign import ccall safe "dynamic" bytestringFree :: Foreign.FunPtr (Foreign.Ptr Foreign.Word8 -> Word -> IO ()) -> Foreign.Ptr Foreign.Word8 -> Word -> IO () +foreign import ccall safe "dynamic" freeByteString :: Foreign.FunPtr (Foreign.Ptr Foreign.Word8 -> Word -> IO ()) -> Foreign.Ptr Foreign.Word8 -> Word -> IO () instance Marshalable CChar instance Marshalable CSChar diff --git a/src/Language/Rust/Inline/Context/Vector.hs b/src/Language/Rust/Inline/Context/Vector.hs new file mode 100644 index 0000000..ad72bb7 --- /dev/null +++ b/src/Language/Rust/Inline/Context/Vector.hs @@ -0,0 +1,87 @@ +{-# LANGUAGE OverloadedStrings#-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# OPTIONS_GHC -Wno-orphans #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Language.Rust.Inline.Context.Vector where +import Language.Rust.Syntax (Ty(PathTy), Path (Path), PathSegment (PathSegment), PathParameters (AngleBracketed)) +import Language.Rust.Inline.Context.Prelude (mkGenPathTy) +import Language.Haskell.TH +import Language.Rust.Quote (ty) +import Data.Maybe (fromMaybe) +import Language.Rust.Inline.Context.Marshalable +import qualified Foreign +import Data.Foldable (foldrM) +import qualified Language.Rust.Inline.Context.Marshalable as Marshalable +import Control.Monad (void) +import Language.Rust.Inline.Context (Context(..), lookupRTypeInContext) + +vectors :: Q Context +vectors = pure $ Context ([rule], [], [rustList, impl]) + where + rule (PathTy Nothing (Path False [PathSegment "Vec" (Just (AngleBracketed [] [t'] [] _)) _] _) _) ctx + | t' /= void [ty|u8|] = do + (t'', rInterOpt) <- lookupRTypeInContext t' ctx + let inter = mkGenPathTy "Vector" . pure <$> fromMaybe (pure t') rInterOpt + pure ([t| [$t''] |], Just inter) + rule _ _ = mempty + + rustList = unlines + [ "#[repr(C)]" + , "pub struct Vector(*mut T, usize, extern \"C\" fn (*mut T, usize));" + , "impl Copy for Vector { }" + , "impl Clone for Vector { fn clone(&self) -> Self { Vector(self.0, self.1, self.2) } }" + ] + + impl = unlines + [ "impl + Copy> MarshalInto> for Vector {" + , " fn marshal(self) -> Vec {" + , " let Vector(ptr, len, _) = self;" + , " let inputs = unsafe { std::slice::from_raw_parts(ptr, len) };" + , " let mut vec = Vec::with_capacity(len);" + , " for i in 0..len {" + , " vec.push(inputs[i].marshal());" + , " }" + , " vec" + , " }" + , "}" + , "" + , "impl> MarshalInto> for Vec {" + , " fn marshal(self) -> Vector {" + , " let mut vec = Vec::with_capacity(self.len());" + , " for x in self.into_iter() {" + , " vec.push(x.marshal());" + , " }" + , "" + , " let slice = Box::leak(vec.into_boxed_slice());" + , " let len = slice.len();" + , "" + , " extern fn free(ptr: *mut U, len: usize) {" + , " let data = unsafe { Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr, len)) };" + , " drop(data);" + , " }" + , " Vector(slice.as_mut_ptr(), len, free)" + , " }" + , "}" + ] + +instance Marshalable a => Marshalable [a] where + sizeOfWith _ = Foreign.sizeOf (undefined :: (Foreign.Ptr (), Word, Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()))) + alignmentWith _ = Foreign.alignment (undefined :: (Foreign.Ptr (), Word, Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()))) + withLoc as loc k = Foreign.allocaBytesAligned (length as * sizeOfWith (undefined :: a)) (alignmentWith (undefined :: a)) $ \space -> do + let pokeSlice = Foreign.poke (Foreign.castPtr loc) (Foreign.castPtr space, length as, Foreign.nullFunPtr) + let pokeAt i a = withLoc a $ space `Foreign.plusPtr` (i * sizeOfWith (undefined :: a)) + foldr ($) (pokeSlice >> k) $ zipWith pokeAt [0..] as + + sizeOfPeek _ = Foreign.sizeOf (undefined :: (Foreign.Ptr (), Word, Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()))) + alignmentPeek _ = Foreign.alignment (undefined :: (Foreign.Ptr (), Word, Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()))) + peek p = do + (ptr, len, finalizer) <- Foreign.peek (Foreign.castPtr p) + let peekAtOffset :: Word -> IO a + peekAtOffset offset = Marshalable.peek $ ptr `Foreign.plusPtr` (fromIntegral offset * sizeOfPeek (undefined :: a)) + list <- foldrM (\a b -> liftA2 (:) (peekAtOffset a) (pure b)) [] [0 .. len - 1] + freeVector finalizer ptr len + pure list + +foreign import ccall safe "dynamic" freeVector :: Foreign.FunPtr (Foreign.Ptr () -> Word -> IO ()) -> Foreign.Ptr () -> Word -> IO () diff --git a/tests/Concurrency.hs b/tests/Concurrency.hs deleted file mode 100644 index 6daab9b..0000000 --- a/tests/Concurrency.hs +++ /dev/null @@ -1,22 +0,0 @@ -module Concurrency where - -import Control.Concurrent.Async -import Control.Monad (replicateM_) -import Foreign (withForeignPtr, peek) -import Language.Rust.Inline -import Language.Rust.Inline.TH (sizeOfWith) -import Test.Hspec -import Data.Word - -extendContext foreignPointers -extendContext pointers -extendContext prelude -extendContext basic -setCrateModule - -concurrencySpec :: Spec -concurrencySpec = describe "Concurrency" $ do - it "does not crash" $ do - let p = [rust| ForeignPtr { Box::new(0).into() } |] - replicateConcurrently_ 10000 $ replicateM_ 100 [rustIO| Option<()> { *$(p: &mut u64) += 1; Some(()) } |] - withForeignPtr p peek `shouldNotReturn` 0 diff --git a/tests/Main.hs b/tests/Main.hs index 5c459be..beb5ef6 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -12,7 +12,6 @@ import Language.Rust.Inline import AlgebraicDataTypes import ByteStrings -import Concurrency (concurrencySpec) import Data.Word import Foreign.Marshal.Array import Foreign.Ptr @@ -26,6 +25,7 @@ import SimpleTypes import Submodule import Submodule.Submodule import Test.Hspec +import Vectors extendContext basic setCrateRoot [] @@ -35,7 +35,7 @@ main = hspec $ describe "Rust quasiquoter" $ do algebraicDataTypes bytestringSpec - concurrencySpec + vectorsSpec foreignPtrTypes funcPointerTypes ghcUnboxedTypes diff --git a/tests/Vectors.hs b/tests/Vectors.hs new file mode 100644 index 0000000..099a1d1 --- /dev/null +++ b/tests/Vectors.hs @@ -0,0 +1,30 @@ +module Vectors where + +import Language.Rust.Inline +import Language.Rust.Inline.TH +import Test.Hspec +import Control.Monad (forM) +import Foreign (withForeignPtr) +import Data.Word +import qualified Foreign + +extendContext basic +extendContext prelude +extendContext foreignPointers +extendContext vectors +setCrateModule + +vectorsSpec :: Spec +vectorsSpec = describe "Vectors" $ do + it "can marshal list return values" $ do + let mints = [rust| Vec> { vec![Some(17), None] } |] + mints `shouldBe` [Just 17, Nothing] + + let fps = [rust| Vec> { vec![Box::new(17).into(), Box::new(42).into() ] } |] + values <- forM fps $ flip withForeignPtr Foreign.peek + values `shouldBe` [17, 42] + + it "can marshal list arguments" $ do + let ints = [17, 42] :: [Word64] + let rsum = [rust| u64 { $(ints: Vec).iter().sum() } |] + rsum `shouldBe` sum ints From ae75275a61f8630c8f8aace0a4b60a9a13981e49 Mon Sep 17 00:00:00 2001 From: ners Date: Sun, 23 Feb 2025 21:21:39 +0100 Subject: [PATCH 12/13] oh no --- src/Language/Rust/Inline/Context/Vector.hs | 2 +- tests/Vectors.hs | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/Language/Rust/Inline/Context/Vector.hs b/src/Language/Rust/Inline/Context/Vector.hs index ad72bb7..08d2536 100644 --- a/src/Language/Rust/Inline/Context/Vector.hs +++ b/src/Language/Rust/Inline/Context/Vector.hs @@ -80,7 +80,7 @@ instance Marshalable a => Marshalable [a] where (ptr, len, finalizer) <- Foreign.peek (Foreign.castPtr p) let peekAtOffset :: Word -> IO a peekAtOffset offset = Marshalable.peek $ ptr `Foreign.plusPtr` (fromIntegral offset * sizeOfPeek (undefined :: a)) - list <- foldrM (\a b -> liftA2 (:) (peekAtOffset a) (pure b)) [] [0 .. len - 1] + list <- foldrM (\a b -> liftA2 (:) (peekAtOffset a) (pure b)) [] $ take (fromIntegral len) [0..] freeVector finalizer ptr len pure list diff --git a/tests/Vectors.hs b/tests/Vectors.hs index 099a1d1..94dbdb2 100644 --- a/tests/Vectors.hs +++ b/tests/Vectors.hs @@ -17,6 +17,8 @@ setCrateModule vectorsSpec :: Spec vectorsSpec = describe "Vectors" $ do it "can marshal list return values" $ do + [rust| Vec { vec![] } |] `shouldBe` [] + let mints = [rust| Vec> { vec![Some(17), None] } |] mints `shouldBe` [Just 17, Nothing] @@ -28,3 +30,7 @@ vectorsSpec = describe "Vectors" $ do let ints = [17, 42] :: [Word64] let rsum = [rust| u64 { $(ints: Vec).iter().sum() } |] rsum `shouldBe` sum ints + + it "can marshal pairs of lists" $ do + let yes = [rust| (Vec, Vec) { (vec![], vec![]) } |] + yes `shouldBe` ([], []) From 337236f3a57c6310888890d75a8b0c7816df324f Mon Sep 17 00:00:00 2001 From: ners Date: Sun, 26 Oct 2025 09:26:59 +0100 Subject: [PATCH 13/13] update flake --- .github/workflows/ci.yml | 2 +- flake.lock | 6 +- flake.nix | 76 +++++++++++++--------- src/Language/Rust/Inline/Context/Vector.hs | 2 +- 4 files changed, 51 insertions(+), 35 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe9da39..00d0e36 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: cachix/install-nix-action@v30 + - uses: ners/simply-nix@main - uses: ryanccn/attic-action@v0 with: endpoint: https://cache.ners.ch diff --git a/flake.lock b/flake.lock index a4f2a68..7c1eea8 100644 --- a/flake.lock +++ b/flake.lock @@ -18,11 +18,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1739866667, - "narHash": "sha256-EO1ygNKZlsAC9avfcwHkKGMsmipUk1Uc0TbrEZpkn64=", + "lastModified": 1761114652, + "narHash": "sha256-f/QCJM/YhrV/lavyCVz8iU3rlZun6d+dAiC3H+CDle4=", "owner": "nixos", "repo": "nixpkgs", - "rev": "73cf49b8ad837ade2de76f87eb53fc85ed5d4680", + "rev": "01f116e4df6a15f4ccdffb1bcd41096869fb385c", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 05b3b10..f24a53f 100644 --- a/flake.nix +++ b/flake.nix @@ -23,38 +23,47 @@ ); sourceFilter = root: with lib.fileset; toSource { inherit root; - fileset = fileFilter (file: any file.hasExt [ "cabal" "hs" "hsc" "md" ]) root; + fileset = fileFilter + (file: any file.hasExt [ "cabal" "hs" "md" ]) + root; }; ghcsFor = pkgs: with lib; foldlAttrs - (acc: name: hp: + (acc: name: hp': let - version = getVersion hp.ghc; + hp = tryEval hp'; + version = getVersion hp.value.ghc; majorMinor = versions.majorMinor version; ghcName = "ghc${replaceStrings ["."] [""] majorMinor}"; in - if hp ? ghc && ! acc ? ${ghcName} && versionAtLeast version "9.2" && versionOlder version "9.11" - then acc // { ${ghcName} = hp; } + if hp.value ? ghc && ! acc ? ${ghcName} && versionAtLeast version "9.4" && versionOlder version "9.13" + then acc // { ${ghcName} = hp.value; } else acc ) { } pkgs.haskell.packages; hpsFor = pkgs: { default = pkgs.haskellPackages; } // ghcsFor pkgs; pname = "inline-rust"; - src = sourceFilter ./.; - overlay = lib.composeManyExtensions [ - (final: prev: { - haskell = prev.haskell // { - packageOverrides = lib.composeManyExtensions [ - prev.haskell.packageOverrides - (hfinal: hprev: with prev.haskell.lib.compose; { - language-rust = hfinal.callCabal2nix "language-rust" inputs.language-rust { }; - ${pname} = addBuildDepend prev.cargo (hfinal.callCabal2nix pname src { }); - }) - ]; - }; - }) - ]; + pnames = [ pname ]; + haskell-overlay = final: prev: hfinal: hprev: with prev.haskell.lib.compose; { + language-rust = hfinal.callCabal2nix "language-rust" inputs.language-rust { }; + ${pname} = addBuildDepend prev.cargo (hfinal.callCabal2nix pname (sourceFilter ./.) { }); + }; + overlay = final: prev: { + haskell = prev.haskell // { + packageOverrides = lib.composeManyExtensions [ + prev.haskell.packageOverrides + (haskell-overlay final prev) + ]; + }; + }; in + { + overlays = { + default = overlay; + haskell = haskell-overlay; + }; + } + // foreach inputs.nixpkgs.legacyPackages (system: pkgs': let @@ -62,11 +71,20 @@ hps = hpsFor pkgs; libs = pkgs.buildEnv { name = "${pname}-libs"; - paths = map (hp: hp.${pname}) (attrValues hps); + paths = + lib.mapCartesianProduct + ({ hp, pname }: hp.${pname}) + { hp = attrValues hps; pname = pnames; }; pathsToLink = [ "/lib" ]; }; - docs = pkgs.haskell.lib.documentationTarball hps.default.${pname}; - sdist = pkgs.haskell.lib.sdistTarball hps.default.${pname}; + docs = pkgs.buildEnv { + name = "${pname}-docs"; + paths = map (pname: pkgs.haskell.lib.documentationTarball hps.default.${pname}) pnames; + }; + sdist = pkgs.buildEnv { + name = "${pname}-sdist"; + paths = map (pname: pkgs.haskell.lib.sdistTarball hps.default.${pname}) pnames; + }; docsAndSdist = pkgs.linkFarm "${pname}-docsAndSdist" { inherit docs sdist; }; in { @@ -80,20 +98,18 @@ devShells.${system} = foreach hps (ghcName: hp: { ${ghcName} = hp.shellFor { - packages = ps: [ hp.${pname} ]; + packages = ps: map (pname: ps.${pname}) pnames; nativeBuildInputs = with pkgs'; with haskellPackages; [ - pkgs'.haskellPackages.cabal-install + cabal-gild + cabal-install cargo fourmolu - gdb - haskell-language-server rustc - valgrind + ] ++ lib.optionals (lib.versionAtLeast (lib.getVersion hp.ghc) "9.4") [ + hp.haskell-language-server ]; }; }); } - ) // { - overlays.default = overlay; - }; + ); } diff --git a/src/Language/Rust/Inline/Context/Vector.hs b/src/Language/Rust/Inline/Context/Vector.hs index 08d2536..1f8057c 100644 --- a/src/Language/Rust/Inline/Context/Vector.hs +++ b/src/Language/Rust/Inline/Context/Vector.hs @@ -80,7 +80,7 @@ instance Marshalable a => Marshalable [a] where (ptr, len, finalizer) <- Foreign.peek (Foreign.castPtr p) let peekAtOffset :: Word -> IO a peekAtOffset offset = Marshalable.peek $ ptr `Foreign.plusPtr` (fromIntegral offset * sizeOfPeek (undefined :: a)) - list <- foldrM (\a b -> liftA2 (:) (peekAtOffset a) (pure b)) [] $ take (fromIntegral len) [0..] + list <- foldrM (\a b -> (:) <$> peekAtOffset a <*> pure b) [] $ take (fromIntegral len) [0..] freeVector finalizer ptr len pure list