fix: update trulens LLM ranking score for retrieval confidence, improve citation (#98)

* Round when displaying not by default

* Add LLMTrulens reranking model

* Use llmtrulensscoring in pipeline

* fix: update UI display for trulen score

---------

Co-authored-by: taprosoft <tadashi@cinnamon.is>
This commit is contained in:
cin-ace
2024-07-07 21:59:59 +07:00
committed by GitHub
parent d04dc2f75d
commit 9e2fe4afc9
6 changed files with 284 additions and 115 deletions

View File

@@ -2,5 +2,12 @@ from .base import BaseReranking
from .cohere import CohereReranking
from .llm import LLMReranking
from .llm_scoring import LLMScoring
from .llm_trulens import LLMTrulensScoring
__all__ = ["CohereReranking", "LLMReranking", "LLMScoring", "BaseReranking"]
__all__ = [
"CohereReranking",
"LLMReranking",
"LLMScoring",
"BaseReranking",
"LLMTrulensScoring",
]

View File

@@ -33,7 +33,7 @@ class CohereReranking(BaseReranking):
)
for r in response.results:
doc = documents[r.index]
doc.metadata["cohere_reranking_score"] = round(r.relevance_score, 2)
doc.metadata["cohere_reranking_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs

View File

@@ -42,9 +42,9 @@ class LLMScoring(LLMReranking):
score = np.exp(np.average(result.logprobs))
include_doc = output_parser.parse(result.text)
if include_doc:
doc.metadata["llm_reranking_score"] = round(score, 2)
doc.metadata["llm_reranking_score"] = score
else:
doc.metadata["llm_reranking_score"] = round(1 - score, 2)
doc.metadata["llm_reranking_score"] = 1 - score
filtered_docs.append(doc)
# prevent returning empty result

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
import re
from concurrent.futures import ThreadPoolExecutor
from kotaemon.base import Document, HumanMessage, SystemMessage
from kotaemon.llms import BaseLLM, PromptTemplate
from .llm import LLMReranking
SYSTEM_PROMPT_TEMPLATE = PromptTemplate(
"""You are a RELEVANCE grader; providing the relevance of the given CONTEXT to the given QUESTION.
Respond only as a number from 0 to 10 where 0 is the least relevant and 10 is the most relevant.
A few additional scoring guidelines:
- Long CONTEXTS should score equally well as short CONTEXTS.
- RELEVANCE score should increase as the CONTEXTS provides more RELEVANT context to the QUESTION.
- RELEVANCE score should increase as the CONTEXTS provides RELEVANT context to more parts of the QUESTION.
- CONTEXT that is RELEVANT to some of the QUESTION should score of 2, 3 or 4. Higher score indicates more RELEVANCE.
- CONTEXT that is RELEVANT to most of the QUESTION should get a score of 5, 6, 7 or 8. Higher score indicates more RELEVANCE.
- CONTEXT that is RELEVANT to the entire QUESTION should get a score of 9 or 10. Higher score indicates more RELEVANCE.
- CONTEXT must be relevant and helpful for answering the entire QUESTION to get a score of 10.
- Never elaborate.""" # noqa: E501
)
USER_PROMPT_TEMPLATE = PromptTemplate(
"""QUESTION: {question}
CONTEXT: {context}
RELEVANCE: """
) # noqa
PATTERN_INTEGER: re.Pattern = re.compile(r"([+-]?[1-9][0-9]*|0)")
"""Regex that matches integers."""
def validate_rating(rating) -> int:
"""Validate a rating is between 0 and 10."""
if not 0 <= rating <= 10:
raise ValueError("Rating must be between 0 and 10")
return rating
def re_0_10_rating(s: str) -> int:
"""Extract a 0-10 rating from a string.
If the string does not match an integer or matches an integer outside the
0-10 range, raises an error instead. If multiple numbers are found within
the expected 0-10 range, the smallest is returned.
Args:
s: String to extract rating from.
Returns:
int: Extracted rating.
Raises:
ParseError: If no integers between 0 and 10 are found in the string.
"""
matches = PATTERN_INTEGER.findall(s)
if not matches:
raise AssertionError
vals = set()
for match in matches:
try:
vals.add(validate_rating(int(match)))
except ValueError:
pass
if not vals:
raise AssertionError
# Min to handle cases like "The rating is 8 out of 10."
return min(vals)
class LLMTrulensScoring(LLMReranking):
llm: BaseLLM
system_prompt_template: PromptTemplate = SYSTEM_PROMPT_TEMPLATE
user_prompt_template: PromptTemplate = USER_PROMPT_TEMPLATE
top_k: int = 3
concurrent: bool = True
normalize: float = 10
def run(
self,
documents: list[Document],
query: str,
) -> list[Document]:
"""Filter down documents based on their relevance to the query."""
filtered_docs = []
documents = sorted(documents, key=lambda doc: doc.get_content())
if self.concurrent:
with ThreadPoolExecutor() as executor:
futures = []
for doc in documents:
messages = []
messages.append(
SystemMessage(self.system_prompt_template.populate())
)
messages.append(
HumanMessage(
self.user_prompt_template.populate(
question=query, context=doc.get_content()
)
)
)
futures.append(executor.submit(lambda: self.llm(messages).text))
results = [future.result() for future in futures]
else:
results = []
for doc in documents:
messages = []
messages.append(SystemMessage(self.system_prompt_template.populate()))
messages.append(
SystemMessage(
self.user_prompt_template.populate(
question=query, context=doc.get_content()
)
)
)
results.append(self.llm(messages).text)
# use Boolean parser to extract relevancy output from LLM
results = [
(r_idx, float(re_0_10_rating(result)) / self.normalize)
for r_idx, result in enumerate(results)
]
results.sort(key=lambda x: x[1], reverse=True)
for r_idx, score in results:
doc = documents[r_idx]
doc.metadata["llm_trulens_score"] = score
filtered_docs.append(doc)
# prevent returning empty result
if len(filtered_docs) == 0:
filtered_docs = documents[: self.top_k]
return filtered_docs

View File

@@ -35,7 +35,7 @@ from kotaemon.indices.rankings import (
BaseReranking,
CohereReranking,
LLMReranking,
LLMScoring,
LLMTrulensScoring,
)
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
@@ -254,7 +254,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
)
],
retrieval_mode=user_settings["retrieval_mode"],
rerankers=[LLMScoring(), CohereReranking()],
rerankers=[CohereReranking(), LLMTrulensScoring()],
)
if not user_settings["use_reranking"]:
retriever.rerankers = [] # type: ignore

