From 9e2fe4afc9259733c2bec05d79fed48107b84df8 Mon Sep 17 00:00:00 2001 From: cin-ace Date: Sun, 7 Jul 2024 21:59:59 +0700 Subject: [PATCH] fix: update trulens LLM ranking score for retrieval confidence, improve citation (#98) * Round when displaying not by default * Add LLMTrulens reranking model * Use llmtrulensscoring in pipeline * fix: update UI display for trulen score --------- Co-authored-by: taprosoft --- .../kotaemon/indices/rankings/__init__.py | 9 +- .../kotaemon/indices/rankings/cohere.py | 2 +- .../kotaemon/indices/rankings/llm_scoring.py | 4 +- .../kotaemon/indices/rankings/llm_trulens.py | 155 ++++++++++++ libs/ktem/ktem/index/file/pipelines.py | 4 +- libs/ktem/ktem/reasoning/simple.py | 225 +++++++++--------- 6 files changed, 284 insertions(+), 115 deletions(-) create mode 100644 libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py diff --git a/libs/kotaemon/kotaemon/indices/rankings/__init__.py b/libs/kotaemon/kotaemon/indices/rankings/__init__.py index 9de04d8d..84b8765b 100644 --- a/libs/kotaemon/kotaemon/indices/rankings/__init__.py +++ b/libs/kotaemon/kotaemon/indices/rankings/__init__.py @@ -2,5 +2,12 @@ from .base import BaseReranking from .cohere import CohereReranking from .llm import LLMReranking from .llm_scoring import LLMScoring +from .llm_trulens import LLMTrulensScoring -__all__ = ["CohereReranking", "LLMReranking", "LLMScoring", "BaseReranking"] +__all__ = [ + "CohereReranking", + "LLMReranking", + "LLMScoring", + "BaseReranking", + "LLMTrulensScoring", +] diff --git a/libs/kotaemon/kotaemon/indices/rankings/cohere.py b/libs/kotaemon/kotaemon/indices/rankings/cohere.py index d22be9a8..4f5866ac 100644 --- a/libs/kotaemon/kotaemon/indices/rankings/cohere.py +++ b/libs/kotaemon/kotaemon/indices/rankings/cohere.py @@ -33,7 +33,7 @@ class CohereReranking(BaseReranking): ) for r in response.results: doc = documents[r.index] - doc.metadata["cohere_reranking_score"] = round(r.relevance_score, 2) + doc.metadata["cohere_reranking_score"] = r.relevance_score compressed_docs.append(doc) return compressed_docs diff --git a/libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py b/libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py index 0ee5b23a..b4f51053 100644 --- a/libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py +++ b/libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py @@ -42,9 +42,9 @@ class LLMScoring(LLMReranking): score = np.exp(np.average(result.logprobs)) include_doc = output_parser.parse(result.text) if include_doc: - doc.metadata["llm_reranking_score"] = round(score, 2) + doc.metadata["llm_reranking_score"] = score else: - doc.metadata["llm_reranking_score"] = round(1 - score, 2) + doc.metadata["llm_reranking_score"] = 1 - score filtered_docs.append(doc) # prevent returning empty result diff --git a/libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py b/libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py new file mode 100644 index 00000000..1fa4dc45 --- /dev/null +++ b/libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import re +from concurrent.futures import ThreadPoolExecutor + +from kotaemon.base import Document, HumanMessage, SystemMessage +from kotaemon.llms import BaseLLM, PromptTemplate + +from .llm import LLMReranking + +SYSTEM_PROMPT_TEMPLATE = PromptTemplate( + """You are a RELEVANCE grader; providing the relevance of the given CONTEXT to the given QUESTION. + Respond only as a number from 0 to 10 where 0 is the least relevant and 10 is the most relevant. + + A few additional scoring guidelines: + + - Long CONTEXTS should score equally well as short CONTEXTS. + + - RELEVANCE score should increase as the CONTEXTS provides more RELEVANT context to the QUESTION. + + - RELEVANCE score should increase as the CONTEXTS provides RELEVANT context to more parts of the QUESTION. + + - CONTEXT that is RELEVANT to some of the QUESTION should score of 2, 3 or 4. Higher score indicates more RELEVANCE. + + - CONTEXT that is RELEVANT to most of the QUESTION should get a score of 5, 6, 7 or 8. Higher score indicates more RELEVANCE. + + - CONTEXT that is RELEVANT to the entire QUESTION should get a score of 9 or 10. Higher score indicates more RELEVANCE. + + - CONTEXT must be relevant and helpful for answering the entire QUESTION to get a score of 10. + + - Never elaborate.""" # noqa: E501 +) + +USER_PROMPT_TEMPLATE = PromptTemplate( + """QUESTION: {question} + + CONTEXT: {context} + + RELEVANCE: """ +) # noqa + +PATTERN_INTEGER: re.Pattern = re.compile(r"([+-]?[1-9][0-9]*|0)") +"""Regex that matches integers.""" + + +def validate_rating(rating) -> int: + """Validate a rating is between 0 and 10.""" + + if not 0 <= rating <= 10: + raise ValueError("Rating must be between 0 and 10") + + return rating + + +def re_0_10_rating(s: str) -> int: + """Extract a 0-10 rating from a string. + + If the string does not match an integer or matches an integer outside the + 0-10 range, raises an error instead. If multiple numbers are found within + the expected 0-10 range, the smallest is returned. + + Args: + s: String to extract rating from. + + Returns: + int: Extracted rating. + + Raises: + ParseError: If no integers between 0 and 10 are found in the string. + """ + + matches = PATTERN_INTEGER.findall(s) + if not matches: + raise AssertionError + + vals = set() + for match in matches: + try: + vals.add(validate_rating(int(match))) + except ValueError: + pass + + if not vals: + raise AssertionError + + # Min to handle cases like "The rating is 8 out of 10." + return min(vals) + + +class LLMTrulensScoring(LLMReranking): + llm: BaseLLM + system_prompt_template: PromptTemplate = SYSTEM_PROMPT_TEMPLATE + user_prompt_template: PromptTemplate = USER_PROMPT_TEMPLATE + top_k: int = 3 + concurrent: bool = True + normalize: float = 10 + + def run( + self, + documents: list[Document], + query: str, + ) -> list[Document]: + """Filter down documents based on their relevance to the query.""" + filtered_docs = [] + + documents = sorted(documents, key=lambda doc: doc.get_content()) + if self.concurrent: + with ThreadPoolExecutor() as executor: + futures = [] + for doc in documents: + messages = [] + messages.append( + SystemMessage(self.system_prompt_template.populate()) + ) + messages.append( + HumanMessage( + self.user_prompt_template.populate( + question=query, context=doc.get_content() + ) + ) + ) + futures.append(executor.submit(lambda: self.llm(messages).text)) + + results = [future.result() for future in futures] + else: + results = [] + for doc in documents: + messages = [] + messages.append(SystemMessage(self.system_prompt_template.populate())) + messages.append( + SystemMessage( + self.user_prompt_template.populate( + question=query, context=doc.get_content() + ) + ) + ) + results.append(self.llm(messages).text) + + # use Boolean parser to extract relevancy output from LLM + results = [ + (r_idx, float(re_0_10_rating(result)) / self.normalize) + for r_idx, result in enumerate(results) + ] + results.sort(key=lambda x: x[1], reverse=True) + + for r_idx, score in results: + doc = documents[r_idx] + doc.metadata["llm_trulens_score"] = score + filtered_docs.append(doc) + + # prevent returning empty result + if len(filtered_docs) == 0: + filtered_docs = documents[: self.top_k] + + return filtered_docs diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index 22e9a1d2..afb655a8 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -35,7 +35,7 @@ from kotaemon.indices.rankings import ( BaseReranking, CohereReranking, LLMReranking, - LLMScoring, + LLMTrulensScoring, ) from kotaemon.indices.splitters import BaseSplitter, TokenSplitter @@ -254,7 +254,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever): ) ], retrieval_mode=user_settings["retrieval_mode"], - rerankers=[LLMScoring(), CohereReranking()], + rerankers=[CohereReranking(), LLMTrulensScoring()], ) if not user_settings["use_reranking"]: retriever.rerankers = [] # type: ignore diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index ad8ef4a3..2e9e00cb 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -3,6 +3,7 @@ import html import logging import re from collections import defaultdict +from difflib import SequenceMatcher from functools import partial from typing import Generator @@ -50,6 +51,26 @@ def is_close(val1, val2, tolerance=1e-9): return abs(val1 - val2) <= tolerance +def find_text(search_span, context): + sentence_list = search_span.split("\n") + matches = [] + # don't search for small text + if len(search_span) > 5: + for sentence in sentence_list: + match = SequenceMatcher( + None, sentence, context, autojunk=False + ).find_longest_match() + if match.size > len(search_span) * 0.6: + matches.append((match.b, match.b + match.size)) + print( + "search", + search_span, + "matched", + context[match.b : match.b + match.size], + ) + return matches + + _default_token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode @@ -139,9 +160,9 @@ class PrepareEvidencePipeline(BaseComponent): DEFAULT_QA_TEXT_PROMPT = ( - "Use the following pieces of context to answer the question at the end. " + "Use the following pieces of context to answer the question at the end in detail with clear explanation. " # noqa: E501 "If you don't know the answer, just say that you don't know, don't try to " - "make up an answer. Keep the answer as concise as possible. Give answer in " + "make up an answer. Give answer in " "{lang}.\n\n" "{context}\n" "Question: {question}\n" @@ -150,14 +171,14 @@ DEFAULT_QA_TEXT_PROMPT = ( DEFAULT_QA_TABLE_PROMPT = ( "List all rows (row number) from the table context that related to the question, " - "then provide detail answer with clear explanation and citations. " + "then provide detail answer with clear explanation. " "If you don't know the answer, just say that you don't know, " "don't try to make up an answer. Give answer in {lang}.\n\n" "Context:\n" "{context}\n" "Question: {question}\n" "Helpful Answer:" -) +) # noqa DEFAULT_QA_CHATBOT_PROMPT = ( "Pick the most suitable chatbot scenarios to answer the question at the end, " @@ -168,7 +189,7 @@ DEFAULT_QA_CHATBOT_PROMPT = ( "{context}\n" "Question: {question}\n" "Answer:" -) +) # noqa DEFAULT_QA_FIGURE_PROMPT = ( "Use the given context: texts, tables, and figures below to answer the question. " @@ -178,7 +199,7 @@ DEFAULT_QA_FIGURE_PROMPT = ( "{context}\n" "Question: {question}\n" "Answer: " -) +) # noqa DEFAULT_REWRITE_PROMPT = ( "Given the following question, rephrase and expand it " @@ -187,7 +208,7 @@ DEFAULT_REWRITE_PROMPT = ( "Give answer in {lang}\n" "Original question: {question}\n" "Rephrased question: " -) +) # noqa class AnswerWithContextPipeline(BaseComponent): @@ -400,12 +421,14 @@ class AnswerWithContextPipeline(BaseComponent): if evidence and self.enable_citation: citation = self.citation_pipeline(context=evidence, question=question) + if logprobs: + qa_score = np.exp(np.average(logprobs)) + else: + qa_score = None + answer = Document( text=output, - metadata={ - "citation": citation, - "qa_score": round(np.exp(np.average(logprobs)), 2), - }, + metadata={"citation": citation, "qa_score": qa_score}, ) return answer @@ -556,6 +579,47 @@ class FullQAPipeline(BaseReasoning): return docs, info + def _format_retrieval_score_and_doc( + self, + doc: Document, + rendered_doc_content: str, + open_collapsible: bool = False, + ) -> str: + """Format the retrieval score and the document""" + # score from doc_store (Elasticsearch) + if is_close(doc.score, -1.0): + text_search_str = " default from full-text search
" + else: + text_search_str = "
" + + vectorstore_score = round(doc.score, 2) + llm_reranking_score = ( + round(doc.metadata["llm_trulens_score"], 2) + if doc.metadata.get("llm_trulens_score") is not None + else None + ) + cohere_reranking_score = ( + round(doc.metadata["cohere_reranking_score"], 2) + if doc.metadata.get("cohere_reranking_score") + else None + ) + item_type_prefix = doc.metadata.get("type", "") + item_type_prefix = item_type_prefix.capitalize() + if item_type_prefix: + item_type_prefix += " from " + + return Render.collapsible( + header=(f"{item_type_prefix}{get_header(doc)} [{llm_reranking_score}]"), + content="Vectorstore score:" + f" {vectorstore_score}" + f"{text_search_str}" + "LLM reranking score:" + f" {llm_reranking_score}
" + "Cohere reranking score:" + f" {cohere_reranking_score}
" + rendered_doc_content, + open=open_collapsible, + ) + def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]: """Prepare the citations to show on the UI""" with_citation, without_citation = [], [] @@ -565,116 +629,63 @@ class FullQAPipeline(BaseReasoning): for fact_with_evidence in answer.metadata["citation"].answer: for quote in fact_with_evidence.substring_quote: for doc in docs: - start_idx = doc.text.find(quote) - if start_idx == -1: - continue + matches = find_text(quote, doc.text) - end_idx = start_idx + len(quote) - - current_idx = start_idx - if "|" not in doc.text[start_idx:end_idx]: - spans[doc.doc_id].append( - {"start": start_idx, "end": end_idx} - ) - else: - while doc.text[current_idx:end_idx].find("|") != -1: - match_idx = doc.text[current_idx:end_idx].find("|") + for start, end in matches: + if "|" not in doc.text[start:end]: spans[doc.doc_id].append( { - "start": current_idx, - "end": current_idx + match_idx, + "start": start, + "end": end, } ) - current_idx += match_idx + 2 - if current_idx > end_idx: - break - break id2docs = {doc.doc_id: doc for doc in docs} not_detected = set(id2docs.keys()) - set(spans.keys()) - for id, ss in spans.items(): + + # render highlight spans + for _id, ss in spans.items(): if not ss: - not_detected.add(id) + not_detected.add(_id) continue + cur_doc = id2docs[_id] ss = sorted(ss, key=lambda x: x["start"]) - text = id2docs[id].text[: ss[0]["start"]] + text = cur_doc.text[: ss[0]["start"]] for idx, span in enumerate(ss): - text += Render.highlight(id2docs[id].text[span["start"] : span["end"]]) + text += Render.highlight(cur_doc.text[span["start"] : span["end"]]) if idx < len(ss) - 1: - text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] - text += id2docs[id].text[ss[-1]["end"] :] - if is_close(id2docs[id].score, -1.0): - text_search_str = " default from full-text search
" - else: - text_search_str = "
" - - if ( - id2docs[id].metadata.get("llm_reranking_score") is None - or id2docs[id].metadata.get("cohere_reranking_score") is None - ): - cloned_chunk_str = ( - "Cloned chunk for a table. No reranking score
" - ) - else: - cloned_chunk_str = "" - + text += cur_doc.text[span["end"] : ss[idx + 1]["start"]] + text += cur_doc.text[ss[-1]["end"] :] + # add to display list with_citation.append( Document( channel="info", - content=Render.collapsible( - header=( - f"{get_header(id2docs[id])}
" - "Vectorstore score:" - f" {round(id2docs[id].score, 2)}" - f"{text_search_str}" - f"{cloned_chunk_str}" - "LLM reranking score:" - f' {id2docs[id].metadata.get("llm_reranking_score")}
' - "Cohere reranking score:" - f' {id2docs[id].metadata.get("cohere_reranking_score")}
' - ), - content=Render.table(text), - open=True, + content=self._format_retrieval_score_and_doc( + cur_doc, + Render.table(text), + open_collapsible=True, ), ) ) + print("Got {} cited docs".format(len(with_citation))) - for id_ in list(not_detected): + sorted_not_detected_items_with_scores = [ + (id_, id2docs[id_].metadata.get("llm_trulens_score", 0.0)) + for id_ in not_detected + ] + sorted_not_detected_items_with_scores.sort(key=lambda x: x[1], reverse=True) + + for id_, _ in sorted_not_detected_items_with_scores: doc = id2docs[id_] - if is_close(doc.score, -1.0): - text_search_str = " default from full-text search
" - else: - text_search_str = "
" - - if ( - doc.metadata.get("llm_reranking_score") is None - or doc.metadata.get("cohere_reranking_score") is None - ): - cloned_chunk_str = ( - "Cloned chunk for a table. No reranking score
" - ) - else: - cloned_chunk_str = "" if doc.metadata.get("type", "") == "image": without_citation.append( Document( channel="info", - content=Render.collapsible( - header=( - f"{get_header(doc)}
" - "Vectorstore score:" - f" {round(doc.score, 2)}" - f"{text_search_str}" - f"{cloned_chunk_str}" - "LLM reranking score:" - f' {doc.metadata.get("llm_reranking_score")}
' - "Cohere reranking score:" - f' {doc.metadata.get("cohere_reranking_score")}
' - ), - content=Render.image( + content=self._format_retrieval_score_and_doc( + doc, + Render.image( url=doc.metadata["image_origin"], text=doc.text ), - open=True, ), ) ) @@ -682,20 +693,8 @@ class FullQAPipeline(BaseReasoning): without_citation.append( Document( channel="info", - content=Render.collapsible( - header=( - f"{get_header(doc)}
" - "Vectorstore score:" - f" {round(doc.score, 2)}" - f"{text_search_str}" - f"{cloned_chunk_str}" - "LLM reranking score:" - f' {doc.metadata.get("llm_reranking_score")}
' - "Cohere reranking score:" - f' {doc.metadata.get("cohere_reranking_score")}
' - ), - content=Render.table(doc.text), - open=True, + content=self._format_retrieval_score_and_doc( + doc, Render.table(doc.text) ), ) ) @@ -744,6 +743,8 @@ class FullQAPipeline(BaseReasoning): if self.use_rewrite: message = self.rewrite_pipeline(question=message).text + print(f"Rewritten message (use_rewrite={self.use_rewrite}): {message}") + print(f"Retrievers {self.retrievers}") # should populate the context docs, infos = self.retrieve(message, history) for _ in infos: @@ -770,12 +771,18 @@ class FullQAPipeline(BaseReasoning): if without_citation: for _ in without_citation: yield _ + + qa_score = ( + round(answer.metadata["qa_score"], 2) + if answer.metadata.get("qa_score") + else None + ) yield Document( channel="info", content=( "
Question answering

" "Question answering confidence: " - f"{answer.metadata.get('qa_score')}" + f"{qa_score}" ), )