mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2026-02-24 12:10:24 +01:00
fix: update trulens LLM ranking score for retrieval confidence, improve citation (#98)
* Round when displaying not by default * Add LLMTrulens reranking model * Use llmtrulensscoring in pipeline * fix: update UI display for trulens score --------- Co-authored-by: taprosoft <tadashi@cinnamon.is>
This commit is contained in:
@@ -2,5 +2,12 @@ from .base import BaseReranking
|
||||
from .cohere import CohereReranking
|
||||
from .llm import LLMReranking
|
||||
from .llm_scoring import LLMScoring
|
||||
from .llm_trulens import LLMTrulensScoring
|
||||
|
||||
__all__ = ["CohereReranking", "LLMReranking", "LLMScoring", "BaseReranking"]
|
||||
__all__ = [
|
||||
"CohereReranking",
|
||||
"LLMReranking",
|
||||
"LLMScoring",
|
||||
"BaseReranking",
|
||||
"LLMTrulensScoring",
|
||||
]
|
||||
|
||||
@@ -33,7 +33,7 @@ class CohereReranking(BaseReranking):
|
||||
)
|
||||
for r in response.results:
|
||||
doc = documents[r.index]
|
||||
doc.metadata["cohere_reranking_score"] = round(r.relevance_score, 2)
|
||||
doc.metadata["cohere_reranking_score"] = r.relevance_score
|
||||
compressed_docs.append(doc)
|
||||
|
||||
return compressed_docs
|
||||
|
||||
@@ -42,9 +42,9 @@ class LLMScoring(LLMReranking):
|
||||
score = np.exp(np.average(result.logprobs))
|
||||
include_doc = output_parser.parse(result.text)
|
||||
if include_doc:
|
||||
doc.metadata["llm_reranking_score"] = round(score, 2)
|
||||
doc.metadata["llm_reranking_score"] = score
|
||||
else:
|
||||
doc.metadata["llm_reranking_score"] = round(1 - score, 2)
|
||||
doc.metadata["llm_reranking_score"] = 1 - score
|
||||
filtered_docs.append(doc)
|
||||
|
||||
# prevent returning empty result
|
||||
|
||||
155
libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py
Normal file
155
libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from kotaemon.base import Document, HumanMessage, SystemMessage
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
|
||||
from .llm import LLMReranking
|
||||
|
||||
# System prompt instructing the LLM to act as a 0-10 relevance grader
# (TruLens-style context-relevance feedback prompt).
SYSTEM_PROMPT_TEMPLATE = PromptTemplate(
    """You are a RELEVANCE grader; providing the relevance of the given CONTEXT to the given QUESTION.
Respond only as a number from 0 to 10 where 0 is the least relevant and 10 is the most relevant.

A few additional scoring guidelines:

- Long CONTEXTS should score equally well as short CONTEXTS.

- RELEVANCE score should increase as the CONTEXTS provides more RELEVANT context to the QUESTION.

- RELEVANCE score should increase as the CONTEXTS provides RELEVANT context to more parts of the QUESTION.

- CONTEXT that is RELEVANT to some of the QUESTION should score of 2, 3 or 4. Higher score indicates more RELEVANCE.

- CONTEXT that is RELEVANT to most of the QUESTION should get a score of 5, 6, 7 or 8. Higher score indicates more RELEVANCE.

- CONTEXT that is RELEVANT to the entire QUESTION should get a score of 9 or 10. Higher score indicates more RELEVANCE.

- CONTEXT must be relevant and helpful for answering the entire QUESTION to get a score of 10.

- Never elaborate."""  # noqa: E501
)

# User prompt carrying the concrete question/context pair to be graded; the
# LLM is expected to complete the "RELEVANCE: " line with a number.
USER_PROMPT_TEMPLATE = PromptTemplate(
    """QUESTION: {question}

CONTEXT: {context}

RELEVANCE: """
)  # noqa
|
||||
|
||||
PATTERN_INTEGER: re.Pattern = re.compile(r"([+-]?[1-9][0-9]*|0)")
"""Regex that matches integers."""


