Guardrails - Controlling the Chaos
Guardrails are mechanisms that let you validate the input and output of the LLM to ensure they meet your expectations. Typically, you can:
- Ensure that the format is correct (e.g., the output is a JSON document with the right schema)
- Detect hallucinations (especially in the context of a RAG)
- Verify that the user input is not out of scope
Input guardrails are executed before the LLM is called, while output guardrails are executed after the LLM has produced its output. Failing an input guardrail prevents the LLM from being called. Failing an output guardrail allows retrying or reprompting to improve the response.
Input Guardrails
Input guardrails are functions invoked before the LLM is called.
Implementing Input Guardrails
Input guardrails are implemented as CDI beans and must implement the io.quarkiverse.langchain4j.guardrails.InputGuardrail
interface:
package io.quarkiverse.langchain4j.guardrails;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
/**
* An input guardrail is a rule that is applied to the input of the model to ensure that the input (the user message) is
* safe and meets the expectations of the model.
* <p>
* Implementation should be exposed as a CDI bean, and the class name configured in {@link InputGuardrails#value()} annotation.
*/
public interface InputGuardrail extends Guardrail<InputGuardrail.InputGuardrailParams, InputGuardrailResult> {
/**
* Validates the {@code user message} that will be sent to the LLM.
* <p>
*
* @param userMessage the user message to validate
*/
default InputGuardrailResult validate(UserMessage userMessage) {
return failure("Validation not implemented");
}
/**
* Validates the input that will be sent to the LLM.
* <p>
* Unlike {@link #validate(UserMessage)}, this method allows access to the memory and the augmentation result (in the
* case of a RAG).
* <p>
* Implementation must not attempt to write to the memory or the augmentation result.
*
* @param params the parameters, including the user message, the memory (may be null),
* and the augmentation result (may be null). Cannot be {@code null}
*/
@Override
default InputGuardrailResult validate(InputGuardrailParams params) {
return validate(params.userMessage());
}
/**
* Represents the parameter passed to {@link #validate(InputGuardrailParams)}.
*
* @param userMessage the user message, cannot be {@code null}
* @param memory the memory, can be {@code null} or empty
* @param augmentationResult the augmentation result, can be {@code null}
*/
record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
AugmentationResult augmentationResult) implements GuardrailParams {
}
}
The validate method of the InputGuardrail interface can have two signatures:
- InputGuardrailResult validate(UserMessage userMessage)
- InputGuardrailResult validate(InputGuardrailParams params)
The first one is used when the guardrail only needs the user message. Simple guardrails can use this method. The second one is used for more complex guardrails that need more information, like the memory or the augmentation results. For example, they can check that there are enough documents in the augmentation results, or that the user is not asking the same question multiple times, as sketched below.
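For illustration, here is a minimal sketch of such a guardrail using the second signature; the class name EnoughDocumentsGuard is hypothetical and not part of the library:
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;

// Hypothetical sketch: fail fast when the RAG retrieval returned no documents
@ApplicationScoped
public class EnoughDocumentsGuard implements InputGuardrail {

    @Override
    public InputGuardrailResult validate(InputGuardrailParams params) {
        // The augmentation result can be null when no retrieval augmentor is configured
        if (params.augmentationResult() == null
                || params.augmentationResult().contents().isEmpty()) {
            return failure("No documents were retrieved for this question");
        }
        return success();
    }
}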
Input Guardrails Outcome
Input guardrails can have three outcomes:
- pass - The input is valid, the next guardrail is executed, and if the last guardrail passes, the LLM is called.
- fail - The input is invalid, but the next guardrail is still executed, in order to accumulate all the possible validation problems. The LLM is not called.
- fatal - The input is invalid, the next guardrail is not executed, and the error is rethrown. The LLM is not called.
Declaring Input Guardrails
Input guardrails are declared on the AI Service interface. You can declare input guardrails in two ways:
- By annotating the AI Service interface with @InputGuardrails and listing the guardrails - these guardrails will be applied to all the methods of the AI Service.
- By annotating the method of the AI Service with @InputGuardrails and listing the guardrails - these guardrails will be applied to this method only.
Method guardrails take precedence over class guardrails.
Here is an example of an AI Service interface with input guardrails:
import dev.langchain4j.service.SystemMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import jakarta.enterprise.context.SessionScoped;
@RegisterAiService(retrievalAugmentor = Retriever.class)
@SystemMessage("""
You are Mona, a chatbot answering questions about a museum. Be polite, concise, and helpful.
""")
@SessionScoped
public interface ChatBot {
@InputGuardrails(InScopeGuard.class)
String chat(String question);
}
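The InScopeGuard itself is not shown above; a minimal sketch could look like the following. The keyword check is a naive placeholder for illustration; a real implementation would likely delegate to a classifier or another LLM call:
import dev.langchain4j.data.message.UserMessage;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;

// Hypothetical sketch: reject questions unrelated to the museum
@ApplicationScoped
public class InScopeGuard implements InputGuardrail {

    @Override
    public InputGuardrailResult validate(UserMessage userMessage) {
        String text = userMessage.singleText().toLowerCase();
        if (!text.contains("museum") && !text.contains("exhibit")) {
            return failure("The question is out of scope for this chatbot");
        }
        return success();
    }
}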
Input Guardrail Chain
You can declare multiple guardrails. In this case, a chain is created, and the guardrails are executed in the order they are declared. Thus, the order of the guardrails is important.
@RegisterAiService
@SystemMessage("""
You are simulating fights between a superhero and a supervillain.
""")
public interface Simulator {
@UserMessage("""
Simulate a fight between:
- a hero: {hero}
- a villain: {villain}
""")
@InputGuardrails({VerifyHeroFormat.class, VerifyVillainFormat.class})
FightResult fight(Hero hero, Villain villain);
}
In this example, the VerifyHeroFormat is executed first to check that the passed hero is valid. Then, the VerifyVillainFormat is executed to check that the villain is valid. If the VerifyHeroFormat fails, the VerifyVillainFormat may or may not be executed, depending on whether the failure is fatal or not. For instance, the VerifyHeroFormat could be implemented as follows.
import dev.langchain4j.data.message.UserMessage;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;
@ApplicationScoped
public class VerifyHeroFormat implements InputGuardrail {
@Override
    public InputGuardrailResult validate(UserMessage um) {
String text = um.singleText();
if (text.length() > 1000) {
// a fatal failure, the next InputGuardrail won't be evaluated
return fatal("Input too long, size = " + text.length());
}
if (!text.contains("hero")) {
            // a normal failure: the next InputGuardrail is still evaluated, so multiple failures can accumulate
return failure("The input should contain the word 'hero'");
}
return success();
}
}
Output Guardrails
Output guardrails are functions invoked once the LLM has produced its output.
Implementing Output Guardrails
Output guardrails are implemented as CDI beans and must implement the io.quarkiverse.langchain4j.guardrails.OutputGuardrail
interface:
package io.quarkiverse.langchain4j.guardrails;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
/**
* An output guardrail is a rule that is applied to the output of the model to ensure that the output is safe and meets the
* expectations.
* <p>
* Implementation should be exposed as a CDI bean, and the class name configured in {@link OutputGuardrails#value()} annotation.
* <p>
* In the case of reprompting, the reprompt message is added to the LLM context and the request is retried.
* <p>
* The maximum number of retries is configurable using {@code quarkus.langchain4j.guardrails.max-retries}, defaulting to 3.
*/
public interface OutputGuardrail extends Guardrail<OutputGuardrail.OutputGuardrailParams, OutputGuardrailResult> {
/**
* Validates the response from the LLM.
*
* @param responseFromLLM the response from the LLM
*/
default OutputGuardrailResult validate(AiMessage responseFromLLM) {
return failure("Validation not implemented");
}
/**
* Validates the response from the LLM.
* <p>
* Unlike {@link #validate(AiMessage)}, this method allows access to the memory and the augmentation result (in the
* case of a RAG).
* <p>
* Implementation must not attempt to write to the memory or the augmentation result.
*
* @param params the parameters, including the response from the LLM, the memory (may be null),
* and the augmentation result (may be null). Cannot be {@code null}
*/
@Override
default OutputGuardrailResult validate(OutputGuardrailParams params) {
return validate(params.responseFromLLM());
}
/**
* Represents the parameter passed to {@link #validate(OutputGuardrailParams)}.
*
* @param responseFromLLM the response from the LLM
* @param memory the memory, can be {@code null} or empty
* @param augmentationResult the augmentation result, can be {@code null}
*/
record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
AugmentationResult augmentationResult) implements GuardrailParams {
}
}
The validate method of the OutputGuardrail interface can have two signatures:
- OutputGuardrailResult validate(AiMessage responseFromLLM)
- OutputGuardrailResult validate(OutputGuardrailParams params)
The first one is used when the guardrail only needs the output of the LLM. Simple guardrails can use this method. For example, here is an output guardrail that checks that the output is a JSON document:
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
@ApplicationScoped
public class JsonGuardrail implements OutputGuardrail {
@Inject
ObjectMapper mapper;
@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
try {
mapper.readTree(responseFromLLM.text());
} catch (Exception e) {
return reprompt("Invalid JSON", e, "Make sure you return a valid JSON object");
}
return success();
}
}
The second signature is used when the guardrail needs more information, like the augmentation results or the memory. Note that the guardrail cannot modify the memory or the augmentation results. The Detecting Hallucinations in the RAG Context section gives an example of a guardrail using the augmentation results.
Output Guardrails Outcome
Output guardrails can have five outcomes:
- pass - The output is valid, the next guardrail is executed, and if the last guardrail passes, the output is returned to the caller.
- fail - The output is invalid, but the next guardrail is still executed, in order to accumulate all the possible validation problems.
- fatal - The output is invalid, the next guardrail is not executed, and the error is rethrown.
- fatal with retry - The output is invalid, the next guardrail is not executed, and the LLM is called again with the same prompt.
- fatal with reprompt - The output is invalid, the next guardrail is not executed, and the LLM is called again with a new prompt.
When validation fails, the guardrail can specify whether the LLM should be retried or reprompted:
// Retry - The LLM is called again with the same prompt and context
// The guardrails will be called again with the new output
return retry("Invalid JSON");
// Retry with reprompt - The LLM is called again with a new prompt and context
// A new user message is added to the LLM context (memory), and the LLM is called again with this new context.
// The guardrails will be called again with the new output
return reprompt("Invalid JSON", "Make sure you return a valid JSON object");
By default, Quarkus LangChain4j limits the number of retries to 3. This is configurable using the quarkus.langchain4j.guardrails.max-retries configuration property:
quarkus.langchain4j.guardrails.max-retries=5
Setting quarkus.langchain4j.guardrails.max-retries to 0 disables retries.
Output Guardrails Scopes
Output guardrails are CDI beans. They can be in any CDI scope, including request scope, application scope, or session scope.
The scope of the guardrail is important as it defines the lifecycle of the guardrail, especially when the guardrail is stateful.
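For example, a session-scoped guardrail can keep state across calls within the same user session. The following sketch (the class name NoRepetitionGuard is hypothetical) rejects a response that the session has already produced:
import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import jakarta.enterprise.context.SessionScoped;
import java.util.HashSet;
import java.util.Set;

// Hypothetical sketch: a stateful guardrail, the set lives as long as the session
@SessionScoped
public class NoRepetitionGuard implements OutputGuardrail {

    private final Set<String> previousResponses = new HashSet<>();

    @Override
    public OutputGuardrailResult validate(AiMessage responseFromLLM) {
        // add() returns false when this exact response was already produced in this session
        if (!previousResponses.add(responseFromLLM.text())) {
            return retry("The model returned the same response twice");
        }
        return success();
    }
}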
You cannot use output guardrails on AI Service methods returning a Multi. This is because it would require accumulating the items from the stream before calling the guardrails, which would defeat the purpose of streaming the results.
Declaring Output Guardrails
Output guardrails are declared on the AI Service interface. You can declare output guardrails in two ways:
- By annotating the AI Service interface with @OutputGuardrails and listing the guardrails - these guardrails will be applied to all the methods of the AI Service.
- By annotating the method of the AI Service with @OutputGuardrails and listing the guardrails - these guardrails will be applied to this method only.
Method guardrails take precedence over class guardrails.
Here is an example of an AI Service interface with output guardrails:
import dev.langchain4j.service.SystemMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import jakarta.enterprise.context.SessionScoped;
@RegisterAiService(retrievalAugmentor = Retriever.class)
@SystemMessage("""
You are Mona, a chatbot answering questions about a museum. Be polite, concise, and helpful.
""")
@SessionScoped
public interface ChatBot {
@OutputGuardrails(HallucinationGuard.class)
String chat(String question);
}
Output Guardrail Chain
You can declare multiple guardrails. In this case, a chain is created, and the guardrails are executed in the order they are declared. Thus, the order of the guardrails is important.
Typically, it’s a good idea to have a guardrail that checks the format of the output first, and then a guardrail that checks the content.
@RegisterAiService
@SystemMessage("""
You are simulating fights between a superhero and a supervillain.
""")
public interface Simulator {
@UserMessage("""
Simulate a fight between:
- a hero: {hero}
- a villain: {villain}
""")
@OutputGuardrails({JsonGuardrail.class, ConsistentStoryGuardrail.class})
FightResult fight(Hero hero, Villain villain);
}
In this example, the JsonGuardrail is executed first to check that the output is a valid JSON document. Then, the ConsistentStoryGuardrail is executed to check that the story is consistent. If the JsonGuardrail fails, the ConsistentStoryGuardrail is not executed. However, if the ConsistentStoryGuardrail fails with a retry or reprompt, the JsonGuardrail is executed again with the new response.
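The ConsistentStoryGuardrail is not shown above; a minimal sketch could be the following, where the consistency check is deliberately simplistic and only for illustration:
import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;

// Hypothetical sketch: verify the generated story involves both fighters
@ApplicationScoped
public class ConsistentStoryGuardrail implements OutputGuardrail {

    @Override
    public OutputGuardrailResult validate(AiMessage responseFromLLM) {
        String story = responseFromLLM.text();
        if (!story.contains("hero") || !story.contains("villain")) {
            // Reprompting re-runs the whole chain, including JsonGuardrail
            return reprompt("Inconsistent story",
                    "Make sure the story involves both the hero and the villain");
        }
        return success();
    }
}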
Detecting Hallucinations in the RAG Context
This section is an example of how to implement a guardrail that detects hallucinations in the context of a RAG. The idea is to check that the output of the LLM is consistent with the augmentation results.
package me.escoffier.langchain4j.nomic;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.content.Content;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkus.logging.Log;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import org.eclipse.microprofile.config.inject.ConfigProperty;
@ApplicationScoped
public class HallucinationGuard implements OutputGuardrail {
@Inject
NomicEmbeddingV1 embedding;
@ConfigProperty(name = "hallucination.threshold", defaultValue = "0.7")
double threshold;
@Override
public OutputGuardrailResult validate(OutputGuardrailParams params) {
        if (params.augmentationResult() == null || params.augmentationResult().contents().isEmpty()) {
            Log.info("No content to validate against");
            return success();
        }
        // Compute the response embedding only when there is content to compare against
        Response<Embedding> embeddingOfTheResponse = embedding.embed(params.responseFromLLM().text());
float[] vectorOfTheResponse = embeddingOfTheResponse.content().vector();
for (Content content : params.augmentationResult().contents()) {
Response<Embedding> embeddingOfTheContent = embedding.embed(content.textSegment());
float[] vectorOfTheContent = embeddingOfTheContent.content().vector();
double distance = cosineDistance(vectorOfTheResponse, vectorOfTheContent);
if (distance < threshold) {
Log.info("Passed hallucination guardrail: " + distance);
return success();
}
}
return reprompt("Hallucination detected", "Make sure you use the given documents to produce the response");
}
public static double cosineDistance(float[] vector1, float[] vector2) {
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
for (int i = 0; i < vector1.length; i++) {
dotProduct += vector1[i] * vector2[i];
normA += Math.pow(vector1[i], 2);
normB += Math.pow(vector2[i], 2);
}
double cosineSimilarity = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
return 1.0 - cosineSimilarity;
}
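Because the threshold is injected with @ConfigProperty, it can be tuned in application.properties without touching the code. For example (the value is an assumption to illustrate the syntax):
hallucination.threshold=0.8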
}