Response Augmenter
A Response Augmenter lets you extend the response returned by the LLM. A typical use case is appending the sources that the LLM used to compute the response.
A Response Augmenter acts as a post-processor: it runs after the LLM has computed the response and, when structured output is used, after that response has been mapped to an object.
Implementing a Response Augmenter
A response augmenter is a CDI bean implementing the io.quarkiverse.langchain4j.response.AiResponseAugmenter interface.
This interface defines one method for imperative interaction and one for streaming responses.
Both methods have a default implementation that returns its input unchanged, so an augmenter overrides at least one of them (and possibly both):
package io.quarkiverse.langchain4j.response;

import io.smallrye.mutiny.Multi;

/**
 * A CDI bean willing to manipulate the response of the AI model needs to implement this interface.
 * An AI method that wants to use an augmenter should be annotated with {@link ResponseAugmenter}, indicating the
 * augmenter implementation class name.
 * <p>
 * The default implementation keeps the response unchanged.
 *
 * @param <T> the type of the response
 */
public interface AiResponseAugmenter<T> {

    /**
     * Augment the response.
     *
     * @param response the response to augment
     * @param params the parameters to use for the augmentation
     * @return the augmented response
     */
    default T augment(T response, ResponseAugmenterParams params) {
        return response;
    }

    /**
     * Augment a streamed response.
     *
     * @param stream the stream to augment
     * @param params the parameters to use for the augmentation
     * @return the augmented stream
     */
    default Multi<T> augment(Multi<T> stream, ResponseAugmenterParams params) {
        return stream;
    }
}
The ResponseAugmenterParams object contains the following information:
- The user message
- The chat memory
- The augmentation result (RAG text segments)
- The user message template
- The variables used to compute the user message from the template
The implementation is free to modify the response in any way, for example by rewriting it or appending information. Note that for streamed responses, the method is called on the event loop, so it should not perform blocking operations.
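To illustrate how the two default methods interact, here is a minimal, dependency-free sketch of an augmenter that overrides only the streaming variant. It is not the extension's API: DemoAugmenter and DemoParams are hypothetical stand-ins that mirror the shape of AiResponseAugmenter and ResponseAugmenterParams, and java.util.stream.Stream is used in place of Multi so the snippet compiles on its own.

```java
import java.util.List;
import java.util.stream.Stream;

// Hypothetical stand-in for ResponseAugmenterParams: it only carries source names.
record DemoParams(List<String> sources) {}

// Hypothetical stand-in mirroring the shape of AiResponseAugmenter<T>.
// Stream replaces Multi here purely to avoid external dependencies.
interface DemoAugmenter<T> {

    // Imperative variant: by default the response is returned unchanged.
    default T augment(T response, DemoParams params) {
        return response;
    }

    // Streaming variant: by default the stream is returned unchanged.
    default Stream<T> augment(Stream<T> stream, DemoParams params) {
        return stream;
    }
}

// Overrides only the streaming variant; the imperative one keeps its default.
class StreamingSourceAugmenter implements DemoAugmenter<String> {

    @Override
    public Stream<String> augment(Stream<String> stream, DemoParams params) {
        // Emit one extra chunk after the last chunk of the original stream.
        String suffix = " (Sources: " + String.join(", ", params.sources()) + ")";
        return Stream.concat(stream, Stream.of(suffix));
    }
}
```

With this stand-in, a stream of chunks gains one trailing chunk listing the sources, while a plain String passed to the imperative method comes back untouched because that method was not overridden.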
Using a Response Augmenter
Once you have implemented a response augmenter, you can use it in your AI service by annotating the method with @ResponseAugmenter:
@SessionScoped
@RegisterAiService
public interface CustomerSupportAgent {

    @SystemMessage("""
            ...
            """)
    @InputGuardrails(PromptInjectionGuard.class)
    @ToolBox(BookingRepository.class)
    @ResponseAugmenter(SourceAugmenter.class) // <--- here
    String chat(String userMessage);
}
In this example, the SourceAugmenter class is used to augment the response.
Example
Here is an example of a response augmenter that adds the sources used to compute the response:
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity;
import io.quarkiverse.langchain4j.response.AiResponseAugmenter;
import io.quarkiverse.langchain4j.response.ResponseAugmenterParams;

@ApplicationScoped
public class SourceAugmenter implements AiResponseAugmenter<String> {

    @Inject
    EmbeddingModel embeddingModel;

    // A retrieved segment, the file it came from, and the embedding of its text
    record SourceEmbedding(TextSegment textSegment, String file, Embedding embedding) {}

    @Override
    public String augment(String response, ResponseAugmenterParams params) {
        // Only add sources that are similar to the computed response
        var embeddingOfTheResponse = embeddingModel.embed(response).content();
        List<SourceEmbedding> sources = params.augmentationResult()
                .contents().stream().map(c -> {
                    var embedding = embeddingModel.embed(c.textSegment().text()).content();
                    // Extract the "source" of the content from the metadata:
                    return new SourceEmbedding(c.textSegment(),
                            c.textSegment().metadata().getString("file"), embedding);
                }).toList();

        // Ignore segments not similar enough
        Set<SourceEmbedding> filtered = filter(embeddingOfTheResponse, sources);

        // Remove duplicates (several segments may come from the same file)
        Set<String> names = new LinkedHashSet<>();
        for (var source : filtered) {
            names.add(source.file());
        }

        // Append the sources to the response
        return response + " (Sources: "
                + String.join(", ", names) + ")";
    }

    // Keeps only the sources whose cosine similarity with the response exceeds 0.85
    private Set<SourceEmbedding> filter(Embedding embeddingOfTheResponse, List<SourceEmbedding> contents) {
        Set<SourceEmbedding> filtered = new LinkedHashSet<>();
        for (SourceEmbedding content : contents) {
            double similarity = CosineSimilarity.between(embeddingOfTheResponse, content.embedding());
            if (similarity > 0.85) {
                filtered.add(content);
            }
        }
        return filtered;
    }
}
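The filtering and de-duplication steps of this example can be tried without an embedding model. The following is a small sketch under stated assumptions: plain double[] vectors stand in for Embedding objects, the cosine computation is written by hand rather than calling CosineSimilarity.between, and SimilarityFilterSketch, Source, cosine and similarFiles are hypothetical names introduced only for illustration.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

final class SimilarityFilterSketch {

    // Hypothetical stand-in for a RAG source: a file name plus an embedding vector.
    record Source(String file, double[] vector) {}

    // Cosine similarity: dot(a, b) / (|a| * |b|).
    static double cosine(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same dimension");
        }
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0 || normB == 0) {
            return 0.0; // convention chosen for this sketch: a zero vector matches nothing
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Returns the file names of sources above the threshold, once each, in first-seen order.
    static Set<String> similarFiles(double[] response, List<Source> sources, double threshold) {
        Set<String> names = new LinkedHashSet<>();
        for (Source source : sources) {
            if (cosine(response, source.vector()) > threshold) {
                names.add(source.file());
            }
        }
        return names;
    }
}
```

Using a LinkedHashSet means a file that contributed several matching segments is listed only once, and the names keep the order in which the segments were retrieved. The 0.85 threshold from the example is a tuning choice: a lower value lists more sources, a higher value lists fewer.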