11use crate :: { AppState , Done } ;
22use axum:: body:: { Body , BodyDataStream } ;
3- use axum:: extract:: { Path , State } ;
3+ use axum:: extract:: { Path , Request , State } ;
44use axum:: http:: { header, HeaderMap , HeaderValue , StatusCode } ;
5- use axum:: response:: IntoResponse ;
6- use axum:: routing:: { delete, get, post} ;
5+ use axum:: middleware:: { self , Next } ;
6+ use axum:: response:: { IntoResponse , Response } ;
7+ use axum:: routing:: { get, post} ;
78use axum:: Router ;
89use std:: collections:: HashMap ;
910use std:: sync:: Arc ;
10- use tokio:: sync:: Mutex ;
11+ use tokio:: sync:: { oneshot , Mutex } ;
1112
1213type Namespace = String ;
1314type ChannelName = String ;
@@ -24,50 +25,58 @@ pub(crate) type ChannelClients = Mutex<
2425 > ,
2526> ;
2627
27- pub ( crate ) fn routes ( ) -> Router < Arc < AppState > > {
28+ pub ( crate ) fn routes ( state : Arc < AppState > ) -> Router < Arc < AppState > > {
2829 Router :: new ( )
2930 . route ( "/channels/namespaces" , get ( list_all_namespaces) )
3031 . route ( "/channels/{namespace}" , get ( list_all_namespace_channels) )
31- . route (
32- "/channels/{namespace}" ,
33- delete ( delete_namespace_and_all_channels) ,
34- )
35- . route (
36- "/channels/{namespace}/{channel_name}" ,
37- get ( subscribe_to_channel) ,
38- )
3932 . route (
4033 "/channels/{namespace}/{channel_name}" ,
41- post ( broadcast_to_channel) ,
34+ get ( subscribe_to_channel) . route_layer ( middleware:: from_fn_with_state (
35+ state. clone ( ) ,
36+ clean_up_unused_channels,
37+ ) ) ,
4238 )
4339 . route (
4440 "/channels/{namespace}/{channel_name}" ,
45- delete ( delete_channel) ,
41+ post ( broadcast_to_channel) . route_layer ( middleware:: from_fn_with_state (
42+ state. clone ( ) ,
43+ clean_up_unused_channels,
44+ ) ) ,
4645 )
4746}
4847
49- async fn delete_namespace_and_all_channels (
50- Path ( namespace) : Path < String > ,
48+ async fn clean_up_unused_channels (
49+ Path ( ( namespace, channel_name ) ) : Path < ( String , String ) > ,
5150 State ( state) : State < Arc < AppState > > ,
52- ) -> axum:: response:: Result < ( ) > {
53- let mut channel_clients = state. channel_clients . lock ( ) . await ;
51+ request : Request ,
52+ next : Next ,
53+ ) -> Response {
54+ let ( tx, rx) = oneshot:: channel ( ) ;
5455
55- channel_clients. remove ( & namespace) ;
56+ tokio:: spawn ( async move {
57+ let _ = rx. await ;
5658
57- Ok ( ( ) )
58- }
59+ let mut channel_clients = state. channel_clients . lock ( ) . await ;
5960
60- async fn delete_channel (
61- Path ( ( namespace, channel_name) ) : Path < ( String , String ) > ,
62- State ( state) : State < Arc < AppState > > ,
63- ) -> axum:: response:: Result < ( ) > {
64- let mut channel_clients = state. channel_clients . lock ( ) . await ;
61+ let delete_namespace = if let Some ( namespace_channels) = channel_clients. get_mut ( & namespace)
62+ {
63+ namespace_channels. remove ( & channel_name) ;
6564
66- if let Some ( channels) = channel_clients. get_mut ( & namespace) {
67- channels. remove ( & channel_name) ;
68- }
65+ namespace_channels. is_empty ( )
66+ } else {
67+ false
68+ } ;
6969
70- Ok ( ( ) )
70+ if delete_namespace {
71+ channel_clients. remove ( & namespace) ;
72+ }
73+ } ) ;
74+
75+ let response = next. run ( request) . await ;
76+
77+ let _ = tx. send ( ( ) ) ;
78+
79+ response
7180}
7281
7382async fn list_all_namespaces (
@@ -105,11 +114,9 @@ async fn broadcast_to_channel(
105114) -> axum:: response:: Result < ( ) > {
106115 let mut channel_clients = state. channel_clients . lock ( ) . await ;
107116
108- let namespace_channels = if let Some ( channels) = channel_clients. get_mut ( & namespace) {
109- channels
110- } else {
111- channel_clients. insert ( namespace. clone ( ) , HashMap :: new ( ) ) ;
112- channel_clients. get_mut ( & namespace) . unwrap ( )
117+ let namespace_channels = match channel_clients. entry ( namespace) {
118+ std:: collections:: hash_map:: Entry :: Occupied ( e) => e. into_mut ( ) ,
119+ std:: collections:: hash_map:: Entry :: Vacant ( e) => e. insert ( HashMap :: new ( ) ) ,
113120 } ;
114121
115122 let tx = if let Some ( ( tx, _rx) ) = namespace_channels. get ( & channel_name) {
@@ -124,11 +131,11 @@ async fn broadcast_to_channel(
124131
125132 drop ( channel_clients) ;
126133
127- let body_stream = body. into_data_stream ( ) ;
134+ let request_body_stream = body. into_data_stream ( ) ;
128135
129136 let ( done, done_rx) = Done :: new ( ) ;
130137
131- tx. send_async ( ( body_stream , request_headers, done) )
138+ tx. send_async ( ( request_body_stream , request_headers, done) )
132139 . await
133140 . map_err ( |_e| StatusCode :: INTERNAL_SERVER_ERROR ) ?;
134141
@@ -145,11 +152,9 @@ async fn subscribe_to_channel(
145152) -> axum:: response:: Result < impl IntoResponse > {
146153 let mut channel_clients = state. channel_clients . lock ( ) . await ;
147154
148- let namespace_channels = if let Some ( channels) = channel_clients. get_mut ( & namespace) {
149- channels
150- } else {
151- channel_clients. insert ( namespace. clone ( ) , HashMap :: new ( ) ) ;
152- channel_clients. get_mut ( & namespace) . unwrap ( )
155+ let namespace_channels = match channel_clients. entry ( namespace) {
156+ std:: collections:: hash_map:: Entry :: Occupied ( e) => e. into_mut ( ) ,
157+ std:: collections:: hash_map:: Entry :: Vacant ( e) => e. insert ( HashMap :: new ( ) ) ,
153158 } ;
154159
155160 let rx = if let Some ( ( _tx, rx) ) = namespace_channels. get ( & channel_name) {
@@ -420,166 +425,4 @@ mod tests {
420425
421426 assert_eq ! ( ids, vec![ "it_should_autovivify_on_publish" ] )
422427 }
423-
424- #[ tokio:: test]
425- async fn delete_channel ( ) {
426- let options = Options :: default ( ) ;
427-
428- let port = get_port ( ) ;
429-
430- let listener = tokio:: net:: TcpListener :: bind ( ( "0.0.0.0" , port) )
431- . await
432- . unwrap ( ) ;
433-
434- let ( _done, done_rx) = Done :: new ( ) ;
435-
436- tokio:: spawn ( async move {
437- axum:: serve ( listener, app ( options) )
438- . with_graceful_shutdown ( async move { done_rx. await . unwrap ( ) } )
439- . await
440- . unwrap ( ) ;
441- } ) ;
442-
443- tokio:: spawn ( async move {
444- reqwest:: Client :: new ( )
445- . post ( format ! (
446- "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish"
447- ) )
448- . body ( "some body" )
449- . send ( )
450- . await
451- . unwrap ( )
452- } ) ;
453-
454- reqwest:: get ( format ! (
455- "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish"
456- ) )
457- . await
458- . unwrap ( ) ;
459-
460- let namespaces: HashSet < String > =
461- reqwest:: get ( format ! ( "http://localhost:{port}/channels/namespaces" ) )
462- . await
463- . unwrap ( )
464- . json ( )
465- . await
466- . unwrap ( ) ;
467-
468- assert_eq ! ( namespaces, HashSet :: from( [ "a_great_ns" . to_string( ) ] ) ) ;
469-
470- let ids: Vec < String > = reqwest:: get ( format ! ( "http://localhost:{port}/channels/a_great_ns" ) )
471- . await
472- . unwrap ( )
473- . json ( )
474- . await
475- . unwrap ( ) ;
476-
477- assert_eq ! ( ids, vec![ "it_should_autovivify_on_publish" ] ) ;
478-
479- reqwest:: Client :: new ( )
480- . delete ( format ! (
481- "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish"
482- ) )
483- . send ( )
484- . await
485- . unwrap ( ) ;
486-
487- let ids: Vec < String > = reqwest:: get ( format ! ( "http://localhost:{port}/channels/a_great_ns" ) )
488- . await
489- . unwrap ( )
490- . json ( )
491- . await
492- . unwrap ( ) ;
493-
494- assert_eq ! ( ids, Vec :: <String >:: new( ) ) ;
495-
496- let namespaces: HashSet < String > =
497- reqwest:: get ( format ! ( "http://localhost:{port}/channels/namespaces" ) )
498- . await
499- . unwrap ( )
500- . json ( )
501- . await
502- . unwrap ( ) ;
503-
504- assert_eq ! ( namespaces, HashSet :: from( [ "a_great_ns" . to_string( ) ] ) ) ;
505- }
506-
507- #[ tokio:: test]
508- async fn delete_namespace_and_all_channels ( ) {
509- let options = Options :: default ( ) ;
510-
511- let port = get_port ( ) ;
512-
513- let listener = tokio:: net:: TcpListener :: bind ( ( "0.0.0.0" , port) )
514- . await
515- . unwrap ( ) ;
516-
517- let ( _done, done_rx) = Done :: new ( ) ;
518-
519- tokio:: spawn ( async move {
520- axum:: serve ( listener, app ( options) )
521- . with_graceful_shutdown ( async move { done_rx. await . unwrap ( ) } )
522- . await
523- . unwrap ( ) ;
524- } ) ;
525-
526- tokio:: spawn ( async move {
527- reqwest:: Client :: new ( )
528- . post ( format ! (
529- "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish"
530- ) )
531- . body ( "some body" )
532- . send ( )
533- . await
534- . unwrap ( )
535- } ) ;
536-
537- reqwest:: get ( format ! (
538- "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish"
539- ) )
540- . await
541- . unwrap ( ) ;
542-
543- let namespaces: HashSet < String > =
544- reqwest:: get ( format ! ( "http://localhost:{port}/channels/namespaces" ) )
545- . await
546- . unwrap ( )
547- . json ( )
548- . await
549- . unwrap ( ) ;
550-
551- assert_eq ! ( namespaces, HashSet :: from( [ "a_great_ns" . to_string( ) ] ) ) ;
552-
553- let ids: Vec < String > = reqwest:: get ( format ! ( "http://localhost:{port}/channels/a_great_ns" ) )
554- . await
555- . unwrap ( )
556- . json ( )
557- . await
558- . unwrap ( ) ;
559-
560- assert_eq ! ( ids, vec![ "it_should_autovivify_on_publish" ] ) ;
561-
562- reqwest:: Client :: new ( )
563- . delete ( format ! ( "http://localhost:{port}/channels/a_great_ns" ) )
564- . send ( )
565- . await
566- . unwrap ( ) ;
567-
568- let ns_status = reqwest:: get ( format ! ( "http://localhost:{port}/channels/a_great_ns" ) )
569- . await
570- . unwrap ( )
571- . status ( ) ;
572-
573- assert_eq ! ( ns_status, StatusCode :: NOT_FOUND ) ;
574-
575- let namespaces: HashSet < String > =
576- reqwest:: get ( format ! ( "http://localhost:{port}/channels/namespaces" ) )
577- . await
578- . unwrap ( )
579- . json ( )
580- . await
581- . unwrap ( ) ;
582-
583- assert_eq ! ( namespaces, HashSet :: new( ) ) ;
584- }
585428}
0 commit comments