def validate_rating(rating) -> int:
    """Validate that a rating lies in the inclusive range [0, 10].

    Args:
        rating: Candidate rating value.

    Returns:
        int: The rating, unchanged.

    Raises:
        ValueError: If the rating is outside [0, 10].
    """
    if not 0 <= rating <= 10:
        raise ValueError("Rating must be between 0 and 10")

    return rating


def re_0_10_rating(s: str) -> int:
    """Extract a 0-10 rating from a string.

    If the string does not match an integer or matches only integers outside
    the 0-10 range, raises an error instead. If multiple numbers are found
    within the expected 0-10 range, the smallest is returned.

    Args:
        s: String to extract rating from.

    Returns:
        int: Extracted rating.

    Raises:
        AssertionError: If no integers between 0 and 10 are found in the
            string.
    """
    matches = PATTERN_INTEGER.findall(s)
    if not matches:
        raise AssertionError(f"No integers found in {s!r}")

    # keep only the matched numbers that fall inside the valid 0-10 range
    vals = set()
    for match in matches:
        try:
            vals.add(validate_rating(int(match)))
        except ValueError:
            pass

    if not vals:
        raise AssertionError(f"No integers between 0 and 10 found in {s!r}")

    # Min to handle cases like "The rating is 8 out of 10."
    return min(vals)
|
||||
|
||||
|
||||
class LLMTrulensScoring(LLMReranking):
    """Rerank documents by a TruLens-style 0-10 LLM relevance score.

    Each document is graded against the query by the LLM using the
    system/user prompt templates above. The parsed 0-10 rating is divided
    by ``normalize`` and stored in ``doc.metadata["llm_trulens_score"]``;
    documents are returned sorted by that score, highest first.
    """

    llm: BaseLLM
    system_prompt_template: PromptTemplate = SYSTEM_PROMPT_TEMPLATE
    user_prompt_template: PromptTemplate = USER_PROMPT_TEMPLATE
    # fallback size when no document could be scored
    top_k: int = 3
    # send LLM requests from worker threads when True
    concurrent: bool = True
    # divisor mapping the raw 0-10 rating onto [0, 1]
    normalize: float = 10

    def run(
        self,
        documents: list[Document],
        query: str,
    ) -> list[Document]:
        """Filter down documents based on their relevance to the query.

        Args:
            documents: Candidate documents to score.
            query: User question the documents are graded against.

        Returns:
            list[Document]: Documents sorted by descending trulens score;
            falls back to the first ``top_k`` documents if none scored.
        """
        filtered_docs = []

        # sort for a deterministic request order regardless of input order
        documents = sorted(documents, key=lambda doc: doc.get_content())

        def build_messages(doc: Document) -> list:
            # one system message (grading instructions) + one human message
            # (question/context pair) per document
            return [
                SystemMessage(self.system_prompt_template.populate()),
                HumanMessage(
                    self.user_prompt_template.populate(
                        question=query, context=doc.get_content()
                    )
                ),
            ]

        if self.concurrent:
            with ThreadPoolExecutor() as executor:
                # NOTE: bind the messages eagerly via a default argument —
                # a plain ``lambda: self.llm(messages)`` would late-bind the
                # loop variable and could grade every document with the last
                # document's messages.
                futures = [
                    executor.submit(
                        lambda msgs=build_messages(doc): self.llm(msgs).text
                    )
                    for doc in documents
                ]
                results = [future.result() for future in futures]
        else:
            results = [self.llm(build_messages(doc)).text for doc in documents]

        # parse the 0-10 rating out of each LLM reply, normalize to [0, 1]
        results = [
            (r_idx, float(re_0_10_rating(result)) / self.normalize)
            for r_idx, result in enumerate(results)
        ]
        results.sort(key=lambda x: x[1], reverse=True)

        for r_idx, score in results:
            doc = documents[r_idx]
            doc.metadata["llm_trulens_score"] = score
            filtered_docs.append(doc)

        # prevent returning empty result
        if len(filtered_docs) == 0:
            filtered_docs = documents[: self.top_k]

        return filtered_docs
