diff --git a/servant-client/test/Servant/ClientTestUtils.hs b/servant-client/test/Servant/ClientTestUtils.hs index 2af8bbc9e..b6cefc462 100644 --- a/servant-client/test/Servant/ClientTestUtils.hs +++ b/servant-client/test/Servant/ClientTestUtils.hs @@ -138,6 +138,21 @@ type TestHeaders = '[Header "X-Example1" Int, Header "X-Example2" String] type TestSetCookieHeaders = '[Header "Set-Cookie" String, Header "Set-Cookie" String] +-- | AsHeaders instance for extracting two headers (Required by the MultiVerbSetCookie test) +-- Returns: (body, (cookie1, cookie2)) +instance AsHeaders '[a, b] c (c, (a, b)) where + toHeaders (body, (h1, h2)) = (I h1 :* I h2 :* Nil, body) + fromHeaders (I h1 :* I h2 :* Nil, body) = (body, (h1, h2)) + +-- | MultiVerb endpoint definition for SetCookie test +type MultiVerbSetCookie = + "multiverb-set-cookie" + :> MultiVerb + 'GET + '[JSON] + '[WithHeaders TestSetCookieHeaders (Bool, (String, String)) (Respond 200 "OK" Bool)] + (Bool, (String, String)) + data RecordRoutes mode = RecordRoutes { version :: mode :- "version" :> Get '[JSON] Int , echo :: mode :- "echo" :> Capture "string" String :> Get '[JSON] String @@ -252,6 +267,7 @@ type Api = :<|> "multiple-choices-int" :> MultipleChoicesInt :<|> "captureVerbatim" :> Capture "someString" Verbatim :> Get '[PlainText] Text :<|> "host-test" :> Host "servant.example" :> Get '[JSON] Bool + :<|> MultiVerbSetCookie :<|> PaginatedAPI api :: Proxy Api @@ -298,6 +314,7 @@ recordRoutes :: RecordRoutes (AsClientT ClientM) multiChoicesInt :: Int -> ClientM MultipleChoicesIntResult captureVerbatim :: Verbatim -> ClientM Text getHost :: ClientM Bool +getMultiVerbSetCookie :: ClientM (Bool, (String, String)) getPaginatedPerson :: Maybe (Range 1 100) -> ClientM [Person] getRoot :<|> getGet @@ -329,6 +346,7 @@ getRoot :<|> multiChoicesInt :<|> captureVerbatim :<|> getHost + :<|> getMultiVerbSetCookie :<|> getPaginatedPerson = client api server :: Application @@ -409,6 +427,7 @@ server = ) :<|> pure . decodeUtf8 . unVerbatim :<|> pure True + :<|> pure (True, ("cookie1", "cookie2")) :<|> usersServer ) diff --git a/servant-client/test/Servant/SuccessSpec.hs b/servant-client/test/Servant/SuccessSpec.hs index 435f97a05..f16420334 100644 --- a/servant-client/test/Servant/SuccessSpec.hs +++ b/servant-client/test/Servant/SuccessSpec.hs @@ -147,6 +147,15 @@ successSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do Left e -> assertFailure $ show e Right val -> getHeaders val `shouldBe` [("Set-Cookie", "cookie1"), ("Set-Cookie", "cookie2")] + it "Returns multiple Set-Cookie headers via MultiVerb WithHeaders" $ \(_, baseUrl) -> do + res <- runClient getMultiVerbSetCookie baseUrl + case res of + Left e -> assertFailure $ show e + Right (body, (cookie1, cookie2)) -> do + body `shouldBe` True + cookie1 `shouldBe` "cookie1" + cookie2 `shouldBe` "cookie2" + it "Stores Cookie in CookieJar after a redirect" $ \(_, baseUrl) -> do mgr <- C.newManager C.defaultManagerSettings cj <- atomically . newTVar $ C.createCookieJar [] diff --git a/servant/src/Servant/API/MultiVerb.hs b/servant/src/Servant/API/MultiVerb.hs index 45ea5a816..f885abe52 100644 --- a/servant/src/Servant/API/MultiVerb.hs +++ b/servant/src/Servant/API/MultiVerb.hs @@ -187,14 +187,16 @@ instance constructHeader @h x <> constructHeaders @headers xs - -- NOTE: should we concatenate all the matching headers instead of just taking the first one? + -- This implementation retrieves the *first* header with matching name. + -- It leaves other instances of the same header intact for subsequent extraction, which allows + -- multiple headers with the same name to be extracted (e.g. Set-Cookie). extractHeaders headers = do let name' = headerName @name - (headers0, headers1) = Seq.partition (\(h, _) -> h == name') headers - x <- case headers0 of - Seq.Empty -> empty - ((_, h) :<| _) -> either (const empty) pure (parseHeader h) - xs <- extractHeaders @headers headers1 + idx <- Seq.findIndexL (\(h, _) -> h == name') headers + let (_, val) = Seq.index headers idx + headers' = Seq.deleteAt idx headers + x <- either (const empty) pure (parseHeader val) + xs <- extractHeaders @headers headers' pure (I x :* xs) class ServantHeader h (name :: Symbol) x | h -> name x where