Guardrails - Controlling the Chaos

Guardrails are mechanisms that let you validate the output of the LLM to ensure it meets your expectations. Typically, you can:

  • Ensure that the format is correct (e.g., it is a JSON document with the right schema)

  • Detect hallucinations (especially in the context of a RAG)

  • Verify that the user input is not out of scope

Input guardrails are executed before the LLM is called, while output guardrails are executed after the LLM has produced its output. Failing an input guardrail prevents the LLM from being called. Failing an output guardrail allows retrying or reprompting to improve the response.

Input Guardrails

Input guardrails are functions invoked before the LLM is called.

Implementing Input Guardrails

Input guardrails are implemented as CDI beans and must implement the io.quarkiverse.langchain4j.guardrails.InputGuardrail interface:

package io.quarkiverse.langchain4j.guardrails;

import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;

/**
 * An input guardrail is a rule that is applied to the input of the model to ensure that the input (the user message) is
 * safe and meets the expectations of the model.
 * <p>
 * Implementation should be exposed as a CDI bean, and the class name configured in {@link InputGuardrails#value()} annotation.
 */
public interface InputGuardrail extends Guardrail<InputGuardrailParams, InputGuardrailResult> {

    /**
     * Validates the {@code user message} that will be sent to the LLM.
     * <p>
     *
     * @param userMessage the user message that will be sent to the LLM
     */
    default InputGuardrailResult validate(UserMessage userMessage) {
        return failure("Validation not implemented");
    }

    /**
     * Validates the input that will be sent to the LLM.
     * <p>
     * Unlike {@link #validate(UserMessage)}, this method allows to access the memory and the augmentation result (in the
     * case of a RAG).
     * <p>
     * Implementation must not attempt to write to the memory or the augmentation result.
     *
     * @param params the parameters, including the user message, the memory (maybe null),
     *        and the augmentation result (maybe null). Cannot be {@code null}
     */
    @Override
    default InputGuardrailResult validate(InputGuardrailParams params) {
        return validate(params.userMessage());
    }

    /**
     * Represents the parameter passed to {@link InputGuardrail#validate(InputGuardrailParams)}.
     *
     * @param userMessage the user message, cannot be {@code null}
     * @param memory the memory, can be {@code null} or empty
     * @param augmentationResult the augmentation result, can be {@code null}
     */
    record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
            AugmentationResult augmentationResult) implements GuardrailParams {
    }

}

The validate method of the InputGuardrail interface can have two signatures:

  • InputGuardrailResult validate(UserMessage userMessage)

  • InputGuardrailResult validate(InputGuardrailParams params)

The first one is used when the guardrail only needs the user message. Simple guardrails can use this method. The second one is used for more complex guardrails that need more information, like the memory or the augmentation results. For example, they can check that there are enough documents in the augmentation results, or that the user is not asking the same question multiple times.

Input Guardrails Outcome

Input guardrails can have three outcomes:

  • pass - The input is valid, the next guardrail is executed, and if the last guardrail passes, the LLM is called.

  • fail - The input is invalid, but the next guardrail is still executed in order to accumulate all the possible validation problems. The LLM is not called.

  • fatal - The input is invalid, the next guardrail is not executed, and the error is rethrown. The LLM is not called.
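
The three outcomes can be pictured with a small stand-alone model. This is only a sketch of the semantics described above, not the Quarkus LangChain4j implementation: InputChainSketch, Outcome, Result, and ChainReport are hypothetical names, and the model reports a fatal failure in its result instead of rethrowing an error.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Stand-alone model of the pass / fail / fatal semantics; not the library's code. */
final class InputChainSketch {

    /** Hypothetical outcome type mirroring the three outcomes listed above. */
    enum Outcome { PASS, FAIL, FATAL }

    /** Hypothetical result: an outcome plus a failure message (null when passing). */
    record Result(Outcome outcome, String message) {
        static Result pass() { return new Result(Outcome.PASS, null); }
        static Result fail(String message) { return new Result(Outcome.FAIL, message); }
        static Result fatal(String message) { return new Result(Outcome.FATAL, message); }
    }

    /** What the chain reports: whether the LLM would be called, and the accumulated failures. */
    record ChainReport(boolean llmCalled, List<String> failures) { }

    static ChainReport run(String input, List<Function<String, Result>> guardrails) {
        List<String> failures = new ArrayList<>();
        for (Function<String, Result> guardrail : guardrails) {
            Result result = guardrail.apply(input);
            if (result.outcome() == Outcome.PASS) {
                continue; // valid: move on to the next guardrail
            }
            failures.add(result.message());
            if (result.outcome() == Outcome.FATAL) {
                break; // fatal: the remaining guardrails are skipped
            }
            // fail: keep going so that further problems are accumulated
        }
        // The LLM is called only when no guardrail reported a problem.
        return new ChainReport(failures.isEmpty(), List.copyOf(failures));
    }

    public static void main(String[] args) {
        List<Function<String, Result>> chain = List.of(
                text -> text.length() > 20 ? Result.fatal("too long") : Result.pass(),
                text -> text.contains("hero") ? Result.pass() : Result.fail("missing 'hero'"),
                text -> text.contains("villain") ? Result.pass() : Result.fail("missing 'villain'"));

        System.out.println(run("hero vs villain", chain)); // all guardrails pass
        System.out.println(run("nobody", chain)); // two non-fatal failures are accumulated
        System.out.println(run("a very long input that exceeds the limit", chain)); // fatal stops the chain
    }
}
```

In this model, a non-fatal failure never stops the chain, yet a single failure of any kind is enough to prevent the LLM call.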

Input Guardrails Scopes

Input guardrails are CDI beans. They can be in any CDI scope, including request scope, application scope, or session scope.

The scope of the guardrail is important as it defines the lifecycle of the guardrail, especially when the guardrail is stateful.

Declaring Input Guardrails

Input guardrails are declared on the AI Service interface. You can declare input guardrails in two ways:

  • By annotating the AI Service interface with @InputGuardrails and listing the guardrails - these guardrails will be applied to all the methods of the AI Service.

  • By annotating the method of the AI Service with @InputGuardrails and listing the guardrails - these guardrails will be applied to this method only.

Method guardrails take precedence over class guardrails.

Here is an example of an AI Service interface with input guardrails:

import dev.langchain4j.service.SystemMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import jakarta.enterprise.context.SessionScoped;

@RegisterAiService(retrievalAugmentor = Retriever.class)
@SystemMessage("""
    You are Mona, a chatbot answering questions about a museum. Be polite, concise, and helpful.
""")
@SessionScoped
public interface ChatBot {

    @InputGuardrails(InScopeGuard.class)
    String chat(String question);

}

Input Guardrail Chain

You can declare multiple guardrails. In this case, a chain is created, and the guardrails are executed in the order they are declared. Thus, the order of the guardrails is important.

@RegisterAiService
@SystemMessage("""
    You are simulating fights between a superhero and a supervillain.
""")
public interface Simulator {

    @UserMessage("""
        Simulate a fight between:
        - a hero: {hero}
        - a villain: {villain}
    """)
    @InputGuardrails({VerifyHeroFormat.class, VerifyVillainFormat.class})
    FightResult fight(Hero hero, Villain villain);

}

In this example, the VerifyHeroFormat is executed first to check that the passed hero is valid. Then, the VerifyVillainFormat is executed to check that the villain is valid.

If the VerifyHeroFormat fails, the VerifyVillainFormat may or may not be executed, depending on whether the failure is fatal. For instance, the VerifyHeroFormat could be implemented as follows:

import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;

@ApplicationScoped
public class VerifyHeroFormat implements InputGuardrail {

    @Override
    public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
        String text = um.singleText();
        if (text.length() > 1000) {
            // a fatal failure, the next InputGuardrail won't be evaluated
            return fatal("Input too long, size = " + text.length());
        }
        if (!text.contains("hero")) {
            // a normal failure: the next InputGuardrail is still evaluated, so multiple failures can accumulate
            return failure("The input should contain the word 'hero'");
        }
        return success();
    }
}

Output Guardrails

Output guardrails are functions invoked once the LLM has produced its output.

Implementing Output Guardrails

Output guardrails are implemented as CDI beans and must implement the io.quarkiverse.langchain4j.guardrails.OutputGuardrail interface:

package io.quarkiverse.langchain4j.guardrails;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;

/**
 * An output guardrail is a rule that is applied to the output of the model to ensure that the output is safe and meets the
 * expectations.
 * <p>
 * Implementation should be exposed as a CDI bean, and the class name configured in {@link OutputGuardrails#value()} annotation.
 * <p>
 * In the case of reprompting, the reprompt message is added to the LLM context and the request is retried.
 * <p>
 * The maximum number of retries is configurable using {@code quarkus.langchain4j.guardrails.max-retries}, defaulting to 3.
 */
public interface OutputGuardrail extends Guardrail<OutputGuardrailParams, OutputGuardrailResult> {

    /**
     * Validates the response from the LLM.
     *
     * @param responseFromLLM the response from the LLM
     */
    default OutputGuardrailResult validate(AiMessage responseFromLLM) {
        return failure("Validation not implemented");
    }

    /**
     * Validates the response from the LLM.
     * <p>
     * Unlike {@link #validate(AiMessage)}, this method allows to access the memory and the augmentation result (in the
     * case of a RAG).
     * <p>
     * Implementation must not attempt to write to the memory or the augmentation result.
     *
     * @param params the parameters, including the response from the LLM, the memory (maybe null),
     *        and the augmentation result (maybe null). Cannot be {@code null}
     */
    @Override
    default OutputGuardrailResult validate(OutputGuardrailParams params) {
        return validate(params.responseFromLLM());
    }

    /**
     * Represents the parameter passed to {@link OutputGuardrail#validate(OutputGuardrailParams)}.
     *
     * @param responseFromLLM the response from the LLM
     * @param memory the memory, can be {@code null} or empty
     * @param augmentationResult the augmentation result, can be {@code null}
     */
    record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
            AugmentationResult augmentationResult) implements GuardrailParams {
    }
}

The validate method of the OutputGuardrail interface can have two signatures:

  • OutputGuardrailResult validate(AiMessage responseFromLLM)

  • OutputGuardrailResult validate(OutputGuardrailParams params)

The first one is used when the guardrail only needs the output of the LLM. Simple guardrails can use this method. For example, here is an output guardrail that checks that the output is a JSON document:

import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

@ApplicationScoped
public class JsonGuardrail implements OutputGuardrail {

    @Inject
    ObjectMapper mapper;

    @Override
    public OutputGuardrailResult validate(AiMessage responseFromLLM) {
        try {
            mapper.readTree(responseFromLLM.text());
        } catch (Exception e) {
            return reprompt("Invalid JSON", e, "Make sure you return a valid JSON object");
        }
        return success();
    }

}

The second signature is used when the guardrail needs more information, like the augmentation results or the memory. Note that the guardrail cannot modify the memory or the augmentation results. The Detecting Hallucinations in the RAG Context section gives an example of a guardrail that uses the augmentation results.

Output Guardrails Outcome

Output guardrails can have six outcomes:

  • pass - The output is valid, the next guardrail is executed, and if the last guardrail passes, the output is returned to the caller.

  • pass with rewrite - The output isn’t valid in its original form, but has been rewritten to make it valid. The next guardrail is executed against the rewritten output, and if the last guardrail passes, the rewritten output is returned to the caller.

  • fail - The output is invalid, but the next guardrail is still executed in order to accumulate all the possible validation problems.

  • fatal - The output is invalid, the next guardrail is not executed, and the error is rethrown.

  • fatal with retry - The output is invalid, the next guardrail is not executed, and the LLM is called again with the same prompt.

  • fatal with reprompt - The output is invalid, the next guardrail is not executed, and the LLM is called again with a new prompt.

When the validation fails, the guardrail can specify whether the LLM should be retried or reprompted:

// Retry - The LLM is called again with the same prompt and context
// The guardrails will be called again with the new output
return retry("Invalid JSON");

// Retry with reprompt - The LLM is called again with a new prompt and context
// A new user message is added to the LLM context (memory), and the LLM is called again with this new context.
// The guardrails will be called again with the new output
return reprompt("Invalid JSON", "Make sure you return a valid JSON object");

By default, Quarkus LangChain4j limits the number of retries to 3. This is configurable using the quarkus.langchain4j.guardrails.max-retries configuration property:

quarkus.langchain4j.guardrails.max-retries=5

Setting quarkus.langchain4j.guardrails.max-retries to 0 disables retries.
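
The difference between a retry and a reprompt, and the effect of the retry limit, can be sketched with a stand-alone model. This is not the Quarkus LangChain4j implementation: RetrySketch, Verdict, and Decision are hypothetical names, the LLM is a plain function, and the way the limit is counted (the number of calls made after the first one, so that 0 means a single call) is an assumption made for the sketch.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Stand-alone model of retry vs. reprompt; not the library's code. */
final class RetrySketch {

    /** Hypothetical verdicts an output guardrail can return in this model. */
    enum Verdict { SUCCESS, RETRY, REPROMPT }

    /** Hypothetical decision; repromptMessage is only meaningful for REPROMPT. */
    record Decision(Verdict verdict, String repromptMessage) { }

    /**
     * Calls the model, validates the output, and calls again while the guardrail asks for it.
     * ASSUMPTION: maxRetries counts the calls made after the first one, so 0 means a single call.
     */
    static String callWithGuardrail(List<String> context,
            Function<List<String>, String> llm,
            Function<String, Decision> guardrail,
            int maxRetries) {
        List<String> messages = new ArrayList<>(context);
        for (int attempt = 0; attempt <= maxRetries; attempt++) {
            String output = llm.apply(List.copyOf(messages));
            Decision decision = guardrail.apply(output);
            if (decision.verdict() == Verdict.SUCCESS) {
                return output;
            }
            if (decision.verdict() == Verdict.REPROMPT) {
                // Reprompt: the next call sees one extra message in its context.
                messages.add(decision.repromptMessage());
            }
            // Retry: the next call reuses the same context unchanged.
        }
        throw new IllegalStateException("Output still invalid after " + maxRetries + " retries");
    }

    public static void main(String[] args) {
        // Fake model: it answers with JSON only once it has been told to.
        Function<List<String>, String> llm = messages ->
                messages.contains("Return valid JSON") ? "{\"ok\":true}" : "Sure! Here you go";
        Function<String, Decision> jsonOnly = output -> output.startsWith("{")
                ? new Decision(Verdict.SUCCESS, null)
                : new Decision(Verdict.REPROMPT, "Return valid JSON");

        System.out.println(callWithGuardrail(List.of("Describe Alex"), llm, jsonOnly, 3));
    }
}
```

With the limit set to 0 in this model, the first invalid output immediately ends in an error, which corresponds to disabling retries.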

Output Guardrails Scopes

Output guardrails are CDI beans. They can be in any CDI scope, including request scope, application scope, or session scope.

The scope of the guardrail is important as it defines the lifecycle of the guardrail, especially when the guardrail is stateful.

Declaring Output Guardrails

Output guardrails are declared on the AI Service interface. You can declare output guardrails in two ways:

  • By annotating the AI Service interface with @OutputGuardrails and listing the guardrails - these guardrails will be applied to all the methods of the AI Service.

  • By annotating the method of the AI Service with @OutputGuardrails and listing the guardrails - these guardrails will be applied to this method only.

Method guardrails take precedence over class guardrails.

Here is an example of an AI Service interface with output guardrails:

import dev.langchain4j.service.SystemMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import jakarta.enterprise.context.SessionScoped;

@RegisterAiService(retrievalAugmentor = Retriever.class)
@SystemMessage("""
    You are Mona, a chatbot answering questions about a museum. Be polite, concise, and helpful.
""")
@SessionScoped
public interface ChatBot {

    @OutputGuardrails(HallucinationGuard.class)
    String chat(String question);

}

Output Guardrail Chain

You can declare multiple guardrails. In this case, a chain is created, and the guardrails are executed in the order they are declared. Thus, the order of the guardrails is important.

Typically, it’s a good idea to have a guardrail that checks the format of the output first, and then a guardrail that checks the content.

@RegisterAiService
@SystemMessage("""
    You are simulating fights between a superhero and a supervillain.
""")
public interface Simulator {

    @UserMessage("""
        Simulate a fight between:
        - a hero: {hero}
        - a villain: {villain}
    """)
    @OutputGuardrails({JsonGuardrail.class, ConsistentStoryGuardrail.class})
    FightResult fight(Hero hero, Villain villain);

}

In this example, the JsonGuardrail is executed first to check that the output is a valid JSON document. Then, the ConsistentStoryGuardrail is executed to check that the story is consistent.

If the JsonGuardrail fails, the ConsistentStoryGuardrail is not executed. However, if the ConsistentStoryGuardrail fails with a retry or reprompt, the JsonGuardrail is executed again with the new response.

Output Guardrails for Streamed Responses

Output guardrails can be applied to methods that return Multi, but not to those returning TokenStream. By default, Quarkus will automatically assemble the full response before executing the guardrail chain. Keep in mind that this may have a performance impact when handling large responses.

To control when the guardrail chain is invoked during streaming, you can configure an accumulator:

@UserMessage("...")
@OutputGuardrails(MyGuardrail.class)
@OutputGuardrailAccumulator(PassThroughAccumulator.class) // Defines the accumulator
Multi<String> ask();

The @OutputGuardrailAccumulator annotation allows you to specify a custom accumulator. The accumulator must implement the io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator interface and be exposed as a CDI bean. Here’s a simple example of an accumulator that does not perform any accumulation:

@ApplicationScoped
public class PassThroughAccumulator implements OutputTokenAccumulator {

    @Override
    public Multi<String> accumulate(Multi<String> tokens) {
        return tokens; // Passes the tokens through without accumulating
    }
}

You can create accumulators based on various criteria, such as the number of tokens, a specific separator, or time intervals.
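
As an illustration of the separator criterion, the grouping logic alone can be sketched in plain Java. SentenceChunker is a hypothetical helper working on a list rather than on a Multi, so it only shows how tokens could be regrouped into chunks; wiring such logic into an OutputTokenAccumulator is left out.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: regroups streamed tokens into chunks that end at a separator. */
final class SentenceChunker {

    static List<String> chunk(List<String> tokens, String separator) {
        List<String> chunks = new ArrayList<>();
        StringBuilder current = new StringBuilder();
        for (String token : tokens) {
            current.append(token);
            if (token.endsWith(separator)) {
                // A separator closes the current chunk; the guardrails would see it as one item.
                chunks.add(current.toString());
                current.setLength(0);
            }
        }
        if (current.length() > 0) {
            // Flush whatever remains when the stream ends without a separator.
            chunks.add(current.toString());
        }
        return chunks;
    }

    public static void main(String[] args) {
        List<String> tokens = List.of("The", " hero", " wins.", " The", " villain", " flees.");
        chunk(tokens, ".").forEach(System.out::println);
    }
}
```

Grouping by sentence in this way would let a guardrail validate complete sentences instead of isolated tokens, at the cost of delaying emission until each separator arrives.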

When an accumulator is set, the output guardrail chain is invoked for each item emitted by the Multi returned by the accumulate method.

In the case of a retry, the accumulator is called again with the new response, restarting the stream from the beginning. The same behavior applies for reprompts.

Detecting Hallucinations in the RAG Context

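This section is an example of how to implement a guardrail that detects hallucinations in the context of a RAG. The idea is to check that the output of the LLM is consistent with the augmentation results: the guardrail embeds the response and each retrieved content, and passes as soon as the cosine distance to one of them falls below a configurable threshold. Otherwise, it reprompts the LLM.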

package me.escoffier.langchain4j.nomic;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.content.Content;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkus.logging.Log;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import org.eclipse.microprofile.config.inject.ConfigProperty;

@ApplicationScoped
public class HallucinationGuard implements OutputGuardrail {

    @Inject
    NomicEmbeddingV1 embedding;

    @ConfigProperty(name = "hallucination.threshold", defaultValue = "0.7")
    double threshold;

    @Override
    public OutputGuardrailResult validate(OutputGuardrailParams params) {
        if (params.augmentationResult() == null || params.augmentationResult().contents().isEmpty()) {
            Log.info("No content to validate against");
            return success();
        }
        // Embed the response only once there is retrieved content to compare it with
        Response<Embedding> embeddingOfTheResponse = embedding.embed(params.responseFromLLM().text());
        float[] vectorOfTheResponse = embeddingOfTheResponse.content().vector();
        for (Content content : params.augmentationResult().contents()) {
            Response<Embedding> embeddingOfTheContent = embedding.embed(content.textSegment());
            float[] vectorOfTheContent = embeddingOfTheContent.content().vector();
            double distance = cosineDistance(vectorOfTheResponse, vectorOfTheContent);
            if (distance < threshold) {
                Log.info("Passed hallucination guardrail: " + distance);
                return success();
            }
        }

        return reprompt("Hallucination detected", "Make sure you use the given documents to produce the response");
    }

    public static double cosineDistance(float[] vector1, float[] vector2) {
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;

        for (int i = 0; i < vector1.length; i++) {
            dotProduct += vector1[i] * vector2[i];
            normA += Math.pow(vector1[i], 2);
            normB += Math.pow(vector2[i], 2);
        }

        double cosineSimilarity = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
        return 1.0 - cosineSimilarity;
    }
}
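
To get a feel for the threshold, the cosineDistance computation can be run on a few hand-made vectors. The sketch below copies the formula from the guardrail into a hypothetical CosineDistanceDemo class; the vectors are illustrative and are not real embeddings.

```java
/** Stand-alone demo of the cosine distance used by the guardrail; the vectors are made up. */
final class CosineDistanceDemo {

    static double cosineDistance(float[] vector1, float[] vector2) {
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < vector1.length; i++) {
            dotProduct += vector1[i] * vector2[i];
            normA += Math.pow(vector1[i], 2);
            normB += Math.pow(vector2[i], 2);
        }
        double cosineSimilarity = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
        return 1.0 - cosineSimilarity;
    }

    public static void main(String[] args) {
        double threshold = 0.7; // same default as hallucination.threshold in the guardrail

        // Same direction: distance close to 0, well below the threshold.
        double same = cosineDistance(new float[] {1, 2, 3}, new float[] {2, 4, 6});
        // Orthogonal vectors: distance 1, above the threshold.
        double orthogonal = cosineDistance(new float[] {1, 0}, new float[] {0, 1});
        // Opposite vectors: distance 2, the maximum.
        double opposite = cosineDistance(new float[] {1, 0}, new float[] {-1, 0});

        System.out.println("same direction passes: " + (same < threshold));
        System.out.println("orthogonal passes: " + (orthogonal < threshold));
        System.out.println("opposite passes: " + (opposite < threshold));
    }
}
```

The distance ranges from 0 (same direction) to 2 (opposite direction), so a lower threshold makes the guardrail stricter. Note that a vector made only of zeros leads to a division by zero and a NaN distance, which never satisfies the comparison with the threshold.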

Rewriting the LLM Output

The output generated by the LLM may not be fully satisfactory and yet be fixable programmatically, without a retry or a reprompt, both of which imply a costly, time-consuming, and less reliable new interaction with the LLM. For instance, it is quite common for an LLM to produce the JSON of the data object it was asked to extract from the user prompt, but then append an unwanted explanation of why it generated that result, which makes the response unparsable as JSON. For example:

{"name":"Alex", "age":18} Alex is 18 since he became an adult a few days ago.

In this situation, it is better to first trim the response down to its JSON part programmatically and check whether a valid Person object can be deserialized from it, before reprompting the LLM. If this programmatic extraction of the JSON string from the LLM output succeeds, the rewritten output can be propagated through the successWith method.
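
The trimming step can be sketched in plain Java as follows. JsonTrimSketch and trimToJsonObject are hypothetical names, and this naive version (keep everything from the first opening brace to the last closing brace) is only an assumption about what such a cleanup could look like; it is not the library's trimming logic, and it does not validate the extracted text.

```java
/** Hypothetical helper: keeps only the part of a response that looks like a JSON object. */
final class JsonTrimSketch {

    /** Returns the text from the first '{' to the last '}', or null when there is no such span. */
    static String trimToJsonObject(String llmResponse) {
        int start = llmResponse.indexOf('{');
        int end = llmResponse.lastIndexOf('}');
        if (start < 0 || end < start) {
            return null; // nothing that looks like a JSON object
        }
        return llmResponse.substring(start, end + 1);
    }

    public static void main(String[] args) {
        String response = "{\"name\":\"Alex\", \"age\":18} Alex is 18 since he became an adult a few days ago.";
        // Only the JSON object remains; it still has to be deserialized to confirm it is valid.
        System.out.println(trimToJsonObject(response));
    }
}
```

The extracted string is only a candidate: a guardrail would still need to deserialize it, and fall back to a reprompt when deserialization fails.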

This scenario is so common that an abstract class implementing the OutputGuardrail interface and performing this JSON sanitization is provided out of the box.

import jakarta.inject.Inject;
import org.jboss.logging.Logger;
import com.fasterxml.jackson.core.type.TypeReference;
import dev.langchain4j.data.message.AiMessage;

public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail {

    @Inject
    Logger logger;

    @Inject
    JsonGuardrailsUtils jsonGuardrailsUtils;

    protected AbstractJsonExtractorOutputGuardrail() {
        if (getOutputClass() == null && getOutputType() == null) {
            throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented");
        }
    }

    @Override
    public OutputGuardrailResult validate(AiMessage responseFromLLM) {
        String llmResponse = responseFromLLM.text();
        logger.debugf("LLM output: %s", llmResponse);

        Object result = deserialize(llmResponse);
        if (result != null) {
            return successWith(llmResponse, result);
        }

        String json = jsonGuardrailsUtils.trimNonJson(llmResponse);
        if (json != null) {
            result = deserialize(json);
            if (result != null) {
                return successWith(json, result);
            }
        }

        return reprompt("Invalid JSON",
                "Make sure you return a valid JSON object following "
                        + "the specified format");
    }

    protected Object deserialize(String llmResponse) {
        return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass())
                : jsonGuardrailsUtils.deserialize(llmResponse, getOutputType());
    }

    protected Class<?> getOutputClass() {
        return null;
    }

    protected TypeReference<?> getOutputType() {
        return null;
    }
}

This implementation first tries to deserialize the LLM response into the class expected from the data extraction. If that fails, it trims away the non-JSON part of the response and attempts the deserialization again. In both cases, the successWith method returns the deserialized object together with the JSON response (either the original one or the programmatically trimmed one), so the object can be used directly as the final result of the data extraction without a second deserialization. If both deserialization attempts fail, the guardrail performs a reprompt, in the hope that the LLM will then produce a valid JSON string.

For example, consider an AI service that extracts customer data from user prompts:

@RegisterAiService
public interface CustomerExtractor {

    @UserMessage("Extract information about a customer from this text '{text}'. The response must contain only the JSON with customer's data and without any other sentence.")
    @OutputGuardrails(CustomerExtractionOutputGuardrail.class)
    Customer extractData(String text);
}

You can attach to it an output guardrail that sanitizes the JSON response from the LLM by simply extending the abstract class above and declaring the expected output class of the data extraction:

@ApplicationScoped
public class CustomerExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail {

    @Override
    protected Class<?> getOutputClass() {
        return Customer.class;
    }
}

Note that if the data extraction requires a generic Java type, like a List<Customer>, you need to override getOutputType instead and return a Jackson TypeReference as follows:

@ApplicationScoped
public class CustomersExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail {

    @Override
    protected TypeReference<?> getOutputType() {
        return new TypeReference<List<Customer>>() {};
    }
}

Unit Testing

You can also unit test your output guardrails. A set of AssertJ custom assertions (following AssertJ’s custom assertions pattern) are available to help you unit test your guardrails.

To get access to these helpers you’ll need to add a test dependency to io.quarkiverse.langchain4j:quarkus-langchain4j-testing-core:

<dependency>
    <groupId>io.quarkiverse.langchain4j</groupId>
    <artifactId>quarkus-langchain4j-testing-core</artifactId>
    <version>{project-version}</version>
    <scope>test</scope>
</dependency>

Then you can import the helpers into your test class:

import static io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.*;

Then, for an OutputGuardrail class that looks something like this:

src/main/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrail.java
package io.quarkiverse.langchain4j.guardrails;

import java.util.Optional;

import dev.langchain4j.data.message.AiMessage;

public class EmailContainsRequiredInformationOutputGuardrail implements OutputGuardrail {
  static final String NO_RESPONSE_MESSAGE = "No response found";
  static final String NO_RESPONSE_PROMPT = "The response was empty. Please try again.";
  static final String CLIENT_NAME_NOT_FOUND_MESSAGE = "Client name not found";
  static final String CLIENT_NAME_NOT_FOUND_PROMPT = "The response did not contain the client name. Please include the client name \"%s\", exactly as is (case-sensitive), in the email body.";
  static final String CLAIM_NUMBER_NOT_FOUND_MESSAGE = "Claim number not found";
  static final String CLAIM_NUMBER_NOT_FOUND_PROMPT = "The response did not contain the claim number. Please include the claim number \"%s\", exactly as is (case-sensitive), in the email body.";
  static final String CLAIM_STATUS_NOT_FOUND_MESSAGE = "Claim status not found";
  static final String CLAIM_STATUS_NOT_FOUND_PROMPT = "The response did not contain the claim status. Please include the claim status \"%s\", exactly as is (case-sensitive), in the email body.";

  @Override
  public OutputGuardrailResult validate(OutputGuardrailParams params) {
    var claimInfo = Optional.ofNullable(params.variables())
        .map(vars -> vars.get("claimInfo"))
        .map(ClaimInfo.class::cast)
        .orElse(null);

    if (claimInfo != null) {
      var response = Optional.ofNullable(params.responseFromLLM())
            .map(AiMessage::text)
            .orElse("");

      if (response.isBlank()) {
        return reprompt(NO_RESPONSE_MESSAGE, NO_RESPONSE_PROMPT);
      }

      if (!claimInfo.clientName().isBlank() && !response.contains(claimInfo.clientName())) {
        return reprompt(CLIENT_NAME_NOT_FOUND_MESSAGE, CLIENT_NAME_NOT_FOUND_PROMPT.formatted(claimInfo.clientName()));
      }

      if (!claimInfo.claimNumber().isBlank() && !response.contains(claimInfo.claimNumber())) {
        return reprompt(CLAIM_NUMBER_NOT_FOUND_MESSAGE,
                        CLAIM_NUMBER_NOT_FOUND_PROMPT.formatted(claimInfo.claimNumber()));
      }

      if (!claimInfo.claimStatus().isBlank() && !response.contains(claimInfo.claimStatus())) {
          return reprompt(CLAIM_STATUS_NOT_FOUND_MESSAGE,
                        CLAIM_STATUS_NOT_FOUND_PROMPT.formatted(claimInfo.claimStatus()));
      }
    }

    return success();
  }
}

You can test it like this:

src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrailTests.java
package io.quarkiverse.langchain4j.guardrails;

import static io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

import java.util.Map;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.guardrails.GuardrailResult.Result;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult.Failure;

class EmailContainsRequiredInformationOutputGuardrailTests {
  private static final String CLAIM_NUMBER = "CLM195501";
  private static final String CLAIM_STATUS = "denied";
  private static final String CLIENT_NAME = "Marty McFly";
  private static final String EMAIL_TEMPLATE = """
    Dear %s,

    We are writing to inform you that your claim (%s) has been reviewed and is currently under consideration. After careful evaluation of the evidence provided, we regret to inform you that your claim has been %s.

    Please note that our decision is based on the information provided in your policy declarations page, as well as applicable laws and regulations governing vehicle insurance claims.

    If you have any questions or concerns regarding this decision, please do not hesitate to contact us at 800-CAR-SAFE or email <a href="mailto:claims@parasol.com">claims@parasol.com</a>. A member of our team will be happy to assist you.

    Sincerely,
    Parasoft Insurance Claims Department

    --------------------------------------------
    Please note this is an unmonitored email box.
    Should you choose to reply, nobody (not even an AI bot) will see your message.
    Call a real human should you have any questions. 1-800-CAR-SAFE.
    """;

  EmailContainsRequiredInformationOutputGuardrail guardrail = spy(new EmailContainsRequiredInformationOutputGuardrail());

  @Test
  void guardrailSuccess() {
    var response = EMAIL_TEMPLATE.formatted(CLIENT_NAME, CLAIM_NUMBER, CLAIM_STATUS);
    var params = createParams(response, CLAIM_NUMBER, CLAIM_STATUS, CLIENT_NAME);

    assertThat(this.guardrail.validate(params))
        .isSuccessful();  // Custom assertion from OutputGuardrailResultAssert

    verify(this.guardrail).validate(params);
    verify(this.guardrail).success();
    verifyNoMoreInteractions(this.guardrail);
  }

  @Test
  void emptyEmail() {
    var params = createParams("", CLAIM_NUMBER, CLAIM_STATUS, CLIENT_NAME);
    var result = this.guardrail.validate(params);

    assertThat(result)
        // custom assertions from OutputGuardrailResultAssert
        .hasFailures()
        .hasResult(Result.FATAL)
        .hasSingleFailureWithMessage(EmailContainsRequiredInformationOutputGuardrail.NO_RESPONSE_MESSAGE);

    verify(this.guardrail).validate(params);
    verify(this.guardrail).reprompt(EmailContainsRequiredInformationOutputGuardrail.NO_RESPONSE_MESSAGE,
                EmailContainsRequiredInformationOutputGuardrail.NO_RESPONSE_PROMPT);
    verifyNoMoreInteractions(this.guardrail);
  }

  @ParameterizedTest
  @MethodSource("emailDoesntContainRequiredInfoParams")
  void emailDoesntContainRequiredInfo(ClaimInfo missingClaimInfo, String expectedRepromptMessage, String expectedRepromptPrompt) {
    var responseWithMissingInfo = EMAIL_TEMPLATE.formatted(missingClaimInfo.clientName(), missingClaimInfo.claimNumber(), missingClaimInfo.claimStatus());
    var params = createParams(responseWithMissingInfo, CLAIM_NUMBER, CLAIM_STATUS, CLIENT_NAME);
    var result = this.guardrail.validate(params);

    assertThat(result)
        // custom assertions from OutputGuardrailResultAssert
        .hasFailures()
        .hasResult(Result.FATAL)
        .hasSingleFailureWithMessageAndReprompt(expectedRepromptMessage, expectedRepromptPrompt)
        .assertSingleFailureSatisfies(failure ->
            assertThat(failure)
                .isNotNull()
                .extracting(
                        Failure::retry,
                        Failure::message,
                        Failure::cause
                )
                .containsExactly(
                        true,
                        expectedRepromptMessage,
                        null)
                );

    verify(this.guardrail).validate(params);
    verify(this.guardrail).reprompt(expectedRepromptMessage, expectedRepromptPrompt);
    verifyNoMoreInteractions(this.guardrail);
  }

  static Stream<Arguments> emailDoesntContainRequiredInfoParams() {
    return Stream.of(
            Arguments.of(
                new ClaimInfo("", CLAIM_NUMBER, CLAIM_STATUS),
                EmailContainsRequiredInformationOutputGuardrail.CLIENT_NAME_NOT_FOUND_MESSAGE,
                EmailContainsRequiredInformationOutputGuardrail.CLIENT_NAME_NOT_FOUND_PROMPT.formatted(CLIENT_NAME)
            ),
            Arguments.of(
                new ClaimInfo(CLIENT_NAME, "", CLAIM_STATUS),
                EmailContainsRequiredInformationOutputGuardrail.CLAIM_NUMBER_NOT_FOUND_MESSAGE,
                EmailContainsRequiredInformationOutputGuardrail.CLAIM_NUMBER_NOT_FOUND_PROMPT.formatted(CLAIM_NUMBER)
            ),
            Arguments.of(
                new ClaimInfo(CLIENT_NAME, CLAIM_NUMBER, ""),
                EmailContainsRequiredInformationOutputGuardrail.CLAIM_STATUS_NOT_FOUND_MESSAGE,
                EmailContainsRequiredInformationOutputGuardrail.CLAIM_STATUS_NOT_FOUND_PROMPT.formatted(CLAIM_STATUS)
            )
          );
  }

  private static OutputGuardrailParams createParams(String response, String claimNumber, String claimStatus, String clientName) {
    return createParams(response, new ClaimInfo(clientName, claimNumber, claimStatus));
  }

  private static OutputGuardrailParams createParams(String response, ClaimInfo claimInfo) {
    return new OutputGuardrailParams(
        AiMessage.from(response),
        null,
        null,
        null,
        Map.of("claimInfo", claimInfo)
    );
  }
}

See OutputGuardrailResultAssert.java for more information about the different kinds of asserts you can do.