Skip to content

Commit 6e48bab

Browse files
committed
Hack in support for Maybe ByteString
1 parent 148d725 commit 6e48bab

File tree

4 files changed

+54
-6
lines changed

4 files changed

+54
-6
lines changed

src/Language/Rust/Inline.hs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ import Foreign.Marshal.Alloc (alloca, free)
103103
import Foreign.Marshal.Array (newArray, withArrayLen)
104104
import Foreign.Marshal.Unsafe (unsafeLocalState)
105105
import Foreign.Marshal.Utils (new, with)
106-
import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr)
106+
import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr, nullPtr)
107107

108108
import Control.Monad (void)
109109
import Data.List (intercalate)
@@ -315,6 +315,7 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
315315
(returnFfi, haskRet') <- do
316316
marshalForm <- ghcMarshallable haskRet
317317
let fptrRet haskRet' = [t|Ptr (Ptr $(pure haskRet'), FunPtr (Ptr $(pure haskRet') -> IO ())) -> IO ()|]
318+
let bsRet = [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) -> IO ()|]
318319
ret <- case marshalForm of
319320
BoxedDirect -> [t|IO $(pure haskRet)|]
320321
BoxedIndirect -> [t|Ptr $(pure haskRet) -> IO ()|]
@@ -323,7 +324,8 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
323324
| otherwise ->
324325
let retTy = showTy haskRet
325326
in fail ("Cannot put unlifted type ‘" ++ retTy ++ "’ in IO")
326-
ByteString -> [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) -> IO ()|]
327+
ByteString -> bsRet
328+
OptionalByteString -> bsRet
327329
ForeignPtr
328330
| AppT _ haskRet' <- haskRet -> fptrRet haskRet'
329331
| otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " using the ForeignPtr calling convention")
@@ -356,6 +358,9 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
356358
ByteString -> do
357359
rbsT <- [t|Ptr (Ptr Word8, Word)|]
358360
pure (ByteString, rbsT)
361+
OptionalByteString -> do
362+
rbsT <- [t|Ptr (Ptr Word8, Word)|]
363+
pure (OptionalByteString, rbsT)
359364
ForeignPtr
360365
| AppT _ haskArg' <- haskArg -> do
361366
ptr <- [t|Ptr $(pure haskArg')|]
@@ -431,6 +436,26 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
431436
else Just <$> newForeignPtr $(varE finalizer) $(varE ptr)
432437
)
433438
|]
439+
| returnFfi == OptionalByteString = do
440+
ret <- newName "ret"
441+
ptr <- newName "ptr"
442+
len <- newName "len"
443+
finalizer <- newName "finalizer"
444+
[e|
445+
alloca
446+
( \($(varP ret)) -> do
447+
$(appsE (varE qqName : reverse (varE ret : acc)))
448+
($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret)
449+
if $(varE ptr) == nullPtr
450+
then pure Nothing
451+
else
452+
Just
453+
<$> ByteString.unsafePackCStringFinalizer
454+
$(varE ptr)
455+
(fromIntegral $(varE len))
456+
($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len))
457+
)
458+
|]
434459
| returnByValue returnFfi = appsE (varE qqName : reverse acc)
435460
| otherwise = do
436461
ret <- newName "ret"
@@ -475,6 +500,7 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
475500
Just $(varP fptr) ->
476501
withForeignPtr $(varE fptr) (\($(varP ptr)) -> $(goArgs (varE ptr : acc) args))
477502
|]
503+
| marshalForm == OptionalByteString -> fail "Don't"
478504
| passByValue marshalForm -> goArgs (varE argName : acc) args
479505
| otherwise -> do
480506
x <- newName "x"

src/Language/Rust/Inline/Context/ByteString.hs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@ import Data.Functor (void)
2929
import Data.Word (Word8)
3030
import Foreign.Ptr (Ptr)
3131

32+
import Debug.Trace (traceM)
33+
3234
bytestrings :: Q Context
3335
bytestrings =
3436
pure $ Context ([rule], [], [rustByteString, impl])
3537
where
3638
rule rty _
3739
| rty == void [ty| &[u8] |] = pure ([t|ByteString|], pure . pure $ void [ty| RustByteString |])
3840
| rty == void [ty| Vec<u8> |] = pure ([t|ByteString|], pure . pure $ void [ty| RustOwnedByteString |])
41+
| rty == void [ty| Option<Vec<u8>> |] = pure ([t|Maybe ByteString|], pure . pure $ void [ty| RustOwnedByteString |])
3942
rule _ _ = mempty
4043

4144
rustByteString =
@@ -70,4 +73,13 @@ bytestrings =
7073
, " RustOwnedByteString(bytes.as_mut_ptr(), len, free)"
7174
, " }"
7275
, "}"
76+
, ""
77+
, "impl MarshalInto<RustOwnedByteString> for Option<Vec<u8>> {"
78+
, " fn marshal(self) -> RustOwnedByteString {"
79+
, " extern fn panic(ptr: *mut u8, len: usize) {"
80+
, " panic!(\"Attempted to free a null ByteString\");"
81+
, " }"
82+
, " self.map(|bs| bs.marshal()).unwrap_or(RustOwnedByteString(std::ptr::null_mut(), 0, panic))"
83+
, " }"
84+
, "}"
7385
]

src/Language/Rust/Inline/Marshal.hs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ data MarshalForm
3939
| ByteString
4040
| ForeignPtr
4141
| OptionalForeignPtr
42-
deriving (Eq)
42+
| OptionalByteString
43+
deriving (Eq, Show)
4344

4445
passByValue :: MarshalForm -> Bool
4546
passByValue = (`elem` [UnboxedDirect, BoxedDirect, ForeignPtr])
@@ -71,9 +72,10 @@ ghcMarshallable ty = do
7172
AppT con _ | con `elem` tyconsU -> pure UnboxedDirect
7273
| con `elem` tyconsB -> pure BoxedDirect
7374
| con == fptrCons -> pure ForeignPtr
74-
AppT mb (AppT fptr _)
75-
| mb == maybeCons && fptr == fptrCons -> pure OptionalForeignPtr
76-
_ -> pure BoxedIndirect
75+
AppT mb (AppT c _)
76+
| mb == maybeCons && c == fptrCons -> pure OptionalForeignPtr
77+
AppT mb c | mb == maybeCons && c == bytestring -> pure OptionalByteString
78+
_ -> pure BoxedIndirect
7779
where
7880
qSimpleUnboxed = [ [t| Char# |]
7981
, [t| Int# |]

tests/ByteStrings.hs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import Test.Hspec
99
import Data.ByteString (ByteString)
1010
import qualified Data.ByteString as ByteString
1111
import qualified Data.ByteString.Unsafe as ByteString
12+
import Data.Maybe (fromJust)
1213
import Data.String
1314

1415
extendContext basic
@@ -34,3 +35,10 @@ bytestringSpec = describe "ByteStrings" $ do
3435
ByteString.pack [0, 1, 2, 3]
3536
`shouldBe` rustBs
3637
ByteString.unsafeFinalize rustBs
38+
39+
it "can marshal optional ByteString return values" $ do
40+
let noRustBs = [rust| Option<Vec<u8>> { None } |]
41+
noRustBs `shouldBe` Nothing
42+
43+
let rustBs = [rust| Option<Vec<u8>> { Some(vec![0, 1, 2, 3]) } |]
44+
fromJust rustBs `shouldBe` ByteString.pack [0, 1, 2, 3]

0 commit comments

Comments
 (0)