From 47b8412695d0bc07705bda54ad07cd9145d09d66 Mon Sep 17 00:00:00 2001 From: jvinolus Date: Wed, 15 Jan 2025 17:05:04 -0800 Subject: [PATCH 1/3] Initialize support for prefixing embeddings --- backend/open_webui/config.py | 12 ++++++++ backend/open_webui/retrieval/utils.py | 40 +++++++++++++------------ backend/open_webui/routers/retrieval.py | 3 +- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index a48b2db055..ac121672e4 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1330,6 +1330,18 @@ RAG_EMBEDDING_BATCH_SIZE = PersistentConfig( ), ) +RAG_EMBEDDING_PASSAGE_PREFIX = PersistentConfig( + "RAG_EMBEDDING_PASSAGE_PREFIX", + "rag.embedding_passage_prefix", + os.environ.get("RAG_EMBEDDING_PASSAGE_PREFIX", False), +) + +RAG_EMBEDDING_QUERY_PREFIX = PersistentConfig( + "RAG_EMBEDDING_QUERY_PREFIX", + "rag.embedding_query_prefix", + os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", False), +) + RAG_RERANKING_MODEL = PersistentConfig( "RAG_RERANKING_MODEL", "rag.reranking_model", diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index c95367e6c3..e420814d80 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -15,7 +15,7 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE - +from open_webui.config import RAG_EMBEDDING_QUERY_PREFIX, RAG_EMBEDDING_PASSAGE_PREFIX log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -39,7 +39,7 @@ class VectorSearchRetriever(BaseRetriever): ) -> list[Document]: result = VECTOR_DB_CLIENT.search( collection_name=self.collection_name, - vectors=[self.embedding_function(query)], + vectors=[self.embedding_function(query,RAG_EMBEDDING_QUERY_PREFIX)], limit=self.top_k, ) @@ -183,7 +183,7 @@ def query_collection( ) -> dict: results = [] for query in queries: - query_embedding = embedding_function(query) + query_embedding = embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX) for collection_name in collection_names: if collection_name: try: @@ -247,26 +247,27 @@ def get_embedding_function( embedding_batch_size, ): if embedding_engine == "": - return lambda query: embedding_function.encode(query).tolist() + return lambda query, prefix: embedding_function.encode(query, prompt = prefix if prefix else None).tolist() elif embedding_engine in ["ollama", "openai"]: - func = lambda query: generate_embeddings( + func = lambda query, prefix: generate_embeddings( engine=embedding_engine, model=embedding_model, text=query, + prefix=prefix, url=url, key=key, ) - def generate_multiple(query, func): + def generate_multiple(query, prefix, func): if isinstance(query, list): embeddings = [] for i in range(0, len(query), embedding_batch_size): - embeddings.extend(func(query[i : i + embedding_batch_size])) + embeddings.extend(func(query[i : i + embedding_batch_size], prefix)) return embeddings else: return func(query) - return lambda query: generate_multiple(query, func) + return lambda query, prefix: generate_multiple(query, prefix, func) def get_sources_from_files( @@ -411,7 +412,7 @@ def get_model_path(model: str, update_model: bool = False): def generate_openai_batch_embeddings( - model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "" + model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "", prefix: str = None ) -> Optional[list[list[float]]]: try: r = requests.post( @@ -420,7 +421,7 @@ def generate_openai_batch_embeddings( "Content-Type": "application/json", "Authorization": f"Bearer {key}", }, - json={"input": texts, "model": model}, + json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix}, ) r.raise_for_status() data = r.json() @@ -434,7 +435,7 @@ def generate_openai_batch_embeddings( def generate_ollama_batch_embeddings( - model: str, texts: list[str], url: str, key: str = "" + model: str, texts: list[str], url: str, key: str = "", prefix: str = None ) -> Optional[list[list[float]]]: try: r = requests.post( @@ -443,7 +444,7 @@ def generate_ollama_batch_embeddings( "Content-Type": "application/json", "Authorization": f"Bearer {key}", }, - json={"input": texts, "model": model}, + json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix}, ) r.raise_for_status() data = r.json() @@ -457,25 +458,25 @@ def generate_ollama_batch_embeddings( return None -def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs): +def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], prefix: Union[str , None] = None, **kwargs): url = kwargs.get("url", "") key = kwargs.get("key", "") if engine == "ollama": if isinstance(text, list): embeddings = generate_ollama_batch_embeddings( - **{"model": model, "texts": text, "url": url, "key": key} + **{"model": model, "texts": text, "url": url, "key": key, "prefix": prefix} ) else: embeddings = generate_ollama_batch_embeddings( - **{"model": model, "texts": [text], "url": url, "key": key} + **{"model": model, "texts": [text], "url": url, "key": key, "prefix": prefix} ) return embeddings[0] if isinstance(text, str) else embeddings elif engine == "openai": if isinstance(text, list): - embeddings = generate_openai_batch_embeddings(model, text, url, key) + embeddings = generate_openai_batch_embeddings(model, text, url, key, prefix) else: - embeddings = generate_openai_batch_embeddings(model, [text], url, key) + embeddings = generate_openai_batch_embeddings(model, [text], url, key, prefix) return embeddings[0] if isinstance(text, str) else embeddings @@ -512,9 +513,10 @@ class RerankCompressor(BaseDocumentCompressor): else: from sentence_transformers import util - query_embedding = self.embedding_function(query) + query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX) document_embedding = self.embedding_function( - [doc.page_content for doc in documents] + [doc.page_content for doc in documents], + RAG_EMBEDDING_PASSAGE_PREFIX ) scores = util.cos_sim(query_embedding, document_embedding)[0] diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index c791bde842..b0c3f8e042 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -79,6 +79,7 @@ from open_webui.config import ( RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, UPLOAD_DIR, DEFAULT_LOCALE, + RAG_EMBEDDING_PASSAGE_PREFIX ) from open_webui.env import ( SRC_LOG_LEVELS, @@ -775,7 +776,7 @@ def save_docs_to_vector_db( ) embeddings = embedding_function( - list(map(lambda x: x.replace("\n", " "), texts)) + list(map(lambda x: x.replace("\n", " "), texts)), RAG_EMBEDDING_PASSAGE_PREFIX ) items = [ From 7b8e5d4e7cb03d79ee832dc1107b8d74a405ae2e Mon Sep 17 00:00:00 2001 From: jvinolus Date: Tue, 4 Feb 2025 13:04:36 -0800 Subject: [PATCH 2/3] Fixed errors and added more support --- backend/open_webui/config.py | 16 ++++++++-------- backend/open_webui/retrieval/utils.py | 12 ++++++++---- backend/open_webui/routers/retrieval.py | 8 ++++---- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index ac121672e4..f1b1c14a5c 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1330,16 +1330,16 @@ RAG_EMBEDDING_BATCH_SIZE = PersistentConfig( ), ) -RAG_EMBEDDING_PASSAGE_PREFIX = PersistentConfig( - "RAG_EMBEDDING_PASSAGE_PREFIX", - "rag.embedding_passage_prefix", - os.environ.get("RAG_EMBEDDING_PASSAGE_PREFIX", False), +RAG_EMBEDDING_QUERY_PREFIX = ( + os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None) ) -RAG_EMBEDDING_QUERY_PREFIX = PersistentConfig( - "RAG_EMBEDDING_QUERY_PREFIX", - "rag.embedding_query_prefix", - os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", False), +RAG_EMBEDDING_PASSAGE_PREFIX = ( + os.environ.get("RAG_EMBEDDING_PASSAGE_PREFIX", None) +) + +RAG_EMBEDDING_PREFIX_FIELD_NAME = ( + os.environ.get("RAG_EMBEDDING_PREFIX_FIELD_NAME", "input_type") ) RAG_RERANKING_MODEL = PersistentConfig( diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index e420814d80..544a65a89d 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -15,7 +15,11 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE -from open_webui.config import RAG_EMBEDDING_QUERY_PREFIX, RAG_EMBEDDING_PASSAGE_PREFIX +from open_webui.config import ( + RAG_EMBEDDING_QUERY_PREFIX, + RAG_EMBEDDING_PASSAGE_PREFIX, + RAG_EMBEDDING_PREFIX_FIELD_NAME +) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -265,7 +269,7 @@ def get_embedding_function( embeddings.extend(func(query[i : i + embedding_batch_size], prefix)) return embeddings else: - return func(query) + return func(query, prefix) return lambda query, prefix: generate_multiple(query, prefix, func) @@ -421,7 +425,7 @@ def generate_openai_batch_embeddings( "Content-Type": "application/json", "Authorization": f"Bearer {key}", }, - json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix}, + json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, RAG_EMBEDDING_PREFIX_FIELD_NAME: prefix}, ) r.raise_for_status() data = r.json() @@ -444,7 +448,7 @@ def generate_ollama_batch_embeddings( "Content-Type": "application/json", "Authorization": f"Bearer {key}", }, - json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix}, + json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, RAG_EMBEDDING_PREFIX_FIELD_NAME: prefix}, ) r.raise_for_status() data = r.json() diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index b0c3f8e042..255cff1127 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -70,7 +70,6 @@ from open_webui.utils.misc import ( ) from open_webui.utils.auth import get_admin_user, get_verified_user - from open_webui.config import ( ENV, RAG_EMBEDDING_MODEL_AUTO_UPDATE, @@ -79,7 +78,8 @@ from open_webui.config import ( RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, UPLOAD_DIR, DEFAULT_LOCALE, - RAG_EMBEDDING_PASSAGE_PREFIX + RAG_EMBEDDING_PASSAGE_PREFIX, + RAG_EMBEDDING_QUERY_PREFIX ) from open_webui.env import ( SRC_LOG_LEVELS, @@ -1319,7 +1319,7 @@ def query_doc_handler( else: return query_doc( collection_name=form_data.collection_name, - query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query), + query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query, RAG_EMBEDDING_QUERY_PREFIX), k=form_data.k if form_data.k else request.app.state.config.TOP_K, ) except Exception as e: @@ -1438,7 +1438,7 @@ if ENV == "dev": @router.get("/ef/{text}") async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): - return {"result": request.app.state.EMBEDDING_FUNCTION(text)} + return {"result": request.app.state.EMBEDDING_FUNCTION(text, RAG_EMBEDDING_QUERY_PREFIX)} class BatchProcessFilesForm(BaseModel): From 6d2f87e9044800320656c98a501302f2f6a3f56a Mon Sep 17 00:00:00 2001 From: jayteaftw Date: Wed, 5 Feb 2025 14:03:16 -0800 Subject: [PATCH 3/3] Added server side Prefixing --- backend/open_webui/config.py | 2 +- backend/open_webui/retrieval/utils.py | 25 +++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index f1b1c14a5c..5635b70a67 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1339,7 +1339,7 @@ RAG_EMBEDDING_PASSAGE_PREFIX = ( ) RAG_EMBEDDING_PREFIX_FIELD_NAME = ( - os.environ.get("RAG_EMBEDDING_PREFIX_FIELD_NAME", "input_type") + os.environ.get("RAG_EMBEDDING_PREFIX_FIELD_NAME", None) ) RAG_RERANKING_MODEL = PersistentConfig( diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 544a65a89d..7a9be9ea94 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -418,14 +418,22 @@ def get_model_path(model: str, update_model: bool = False): def generate_openai_batch_embeddings( model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "", prefix: str = None ) -> Optional[list[list[float]]]: + try: + json_data = { + "input": texts, + "model": model + } + if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME,str) and isinstance(prefix,str): + json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix + r = requests.post( f"{url}/embeddings", headers={ "Content-Type": "application/json", "Authorization": f"Bearer {key}", }, - json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, RAG_EMBEDDING_PREFIX_FIELD_NAME: prefix}, + json=json_data, ) r.raise_for_status() data = r.json() @@ -442,13 +450,20 @@ def generate_ollama_batch_embeddings( model: str, texts: list[str], url: str, key: str = "", prefix: str = None ) -> Optional[list[list[float]]]: try: + json_data = { + "input": texts, + "model": model + } + if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME,str) and isinstance(prefix,str): + json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix + r = requests.post( f"{url}/api/embed", headers={ "Content-Type": "application/json", "Authorization": f"Bearer {key}", }, - json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, RAG_EMBEDDING_PREFIX_FIELD_NAME: prefix}, + json=json_data, ) r.raise_for_status() data = r.json() @@ -466,6 +481,12 @@ def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], pr url = kwargs.get("url", "") key = kwargs.get("key", "") + if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None: + if isinstance(text, list): + text = [f'{prefix}{text_element}' for text_element in text] + else: + text = f'{prefix}{text}' + if engine == "ollama": if isinstance(text, list): embeddings = generate_ollama_batch_embeddings(