Guardrails - Controlling the Chaos
Guardrails are mechanisms that let you validate the output of the LLM to ensure it meets your expectations. Typically, you can:
-
Ensure that the format is correct (e.g., it is a JSON document with the right schema)
-
Detect hallucinations (especially in the context of a RAG)
-
Verify that the user input is not out of scope
Input guardrails are executed before the LLM is called, while output guardrails are executed after the LLM has produced its output. Failing an input guardrail prevents the LLM from being called. Failing an output guardrail allows retrying or reprompting to improve the response.
Input Guardrails
Input guardrails are functions invoked before the LLM is called.
Implementing Input Guardrails
Input guardrails are implemented as CDI beans and must implement the io.quarkiverse.langchain4j.guardrails.InputGuardrail interface:
package io.quarkiverse.langchain4j.guardrails;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
/**
* An input guardrail is a rule that is applied to the input of the model to ensure that the input (the user message) is
* safe and meets the expectations of the model.
* <p>
* Implementation should be exposed as a CDI bean, and the class name configured in {@link InputGuardrails#value()} annotation.
*/
public interface InputGuardrail extends Guardrail<InputGuardrailParams, InputGuardrailResult> {
/**
* Validates the {@code user message} that will be sent to the LLM.
* <p>
*
* @param userMessage the user message to validate
*/
default InputGuardrailResult validate(UserMessage userMessage) {
return failure("Validation not implemented");
}
/**
* Validates the input that will be sent to the LLM.
* <p>
* Unlike {@link #validate(UserMessage)}, this method allows to access the memory and the augmentation result (in the
* case of a RAG).
* <p>
* Implementation must not attempt to write to the memory or the augmentation result.
*
* @param params the parameters, including the user message, the memory (maybe null),
* and the augmentation result (maybe null). Cannot be {@code null}
*/
@Override
default InputGuardrailResult validate(InputGuardrailParams params) {
return validate(params.userMessage());
}
/**
* Represents the parameter passed to {@link InputGuardrail#validate(InputGuardrailParams)}.
*
* @param userMessage the user message, cannot be {@code null}
* @param memory the memory, can be {@code null} or empty
* @param augmentationResult the augmentation result, can be {@code null}
*/
record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
AugmentationResult augmentationResult) implements GuardrailParams {
}
}
The validate method of the InputGuardrail interface can have two signatures:
-
InputGuardrailResult validate(UserMessage userMessage)
-
InputGuardrailResult validate(InputGuardrailParams params)
The first one is used when the guardrail only needs the user message. Simple guardrails can use this method. The second one is used for more complex guardrails that need more information, like the memory or the augmentation results. For example, they can check that there are enough documents in the augmentation results, or that the user is not asking the same question multiple times.
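As an illustration of the "same question multiple times" check, here is a minimal, library-free sketch of the comparison logic only. The class and method names are hypothetical, and extracting the earlier questions from the memory inside validate(InputGuardrailParams) is left out.

```java
import java.util.List;
import java.util.Locale;

/** Hypothetical helper: decides whether a question repeats an earlier one. */
final class RepeatedQuestionCheck {

    private RepeatedQuestionCheck() {
    }

    /** Trims, lower-cases, and collapses whitespace so trivial variations compare equal. */
    static String normalize(String text) {
        return text.trim().toLowerCase(Locale.ROOT).replaceAll("\\s+", " ");
    }

    /** Returns true if the new question matches any earlier question after normalization. */
    static boolean isRepeated(List<String> previousQuestions, String question) {
        String candidate = normalize(question);
        return previousQuestions.stream()
                .map(RepeatedQuestionCheck::normalize)
                .anyMatch(candidate::equals);
    }
}
```

A guardrail could return a failure when isRepeated is true and a success otherwise. Exact matching after normalization is a deliberate simplification; a real check might compare embeddings instead.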
Input Guardrails Outcome
Input guardrails can have three outcomes:
-
pass - The input is valid, the next guardrail is executed, and if the last guardrail passes, the LLM is called.
-
fail - The input is invalid, but the next guardrail is still executed in order to accumulate all the validation problems. The LLM is not called.
-
fatal - The input is invalid, the next guardrail is not executed, and the error is rethrown. The LLM is not called.
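The three outcomes above can be pictured with a small, library-free model of a chain. This is only a sketch of the described semantics, not the extension's implementation. All type names here (GuardrailChainModel, Result, ChainResult) are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Simplified model of how a guardrail chain combines pass, fail, and fatal outcomes. */
final class GuardrailChainModel {

    enum Outcome { PASS, FAIL, FATAL }

    /** One guardrail's verdict; the message is only meaningful for FAIL and FATAL. */
    record Result(Outcome outcome, String message) {
        static Result pass() { return new Result(Outcome.PASS, null); }
        static Result fail(String message) { return new Result(Outcome.FAIL, message); }
        static Result fatal(String message) { return new Result(Outcome.FATAL, message); }
    }

    /** Aggregated verdict: the LLM is called only when no guardrail reported a problem. */
    record ChainResult(boolean llmCalled, List<String> failures) { }

    static ChainResult run(String input, List<Function<String, Result>> guardrails) {
        List<String> failures = new ArrayList<>();
        for (Function<String, Result> guardrail : guardrails) {
            Result result = guardrail.apply(input);
            if (result.outcome() == Outcome.FATAL) {
                failures.add(result.message());
                break; // fatal: the remaining guardrails are skipped
            }
            if (result.outcome() == Outcome.FAIL) {
                failures.add(result.message()); // fail: keep going to collect more problems
            }
        }
        return new ChainResult(failures.isEmpty(), List.copyOf(failures));
    }
}
```

In this model, two non-fatal failures yield two accumulated messages, whereas a fatal failure in the first guardrail yields a single message because the rest of the chain never runs.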
Declaring Input Guardrails
Input guardrails are declared on the AI Service interface. You can declare input guardrails in two ways:
-
By annotating the AI Service interface with @InputGuardrails and listing the guardrails - these guardrails will be applied to all the methods of the AI Service.
-
By annotating the method of the AI Service with @InputGuardrails and listing the guardrails - these guardrails will be applied to this method only.
Method guardrails take precedence over class guardrails.
Here is an example of an AI Service interface with input guardrails:
import dev.langchain4j.service.SystemMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import jakarta.enterprise.context.SessionScoped;
@RegisterAiService(retrievalAugmentor = Retriever.class)
@SystemMessage("""
You are Mona, a chatbot answering questions about a museum. Be polite, concise, and helpful.
""")
@SessionScoped
public interface ChatBot {
@InputGuardrails(InScopeGuard.class)
String chat(String question);
}
Input Guardrail Chain
You can declare multiple guardrails. In this case, a chain is created, and the guardrails are executed in the order they are declared. Thus, the order of the guardrails is important.
@RegisterAiService
@SystemMessage("""
You are simulating fights between a superhero and a supervillain.
""")
public interface Simulator {
@UserMessage("""
Simulate a fight between:
- a hero: {hero}
- a villain: {villain}
""")
@InputGuardrails({VerifyHeroFormat.class, VerifyVillainFormat.class})
FightResult fight(Hero hero, Villain villain);
}
In this example, the VerifyHeroFormat guardrail is executed first to check that the passed hero is valid.
Then, the VerifyVillainFormat guardrail is executed to check that the villain is valid.
If VerifyHeroFormat fails, VerifyVillainFormat may or may not be executed, depending on whether the failure is fatal. For instance, VerifyHeroFormat could be implemented as follows:
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;
@ApplicationScoped
public class VerifyHeroFormat implements InputGuardrail {
@Override
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
String text = um.singleText();
if (text.length() > 1000) {
// a fatal failure, the next InputGuardrail won't be evaluated
return fatal("Input too long, size = " + text.length());
}
if (!text.contains("hero")) {
// a normal failure, still allowing to evaluate also the next InputGuardrail and accumulate multiple failures
return failure("The input should contain the word 'hero'");
}
return success();
}
}
Output Guardrails
Output guardrails are functions invoked once the LLM has produced its output.
Implementing Output Guardrails
Output guardrails are implemented as CDI beans and must implement the io.quarkiverse.langchain4j.guardrails.OutputGuardrail interface:
package io.quarkiverse.langchain4j.guardrails;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
/**
* An output guardrail is a rule that is applied to the output of the model to ensure that the output is safe and meets the
* expectations.
* <p>
* Implementation should be exposed as a CDI bean, and the class name configured in {@link OutputGuardrails#value()} annotation.
* <p>
* In the case of reprompting, the reprompt message is added to the LLM context and the request is retried.
* <p>
* The maximum number of retries is configurable using {@code quarkus.langchain4j.guardrails.max-retries}, defaulting to 3.
*/
public interface OutputGuardrail extends Guardrail<OutputGuardrailParams, OutputGuardrailResult> {
/**
* Validates the response from the LLM.
*
* @param responseFromLLM the response from the LLM
*/
default OutputGuardrailResult validate(AiMessage responseFromLLM) {
return failure("Validation not implemented");
}
/**
* Validates the response from the LLM.
* <p>
* Unlike {@link #validate(AiMessage)}, this method allows to access the memory and the augmentation result (in the
* case of a RAG).
* <p>
* Implementation must not attempt to write to the memory or the augmentation result.
*
* @param params the parameters, including the response from the LLM, the memory (maybe null),
* and the augmentation result (maybe null). Cannot be {@code null}
*/
@Override
default OutputGuardrailResult validate(OutputGuardrailParams params) {
return validate(params.responseFromLLM());
}
/**
* Represents the parameter passed to {@link OutputGuardrail#validate(OutputGuardrailParams)}.
*
* @param responseFromLLM the response from the LLM
* @param memory the memory, can be {@code null} or empty
* @param augmentationResult the augmentation result, can be {@code null}
*/
record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
AugmentationResult augmentationResult) implements GuardrailParams {
}
}
The validate method of the OutputGuardrail interface can have two signatures:
-
OutputGuardrailResult validate(AiMessage responseFromLLM)
-
OutputGuardrailResult validate(OutputGuardrailParams params)
The first one is used when the guardrail only needs the output of the LLM. Simple guardrails can use this method. For example, here is an output guardrail that checks that the output is a JSON document:
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
@ApplicationScoped
public class JsonGuardrail implements OutputGuardrail {
@Inject
ObjectMapper mapper;
@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
try {
mapper.readTree(responseFromLLM.text());
} catch (Exception e) {
return reprompt("Invalid JSON", e, "Make sure you return a valid JSON object");
}
return success();
}
}
The second signature is used when the guardrail needs more information, like the augmentation results or the memory. Note that the guardrail cannot modify the memory or the augmentation results. The Detecting Hallucinations in the RAG Context section gives an example of a guardrail that uses the augmentation results.
Output Guardrails Outcome
Output guardrails can have six outcomes:
-
pass - The output is valid, the next guardrail is executed, and if the last guardrail passes, the output is returned to the caller.
-
pass with rewrite - The output is not valid in its original form but has been rewritten to make it valid. The next guardrail is executed against the rewritten output, and if the last guardrail passes, the output is returned to the caller.
-
fail - The output is invalid, but the next guardrail is still executed in order to accumulate all the validation problems.
-
fatal - The output is invalid, the next guardrail is not executed, and the error is rethrown.
-
fatal with retry - The output is invalid, the next guardrail is not executed, and the LLM is called again with the same prompt.
-
fatal with reprompt - The output is invalid, the next guardrail is not executed, and the LLM is called again with a new prompt.
When validation fails, the guardrail can specify whether the LLM call should be retried or reprompted:
// Retry - The LLM is called again with the same prompt and context
// The guardrails will be called again with the new output
return retry("Invalid JSON");
// Retry with reprompt - The LLM is called again with a new prompt and context
// A new user message is added to the LLM context (memory), and the LLM is called again with this new context.
// The guardrails will be called again with the new output
return reprompt("Invalid JSON", "Make sure you return a valid JSON object");
By default, Quarkus LangChain4j limits the number of retries to 3.
This is configurable using the quarkus.langchain4j.guardrails.max-retries configuration property:
quarkus.langchain4j.guardrails.max-retries=5
Setting quarkus.langchain4j.guardrails.max-retries to 0 disables retries.
Declaring Output Guardrails
Output guardrails are declared on the AI Service interface. You can declare output guardrails in two ways:
-
By annotating the AI Service interface with @OutputGuardrails and listing the guardrails - these guardrails will be applied to all the methods of the AI Service.
-
By annotating the method of the AI Service with @OutputGuardrails and listing the guardrails - these guardrails will be applied to this method only.
Method guardrails take precedence over class guardrails.
Here is an example of an AI Service interface with output guardrails:
import dev.langchain4j.service.SystemMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import jakarta.enterprise.context.SessionScoped;
@RegisterAiService(retrievalAugmentor = Retriever.class)
@SystemMessage("""
You are Mona, a chatbot answering questions about a museum. Be polite, concise, and helpful.
""")
@SessionScoped
public interface ChatBot {
@OutputGuardrails(HallucinationGuard.class)
String chat(String question);
}
Output Guardrail Chain
You can declare multiple guardrails. In this case, a chain is created, and the guardrails are executed in the order they are declared. Thus, the order of the guardrails is important.
Typically, it’s a good idea to have a guardrail that checks the format of the output first, and then a guardrail that checks the content.
@RegisterAiService
@SystemMessage("""
You are simulating fights between a superhero and a supervillain.
""")
public interface Simulator {
@UserMessage("""
Simulate a fight between:
- a hero: {hero}
- a villain: {villain}
""")
@OutputGuardrails({JsonGuardrail.class, ConsistentStoryGuardrail.class})
FightResult fight(Hero hero, Villain villain);
}
In this example, the JsonGuardrail is executed first to check that the output is a valid JSON document.
Then, the ConsistentStoryGuardrail is executed to check that the story is consistent.
If the JsonGuardrail fails, the ConsistentStoryGuardrail is not executed.
However, if the ConsistentStoryGuardrail fails with a retry or reprompt, the JsonGuardrail is executed again with the new response.
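The interplay between retries and the chain can be sketched with a library-free model. This is a simplification under stated assumptions: every rejection is treated as a retry request, and the limit is counted as one initial call plus at most maxRetries further calls. The names RetryLoopModel and Outcome are hypothetical and do not come from the extension.

```java
import java.util.List;
import java.util.function.Predicate;
import java.util.function.Supplier;

/** Simplified model of retrying the LLM when an output guardrail rejects a response. */
final class RetryLoopModel {

    /** Final outcome: the accepted response (or null) and how many times the LLM was called. */
    record Outcome(String response, int llmCalls) { }

    static Outcome run(Supplier<String> llm, List<Predicate<String>> guardrails, int maxRetries) {
        int calls = 0;
        // Assumption: one initial attempt plus at most maxRetries further attempts.
        while (calls <= maxRetries) {
            String response = llm.get();
            calls++;
            // The whole chain is evaluated again, from the first guardrail, for each new response.
            boolean accepted = guardrails.stream().allMatch(g -> g.test(response));
            if (accepted) {
                return new Outcome(response, calls);
            }
        }
        return new Outcome(null, calls); // retries exhausted
    }
}
```

With maxRetries set to 0, the model calls the LLM exactly once, which matches the idea that retries are disabled.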
Output Guardrails for Streamed Responses
Output guardrails can be applied to methods that return Multi, but not to those returning TokenStream.
By default, Quarkus will automatically assemble the full response before executing the guardrail chain.
Keep in mind that this may have a performance impact when handling large responses.
To control when the guardrail chain is invoked during streaming, you can configure an accumulator:
@UserMessage("...")
@OutputGuardrails(MyGuardrail.class)
@OutputGuardrailAccumulator(PassThroughAccumulator.class) // Defines the accumulator
Multi<String> ask();
The @OutputGuardrailAccumulator annotation allows you to specify a custom accumulator.
The accumulator must implement the io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator interface and be exposed as a CDI bean.
Here’s a simple example of an accumulator that does not perform any accumulation:
@ApplicationScoped
public class PassThroughAccumulator implements OutputTokenAccumulator {
@Override
public Multi<String> accumulate(Multi<String> tokens) {
return tokens; // Passes the tokens through without accumulating
}
}
You can create accumulators based on various criteria, such as the number of tokens, a specific separator, or time intervals.
When an accumulator is set, the output guardrail chain is invoked for each item emitted by the Multi returned by the accumulate method.
In the case of a retry, the accumulator is called again with the new response, restarting the stream from the beginning. The same behavior applies for reprompts.
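To make the separator-based criterion concrete, here is a library-free sketch of the grouping logic alone. It works on a plain list of tokens, so the Mutiny wiring needed inside a real accumulate method is not shown. The class name SentenceChunker and the choice of sentence-ending punctuation as separator are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of separator-based accumulation: tokens are grouped into sentences. */
final class SentenceChunker {

    private SentenceChunker() {
    }

    /** Emits one chunk each time a token ends with '.', '!' or '?'; leftovers form a last chunk. */
    static List<String> chunk(List<String> tokens) {
        List<String> chunks = new ArrayList<>();
        StringBuilder current = new StringBuilder();
        for (String token : tokens) {
            current.append(token);
            String trimmed = token.stripTrailing();
            if (trimmed.endsWith(".") || trimmed.endsWith("!") || trimmed.endsWith("?")) {
                chunks.add(current.toString().strip());
                current.setLength(0); // start accumulating the next sentence
            }
        }
        if (!current.toString().isBlank()) {
            chunks.add(current.toString().strip());
        }
        return chunks;
    }
}
```

Applied to a stream, each chunk would become one item on which the guardrail chain is invoked, so the guardrails see whole sentences rather than individual tokens.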
Detecting Hallucinations in the RAG Context
This section is an example of how to implement a guardrail that detects hallucinations in the context of a RAG. The idea is to check that the output of the LLM is consistent with the augmentation results.
package me.escoffier.langchain4j.nomic;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.content.Content;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkus.logging.Log;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import org.eclipse.microprofile.config.inject.ConfigProperty;
@ApplicationScoped
public class HallucinationGuard implements OutputGuardrail {
@Inject
NomicEmbeddingV1 embedding;
@ConfigProperty(name = "hallucination.threshold", defaultValue = "0.7")
double threshold;
@Override
public OutputGuardrailResult validate(OutputGuardrailParams params) {
Response<Embedding> embeddingOfTheResponse = embedding.embed(params.responseFromLLM().text());
if (params.augmentationResult() == null || params.augmentationResult().contents().isEmpty()) {
Log.info("No content to validate against");
return success();
}
float[] vectorOfTheResponse = embeddingOfTheResponse.content().vector();
for (Content content : params.augmentationResult().contents()) {
Response<Embedding> embeddingOfTheContent = embedding.embed(content.textSegment());
float[] vectorOfTheContent = embeddingOfTheContent.content().vector();
double distance = cosineDistance(vectorOfTheResponse, vectorOfTheContent);
if (distance < threshold) {
Log.info("Passed hallucination guardrail: " + distance);
return success();
}
}
return reprompt("Hallucination detected", "Make sure you use the given documents to produce the response");
}
public static double cosineDistance(float[] vector1, float[] vector2) {
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
for (int i = 0; i < vector1.length; i++) {
dotProduct += vector1[i] * vector2[i];
normA += Math.pow(vector1[i], 2);
normB += Math.pow(vector2[i], 2);
}
double cosineSimilarity = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
return 1.0 - cosineSimilarity;
}
}
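To get a feel for the values compared against the threshold, the following standalone copy of the distance computation can be run on tiny vectors. The class name CosineDistanceDemo is hypothetical and no embedding model is involved.

```java
/** Standalone copy of the cosine distance computation, to show the range of values it yields. */
final class CosineDistanceDemo {

    private CosineDistanceDemo() {
    }

    static double cosineDistance(float[] a, float[] b) {
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        // 1 - cosine similarity: 0 for the same direction, 1 for orthogonal, 2 for opposite vectors
        return 1.0 - dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
```

For (1, 0) against (2, 0) the distance is 0.0, against (0, 3) it is 1.0, and against (-1, 0) it is 2.0. A threshold of 0.7 therefore accepts a response whose cosine similarity to at least one retrieved document is above 0.3.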
Rewriting the LLM output
The output generated by the LLM is sometimes not fully satisfactory, yet can be adjusted programmatically instead of attempting a retry or a reprompt, both of which imply a costly, time-consuming, and less reliable new interaction with the LLM. For instance, it is quite common for an LLM to produce the JSON of the data object it was asked to extract from the user prompt, but then append an unwanted explanation of why it generated that result, making the JSON unparsable, as in:
{"name":"Alex", "age":18} Alex is 18 since he became an adult a few days ago.
In this situation, it is better to programmatically trim the JSON part of the response and check whether a valid Person object can be deserialized from it before reprompting the LLM. If the programmatic extraction of the JSON string from the partially hallucinated LLM output succeeds, the rewritten output can be propagated through the successWith method.
This scenario is so common that an abstract class implementing the OutputGuardrail interface is already provided; it performs this programmatic JSON sanitization out of the box.
import jakarta.inject.Inject;
import org.jboss.logging.Logger;
import com.fasterxml.jackson.core.type.TypeReference;
import dev.langchain4j.data.message.AiMessage;
public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail {
@Inject
Logger logger;
@Inject
JsonGuardrailsUtils jsonGuardrailsUtils;
protected AbstractJsonExtractorOutputGuardrail() {
if (getOutputClass() == null && getOutputType() == null) {
throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented");
}
}
@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String llmResponse = responseFromLLM.text();
logger.debugf("LLM output: %s", llmResponse);
Object result = deserialize(llmResponse);
if (result != null) {
return successWith(llmResponse, result);
}
String json = jsonGuardrailsUtils.trimNonJson(llmResponse);
if (json != null) {
result = deserialize(json);
if (result != null) {
return successWith(json, result);
}
}
return reprompt("Invalid JSON",
"Make sure you return a valid JSON object following "
+ "the specified format");
}
protected Object deserialize(String llmResponse) {
return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass())
: jsonGuardrailsUtils.deserialize(llmResponse, getOutputType());
}
protected Class<?> getOutputClass() {
return null;
}
protected TypeReference<?> getOutputType() {
return null;
}
}
This implementation first tries to deserialize the LLM response into the class expected from the data extraction. If that does not succeed, it tries to trim away the non-JSON part of the response and performs the deserialization again. Note that in both cases, together with the JSON response (either the original LLM one or the programmatically trimmed one), the successWith method also returns the resulting deserialized object, so that it can be used directly as the final response of the data extraction instead of requiring a second deserialization. If both deserialization attempts fail, the OutputGuardrail performs a reprompt, hoping that the LLM will finally produce a valid JSON string.
For example, consider an AI service that extracts customer data from user prompts, like the following:
@RegisterAiService
public interface CustomerExtractor {
@UserMessage("Extract information about a customer from this text '{text}'. The response must contain only the JSON with customer's data and without any other sentence.")
@OutputGuardrails(CustomerExtractionOutputGuardrail.class)
Customer extractData(String text);
}
It can be combined with an OutputGuardrail that sanitizes the JSON response of the LLM by simply extending the abstract class above and declaring the expected output class of the data extraction:
@ApplicationScoped
public class CustomerExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail {
@Override
protected Class<?> getOutputClass() {
return Customer.class;
}
}
Note that if the data extraction requires a generic Java type, such as List<Customer>, it is instead necessary to override the getOutputType method and return a Jackson TypeReference, as follows:
@ApplicationScoped
public class CustomersExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail {
@Override
protected TypeReference<?> getOutputType() {
return new TypeReference<List<Customer>>() {};
}
}
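The trimNonJson step that AbstractJsonExtractorOutputGuardrail relies on is not shown in this section. As a rough idea of what such trimming can look like, here is a hypothetical, library-free sketch that keeps the text between the first opening brace and the last closing brace. It is an assumption for illustration and not the actual JsonGuardrailsUtils implementation.

```java
/** Hypothetical helper: keeps the text from the first '{' to the last '}' of an LLM response. */
final class JsonTrimSketch {

    private JsonTrimSketch() {
    }

    /** Returns the candidate JSON object, or null when no brace pair is found. */
    static String trimNonJson(String response) {
        int start = response.indexOf('{');
        int end = response.lastIndexOf('}');
        if (start < 0 || end < start) {
            return null;
        }
        return response.substring(start, end + 1);
    }
}
```

The result is only a candidate: it still has to be deserialized to confirm that it is valid JSON. This sketch also ignores top-level arrays and would be misled by braces appearing in the surrounding prose.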