|
||||
@@ -35,7 +35,7 @@ from kotaemon.indices.rankings import (
|
||||
BaseReranking,
|
||||
CohereReranking,
|
||||
LLMReranking,
|
||||
LLMScoring,
|
||||
LLMTrulensScoring,
|
||||
)
|
||||
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
|
||||
|
||||
@@ -254,7 +254,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
||||
)
|
||||
],
|
||||
retrieval_mode=user_settings["retrieval_mode"],
|
||||
rerankers=[LLMScoring(), CohereReranking()],
|
||||
rerankers=[CohereReranking(), LLMTrulensScoring()],
|
||||
)
|
||||
if not user_settings["use_reranking"]:
|
||||
retriever.rerankers = [] # type: ignore
|
||||
|
||||
@@ -3,6 +3,7 @@ import html
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from difflib import SequenceMatcher
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
|
||||
@@ -50,6 +51,26 @@ def is_close(val1, val2, tolerance=1e-9):
|
||||
return abs(val1 - val2) <= tolerance
|
||||
|
||||
|
||||
def find_text(search_span, context):
    """Locate fuzzy occurrences of ``search_span`` inside ``context``.

    The span is split on newlines and each line is matched against the
    context with ``difflib.SequenceMatcher``; a match is kept only when its
    length exceeds 60% of the whole search span.

    Args:
        search_span: Text (possibly multi-line) to look for.
        context: Document text to search within.

    Returns:
        list[tuple[int, int]]: ``(start, end)`` index pairs into ``context``.
    """
    matches = []
    # don't search for small text
    if len(search_span) <= 5:
        return matches

    for sentence in search_span.split("\n"):
        match = SequenceMatcher(
            None, sentence, context, autojunk=False
        ).find_longest_match()
        # NOTE(review): the threshold compares against len(search_span),
        # not len(sentence), so multi-line spans rarely pass — confirm
        # this is intended.
        if match.size > len(search_span) * 0.6:
            matches.append((match.b, match.b + match.size))

    return matches
