mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2026-02-23 19:49:37 +01:00
feat: add SSO login
This commit is contained in:
@@ -65,6 +65,8 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface")
|
||||
KH_DOC_DIR = this_dir / "docs"
|
||||
|
||||
KH_MODE = "dev"
|
||||
KH_SSO_ENABLED = config("KH_SSO_ENABLED", default=False, cast=bool)
|
||||
|
||||
KH_FEATURE_CHAT_SUGGESTION = config(
|
||||
"KH_FEATURE_CHAT_SUGGESTION", default=False, cast=bool
|
||||
)
|
||||
@@ -145,7 +147,7 @@ if config("OPENAI_API_KEY", default=""):
|
||||
"base_url": config("OPENAI_API_BASE", default="")
|
||||
or "https://api.openai.com/v1",
|
||||
"api_key": config("OPENAI_API_KEY", default=""),
|
||||
"model": config("OPENAI_CHAT_MODEL", default="gpt-3.5-turbo"),
|
||||
"model": config("OPENAI_CHAT_MODEL", default="gpt-4o-mini"),
|
||||
"timeout": 20,
|
||||
},
|
||||
"default": True,
|
||||
@@ -156,7 +158,7 @@ if config("OPENAI_API_KEY", default=""):
|
||||
"base_url": config("OPENAI_API_BASE", default="https://api.openai.com/v1"),
|
||||
"api_key": config("OPENAI_API_KEY", default=""),
|
||||
"model": config(
|
||||
"OPENAI_EMBEDDINGS_MODEL", default="text-embedding-ada-002"
|
||||
"OPENAI_EMBEDDINGS_MODEL", default="text-embedding-3-large"
|
||||
),
|
||||
"timeout": 10,
|
||||
"context_length": 8191,
|
||||
@@ -323,7 +325,7 @@ GRAPHRAG_INDICES = [
|
||||
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
|
||||
".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
|
||||
),
|
||||
"private": False,
|
||||
"private": True,
|
||||
},
|
||||
"index_type": graph_type,
|
||||
}
|
||||
@@ -338,7 +340,7 @@ KH_INDICES = [
|
||||
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
|
||||
".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
|
||||
),
|
||||
"private": False,
|
||||
"private": True,
|
||||
},
|
||||
"index_type": "ktem.index.file.FileIndex",
|
||||
},
|
||||
|
||||
@@ -13,7 +13,7 @@ from ktem.settings import BaseSettingGroup, SettingGroup, SettingReasoningGroup
|
||||
from theflow.settings import settings
|
||||
from theflow.utils.modules import import_dotted_string
|
||||
|
||||
BASE_PATH = os.environ.get("GRADIO_ROOT_PATH", "")
|
||||
BASE_PATH = os.environ.get("GR_FILE_ROOT_PATH", "")
|
||||
|
||||
|
||||
class BaseApp:
|
||||
@@ -57,7 +57,7 @@ class BaseApp:
|
||||
self._pdf_view_js = self._pdf_view_js.replace(
|
||||
"PDFJS_PREBUILT_DIR",
|
||||
pdf_js_dist_dir,
|
||||
).replace("GRADIO_ROOT_PATH", BASE_PATH)
|
||||
).replace("GR_FILE_ROOT_PATH", BASE_PATH)
|
||||
with (dir_assets / "js" / "svg-pan-zoom.min.js").open() as fi:
|
||||
self._svg_js = fi.read()
|
||||
|
||||
@@ -79,7 +79,7 @@ class BaseApp:
|
||||
self.default_settings.index.finalize()
|
||||
self.settings_state = gr.State(self.default_settings.flatten())
|
||||
|
||||
self.user_id = gr.State(1 if not self.f_user_management else None)
|
||||
self.user_id = gr.State(None)
|
||||
|
||||
def initialize_indices(self):
|
||||
"""Create the index manager, start indices, and register to app settings"""
|
||||
|
||||
@@ -11,6 +11,14 @@ function run() {
|
||||
version_node.style = "position: fixed; top: 10px; right: 10px;";
|
||||
main_parent.appendChild(version_node);
|
||||
|
||||
// add favicon
|
||||
const favicon = document.createElement("link");
|
||||
// set favicon attributes
|
||||
favicon.rel = "icon";
|
||||
favicon.type = "image/svg+xml";
|
||||
favicon.href = "/favicon.svg";
|
||||
document.head.appendChild(favicon);
|
||||
|
||||
// move info-expand-button
|
||||
let info_expand_button = document.getElementById("info-expand-button");
|
||||
let chat_info_panel = document.getElementById("info-expand");
|
||||
|
||||
@@ -17,7 +17,7 @@ function onBlockLoad () {
|
||||
<span class="close" id="modal-expand">⛶</span>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<pdfjs-viewer-element id="pdf-viewer" viewer-path="GRADIO_ROOT_PATH/file=PDFJS_PREBUILT_DIR" locale="en" phrase="true">
|
||||
<pdfjs-viewer-element id="pdf-viewer" viewer-path="GR_FILE_ROOT_PATH/file=PDFJS_PREBUILT_DIR" locale="en" phrase="true">
|
||||
</pdfjs-viewer-element>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -55,7 +55,7 @@ class BaseUser(SQLModel):
|
||||
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
id: Optional[str] = Field(default=None, primary_key=True)
|
||||
username: str = Field(unique=True)
|
||||
username_lower: str = Field(unique=True)
|
||||
password: str
|
||||
|
||||
@@ -9,6 +9,7 @@ from ktem.pages.setup import SetupPage
|
||||
from theflow.settings import settings as flowsettings
|
||||
|
||||
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
|
||||
KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False)
|
||||
KH_ENABLE_FIRST_SETUP = getattr(flowsettings, "KH_ENABLE_FIRST_SETUP", False)
|
||||
KH_APP_DATA_EXISTS = getattr(flowsettings, "KH_APP_DATA_EXISTS", True)
|
||||
|
||||
@@ -90,14 +91,15 @@ class App(BaseApp):
|
||||
page = index.get_index_page_ui()
|
||||
setattr(self, f"_index_{index.id}", page)
|
||||
|
||||
with gr.Tab(
|
||||
"Resources",
|
||||
elem_id="resources-tab",
|
||||
id="resources-tab",
|
||||
visible=not self.f_user_management,
|
||||
elem_classes=["fill-main-area-height", "scrollable"],
|
||||
) as self._tabs["resources-tab"]:
|
||||
self.resources_page = ResourcesTab(self)
|
||||
if not KH_SSO_ENABLED:
|
||||
with gr.Tab(
|
||||
"Resources",
|
||||
elem_id="resources-tab",
|
||||
id="resources-tab",
|
||||
visible=not self.f_user_management,
|
||||
elem_classes=["fill-main-area-height", "scrollable"],
|
||||
) as self._tabs["resources-tab"]:
|
||||
self.resources_page = ResourcesTab(self)
|
||||
|
||||
with gr.Tab(
|
||||
"Settings",
|
||||
|
||||
@@ -3,6 +3,7 @@ import hashlib
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
from ktem.db.models import User, engine
|
||||
from ktem.pages.resources.user import create_user
|
||||
from sqlmodel import Session, select
|
||||
|
||||
fetch_creds = """
|
||||
@@ -85,19 +86,44 @@ class LoginPage(BasePage):
|
||||
},
|
||||
)
|
||||
|
||||
def login(self, usn, pwd):
|
||||
if not usn or not pwd:
|
||||
return None, usn, pwd
|
||||
def login(self, usn, pwd, request: gr.Request):
|
||||
import gradiologin as grlogin
|
||||
|
||||
user = grlogin.get_user(request)
|
||||
|
||||
if user:
|
||||
user_id = user["sub"]
|
||||
with Session(engine) as session:
|
||||
stmt = select(User).where(
|
||||
User.id == user_id,
|
||||
)
|
||||
result = session.exec(stmt).all()
|
||||
|
||||
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
|
||||
with Session(engine) as session:
|
||||
stmt = select(User).where(
|
||||
User.username_lower == usn.lower().strip(),
|
||||
User.password == hashed_password,
|
||||
)
|
||||
result = session.exec(stmt).all()
|
||||
if result:
|
||||
return result[0].id, "", ""
|
||||
print("Existing user:", user)
|
||||
return user_id, "", ""
|
||||
else:
|
||||
print("Creating new user:", user)
|
||||
create_user(
|
||||
usn=user["email"],
|
||||
pwd="",
|
||||
user_id=user_id,
|
||||
is_admin=False,
|
||||
)
|
||||
return user_id, "", ""
|
||||
else:
|
||||
if not usn or not pwd:
|
||||
return None, usn, pwd
|
||||
|
||||
gr.Warning("Invalid username or password")
|
||||
return None, usn, pwd
|
||||
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
|
||||
with Session(engine) as session:
|
||||
stmt = select(User).where(
|
||||
User.username_lower == usn.lower().strip(),
|
||||
User.password == hashed_password,
|
||||
)
|
||||
result = session.exec(stmt).all()
|
||||
if result:
|
||||
return result[0].id, "", ""
|
||||
|
||||
gr.Warning("Invalid username or password")
|
||||
return None, usn, pwd
|
||||
|
||||
@@ -94,7 +94,7 @@ def validate_password(pwd, pwd_cnf):
|
||||
return ""
|
||||
|
||||
|
||||
def create_user(usn, pwd) -> bool:
|
||||
def create_user(usn, pwd, user_id=None, is_admin=True) -> bool:
|
||||
with Session(engine) as session:
|
||||
statement = select(User).where(User.username_lower == usn.lower())
|
||||
result = session.exec(statement).all()
|
||||
@@ -105,10 +105,11 @@ def create_user(usn, pwd) -> bool:
|
||||
else:
|
||||
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
|
||||
user = User(
|
||||
id=user_id,
|
||||
username=usn,
|
||||
username_lower=usn.lower(),
|
||||
password=hashed_password,
|
||||
admin=True,
|
||||
admin=is_admin,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
@@ -5,6 +5,10 @@ from ktem.app import BasePage
|
||||
from ktem.components import reasonings
|
||||
from ktem.db.models import Settings, User, engine
|
||||
from sqlmodel import Session, select
|
||||
from theflow.settings import settings as flowsettings
|
||||
|
||||
KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False)
|
||||
|
||||
|
||||
signout_js = """
|
||||
function(u, c, pw, pwc) {
|
||||
@@ -80,38 +84,44 @@ class SettingsPage(BasePage):
|
||||
|
||||
# render application page if there are application settings
|
||||
self._render_app_tab = False
|
||||
if self._default_settings.application.settings:
|
||||
|
||||
if not KH_SSO_ENABLED and self._default_settings.application.settings:
|
||||
self._render_app_tab = True
|
||||
|
||||
# render index page if there are index settings (general and/or specific)
|
||||
self._render_index_tab = False
|
||||
if self._default_settings.index.settings:
|
||||
self._render_index_tab = True
|
||||
else:
|
||||
for sig in self._default_settings.index.options.values():
|
||||
if sig.settings:
|
||||
self._render_index_tab = True
|
||||
break
|
||||
|
||||
if not KH_SSO_ENABLED:
|
||||
if self._default_settings.index.settings:
|
||||
self._render_index_tab = True
|
||||
else:
|
||||
for sig in self._default_settings.index.options.values():
|
||||
if sig.settings:
|
||||
self._render_index_tab = True
|
||||
break
|
||||
|
||||
# render reasoning page if there are reasoning settings
|
||||
self._render_reasoning_tab = False
|
||||
if len(self._default_settings.reasoning.settings) > 1:
|
||||
self._render_reasoning_tab = True
|
||||
else:
|
||||
for sig in self._default_settings.reasoning.options.values():
|
||||
if sig.settings:
|
||||
self._render_reasoning_tab = True
|
||||
break
|
||||
|
||||
if not KH_SSO_ENABLED:
|
||||
if len(self._default_settings.reasoning.settings) > 1:
|
||||
self._render_reasoning_tab = True
|
||||
else:
|
||||
for sig in self._default_settings.reasoning.options.values():
|
||||
if sig.settings:
|
||||
self._render_reasoning_tab = True
|
||||
break
|
||||
|
||||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
self.setting_save_btn = gr.Button(
|
||||
"Save & Close",
|
||||
variant="primary",
|
||||
elem_classes=["right-button"],
|
||||
elem_id="save-setting-btn",
|
||||
)
|
||||
if not KH_SSO_ENABLED:
|
||||
self.setting_save_btn = gr.Button(
|
||||
"Save & Close",
|
||||
variant="primary",
|
||||
elem_classes=["right-button"],
|
||||
elem_id="save-setting-btn",
|
||||
)
|
||||
if self._app.f_user_management:
|
||||
with gr.Tab("User settings"):
|
||||
self.user_tab()
|
||||
@@ -175,21 +185,22 @@ class SettingsPage(BasePage):
|
||||
)
|
||||
|
||||
def on_register_events(self):
|
||||
self.setting_save_btn.click(
|
||||
self.save_setting,
|
||||
inputs=[self._user_id] + self.components(),
|
||||
outputs=self._settings_state,
|
||||
).then(
|
||||
lambda: gr.Tabs(selected="chat-tab"),
|
||||
outputs=self._app.tabs,
|
||||
)
|
||||
if not KH_SSO_ENABLED:
|
||||
self.setting_save_btn.click(
|
||||
self.save_setting,
|
||||
inputs=[self._user_id] + self.components(),
|
||||
outputs=self._settings_state,
|
||||
).then(
|
||||
lambda: gr.Tabs(selected="chat-tab"),
|
||||
outputs=self._app.tabs,
|
||||
)
|
||||
self._components["reasoning.use"].change(
|
||||
self.change_reasoning_mode,
|
||||
inputs=[self._components["reasoning.use"]],
|
||||
outputs=list(self._reasoning_mode.values()),
|
||||
show_progress="hidden",
|
||||
)
|
||||
if self._app.f_user_management:
|
||||
if self._app.f_user_management and not KH_SSO_ENABLED:
|
||||
self.password_change_btn.click(
|
||||
self.change_password,
|
||||
inputs=[
|
||||
@@ -223,15 +234,21 @@ class SettingsPage(BasePage):
|
||||
def user_tab(self):
|
||||
# user management
|
||||
self.current_name = gr.Markdown("Current user: ___")
|
||||
self.signout = gr.Button("Logout")
|
||||
|
||||
self.password_change = gr.Textbox(
|
||||
label="New password", interactive=True, type="password"
|
||||
)
|
||||
self.password_change_confirm = gr.Textbox(
|
||||
label="Confirm password", interactive=True, type="password"
|
||||
)
|
||||
self.password_change_btn = gr.Button("Change password", interactive=True)
|
||||
if KH_SSO_ENABLED:
|
||||
import gradiologin as grlogin
|
||||
|
||||
self.sso_singout = grlogin.LogoutButton("Logout")
|
||||
else:
|
||||
self.signout = gr.Button("Logout")
|
||||
|
||||
self.password_change = gr.Textbox(
|
||||
label="New password", interactive=True, type="password"
|
||||
)
|
||||
self.password_change_confirm = gr.Textbox(
|
||||
label="Confirm password", interactive=True, type="password"
|
||||
)
|
||||
self.password_change_btn = gr.Button("Change password", interactive=True)
|
||||
|
||||
def change_password(self, user_id, password, password_confirm):
|
||||
from ktem.pages.resources.user import validate_password
|
||||
|
||||
@@ -5,7 +5,7 @@ from fast_langdetect import detect
|
||||
|
||||
from kotaemon.base import RetrievedDocument
|
||||
|
||||
BASE_PATH = os.environ.get("GRADIO_ROOT_PATH", "")
|
||||
BASE_PATH = os.environ.get("GR_FILE_ROOT_PATH", "")
|
||||
|
||||
|
||||
def is_close(val1, val2, tolerance=1e-9):
|
||||
|
||||
51
sso_app.py
Normal file
51
sso_app.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
|
||||
import gradiologin as grlogin
|
||||
from decouple import config
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from theflow.settings import settings as flowsettings
|
||||
|
||||
KH_APP_DATA_DIR = getattr(flowsettings, "KH_APP_DATA_DIR", ".")
|
||||
GRADIO_TEMP_DIR = os.getenv("GRADIO_TEMP_DIR", None)
|
||||
# override GRADIO_TEMP_DIR if it's not set
|
||||
if GRADIO_TEMP_DIR is None:
|
||||
GRADIO_TEMP_DIR = os.path.join(KH_APP_DATA_DIR, "gradio_tmp")
|
||||
os.environ["GRADIO_TEMP_DIR"] = GRADIO_TEMP_DIR
|
||||
|
||||
|
||||
GOOGLE_CLIENT_ID = config("GOOGLE_CLIENT_ID", default="")
|
||||
GOOGLE_CLIENT_SECRET = config("GOOGLE_CLIENT_SECRET", default="")
|
||||
|
||||
|
||||
from ktem.main import App # noqa
|
||||
|
||||
gradio_app = App()
|
||||
demo = gradio_app.make()
|
||||
|
||||
app = FastAPI()
|
||||
grlogin.register(
|
||||
name="google",
|
||||
server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
|
||||
client_id=GOOGLE_CLIENT_ID,
|
||||
client_secret=GOOGLE_CLIENT_SECRET,
|
||||
client_kwargs={
|
||||
"scope": "openid email profile",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/favicon.svg", include_in_schema=False)
|
||||
async def favicon():
|
||||
return FileResponse(gradio_app._favicon)
|
||||
|
||||
|
||||
grlogin.mount_gradio_app(
|
||||
app,
|
||||
demo,
|
||||
"/app",
|
||||
allowed_paths=[
|
||||
"libs/ktem/ktem/assets",
|
||||
GRADIO_TEMP_DIR,
|
||||
],
|
||||
)
|
||||
Reference in New Issue
Block a user