diff --git a/libs/kotaemon/kotaemon/indices/rankings/__init__.py b/libs/kotaemon/kotaemon/indices/rankings/__init__.py
index 9de04d8d..84b8765b 100644
--- a/libs/kotaemon/kotaemon/indices/rankings/__init__.py
+++ b/libs/kotaemon/kotaemon/indices/rankings/__init__.py
@@ -2,5 +2,12 @@ from .base import BaseReranking
from .cohere import CohereReranking
from .llm import LLMReranking
from .llm_scoring import LLMScoring
+from .llm_trulens import LLMTrulensScoring
-__all__ = ["CohereReranking", "LLMReranking", "LLMScoring", "BaseReranking"]
+__all__ = [
+ "CohereReranking",
+ "LLMReranking",
+ "LLMScoring",
+ "BaseReranking",
+ "LLMTrulensScoring",
+]
diff --git a/libs/kotaemon/kotaemon/indices/rankings/cohere.py b/libs/kotaemon/kotaemon/indices/rankings/cohere.py
index d22be9a8..4f5866ac 100644
--- a/libs/kotaemon/kotaemon/indices/rankings/cohere.py
+++ b/libs/kotaemon/kotaemon/indices/rankings/cohere.py
@@ -33,7 +33,7 @@ class CohereReranking(BaseReranking):
)
for r in response.results:
doc = documents[r.index]
- doc.metadata["cohere_reranking_score"] = round(r.relevance_score, 2)
+ doc.metadata["cohere_reranking_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs
diff --git a/libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py b/libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py
index 0ee5b23a..b4f51053 100644
--- a/libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py
+++ b/libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py
@@ -42,9 +42,9 @@ class LLMScoring(LLMReranking):
score = np.exp(np.average(result.logprobs))
include_doc = output_parser.parse(result.text)
if include_doc:
- doc.metadata["llm_reranking_score"] = round(score, 2)
+ doc.metadata["llm_reranking_score"] = score
else:
- doc.metadata["llm_reranking_score"] = round(1 - score, 2)
+ doc.metadata["llm_reranking_score"] = 1 - score
filtered_docs.append(doc)
# prevent returning empty result
diff --git a/libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py b/libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py
new file mode 100644
index 00000000..1fa4dc45
--- /dev/null
+++ b/libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+import re
+from concurrent.futures import ThreadPoolExecutor
+
+from kotaemon.base import Document, HumanMessage, SystemMessage
+from kotaemon.llms import BaseLLM, PromptTemplate
+
+from .llm import LLMReranking
+
+SYSTEM_PROMPT_TEMPLATE = PromptTemplate(
+ """You are a RELEVANCE grader; providing the relevance of the given CONTEXT to the given QUESTION.
+ Respond only as a number from 0 to 10 where 0 is the least relevant and 10 is the most relevant.
+
+ A few additional scoring guidelines:
+
+ - Long CONTEXTS should score equally well as short CONTEXTS.
+
+ - RELEVANCE score should increase as the CONTEXTS provides more RELEVANT context to the QUESTION.
+
+ - RELEVANCE score should increase as the CONTEXTS provides RELEVANT context to more parts of the QUESTION.
+
+ - CONTEXT that is RELEVANT to some of the QUESTION should score of 2, 3 or 4. Higher score indicates more RELEVANCE.
+
+ - CONTEXT that is RELEVANT to most of the QUESTION should get a score of 5, 6, 7 or 8. Higher score indicates more RELEVANCE.
+
+ - CONTEXT that is RELEVANT to the entire QUESTION should get a score of 9 or 10. Higher score indicates more RELEVANCE.
+
+ - CONTEXT must be relevant and helpful for answering the entire QUESTION to get a score of 10.
+
+ - Never elaborate.""" # noqa: E501
+)
+
+USER_PROMPT_TEMPLATE = PromptTemplate(
+ """QUESTION: {question}
+
+ CONTEXT: {context}
+
+ RELEVANCE: """
+) # noqa
+
+PATTERN_INTEGER: re.Pattern = re.compile(r"([+-]?[1-9][0-9]*|0)")
+"""Regex that matches integers."""
+
+
+def validate_rating(rating) -> int:
+ """Validate a rating is between 0 and 10."""
+
+ if not 0 <= rating <= 10:
+ raise ValueError("Rating must be between 0 and 10")
+
+ return rating
+
+
+def re_0_10_rating(s: str) -> int:
+ """Extract a 0-10 rating from a string.
+
+ If the string does not match an integer or matches an integer outside the
+ 0-10 range, raises an error instead. If multiple numbers are found within
+ the expected 0-10 range, the smallest is returned.
+
+ Args:
+ s: String to extract rating from.
+
+ Returns:
+ int: Extracted rating.
+
+ Raises:
+        AssertionError: If no integers between 0 and 10 are found in the string.
+ """
+
+ matches = PATTERN_INTEGER.findall(s)
+ if not matches:
+ raise AssertionError
+
+ vals = set()
+ for match in matches:
+ try:
+ vals.add(validate_rating(int(match)))
+ except ValueError:
+ pass
+
+ if not vals:
+ raise AssertionError
+
+ # Min to handle cases like "The rating is 8 out of 10."
+ return min(vals)
+
+
+class LLMTrulensScoring(LLMReranking):
+ llm: BaseLLM
+ system_prompt_template: PromptTemplate = SYSTEM_PROMPT_TEMPLATE
+ user_prompt_template: PromptTemplate = USER_PROMPT_TEMPLATE
+ top_k: int = 3
+ concurrent: bool = True
+ normalize: float = 10
+
+ def run(
+ self,
+ documents: list[Document],
+ query: str,
+ ) -> list[Document]:
+ """Filter down documents based on their relevance to the query."""
+ filtered_docs = []
+
+ documents = sorted(documents, key=lambda doc: doc.get_content())
+ if self.concurrent:
+ with ThreadPoolExecutor() as executor:
+ futures = []
+ for doc in documents:
+ messages = []
+ messages.append(
+ SystemMessage(self.system_prompt_template.populate())
+ )
+ messages.append(
+ HumanMessage(
+ self.user_prompt_template.populate(
+ question=query, context=doc.get_content()
+ )
+ )
+ )
+                    futures.append(
+                        executor.submit(lambda messages=messages: self.llm(messages).text)
+                    )
+
+ results = [future.result() for future in futures]
+ else:
+ results = []
+ for doc in documents:
+ messages = []
+ messages.append(SystemMessage(self.system_prompt_template.populate()))
+ messages.append(
+                    HumanMessage(
+ self.user_prompt_template.populate(
+ question=query, context=doc.get_content()
+ )
+ )
+ )
+ results.append(self.llm(messages).text)
+
+        # parse a 0-10 relevance rating from each LLM response and normalize it
+ results = [
+ (r_idx, float(re_0_10_rating(result)) / self.normalize)
+ for r_idx, result in enumerate(results)
+ ]
+ results.sort(key=lambda x: x[1], reverse=True)
+
+ for r_idx, score in results:
+ doc = documents[r_idx]
+ doc.metadata["llm_trulens_score"] = score
+ filtered_docs.append(doc)
+
+ # prevent returning empty result
+ if len(filtered_docs) == 0:
+ filtered_docs = documents[: self.top_k]
+
+ return filtered_docs
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index 22e9a1d2..afb655a8 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -35,7 +35,7 @@ from kotaemon.indices.rankings import (
BaseReranking,
CohereReranking,
LLMReranking,
- LLMScoring,
+ LLMTrulensScoring,
)
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
@@ -254,7 +254,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
)
],
retrieval_mode=user_settings["retrieval_mode"],
- rerankers=[LLMScoring(), CohereReranking()],
+ rerankers=[CohereReranking(), LLMTrulensScoring()],
)
if not user_settings["use_reranking"]:
retriever.rerankers = [] # type: ignore
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index ad8ef4a3..2e9e00cb 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -3,6 +3,7 @@ import html
import logging
import re
from collections import defaultdict
+from difflib import SequenceMatcher
from functools import partial
from typing import Generator
@@ -50,6 +51,26 @@ def is_close(val1, val2, tolerance=1e-9):
return abs(val1 - val2) <= tolerance
+def find_text(search_span, context):
+ sentence_list = search_span.split("\n")
+ matches = []
+ # don't search for small text
+ if len(search_span) > 5:
+ for sentence in sentence_list:
+ match = SequenceMatcher(
+ None, sentence, context, autojunk=False
+ ).find_longest_match()
+            if match.size > len(sentence) * 0.6:
+ matches.append((match.b, match.b + match.size))
+ print(
+ "search",
+ search_span,
+ "matched",
+ context[match.b : match.b + match.size],
+ )
+ return matches
+
+
_default_token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
@@ -139,9 +160,9 @@ class PrepareEvidencePipeline(BaseComponent):
DEFAULT_QA_TEXT_PROMPT = (
- "Use the following pieces of context to answer the question at the end. "
+ "Use the following pieces of context to answer the question at the end in detail with clear explanation. " # noqa: E501
"If you don't know the answer, just say that you don't know, don't try to "
- "make up an answer. Keep the answer as concise as possible. Give answer in "
+ "make up an answer. Give answer in "
"{lang}.\n\n"
"{context}\n"
"Question: {question}\n"
@@ -150,14 +171,14 @@ DEFAULT_QA_TEXT_PROMPT = (
DEFAULT_QA_TABLE_PROMPT = (
"List all rows (row number) from the table context that related to the question, "
- "then provide detail answer with clear explanation and citations. "
+ "then provide detail answer with clear explanation. "
"If you don't know the answer, just say that you don't know, "
"don't try to make up an answer. Give answer in {lang}.\n\n"
"Context:\n"
"{context}\n"
"Question: {question}\n"
"Helpful Answer:"
-)
+) # noqa
DEFAULT_QA_CHATBOT_PROMPT = (
"Pick the most suitable chatbot scenarios to answer the question at the end, "
@@ -168,7 +189,7 @@ DEFAULT_QA_CHATBOT_PROMPT = (
"{context}\n"
"Question: {question}\n"
"Answer:"
-)
+) # noqa
DEFAULT_QA_FIGURE_PROMPT = (
"Use the given context: texts, tables, and figures below to answer the question. "
@@ -178,7 +199,7 @@ DEFAULT_QA_FIGURE_PROMPT = (
"{context}\n"
"Question: {question}\n"
"Answer: "
-)
+) # noqa
DEFAULT_REWRITE_PROMPT = (
"Given the following question, rephrase and expand it "
@@ -187,7 +208,7 @@ DEFAULT_REWRITE_PROMPT = (
"Give answer in {lang}\n"
"Original question: {question}\n"
"Rephrased question: "
-)
+) # noqa
class AnswerWithContextPipeline(BaseComponent):
@@ -400,12 +421,14 @@ class AnswerWithContextPipeline(BaseComponent):
if evidence and self.enable_citation:
citation = self.citation_pipeline(context=evidence, question=question)
+ if logprobs:
+ qa_score = np.exp(np.average(logprobs))
+ else:
+ qa_score = None
+
answer = Document(
text=output,
- metadata={
- "citation": citation,
- "qa_score": round(np.exp(np.average(logprobs)), 2),
- },
+ metadata={"citation": citation, "qa_score": qa_score},
)
return answer
@@ -556,6 +579,47 @@ class FullQAPipeline(BaseReasoning):
return docs, info
+ def _format_retrieval_score_and_doc(
+ self,
+ doc: Document,
+ rendered_doc_content: str,
+ open_collapsible: bool = False,
+ ) -> str:
+ """Format the retrieval score and the document"""
+ # score from doc_store (Elasticsearch)
+ if is_close(doc.score, -1.0):
+ text_search_str = " default from full-text search
"
+ else:
+ text_search_str = "
"
+
+ vectorstore_score = round(doc.score, 2)
+ llm_reranking_score = (
+ round(doc.metadata["llm_trulens_score"], 2)
+ if doc.metadata.get("llm_trulens_score") is not None
+ else None
+ )
+ cohere_reranking_score = (
+ round(doc.metadata["cohere_reranking_score"], 2)
+            if doc.metadata.get("cohere_reranking_score") is not None
+ else None
+ )
+ item_type_prefix = doc.metadata.get("type", "")
+ item_type_prefix = item_type_prefix.capitalize()
+ if item_type_prefix:
+ item_type_prefix += " from "
+
+ return Render.collapsible(
+ header=(f"{item_type_prefix}{get_header(doc)} [{llm_reranking_score}]"),
+ content="Vectorstore score:"
+ f" {vectorstore_score}"
+ f"{text_search_str}"
+ "LLM reranking score:"
+ f" {llm_reranking_score}
"
+ "Cohere reranking score:"
+ f" {cohere_reranking_score}
" + rendered_doc_content,
+ open=open_collapsible,
+ )
+
def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]:
"""Prepare the citations to show on the UI"""
with_citation, without_citation = [], []
@@ -565,116 +629,63 @@ class FullQAPipeline(BaseReasoning):
for fact_with_evidence in answer.metadata["citation"].answer:
for quote in fact_with_evidence.substring_quote:
for doc in docs:
- start_idx = doc.text.find(quote)
- if start_idx == -1:
- continue
+ matches = find_text(quote, doc.text)
- end_idx = start_idx + len(quote)
-
- current_idx = start_idx
- if "|" not in doc.text[start_idx:end_idx]:
- spans[doc.doc_id].append(
- {"start": start_idx, "end": end_idx}
- )
- else:
- while doc.text[current_idx:end_idx].find("|") != -1:
- match_idx = doc.text[current_idx:end_idx].find("|")
+ for start, end in matches:
+ if "|" not in doc.text[start:end]:
spans[doc.doc_id].append(
{
- "start": current_idx,
- "end": current_idx + match_idx,
+ "start": start,
+ "end": end,
}
)
- current_idx += match_idx + 2
- if current_idx > end_idx:
- break
- break
id2docs = {doc.doc_id: doc for doc in docs}
not_detected = set(id2docs.keys()) - set(spans.keys())
- for id, ss in spans.items():
+
+ # render highlight spans
+ for _id, ss in spans.items():
if not ss:
- not_detected.add(id)
+ not_detected.add(_id)
continue
+ cur_doc = id2docs[_id]
ss = sorted(ss, key=lambda x: x["start"])
- text = id2docs[id].text[: ss[0]["start"]]
+ text = cur_doc.text[: ss[0]["start"]]
for idx, span in enumerate(ss):
- text += Render.highlight(id2docs[id].text[span["start"] : span["end"]])
+ text += Render.highlight(cur_doc.text[span["start"] : span["end"]])
if idx < len(ss) - 1:
- text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
- text += id2docs[id].text[ss[-1]["end"] :]
- if is_close(id2docs[id].score, -1.0):
- text_search_str = " default from full-text search
"
- else:
- text_search_str = "
"
-
- if (
- id2docs[id].metadata.get("llm_reranking_score") is None
- or id2docs[id].metadata.get("cohere_reranking_score") is None
- ):
- cloned_chunk_str = (
- "Cloned chunk for a table. No reranking score
"
- )
- else:
- cloned_chunk_str = ""
-
+ text += cur_doc.text[span["end"] : ss[idx + 1]["start"]]
+ text += cur_doc.text[ss[-1]["end"] :]
+ # add to display list
with_citation.append(
Document(
channel="info",
- content=Render.collapsible(
- header=(
- f"{get_header(id2docs[id])}
"
- "Vectorstore score:"
- f" {round(id2docs[id].score, 2)}"
- f"{text_search_str}"
- f"{cloned_chunk_str}"
- "LLM reranking score:"
- f' {id2docs[id].metadata.get("llm_reranking_score")}
'
- "Cohere reranking score:"
- f' {id2docs[id].metadata.get("cohere_reranking_score")}
'
- ),
- content=Render.table(text),
- open=True,
+ content=self._format_retrieval_score_and_doc(
+ cur_doc,
+ Render.table(text),
+ open_collapsible=True,
),
)
)
+ print("Got {} cited docs".format(len(with_citation)))
- for id_ in list(not_detected):
+ sorted_not_detected_items_with_scores = [
+ (id_, id2docs[id_].metadata.get("llm_trulens_score", 0.0))
+ for id_ in not_detected
+ ]
+ sorted_not_detected_items_with_scores.sort(key=lambda x: x[1], reverse=True)
+
+ for id_, _ in sorted_not_detected_items_with_scores:
doc = id2docs[id_]
- if is_close(doc.score, -1.0):
- text_search_str = " default from full-text search
"
- else:
- text_search_str = "
"
-
- if (
- doc.metadata.get("llm_reranking_score") is None
- or doc.metadata.get("cohere_reranking_score") is None
- ):
- cloned_chunk_str = (
- "Cloned chunk for a table. No reranking score
"
- )
- else:
- cloned_chunk_str = ""
if doc.metadata.get("type", "") == "image":
without_citation.append(
Document(
channel="info",
- content=Render.collapsible(
- header=(
- f"{get_header(doc)}
"
- "Vectorstore score:"
- f" {round(doc.score, 2)}"
- f"{text_search_str}"
- f"{cloned_chunk_str}"
- "LLM reranking score:"
- f' {doc.metadata.get("llm_reranking_score")}
'
- "Cohere reranking score:"
- f' {doc.metadata.get("cohere_reranking_score")}
'
- ),
- content=Render.image(
+ content=self._format_retrieval_score_and_doc(
+ doc,
+ Render.image(
url=doc.metadata["image_origin"], text=doc.text
),
- open=True,
),
)
)
@@ -682,20 +693,8 @@ class FullQAPipeline(BaseReasoning):
without_citation.append(
Document(
channel="info",
- content=Render.collapsible(
- header=(
- f"{get_header(doc)}
"
- "Vectorstore score:"
- f" {round(doc.score, 2)}"
- f"{text_search_str}"
- f"{cloned_chunk_str}"
- "LLM reranking score:"
- f' {doc.metadata.get("llm_reranking_score")}
'
- "Cohere reranking score:"
- f' {doc.metadata.get("cohere_reranking_score")}
'
- ),
- content=Render.table(doc.text),
- open=True,
+ content=self._format_retrieval_score_and_doc(
+ doc, Render.table(doc.text)
),
)
)
@@ -744,6 +743,8 @@ class FullQAPipeline(BaseReasoning):
if self.use_rewrite:
message = self.rewrite_pipeline(question=message).text
+ print(f"Rewritten message (use_rewrite={self.use_rewrite}): {message}")
+ print(f"Retrievers {self.retrievers}")
# should populate the context
docs, infos = self.retrieve(message, history)
for _ in infos:
@@ -770,12 +771,18 @@ class FullQAPipeline(BaseReasoning):
if without_citation:
for _ in without_citation:
yield _
+
+ qa_score = (
+ round(answer.metadata["qa_score"], 2)
+            if answer.metadata.get("qa_score") is not None
+ else None
+ )
yield Document(
channel="info",
content=(
"