Open WebUI supports openGauss vector store (#20179)

Author: Dechao Sun
Date: 2025-12-26 22:32:05 +08:00 (committed by GitHub)
Parent: 9c2f5148d9
Commit: 25db8225f8
4 changed files with 459 additions and 0 deletions

backend/open_webui/config.py

@@ -2342,6 +2342,51 @@ else:
except Exception:
    PGVECTOR_IVFFLAT_LISTS = 100

# openGauss
OPENGAUSS_DB_URL = os.environ.get("OPENGAUSS_DB_URL", DATABASE_URL)
OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH = int(
    os.environ.get("OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
)

OPENGAUSS_POOL_SIZE = os.environ.get("OPENGAUSS_POOL_SIZE", None)
if OPENGAUSS_POOL_SIZE is not None:
    try:
        OPENGAUSS_POOL_SIZE = int(OPENGAUSS_POOL_SIZE)
    except Exception:
        OPENGAUSS_POOL_SIZE = None

OPENGAUSS_POOL_MAX_OVERFLOW = os.environ.get("OPENGAUSS_POOL_MAX_OVERFLOW", 0)
if OPENGAUSS_POOL_MAX_OVERFLOW == "":
    OPENGAUSS_POOL_MAX_OVERFLOW = 0
else:
    try:
        OPENGAUSS_POOL_MAX_OVERFLOW = int(OPENGAUSS_POOL_MAX_OVERFLOW)
    except Exception:
        OPENGAUSS_POOL_MAX_OVERFLOW = 0

OPENGAUSS_POOL_TIMEOUT = os.environ.get("OPENGAUSS_POOL_TIMEOUT", 30)
if OPENGAUSS_POOL_TIMEOUT == "":
    OPENGAUSS_POOL_TIMEOUT = 30
else:
    try:
        OPENGAUSS_POOL_TIMEOUT = int(OPENGAUSS_POOL_TIMEOUT)
    except Exception:
        OPENGAUSS_POOL_TIMEOUT = 30

OPENGAUSS_POOL_RECYCLE = os.environ.get("OPENGAUSS_POOL_RECYCLE", 3600)
if OPENGAUSS_POOL_RECYCLE == "":
    OPENGAUSS_POOL_RECYCLE = 3600
else:
    try:
        OPENGAUSS_POOL_RECYCLE = int(OPENGAUSS_POOL_RECYCLE)
    except Exception:
        OPENGAUSS_POOL_RECYCLE = 3600

# Pinecone
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
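
Each pool variable above is parsed leniently: unset, empty, or malformed values fall back to a default rather than failing startup, and OPENGAUSS_DB_URL itself defaults to the main DATABASE_URL. The repeated try/except blocks reduce to one helper; a minimal sketch (hypothetical helper, not part of this commit):

import os

def _int_env(name: str, default: int) -> int:
    # Unset, empty, or non-numeric values fall back to the default,
    # mirroring the per-variable blocks above.
    raw = os.environ.get(name, default)
    if raw == "":
        return default
    try:
        return int(raw)
    except Exception:
        return default

# e.g. OPENGAUSS_POOL_TIMEOUT = _int_env("OPENGAUSS_POOL_TIMEOUT", 30)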

backend/open_webui/retrieval/vector/dbs/opengauss.py

@@ -0,0 +1,409 @@
from typing import Optional, List, Dict, Any
import logging
import re
import json

from sqlalchemy import (
    func,
    literal,
    cast,
    column,
    create_engine,
    Column,
    Integer,
    MetaData,
    LargeBinary,
    select,
    text,
    Text,
    Table,
    values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.dialects import registry


class OpenGaussDialect(PGDialect_psycopg2):
    name = "opengauss"

    def _get_server_version_info(self, connection):
        try:
            version = connection.exec_driver_sql("SELECT version()").scalar()
            if not version:
                return (9, 0, 0)
            match = re.search(
                r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?",
                version,
                re.IGNORECASE,
            )
            if match:
                return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
            return super()._get_server_version_info(connection)
        except Exception:
            return (9, 0, 0)


# Register dialect
registry.register("opengauss", __name__, "OpenGaussDialect")
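
# Why a custom dialect: openGauss answers SELECT version() with a banner
# such as "openGauss 5.0.0 ..." (sample string; actual output varies by
# build), which the stock psycopg2 dialect cannot turn into a version
# tuple. The regex above recovers (5, 0, 0) from that banner, and anything
# unparseable degrades to a conservative (9, 0, 0).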

from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    OPENGAUSS_DB_URL,
    OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH,
    OPENGAUSS_POOL_SIZE,
    OPENGAUSS_POOL_MAX_OVERFLOW,
    OPENGAUSS_POOL_TIMEOUT,
    OPENGAUSS_POOL_RECYCLE,
)
from open_webui.env import SRC_LOG_LEVELS

VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class DocumentChunk(Base):
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)
    text = Column(Text, nullable=True)
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)


class OpenGaussClient(VectorDBBase):
    def __init__(self) -> None:
        # Without a dedicated openGauss URL, reuse the application's
        # shared database session.
        if not OPENGAUSS_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            engine_kwargs = {
                "pool_pre_ping": True,
                # Route connections through the custom openGauss dialect.
                "dialect": OpenGaussDialect(),
            }
            if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
                engine_kwargs.update(
                    {
                        "pool_size": OPENGAUSS_POOL_SIZE,
                        "max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
                        "pool_timeout": OPENGAUSS_POOL_TIMEOUT,
                        "pool_recycle": OPENGAUSS_POOL_RECYCLE,
                        "poolclass": QueuePool,
                    }
                )
            else:
                engine_kwargs["poolclass"] = NullPool

            engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs)
            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )
            self.session.commit()
            log.info("OpenGauss vector database initialization completed.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"OpenGauss initialization failed: {e}")
            raise

    def check_vector_length(self) -> None:
        metadata = MetaData()
        try:
            document_chunk_table = Table(
                "document_chunk", metadata, autoload_with=self.session.bind
            )
        except NoSuchTableError:
            # Table does not exist yet; nothing to validate.
            return

        if "vector" in document_chunk_table.columns:
            vector_column = document_chunk_table.columns["vector"]
            vector_type = vector_column.type
            if isinstance(vector_type, Vector):
                db_vector_length = vector_type.dim
                if db_vector_length != VECTOR_LENGTH:
                    raise Exception(
                        f"Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database."
                    )
            else:
                raise Exception("The 'vector' column type is not Vector.")
        else:
            raise Exception(
                "The 'vector' column does not exist in the 'document_chunk' table."
            )

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        # Pad with zeros or truncate so every stored vector has exactly
        # VECTOR_LENGTH dimensions.
        current_length = len(vector)
        if current_length < VECTOR_LENGTH:
            vector += [0.0] * (VECTOR_LENGTH - current_length)
        elif current_length > VECTOR_LENGTH:
            vector = vector[:VECTOR_LENGTH]
        return vector
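    # Illustration (values hypothetical): with VECTOR_LENGTH = 4,
    #   adjust_vector_length([0.1, 0.2])       -> [0.1, 0.2, 0.0, 0.0]  (zero-padded)
    #   adjust_vector_length([1, 2, 3, 4, 5])  -> [1, 2, 3, 4]          (truncated)
    # Truncation is lossy, so OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH should
    # be at least the embedding model's output dimension.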
    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            new_items = []
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                new_chunk = DocumentChunk(
                    id=item["id"],
                    vector=vector,
                    collection_name=collection_name,
                    text=item["text"],
                    vmetadata=process_metadata(item["metadata"]),
                )
                new_items.append(new_chunk)
            self.session.bulk_save_objects(new_items)
            self.session.commit()
            log.info(
                f"Inserted {len(new_items)} items into collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Failed to insert data: {e}")
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                existing = (
                    self.session.query(DocumentChunk)
                    .filter(DocumentChunk.id == item["id"])
                    .first()
                )
                if existing:
                    existing.vector = vector
                    existing.text = item["text"]
                    existing.vmetadata = process_metadata(item["metadata"])
                    existing.collection_name = collection_name
                else:
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=process_metadata(item["metadata"]),
                    )
                    self.session.add(new_chunk)
            self.session.commit()
            log.info(
                f"Inserted/updated {len(items)} items in collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Failed to insert or update data: {e}")
            raise

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        try:
            if not vectors:
                return None

            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # One (qid, q_vector) row per query vector, so a single
            # statement serves the whole batch.
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            result_fields = [
                DocumentChunk.id,
                DocumentChunk.text,
                DocumentChunk.vmetadata,
                (
                    DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
                ).label("distance"),
            ]
            # LATERAL subquery: the top-`limit` nearest chunks for each
            # query vector.
            subq = (
                select(*result_fields)
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                # Convert cosine distance (range [0, 2]) to a similarity
                # score in [0, 1].
                distances[qid].append((2.0 - row.distance) / 2.0)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            # End the read-only transaction so the scoped session holds
            # nothing open.
            self.session.rollback()
            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Vector search failed: {e}")
            return None
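    # Worked score mapping (illustrative): cosine distance 0.0 -> score 1.0
    # (same direction), 1.0 -> 0.5 (orthogonal), 2.0 -> 0.0 (opposite), so
    # higher scores mean closer matches. Because of the VALUES + LATERAL
    # construction above, a whole batch of query vectors costs one round
    # trip to the database.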
    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            for key, value in filter.items():
                query = query.filter(
                    DocumentChunk.vmetadata[key].astext == str(value)
                )
            if limit is not None:
                query = query.limit(limit)

            results = query.all()
            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]
            self.session.rollback()
            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            self.session.rollback()
            log.exception(f"Conditional query failed: {e}")
            return None
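    # Filter semantics: each key/value pair compiles to a JSONB text
    # comparison, vmetadata->>'key' = 'value', so values are matched as
    # strings. E.g. (hypothetical) filter={"source": "handbook.pdf"}
    # returns only chunks whose metadata carries exactly that source.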
    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if limit is not None:
                query = query.limit(limit)

            results = query.all()
            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]
            self.session.rollback()
            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            self.session.rollback()
            log.exception(f"Failed to retrieve data: {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if ids:
                query = query.filter(DocumentChunk.id.in_(ids))
            if filter:
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
            deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Failed to delete data: {e}")
            raise

    def reset(self) -> None:
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            log.info(f"Reset completed. Deleted {deleted} items.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Reset failed: {e}")
            raise

    def close(self) -> None:
        pass

    def has_collection(self, collection_name: str) -> bool:
        try:
            exists = (
                self.session.query(DocumentChunk)
                .filter(DocumentChunk.collection_name == collection_name)
                .first()
                is not None
            )
            self.session.rollback()
            return exists
        except Exception as e:
            self.session.rollback()
            log.exception(f"Failed to check collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        self.delete(collection_name)
        log.info(f"Collection '{collection_name}' has been deleted.")

backend/open_webui/retrieval/vector/factory.py

@@ -53,6 +53,10 @@ class Vector:
                from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient

                return PgvectorClient()
            case VectorType.OPENGAUSS:
                from open_webui.retrieval.vector.dbs.opengauss import OpenGaussClient

                return OpenGaussClient()
            case VectorType.ELASTICSEARCH:
                from open_webui.retrieval.vector.dbs.elasticsearch import (
                    ElasticsearchClient,
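
Like the neighboring backends, the OPENGAUSS arm imports its client lazily, so openGauss-specific dependencies load only when that backend is actually selected. Selection is configuration-driven; assuming Open WebUI's usual VECTOR_DB environment variable (name assumed here; its value matches the enum member added below), enabling the new store would look like:

import os

os.environ["VECTOR_DB"] = "opengauss"  # pick the backend by enum value
# Hypothetical DSN using the URL scheme registered by the custom dialect.
os.environ["OPENGAUSS_DB_URL"] = "opengauss://user:pass@host:5432/openwebui"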

backend/open_webui/retrieval/vector/type.py

@@ -12,3 +12,4 @@ class VectorType(StrEnum):
    ORACLE23AI = "oracle23ai"
    S3VECTOR = "s3vector"
    WEAVIATE = "weaviate"
    OPENGAUSS = "opengauss"