feat: add paper recommendation
@@ -57,6 +57,7 @@ repos:
           "types-requests",
           "sqlmodel",
           "types-Markdown",
+          "types-cachetools",
           "types-tzlocal",
         ]
       args: ["--check-untyped-defs", "--ignore-missing-imports"]
@@ -26,6 +26,7 @@ from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS

 from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
 from ...utils.commands import WEB_SEARCH_COMMAND
+from ...utils.hf_papers import get_recommended_papers
 from ...utils.rate_limit import check_rate_limit
 from .chat_panel import ChatPanel
 from .chat_suggestion import ChatSuggestion
@@ -68,6 +69,36 @@ function() {
 }
 """

+recommended_papers_js = """
+function() {
+    // Get all links and attach click event
+    var links = document.querySelectorAll("#related-papers a");
+
+    function submitPaper(event) {
+        event.preventDefault();
+        var target = event.currentTarget;
+        var url = target.getAttribute("href");
+        console.log("URL:", url);
+
+        let newChatButton = document.querySelector("#new-conv-button");
+        newChatButton.click();
+
+        setTimeout(() => {
+            let urlInput = document.querySelector("#quick-url-demo textarea");
+            // Fill the URL input
+            urlInput.value = url;
+            urlInput.dispatchEvent(new Event("input", { bubbles: true }));
+            urlInput.dispatchEvent(new KeyboardEvent('keypress', {'key': 'Enter'}));
+            }, 500
+        );
+    }
+
+    for (var i = 0; i < links.length; i++) {
+        links[i].onclick = submitPaper;
+    }
+}
+"""
+
 clear_bot_message_selection_js = """
 function() {
     var bot_messages = document.querySelectorAll(
@@ -268,14 +299,17 @@ class ChatPage(BasePage):
                 if not KH_DEMO_MODE:
                     self.report_issue = ReportIssue(self._app)
                 else:
+                    with gr.Accordion(label="Related papers", open=False):
+                        self.related_papers = gr.Markdown(elem_id="related-papers")

                     self.hint_page = HintPage(self._app)

             with gr.Column(scale=6, elem_id="chat-area"):
-                self.chat_panel = ChatPanel(self._app)
+                if KH_DEMO_MODE:
+                    self.paper_list = PaperListPage(self._app)
+
+                self.chat_panel = ChatPanel(self._app)

                 with gr.Accordion(
                     label="Chat settings",
                     elem_id="chat-settings-expand",
@@ -360,6 +394,19 @@ class ChatPage(BasePage):
         return plot

     def on_register_events(self):
+        # first index paper recommendation
+        if KH_DEMO_MODE and len(self._indices_input) > 0:
+            self._indices_input[1].change(
+                self.get_recommendations,
+                inputs=[self.first_selector_choices, self._indices_input[1]],
+                outputs=[self.related_papers],
+            ).then(
+                fn=None,
+                inputs=None,
+                outputs=None,
+                js=recommended_papers_js,
+            )
+
         chat_event = (
             gr.on(
                 triggers=[
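The registration above chains a Python callback with a browser-only step: in Gradio 4.x an event listener accepts a js argument, and a follow-up .then(fn=None, ..., js=...) skips the server round-trip and runs pure JavaScript against the freshly rendered HTML. A minimal self-contained sketch of the same pattern, with made-up component names (query, papers, attach_js):

    import gradio as gr

    attach_js = """
    function() {
        // Hypothetical client-side step: log each link rendered into #papers.
        document.querySelectorAll("#papers a").forEach(a => console.log(a.href));
    }
    """

    with gr.Blocks() as demo:
        query = gr.Textbox(label="Paper name")
        papers = gr.Markdown(elem_id="papers")

        def recommend(q):
            # Stand-in for ChatPage.get_recommendations
            return f"* [{q}](https://example.org)"

        # The Python step renders the Markdown; the JS step then attaches
        # client-side behavior to the links it produced.
        query.change(recommend, inputs=[query], outputs=[papers]).then(
            fn=None, inputs=None, outputs=None, js=attach_js
        )

    demo.launch()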
@@ -916,6 +963,17 @@ class ChatPage(BasePage):
             + [used_command]
         )

+    def get_recommendations(self, first_selector_choices, file_ids):
+        first_selector_choices_map = {
+            item[1]: item[0] for item in first_selector_choices
+        }
+        file_names = [first_selector_choices_map[file_id] for file_id in file_ids]
+        if not file_names:
+            return ""
+
+        first_file_name = file_names[0].split(".")[0].replace("_", " ")
+        return get_recommended_papers(first_file_name)
+
     def toggle_delete(self, conv_id):
         if conv_id:
             return gr.update(visible=False), gr.update(visible=True)
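To make the lookup concrete, here is a toy run of the same logic with invented file names and IDs; first_selector_choices arrives as (file_name, file_id) pairs:

    # Hypothetical selector state: (file_name, file_id) pairs.
    first_selector_choices = [
        ("attention_is_all_you_need.pdf", "file-001"),
        ("deep_residual_learning.pdf", "file-002"),
    ]
    file_ids = ["file-001"]

    # Invert the pairs, resolve the selected IDs to names, then build the
    # query from the first name: extension stripped, underscores to spaces.
    choices_map = {item[1]: item[0] for item in first_selector_choices}
    file_names = [choices_map[fid] for fid in file_ids]
    query = file_names[0].split(".")[0].replace("_", " ")
    assert query == "attention is all you need"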
@@ -13,8 +13,8 @@ class PaperListPage(BasePage):
     def on_building_ui(self):
         self.papers_state = gr.State(None)
         with gr.Accordion(
-            label="Browse daily top papers",
-            open=False,
+            label="Browse popular daily papers",
+            open=True,
         ) as self.accordion:
             self.examples = gr.DataFrame(
                 value=[],
@@ -435,7 +435,14 @@ class FullQAPipeline(BaseReasoning):
             },
             "system_prompt": {
                 "name": "System Prompt",
-                "value": "This is a question answering system",
+                "value": dedent(
+                    """This is a question answering system.
+                    Organize the answer in bullet points if applicable.
+                    When asked for paper summary, provide a brief summary of the paper
+                    with the following sections:
+                    Background, Hypothesis, Method, Results, Conclusion & Future Work.
+                    """
+                ),
             },
             "qa_prompt": {
                 "name": "QA Prompt (contains {context}, {question}, {lang})",
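A note on textwrap.dedent, which the new value uses: it strips only the whitespace prefix common to every line, and a first line that starts flush against the opening triple quotes fixes that common prefix at the empty string, so later lines keep their indentation. A quick illustration:

    from textwrap import dedent

    prompt = dedent(
        """first line starts at column zero
        so the common margin is empty and this line keeps its spaces
        """
    )
    print(repr(prompt))  # the inner indentation survives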
@@ -1,9 +1,15 @@
 from datetime import datetime, timedelta

 import requests
+from cachetools import TTLCache, cached

 HF_API_URL = "https://huggingface.co/api/daily_papers"
 ARXIV_URL = "https://arxiv.org/abs/{paper_id}"
+SEMANTIC_SCHOLAR_QUERY_URL = "https://api.semanticscholar.org/graph/v1/paper/search/match?query={paper_name}"  # noqa
+SEMANTIC_SCHOLAR_RECOMMEND_URL = (
+    "https://api.semanticscholar.org/recommendations/v1/papers/"  # noqa
+)
+CACHE_TIME = 60 * 60 * 6  # 6 hours


 # Function to parse the date string
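For context on the new imports: cachetools' @cached with a TTLCache memoizes a function by its arguments and evicts entries after ttl seconds, so repeated recommendation lookups within six hours are served from memory instead of hitting the API. A minimal illustration (the function and values here are made up):

    from cachetools import TTLCache, cached

    @cached(cache=TTLCache(maxsize=500, ttl=6 * 60 * 60))
    def lookup(key: str) -> str:
        print("cache miss:", key)  # runs only when the key is not cached
        return key.upper()

    lookup("a")  # prints "cache miss: a" and returns "A"
    lookup("a")  # served from cache for the next 6 hours; nothing printed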
@@ -11,6 +17,69 @@ def parse_date(date_str):
     return datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%fZ")


+@cached(cache=TTLCache(maxsize=500, ttl=CACHE_TIME))
+def get_recommendations_from_semantic_scholar(semantic_scholar_id: str):
+    try:
+        r = requests.post(
+            SEMANTIC_SCHOLAR_RECOMMEND_URL,
+            json={
+                "positivePaperIds": [semantic_scholar_id],
+            },
+            params={"fields": "externalIds,title,year", "limit": 14},  # type: ignore
+        )
+        return r.json()["recommendedPapers"]
+    except KeyError as e:
+        print(e)
+        return []
+
+
+def filter_recommendations(recommendations, max_paper_count=5):
+    # include only arxiv papers
+    arxiv_paper = [
+        r for r in recommendations if r["externalIds"].get("ArXiv", None) is not None
+    ]
+    if len(arxiv_paper) > max_paper_count:
+        arxiv_paper = arxiv_paper[:max_paper_count]
+    return arxiv_paper
+
+
+def format_recommendation_into_markdown(recommendations):
+    comment = "(recommended by the Semantic Scholar API)\n\n"
+    for r in recommendations:
+        hub_paper_url = f"https://arxiv.org/abs/{r['externalIds']['ArXiv']}"
+        comment += f"* [{r['title']}]({hub_paper_url}) ({r['year']})\n"
+
+    return comment
+
+
+def get_paper_id_from_name(paper_name):
+    try:
+        response = requests.get(
+            SEMANTIC_SCHOLAR_QUERY_URL.format(paper_name=paper_name)
+        )
+        response.raise_for_status()
+        items = response.json()
+        paper_id = items.get("data", [])[0].get("paperId")
+    except Exception as e:
+        print(e)
+        return None
+
+    return paper_id
+
+
+def get_recommended_papers(paper_name):
+    paper_id = get_paper_id_from_name(paper_name)
+    recommended_content = ""
+    if paper_id is None:
+        return recommended_content
+
+    recommended_papers = get_recommendations_from_semantic_scholar(paper_id)
+    filtered_recommendations = filter_recommendations(recommended_papers)
+
+    recommended_content = format_recommendation_into_markdown(filtered_recommendations)
+    return recommended_content
+
+
 def fetch_papers(top_n=5):
     try:
         response = requests.get(f"{HF_API_URL}?limit=100")
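End to end, a standalone smoke test of the new helper might look like this (requires network access to the Semantic Scholar API; the module path ktem.utils.hf_papers is inferred from the relative import earlier in the diff):

    # Hypothetical smoke test for the new recommendation pipeline.
    from ktem.utils.hf_papers import get_recommended_papers

    markdown = get_recommended_papers("Attention Is All You Need")
    print(markdown)
    # Expected: up to five "* [title](https://arxiv.org/abs/...) (year)"
    # bullet lines, or an empty string when no paper ID matches the name.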