View File

@@ -3,6 +3,7 @@ import html
import logging
import re
from collections import defaultdict
from difflib import SequenceMatcher
from functools import partial
from typing import Generator
@@ -50,6 +51,26 @@ def is_close(val1, val2, tolerance=1e-9):
return abs(val1 - val2) <= tolerance
def find_text(search_span, context):
sentence_list = search_span.split("\n")
matches = []
# don't search for small text
if len(search_span) > 5:
for sentence in sentence_list:
match = SequenceMatcher(
None, sentence, context, autojunk=False
).find_longest_match()
if match.size > len(search_span) * 0.6:
matches.append((match.b, match.b + match.size))
print(
"search",
search_span,
"matched",
context[match.b : match.b + match.size],
)
return matches
_default_token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
@@ -139,9 +160,9 @@ class PrepareEvidencePipeline(BaseComponent):
DEFAULT_QA_TEXT_PROMPT = (
"Use the following pieces of context to answer the question at the end. "
"Use the following pieces of context to answer the question at the end in detail with clear explanation. " # noqa: E501
"If you don't know the answer, just say that you don't know, don't try to "
"make up an answer. Keep the answer as concise as possible. Give answer in "
"make up an answer. Give answer in "
"{lang}.\n\n"
"{context}\n"
"Question: {question}\n"
@@ -150,14 +171,14 @@ DEFAULT_QA_TEXT_PROMPT = (
DEFAULT_QA_TABLE_PROMPT = (
"List all rows (row number) from the table context that related to the question, "
"then provide detail answer with clear explanation and citations. "
"then provide detail answer with clear explanation. "
"If you don't know the answer, just say that you don't know, "
"don't try to make up an answer. Give answer in {lang}.\n\n"
"Context:\n"
"{context}\n"
"Question: {question}\n"
"Helpful Answer:"
)
) # noqa
DEFAULT_QA_CHATBOT_PROMPT = (
"Pick the most suitable chatbot scenarios to answer the question at the end, "
@@ -168,7 +189,7 @@ DEFAULT_QA_CHATBOT_PROMPT = (
"{context}\n"
"Question: {question}\n"
"Answer:"
)
) # noqa
DEFAULT_QA_FIGURE_PROMPT = (
"Use the given context: texts, tables, and figures below to answer the question. "
@@ -178,7 +199,7 @@ DEFAULT_QA_FIGURE_PROMPT = (
"{context}\n"
"Question: {question}\n"
"Answer: "
)
) # noqa
DEFAULT_REWRITE_PROMPT = (
"Given the following question, rephrase and expand it "
@@ -187,7 +208,7 @@ DEFAULT_REWRITE_PROMPT = (
"Give answer in {lang}\n"
"Original question: {question}\n"
"Rephrased question: "
)
) # noqa
class AnswerWithContextPipeline(BaseComponent):
@@ -400,12 +421,14 @@ class AnswerWithContextPipeline(BaseComponent):
if evidence and self.enable_citation:
citation = self.citation_pipeline(context=evidence, question=question)
if logprobs:
qa_score = np.exp(np.average(logprobs))
else:
qa_score = None
answer = Document(
text=output,
metadata={
"citation": citation,
"qa_score": round(np.exp(np.average(logprobs)), 2),
},
metadata={"citation": citation, "qa_score": qa_score},
)
return answer
@@ -556,6 +579,47 @@ class FullQAPipeline(BaseReasoning):
return docs, info
def _format_retrieval_score_and_doc(
self,
doc: Document,
rendered_doc_content: str,
open_collapsible: bool = False,
) -> str:
"""Format the retrieval score and the document"""
# score from doc_store (Elasticsearch)
if is_close(doc.score, -1.0):
text_search_str = " default from full-text search<br>"
else:
text_search_str = "<br>"
vectorstore_score = round(doc.score, 2)
llm_reranking_score = (
round(doc.metadata["llm_trulens_score"], 2)
if doc.metadata.get("llm_trulens_score") is not None
else None
)
cohere_reranking_score = (
round(doc.metadata["cohere_reranking_score"], 2)
if doc.metadata.get("cohere_reranking_score")
else None
)
item_type_prefix = doc.metadata.get("type", "")
item_type_prefix = item_type_prefix.capitalize()
if item_type_prefix:
item_type_prefix += " from "
return Render.collapsible(
header=(f"{item_type_prefix}{get_header(doc)} [{llm_reranking_score}]"),
content="<b>Vectorstore score:</b>"
f" {vectorstore_score}"
f"{text_search_str}"
"<b>LLM reranking score:</b>"
f" {llm_reranking_score}<br>"
"<b>Cohere reranking score:</b>"
f" {cohere_reranking_score}<br>" + rendered_doc_content,
open=open_collapsible,
)
def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]:
"""Prepare the citations to show on the UI"""
with_citation, without_citation = [], []
@@ -565,116 +629,63 @@ class FullQAPipeline(BaseReasoning):
for fact_with_evidence in answer.metadata["citation"].answer:
for quote in fact_with_evidence.substring_quote:
for doc in docs:
start_idx = doc.text.find(quote)
if start_idx == -1:
continue
matches = find_text(quote, doc.text)
end_idx = start_idx + len(quote)
current_idx = start_idx
if "|" not in doc.text[start_idx:end_idx]:
spans[doc.doc_id].append(
{"start": start_idx, "end": end_idx}
)
else:
while doc.text[current_idx:end_idx].find("|") != -1:
match_idx = doc.text[current_idx:end_idx].find("|")
for start, end in matches:
if "|" not in doc.text[start:end]:
spans[doc.doc_id].append(
{
"start": current_idx,
"end": current_idx + match_idx,
"start": start,
"end": end,
}
)
current_idx += match_idx + 2
if current_idx > end_idx:
break
break
id2docs = {doc.doc_id: doc for doc in docs}
not_detected = set(id2docs.keys()) - set(spans.keys())
for id, ss in spans.items():
# render highlight spans
for _id, ss in spans.items():
if not ss:
not_detected.add(id)
not_detected.add(_id)
continue
cur_doc = id2docs[_id]
ss = sorted(ss, key=lambda x: x["start"])
text = id2docs[id].text[: ss[0]["start"]]
text = cur_doc.text[: ss[0]["start"]]
for idx, span in enumerate(ss):
text += Render.highlight(id2docs[id].text[span["start"] : span["end"]])
text += Render.highlight(cur_doc.text[span["start"] : span["end"]])
if idx < len(ss) - 1:
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
text += id2docs[id].text[ss[-1]["end"] :]
if is_close(id2docs[id].score, -1.0):
text_search_str = " default from full-text search<br>"
else:
text_search_str = "<br>"
if (
id2docs[id].metadata.get("llm_reranking_score") is None
or id2docs[id].metadata.get("cohere_reranking_score") is None
):
cloned_chunk_str = (
"<b>Cloned chunk for a table. No reranking score</b><br>"
)
else:
cloned_chunk_str = ""
text += cur_doc.text[span["end"] : ss[idx + 1]["start"]]
text += cur_doc.text[ss[-1]["end"] :]
# add to display list
with_citation.append(
Document(
channel="info",
content=Render.collapsible(
header=(
f"{get_header(id2docs[id])}<br>"
"<b>Vectorstore score:</b>"
f" {round(id2docs[id].score, 2)}"
f"{text_search_str}"
f"{cloned_chunk_str}"
"<b>LLM reranking score:</b>"
f' {id2docs[id].metadata.get("llm_reranking_score")}<br>'
"<b>Cohere reranking score:</b>"
f' {id2docs[id].metadata.get("cohere_reranking_score")}<br>'
),
content=Render.table(text),
open=True,
content=self._format_retrieval_score_and_doc(
cur_doc,
Render.table(text),
open_collapsible=True,
),
)
)
print("Got {} cited docs".format(len(with_citation)))
for id_ in list(not_detected):
sorted_not_detected_items_with_scores = [
(id_, id2docs[id_].metadata.get("llm_trulens_score", 0.0))
for id_ in not_detected
]
sorted_not_detected_items_with_scores.sort(key=lambda x: x[1], reverse=True)
for id_, _ in sorted_not_detected_items_with_scores:
doc = id2docs[id_]
if is_close(doc.score, -1.0):
text_search_str = " default from full-text search<br>"
else:
text_search_str = "<br>"
if (
doc.metadata.get("llm_reranking_score") is None
or doc.metadata.get("cohere_reranking_score") is None
):
cloned_chunk_str = (
"<b>Cloned chunk for a table. No reranking score</b><br>"
)
else:
cloned_chunk_str = ""
if doc.metadata.get("type", "") == "image":
without_citation.append(
Document(
channel="info",
content=Render.collapsible(
header=(
f"{get_header(doc)}<br>"
"<b>Vectorstore score:</b>"
f" {round(doc.score, 2)}"
f"{text_search_str}"
f"{cloned_chunk_str}"
"<b>LLM reranking score:</b>"
f' {doc.metadata.get("llm_reranking_score")}<br>'
"<b>Cohere reranking score:</b>"
f' {doc.metadata.get("cohere_reranking_score")}<br>'
),
content=Render.image(
content=self._format_retrieval_score_and_doc(
doc,
Render.image(
url=doc.metadata["image_origin"], text=doc.text
),
open=True,
),
)
)
@@ -682,20 +693,8 @@ class FullQAPipeline(BaseReasoning):
without_citation.append(
Document(
channel="info",
content=Render.collapsible(
header=(
f"{get_header(doc)}<br>"
"<b>Vectorstore score:</b>"
f" {round(doc.score, 2)}"
f"{text_search_str}"
f"{cloned_chunk_str}"
"<b>LLM reranking score:</b>"
f' {doc.metadata.get("llm_reranking_score")}<br>'
"<b>Cohere reranking score:</b>"
f' {doc.metadata.get("cohere_reranking_score")}<br>'
),
content=Render.table(doc.text),
open=True,
content=self._format_retrieval_score_and_doc(
doc, Render.table(doc.text)
),
)
)
@@ -744,6 +743,8 @@ class FullQAPipeline(BaseReasoning):
if self.use_rewrite:
message = self.rewrite_pipeline(question=message).text
print(f"Rewritten message (use_rewrite={self.use_rewrite}): {message}")
print(f"Retrievers {self.retrievers}")
# should populate the context
docs, infos = self.retrieve(message, history)
for _ in infos:
@@ -770,12 +771,18 @@ class FullQAPipeline(BaseReasoning):
if without_citation:
for _ in without_citation:
yield _
qa_score = (
round(answer.metadata["qa_score"], 2)
if answer.metadata.get("qa_score")
else None
)
yield Document(
channel="info",
content=(
"<h5><b>Question answering</b></h5><br>"
"<b>Question answering confidence:</b> "
f"{answer.metadata.get('qa_score')}"
f"{qa_score}"
),
)