enh: tiktoken/token splitter support

This commit is contained in:
Timothy J. Baek
2024-10-13 02:07:50 -07:00
parent 8ae605ec4b
commit dff3732fcd
4 changed files with 49 additions and 7 deletions

View File

@@ -47,6 +47,8 @@ from open_webui.apps.retrieval.utils import (
from open_webui.apps.webui.models.files import Files
from open_webui.config import (
BRAVE_SEARCH_API_KEY,
TIKTOKEN_ENCODING_NAME,
RAG_TEXT_SPLITTER,
CHUNK_OVERLAP,
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
@@ -102,7 +104,7 @@ from open_webui.utils.misc import (
)
from open_webui.utils.utils import get_admin_user, get_verified_user
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_community.document_loaders import (
YoutubeLoader,
)
@@ -129,6 +131,9 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
# Chunking settings consulted when documents are split before being stored:
# TEXT_SPLITTER picks the splitter ("" or "character" -> recursive character
# splitter, "token" -> token splitter), TIKTOKEN_ENCODING_NAME is passed to the
# token splitter as its encoding_name, and CHUNK_SIZE / CHUNK_OVERLAP are
# passed to whichever splitter is selected.
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
@@ -648,11 +653,22 @@ def save_docs_to_vector_db(
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
if split:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
if app.state.config.TEXT_SPLITTER in ["", "character"]:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
elif app.state.config.TEXT_SPLITTER == "token":
text_splitter = TokenTextSplitter(
encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME,
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
else:
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
docs = text_splitter.split_documents(docs)
if len(docs) == 0:

View File

@@ -1014,6 +1014,22 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
# Text splitter selection for RAG, seeded from the RAG_TEXT_SPLITTER
# environment variable. The empty-string default is handled the same way as
# "character" by the retrieval app; "token" selects the token-based splitter,
# and any other value is rejected there with a ValueError.
RAG_TEXT_SPLITTER = PersistentConfig(
"RAG_TEXT_SPLITTER",
"rag.text_splitter",
os.environ.get("RAG_TEXT_SPLITTER", ""),
)
# Read from the TIKTOKEN_CACHE_DIR environment variable, defaulting to a
# "tiktoken" directory under CACHE_DIR.
# NOTE(review): presumably tiktoken's on-disk cache location; the code that
# consumes this value is not visible here — confirm.
TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
# Encoding name handed to the token splitter as encoding_name; seeded from the
# TIKTOKEN_ENCODING_NAME environment variable, default "cl100k_base".
TIKTOKEN_ENCODING_NAME = PersistentConfig(
"TIKTOKEN_ENCODING_NAME",
"rag.tiktoken_encoding_name",
os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"),
)
# Passed as chunk_size to whichever splitter is selected. The environment value
# is parsed with int(), so a non-numeric CHUNK_SIZE raises ValueError here.
CHUNK_SIZE = PersistentConfig(
"CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000"))
)

View File

@@ -20,7 +20,7 @@ class ERROR_MESSAGES(str, Enum):
# Delegates to the parent implementation via super(); the enclosing class is
# declared as ERROR_MESSAGES(str, Enum).
# NOTE(review): presumably this makes str(member) yield the raw message text
# rather than the Enum-style "ClassName.MEMBER" form — confirm.
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda err="": f"Something went wrong :/\n{err if err else ''}"
# Generic fallback error message.
#   err: optional detail (string or exception); it is rendered with str()
#        semantics inside a bracketed "[ERROR: ...]" suffix.
# The suffix is emitted only when err is truthy, so calling DEFAULT() with no
# detail yields just the base message instead of a dangling "[ERROR: ]".
DEFAULT = lambda err="": f"Something went wrong :/\n{f'[ERROR: {err}]' if err else ''}"
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."