From d88f72737db903e38d4298041943f0857a575b14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A4=80=ED=99=98?= Date: Tue, 20 Jan 2026 11:42:59 +0900 Subject: [PATCH] Add header propagation predicate support to message return value handlers Signed-off-by: Junhwan Kim --- .../SendToMethodReturnValueHandler.java | 43 ++++++++++++-- .../SubscriptionMethodReturnValueHandler.java | 41 ++++++++++++- .../SendToMethodReturnValueHandlerTests.java | 53 ++++++++++++++++- ...criptionMethodReturnValueHandlerTests.java | 59 +++++++++++++++++++ 4 files changed, 189 insertions(+), 7 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java index 9d13aa045184..f391a82b34df 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java @@ -45,6 +45,7 @@ import org.springframework.util.PropertyPlaceholderHelper; import org.springframework.util.PropertyPlaceholderHelper.PlaceholderResolver; import org.springframework.util.StringUtils; +import java.util.function.Predicate; /** * A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a @@ -73,6 +74,8 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH private @Nullable MessageHeaderInitializer headerInitializer; + private @Nullable Predicate headerFilter; + public SendToMethodReturnValueHandler(SimpMessageSendingOperations messagingTemplate, boolean annotationRequired) { Assert.notNull(messagingTemplate, "'messagingTemplate' must not be null"); @@ -133,6 +136,27 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia return this.headerInitializer; } + /** + * Add a filter to determine which headers from the input message should be propagated to the output message. + * Multiple filters are combined with logical OR. + *

If not set, no input headers are propagated (default behavior).

+ */ + public void addHeaderFilter(Predicate filter) { + Assert.notNull(filter, "Filter predicate must not be null"); + if (this.headerFilter == null) { + this.headerFilter = filter; + } else { + this.headerFilter = this.headerFilter.or(filter); + } + } + + /** + * Return the configured header filter. + */ + public @Nullable Predicate getHeaderFilter() { + return this.headerFilter; + } + @Override public boolean supportsReturnType(MethodParameter returnType) { @@ -171,11 +195,11 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu destination = destinationHelper.expandTemplateVars(destination); if (broadcast) { this.messagingTemplate.convertAndSendToUser( - user, destination, returnValue, createHeaders(null, returnType)); + user, destination, returnValue, createHeaders(null, returnType, message)); } else { this.messagingTemplate.convertAndSendToUser( - user, destination, returnValue, createHeaders(sessionId, returnType)); + user, destination, returnValue, createHeaders(sessionId, returnType, message)); } } } @@ -185,7 +209,7 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix); for (String destination : destinations) { destination = destinationHelper.expandTemplateVars(destination); - this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType)); + this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType, message)); } } } @@ -234,11 +258,22 @@ protected String[] getTargetDestinations(@Nullable Annotation annotation, Messag new String[] {defaultPrefix + destination} : new String[] {defaultPrefix + '/' + destination}); } - private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType) { + private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType, @Nullable Message inputMessage) { SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); if (getHeaderInitializer() != null) { getHeaderInitializer().initHeaders(headerAccessor); } + + if (inputMessage != null && headerFilter != null) { + Map inputHeaders = inputMessage.getHeaders(); + for (Map.Entry entry : inputHeaders.entrySet()) { + String name = entry.getKey(); + if (headerFilter.test(name)) { + headerAccessor.setHeader(name, entry.getValue()); + } + } + } + if (sessionId != null) { headerAccessor.setSessionId(sessionId); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java index 5d9e99c09faf..d487e0c1af30 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java @@ -16,6 +16,9 @@ package org.springframework.messaging.simp.annotation.support; +import java.util.Map; +import java.util.function.Predicate; + import org.apache.commons.logging.Log; import org.jspecify.annotations.Nullable; @@ -65,6 +68,8 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn private @Nullable MessageHeaderInitializer headerInitializer; + private @Nullable Predicate headerFilter; + /** * Construct a new SubscriptionMethodReturnValueHandler. @@ -93,6 +98,27 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia return this.headerInitializer; } + /** + * Add a filter to determine which headers from the input message should be propagated to the output message. + * Multiple filters are combined with logical OR. + *

If not set, no input headers are propagated (default behavior).

+ */ + public void addHeaderFilter(Predicate filter) { + Assert.notNull(filter, "Filter predicate must not be null"); + if (this.headerFilter == null) { + this.headerFilter = filter; + } else { + this.headerFilter = this.headerFilter.or(filter); + } + } + + /** + * Return the configured header filter. + */ + public @Nullable Predicate getHeaderFilter() { + return this.headerFilter; + } + @Override public boolean supportsReturnType(MethodParameter returnType) { @@ -126,15 +152,26 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu if (logger.isDebugEnabled()) { logger.debug("Reply to @SubscribeMapping: " + returnValue); } - MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType); + MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType, message); this.messagingTemplate.convertAndSend(destination, returnValue, headersToSend); } - private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType) { + private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType, @Nullable Message inputMessage) { SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); if (getHeaderInitializer() != null) { getHeaderInitializer().initHeaders(accessor); } + + if (inputMessage != null && headerFilter != null) { + Map inputHeaders = inputMessage.getHeaders(); + for (Map.Entry entry : inputHeaders.entrySet()) { + String name = entry.getKey(); + if (headerFilter.test(name)) { + accessor.setHeader(name, entry.getValue()); + } + } + } + if (sessionId != null) { accessor.setSessionId(sessionId); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java index a0bc78919c8c..1da2f9b02519 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java @@ -293,9 +293,60 @@ public void sendToUserWithSendToOverride() throws Exception { assertResponse(parameter, sessionId, 1, "/dest4"); } + @Test + void sendToWithHeaderFilterSinglePredicate() throws Exception { + given(this.messageChannel.send(any(Message.class))).willReturn(true); + + String sessionId = "sess1"; + String customHeaderName = "x-custom-header"; + String customHeaderValue = "custom-value"; + Message inputMessage = createMessage(sessionId, "sub1", null, null, null); + inputMessage = MessageBuilder.fromMessage(inputMessage) + .setHeader(customHeaderName, customHeaderValue) + .build(); + + SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true); + handler.addHeaderFilter(name -> name.equals(customHeaderName)); + + handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage); + + verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); + for (Message sent : this.messageCaptor.getAllValues()) { + MessageHeaders headers = sent.getHeaders(); + assertThat(headers.get(customHeaderName)).isEqualTo(customHeaderValue); + } + } + + @Test + void sendToWithHeaderFilterMultiplePredicates() throws Exception { + given(this.messageChannel.send(any(Message.class))).willReturn(true); + + String sessionId = "sess1"; + String headerA = "x-header-a"; + String headerB = "x-header-b"; + Message inputMessage = createMessage(sessionId, "sub1", null, null, null); + inputMessage = MessageBuilder.fromMessage(inputMessage) + .setHeader(headerA, "A-value") + .setHeader(headerB, "B-value") + .build(); + + SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true); + handler.addHeaderFilter(name -> name.equals(headerA)); + handler.addHeaderFilter(name -> name.equals(headerB)); + + handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage); + + verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); + for (Message sent : this.messageCaptor.getAllValues()) { + MessageHeaders headers = sent.getHeaders(); + assertThat(headers.get(headerA)).isEqualTo("A-value"); + assertThat(headers.get(headerB)).isEqualTo("B-value"); + } + } + private void assertResponse(MethodParameter methodParameter, String sessionId, - int index, String destination) { + int index, String destination) { SimpMessageHeaderAccessor accessor = getCapturedAccessor(index); assertThat(accessor.getSessionId()).isEqualTo(sessionId); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java index b6368c2a382a..a44d93542400 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java @@ -186,6 +186,65 @@ void testJsonView() throws Exception { assertThat(new String((byte[]) message.getPayload(), StandardCharsets.UTF_8)).isEqualTo("{\"withView1\":\"with\"}"); } + @Test + void testHeaderFilterSinglePredicate() throws Exception { + String sessionId = "sess1"; + String subscriptionId = "subs1"; + String destination = "/dest"; + String customHeaderName = "x-custom-header"; + String customHeaderValue = "custom-value"; + Message inputMessage = MessageBuilder.withPayload(PAYLOAD) + .setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId) + .setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId) + .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination) + .setHeader(customHeaderName, customHeaderValue) + .build(); + + MessageSendingOperations messagingTemplate = mock(); + SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate); + + handler.addHeaderFilter(name -> name.equals(customHeaderName)); + + handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage); + + ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHeaders.class); + verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture()); + + MessageHeaders sentHeaders = captor.getValue(); + assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue); + } + + @Test + void testHeaderFilterMultiplePredicates() throws Exception { + String sessionId = "sess1"; + String subscriptionId = "subs1"; + String destination = "/dest"; + String headerA = "x-header-a"; + String headerB = "x-header-b"; + Message inputMessage = MessageBuilder.withPayload(PAYLOAD) + .setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId) + .setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId) + .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination) + .setHeader(headerA, "A-value") + .setHeader(headerB, "B-value") + .build(); + + MessageSendingOperations messagingTemplate = mock(); + SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate); + + handler.addHeaderFilter(name -> name.equals(headerA)); + handler.addHeaderFilter(name -> name.equals(headerB)); + + handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage); + + ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHeaders.class); + verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture()); + + MessageHeaders sentHeaders = captor.getValue(); + assertThat(sentHeaders.get(headerA)).isEqualTo("A-value"); + assertThat(sentHeaders.get(headerB)).isEqualTo("B-value"); + } + private Message createInputMessage(String sessId, String subsId, String dest, Principal principal) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();