|
||||
|
||||
|
||||
_default_token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
|
||||
|
||||
|
||||
@@ -139,9 +160,9 @@ class PrepareEvidencePipeline(BaseComponent):
|
||||
|
||||
|
||||
DEFAULT_QA_TEXT_PROMPT = (
|
||||
"Use the following pieces of context to answer the question at the end. "
|
||||
"Use the following pieces of context to answer the question at the end in detail with clear explanation. " # noqa: E501
|
||||
"If you don't know the answer, just say that you don't know, don't try to "
|
||||
"make up an answer. Keep the answer as concise as possible. Give answer in "
|
||||
"make up an answer. Give answer in "
|
||||
"{lang}.\n\n"
|
||||
"{context}\n"
|
||||
"Question: {question}\n"
|
||||
@@ -150,14 +171,14 @@ DEFAULT_QA_TEXT_PROMPT = (
|
||||
|
||||
DEFAULT_QA_TABLE_PROMPT = (
|
||||
"List all rows (row number) from the table context that related to the question, "
|
||||
"then provide detail answer with clear explanation and citations. "
|
||||
"then provide detail answer with clear explanation. "
|
||||
"If you don't know the answer, just say that you don't know, "
|
||||
"don't try to make up an answer. Give answer in {lang}.\n\n"
|
||||
"Context:\n"
|
||||
"{context}\n"
|
||||
"Question: {question}\n"
|
||||
"Helpful Answer:"
|
||||
)
|
||||
) # noqa
|
||||
|
||||
DEFAULT_QA_CHATBOT_PROMPT = (
|
||||
"Pick the most suitable chatbot scenarios to answer the question at the end, "
|
||||
@@ -168,7 +189,7 @@ DEFAULT_QA_CHATBOT_PROMPT = (
|
||||
"{context}\n"
|
||||
"Question: {question}\n"
|
||||
"Answer:"
|
||||
)
|
||||
) # noqa
|
||||
|
||||
DEFAULT_QA_FIGURE_PROMPT = (
|
||||
"Use the given context: texts, tables, and figures below to answer the question. "
|
||||
@@ -178,7 +199,7 @@ DEFAULT_QA_FIGURE_PROMPT = (
|
||||
"{context}\n"
|
||||
"Question: {question}\n"
|
||||
"Answer: "
|
||||
)
|
||||
) # noqa
|
||||
|
||||
DEFAULT_REWRITE_PROMPT = (
|
||||
"Given the following question, rephrase and expand it "
|
||||
@@ -187,7 +208,7 @@ DEFAULT_REWRITE_PROMPT = (
|
||||
"Give answer in {lang}\n"
|
||||
"Original question: {question}\n"
|
||||
"Rephrased question: "
|
||||
)
|
||||
) # noqa
|
||||
|
||||
|
||||
class AnswerWithContextPipeline(BaseComponent):
|
||||
@@ -400,12 +421,14 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||
if evidence and self.enable_citation:
|
||||
citation = self.citation_pipeline(context=evidence, question=question)
|
||||
|
||||
if logprobs:
|
||||
qa_score = np.exp(np.average(logprobs))
|
||||
else:
|
||||
qa_score = None
|
||||
|
||||
answer = Document(
|
||||
text=output,
|
||||
metadata={
|
||||
"citation": citation,
|
||||
"qa_score": round(np.exp(np.average(logprobs)), 2),
|
||||
},
|
||||
metadata={"citation": citation, "qa_score": qa_score},
|
||||
)
|
||||
|
||||
return answer
|
||||
@@ -556,6 +579,47 @@ class FullQAPipeline(BaseReasoning):
|
||||
|
||||
return docs, info
|
||||
|
||||
def _format_retrieval_score_and_doc(
    self,
    doc: Document,
    rendered_doc_content: str,
    open_collapsible: bool = False,
) -> str:
    """Format the retrieval score and the document.

    Args:
        doc: Retrieved document whose scores are displayed.
        rendered_doc_content: Pre-rendered HTML body for the document.
        open_collapsible: Whether the collapsible starts expanded.

    Returns:
        str: HTML collapsible with vectorstore / LLM / Cohere scores.
    """
    # score from doc_store (Elasticsearch); -1.0 marks the default
    # full-text-search score
    if is_close(doc.score, -1.0):
        text_search_str = " default from full-text search<br>"
    else:
        text_search_str = "<br>"

    vectorstore_score = round(doc.score, 2)
    llm_reranking_score = (
        round(doc.metadata["llm_trulens_score"], 2)
        if doc.metadata.get("llm_trulens_score") is not None
        else None
    )
    # explicit None check so a legitimate 0.0 score is still displayed
    # (a truthiness test would render it as None)
    cohere_reranking_score = (
        round(doc.metadata["cohere_reranking_score"], 2)
        if doc.metadata.get("cohere_reranking_score") is not None
        else None
    )
    item_type_prefix = doc.metadata.get("type", "")
    item_type_prefix = item_type_prefix.capitalize()
    if item_type_prefix:
        item_type_prefix += " from "

    return Render.collapsible(
        header=(f"{item_type_prefix}{get_header(doc)} [{llm_reranking_score}]"),
        content="<b>Vectorstore score:</b>"
        f" {vectorstore_score}"
        f"{text_search_str}"
        "<b>LLM reranking score:</b>"
        f" {llm_reranking_score}<br>"
        "<b>Cohere reranking score:</b>"
        f" {cohere_reranking_score}<br>" + rendered_doc_content,
        open=open_collapsible,
    )
|
||||
|
||||
def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]:
|
||||
"""Prepare the citations to show on the UI"""
|
||||
with_citation, without_citation = [], []
|
||||
@@ -565,116 +629,63 @@ class FullQAPipeline(BaseReasoning):
|
||||
for fact_with_evidence in answer.metadata["citation"].answer:
|
||||
for quote in fact_with_evidence.substring_quote:
|
||||
for doc in docs:
|
||||
start_idx = doc.text.find(quote)
|
||||
if start_idx == -1:
|
||||
continue
|
||||
matches = find_text(quote, doc.text)
|
||||
|
||||
end_idx = start_idx + len(quote)
|
||||
|
||||
current_idx = start_idx
|
||||
if "|" not in doc.text[start_idx:end_idx]:
|
||||
spans[doc.doc_id].append(
|
||||
{"start": start_idx, "end": end_idx}
|
||||
)
|
||||
else:
|
||||
while doc.text[current_idx:end_idx].find("|") != -1:
|
||||
match_idx = doc.text[current_idx:end_idx].find("|")
|
||||
for start, end in matches:
|
||||
if "|" not in doc.text[start:end]:
|
||||
spans[doc.doc_id].append(
|
||||
{
|
||||
"start": current_idx,
|
||||
"end": current_idx + match_idx,
|
||||
"start": start,
|
||||
"end": end,
|
||||
}
|
||||
)
|
||||
current_idx += match_idx + 2
|
||||
if current_idx > end_idx:
|
||||
break
|
||||
break
|
||||
|
||||
id2docs = {doc.doc_id: doc for doc in docs}
|
||||
not_detected = set(id2docs.keys()) - set(spans.keys())
|
||||
for id, ss in spans.items():
|
||||
|
||||
# render highlight spans
|
||||
for _id, ss in spans.items():
|
||||
if not ss:
|
||||
not_detected.add(id)
|
||||
not_detected.add(_id)
|
||||
continue
|
||||
cur_doc = id2docs[_id]
|
||||
ss = sorted(ss, key=lambda x: x["start"])
|
||||
text = id2docs[id].text[: ss[0]["start"]]
|
||||
text = cur_doc.text[: ss[0]["start"]]
|
||||
for idx, span in enumerate(ss):
|
||||
text += Render.highlight(id2docs[id].text[span["start"] : span["end"]])
|
||||
text += Render.highlight(cur_doc.text[span["start"] : span["end"]])
|
||||
if idx < len(ss) - 1:
|
||||
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
|
||||
text += id2docs[id].text[ss[-1]["end"] :]
|
||||
if is_close(id2docs[id].score, -1.0):
|
||||
text_search_str = " default from full-text search<br>"
|
||||
else:
|
||||
text_search_str = "<br>"
|
||||
|
||||
if (
|
||||
id2docs[id].metadata.get("llm_reranking_score") is None
|
||||
or id2docs[id].metadata.get("cohere_reranking_score") is None
|
||||
):
|
||||
cloned_chunk_str = (
|
||||
"<b>Cloned chunk for a table. No reranking score</b><br>"
|
||||
)
|
||||
else:
|
||||
cloned_chunk_str = ""
|
||||
|
||||
text += cur_doc.text[span["end"] : ss[idx + 1]["start"]]
|
||||
text += cur_doc.text[ss[-1]["end"] :]
|
||||
# add to display list
|
||||
with_citation.append(
|
||||
Document(
|
||||
channel="info",
|
||||
content=Render.collapsible(
|
||||
header=(
|
||||
f"{get_header(id2docs[id])}<br>"
|
||||
"<b>Vectorstore score:</b>"
|
||||
f" {round(id2docs[id].score, 2)}"
|
||||
f"{text_search_str}"
|
||||
f"{cloned_chunk_str}"
|
||||
"<b>LLM reranking score:</b>"
|
||||
f' {id2docs[id].metadata.get("llm_reranking_score")}<br>'
|
||||
"<b>Cohere reranking score:</b>"
|
||||
f' {id2docs[id].metadata.get("cohere_reranking_score")}<br>'
|
||||
),
|
||||
content=Render.table(text),
|
||||
open=True,
|
||||
content=self._format_retrieval_score_and_doc(
|
||||
cur_doc,
|
||||
Render.table(text),
|
||||
open_collapsible=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
print("Got {} cited docs".format(len(with_citation)))
|
||||
|
||||
for id_ in list(not_detected):
|
||||
sorted_not_detected_items_with_scores = [
|
||||
(id_, id2docs[id_].metadata.get("llm_trulens_score", 0.0))
|
||||
for id_ in not_detected
|
||||
]
|
||||
sorted_not_detected_items_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
for id_, _ in sorted_not_detected_items_with_scores:
|
||||
doc = id2docs[id_]
|
||||
if is_close(doc.score, -1.0):
|
||||
text_search_str = " default from full-text search<br>"
|
||||
else:
|
||||
text_search_str = "<br>"
|
||||
|
||||
if (
|
||||
doc.metadata.get("llm_reranking_score") is None
|
||||
or doc.metadata.get("cohere_reranking_score") is None
|
||||
):
|
||||
cloned_chunk_str = (
|
||||
"<b>Cloned chunk for a table. No reranking score</b><br>"
|
||||
)
|
||||
else:
|
||||
cloned_chunk_str = ""
|
||||
if doc.metadata.get("type", "") == "image":
|
||||
without_citation.append(
|
||||
Document(
|
||||
channel="info",
|
||||
content=Render.collapsible(
|
||||
header=(
|
||||
f"{get_header(doc)}<br>"
|
||||
"<b>Vectorstore score:</b>"
|
||||
f" {round(doc.score, 2)}"
|
||||
f"{text_search_str}"
|
||||
f"{cloned_chunk_str}"
|
||||
"<b>LLM reranking score:</b>"
|
||||
f' {doc.metadata.get("llm_reranking_score")}<br>'
|
||||
"<b>Cohere reranking score:</b>"
|
||||
f' {doc.metadata.get("cohere_reranking_score")}<br>'
|
||||
),
|
||||
content=Render.image(
|
||||
content=self._format_retrieval_score_and_doc(
|
||||
doc,
|
||||
Render.image(
|
||||
url=doc.metadata["image_origin"], text=doc.text
|
||||
),
|
||||
open=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -682,20 +693,8 @@ class FullQAPipeline(BaseReasoning):
|
||||
without_citation.append(
|
||||
Document(
|
||||
channel="info",
|
||||
content=Render.collapsible(
|
||||
header=(
|
||||
f"{get_header(doc)}<br>"
|
||||
"<b>Vectorstore score:</b>"
|
||||
f" {round(doc.score, 2)}"
|
||||
f"{text_search_str}"
|
||||
f"{cloned_chunk_str}"
|
||||
"<b>LLM reranking score:</b>"
|
||||
f' {doc.metadata.get("llm_reranking_score")}<br>'
|
||||
"<b>Cohere reranking score:</b>"
|
||||
f' {doc.metadata.get("cohere_reranking_score")}<br>'
|
||||
),
|
||||
content=Render.table(doc.text),
|
||||
open=True,
|
||||
content=self._format_retrieval_score_and_doc(
|
||||
doc, Render.table(doc.text)
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -744,6 +743,8 @@ class FullQAPipeline(BaseReasoning):
|
||||
if self.use_rewrite:
|
||||
message = self.rewrite_pipeline(question=message).text
|
||||
|
||||
print(f"Rewritten message (use_rewrite={self.use_rewrite}): {message}")
|
||||
print(f"Retrievers {self.retrievers}")
|
||||
# should populate the context
|
||||
docs, infos = self.retrieve(message, history)
|
||||
for _ in infos:
|
||||
@@ -770,12 +771,18 @@ class FullQAPipeline(BaseReasoning):
|
||||
if without_citation:
|
||||
for _ in without_citation:
|
||||
yield _
|
||||
|
||||
qa_score = (
|
||||
round(answer.metadata["qa_score"], 2)
|
||||
if answer.metadata.get("qa_score")
|
||||
else None
|
||||
)
|
||||
yield Document(
|
||||
channel="info",
|
||||
content=(
|
||||
"<h5><b>Question answering</b></h5><br>"
|
||||
"<b>Question answering confidence:</b> "
|
||||
f"{answer.metadata.get('qa_score')}"
|
||||
f"{qa_score}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user