Waiting for Spring AI Streaming Responses in JUnit Tests

Recently I taught a training class on Spring AI, the newly released (version 1.0.0) project from the Spring portfolio that lets you incorporate AI tools into Java systems. I like to do simple demos for the students, but rather than write them inside a Java main method, I prefer test cases.


For example, here’s a JUnit test that creates a ChatClient and asks it a question:

@Test
void simpleQuery(@Autowired OpenAiChatModel model) {
    var chatClient = ChatClient.create(model);
    String question = "Why is the sky blue?";
    String response = chatClient.prompt()
            .user(question)
            .call()
            .content();
    System.out.println(response);
}

The test uses an autowired chat model from OpenAI, as configured in the application.properties file:

spring.ai.openai.api-key=${OPENAI_API_KEY}
spring.ai.openai.chat.options.model=gpt-4.1

(You probably noticed I’m not really testing anything in that test. How to test AI responses, with their zillions of possible variations, will be the subject of another post.)

You get an answer like:

The sky appears blue because of a phenomenon called Rayleigh scattering.

Here’s how it works:

  • Sunlight may look white, but it is actually made up of many colors (all the colors of the rainbow).
  • As sunlight passes through Earth’s atmosphere, it interacts with air molecules and tiny particles.
  • The shorter wavelengths of light (blue and violet) are scattered in all directions by the gases and particles in the atmosphere much more than the longer wavelengths (red, orange, yellow).
  • Our eyes are more sensitive to blue light than violet, and some violet light is absorbed by the upper atmosphere, so the sky appears mostly blue to us.

In summary:
The blue color of the sky is due to the scattering of shorter wavelength blue light by air molecules in the atmosphere.

This test works fine because the call() method is synchronous. It blocks until we get a response. We then extract the string content from it and print the result.

Where life gets interesting, however, is if you use the stream() method instead, which returns a series of tokens asynchronously. If you try to do the same thing as above, you quickly run into a problem. (In this example, the chatClient initialization has been moved to a @BeforeEach method.)

@Test
void streamingChat() {
    Flux<String> output = chatClient.prompt()
        .user("Why is the sky blue?")
        .stream()
        .content();
    System.out.println(output);
}

The stream() method here means this query will return a Flux instance, which is a class from Project Reactor, the reactive library underlying Spring’s reactive web module. In the dependencies, that module is called spring-webflux. Interestingly, you don’t have to add it to your build, because it’s a transitive dependency of the Spring AI module.

[Figure: Gradle dependencies showing spring-webflux (version 6.2.6) as a transitive dependency of spring-ai]

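The problem with this test is that it never waits for the asynchronous responses. As written, nothing even subscribes to the Flux, so the stream never starts, and if you simply add a subscribe() call, the test method returns before the responses come back. Either way, the test ends and we get nothing. What we need is a way to tell the test thread to wait until the query thread is finished. That’s what this blog post is all about.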

There are several ways to accomplish this, from using built-in Java library classes to working with the Project Reactor library directly.

1. Use a CountDownLatch

One simple way to handle this (simple in the sense that it uses only existing Java library classes) is to take advantage of the CountDownLatch class from java.util.concurrent:

@Test
void streamingChatCountDownLatch() throws InterruptedException {
    Flux<String> output = chatClient.prompt()
            .user("Why is the sky blue?")
            .stream()
            .content();

    var latch = new CountDownLatch(1);
    output.subscribe(
            System.out::println,
            e -> {
                System.out.println("Error: " + e.getMessage());
                latch.countDown();
            },
            () -> {
                System.out.println("Completed");
                latch.countDown();
            }
    );
    latch.await();
}

The idea here is that the await() method blocks the test thread until the latch’s count reaches zero, and each call to countDown() decrements that count. The count here is just one, because we only need to wait for the stream to finish.

The three-argument subscribe() method from Flux takes two Consumer instances and a Runnable. The first one is invoked for each element as it arrives, the second gets called on an error condition, and the last one is invoked when the stream completes. Here’s the so-called “marble diagram” from the Javadocs:

[Figure: Marble diagram for the three-argument subscribe method on Flux]

The result is that the code prints each token being returned on a separate line:

The
sky
appears
blue
because
of
a
phenomenon
called
**
Ray
leigh
scattering
**
.

And so on. If an error occurs, it prints that and counts down on the latch. When the response completes normally, the word “Completed” is printed and again the latch counts down, completing the process.

The process is a little involved, but is a very typical usage of a CountDownLatch, in case you’ve never used one before.
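One thing to watch with this pattern: a no-argument await() waits forever if neither callback ever fires, so a timeout can be a useful guard in a test. The following is a minimal, JDK-only sketch of the same latch idea, not the Spring AI code itself. It assumes the JDK’s SubmissionPublisher as a stand-in for the Flux (it is not Reactor, but its subscriber receives the same kinds of signals: next, error, and complete). The class and method names are hypothetical.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;
import java.util.concurrent.TimeUnit;

class LatchWithTimeoutSketch {

    // Publishes the given tokens asynchronously, waits for the terminal
    // signal, and returns a log of every signal the subscriber received.
    static String collectSignals(String... tokens) {
        var log = new StringBuffer();          // thread-safe: callbacks run on another thread
        var latch = new CountDownLatch(1);     // one terminal event to wait for

        try (var publisher = new SubmissionPublisher<String>()) {
            publisher.subscribe(new Flow.Subscriber<String>() {
                @Override public void onSubscribe(Flow.Subscription subscription) {
                    subscription.request(Long.MAX_VALUE);   // ask for every element
                }
                @Override public void onNext(String token) {
                    log.append(token).append('|');
                }
                @Override public void onError(Throwable e) {
                    log.append("Error: ").append(e.getMessage());
                    latch.countDown();
                }
                @Override public void onComplete() {
                    log.append("Completed");
                    latch.countDown();
                }
            });
            for (String token : tokens) {
                publisher.submit(token);
            }
        } // close() signals onComplete after the submitted tokens are delivered

        try {
            // await(timeout, unit) returns false if the count never reached zero
            if (!latch.await(5, TimeUnit.SECONDS)) {
                throw new IllegalStateException("Timed out waiting for the stream");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting", e);
        }
        return log.toString();
    }

    public static void main(String[] args) {
        System.out.println(collectSignals("The", "sky", "appears", "blue"));
    }
}
```

In the Spring AI test, the equivalent change would be swapping latch.await() for the two-argument await and failing the test when it returns false. The five-second limit here is arbitrary; a real model call may need longer.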

2. The doOn… Methods in Flux

The second approach is to streamline this a bit by using the “doOn…” methods from Flux. Here’s the code, followed by a bit of explanation:

@Test
void streamingChatDoOnNext() {
    Flux<String> output = chatClient.prompt()
          .user("Why is the sky blue?")
          .stream()
          .content();

    output.doOnNext(System.out::println)
          .doOnError(e -> System.out.println(
                    "Error: " + e.getMessage()))
          .doOnCancel(() -> System.out.println("Cancelled"))
          .doOnComplete(() -> System.out.println("Completed"))
          .blockLast();
}

The doOn methods are rather like the peek() method on regular Java streams. They are part of the reactive lifecycle, and are called automatically by the Reactor framework as events occur in the stream:

  • doOnNext() is called for each element that flows through the stream
  • doOnError() is called if an error occurs anywhere upstream
  • doOnCancel() is called if the subscription is cancelled
  • doOnComplete() is called when the stream completes normally

Note that these are all side-effects. They don’t affect the signals emitted by the stream in any way.
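To make the peek() analogy concrete, here is a small sketch using only regular Java streams (no Reactor involved). It is meant to illustrate two points by comparison: the side effect observes each element without changing it, and nothing runs until the pipeline is actually consumed. The class name and sample tokens are made up for the illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class PeekSketch {

    // Returns "<observed before terminal op>:<observed after>:<joined result>".
    static String run() {
        List<String> observed = new ArrayList<>();

        Stream<String> pipeline = Stream.of("Ray", "leigh", " scattering")
                .peek(observed::add);      // side effect only; elements pass through unchanged

        int before = observed.size();      // still 0: the pipeline has not run yet

        // The terminal operation is what pulls elements through peek()
        String joined = pipeline.collect(Collectors.joining());

        return before + ":" + observed.size() + ":" + joined;
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}
```

The doOn methods play a similar role on a Flux, with the subscription taking the place of the terminal operation.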

As with all reactive streams, nothing happens until something subscribes.

Joke: How many reactive coders does it take to change a lightbulb? Only one, but if nobody watches the bulb, the room stays dark.

Here, the blockLast() method is what triggers everything. Like the previous example, we could use a countdown latch here, but blockLast() subscribes to this Flux and blocks indefinitely until the upstream signals its last value or completes. Here’s its marble diagram:

[Figure: Marble diagram for blockLast()]

3. StepVerifier from reactor-test

Our third alternative, and arguably the preferred one, is to use the StepVerifier class from Project Reactor’s test module. When you add the webflux starter to a Spring Boot app, the reactor-test dependency is added automatically. With Spring AI, however, that doesn’t happen, so you need to add it yourself:

testImplementation("io.projectreactor:reactor-test")

That will add the reactor.test.StepVerifier class, among others, to the classpath and you can use it in a test:

// Note: Requires the reactor-test dependency
// (not included in the starter)
@Test
void streamingChatStepVerifier() {
    Flux<String> output = chatClient.prompt()
          .user("Why is the sky blue?")
          .stream()
          .content();

    output.as(StepVerifier::create)
          .expectSubscription()
          .thenConsumeWhile(s -> true, System.out::println)
          .verifyComplete();
}

StepVerifier creates a test subscriber that lets you define expectations about what should happen in your reactive stream, step by step.

The as() method transforms the Flux into a StepVerifier.FirstStep, which is a fluent way to wrap your stream for testing. Then expectSubscription does what it sounds like — verifies the subscription actually occurs. The call to the two-argument version of thenConsumeWhile takes a Predicate which is always true in this case, and a Consumer to print the value. This consumes all the elements without making any specific assertions on their content. Finally, the call to verifyComplete expects the stream to complete normally and triggers the actual verification.

The general recommendation is to use the StepVerifier whenever you want to verify specific stream behavior, whereas blockLast with the doOn methods is good for simple integration tests when you just want to see the output. All three of these approaches work, however, so use whatever technique makes the most sense to you.

I should note there are also libraries like Awaitility that are specifically designed to handle asynchronous scenarios, but they require a bit more setup. For my purposes, all I wanted to do was demonstrate how the stream() call worked, so that seemed like overkill.

(Again I’ll note that none of these tests actually test anything. I’ll cover that in another blog post.)

I hope this helps. Feel free to comment here, or on my Tales from the jar side newsletter or YouTube channel.
