Merge pull request #2574 from cheahjs/feat/oauth

feat: experimental SSO support for Google, Microsoft, and OIDC
This commit is contained in:
Timothy Jaeryang Baek
2024-06-24 19:05:58 -07:00
committed by GitHub
52 changed files with 633 additions and 13 deletions

View File

@@ -0,0 +1,49 @@
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user",
oauth_sub=pw.TextField(null=True, unique=True),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "oauth_sub")

View File

@@ -2,6 +2,8 @@ from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from apps.webui.routers import (
auths,
users,

View File

@@ -105,6 +105,7 @@ class AuthsTable:
name: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
log.info("insert_new_auth")
@@ -115,7 +116,9 @@ class AuthsTable:
)
result = Auth.create(**auth.model_dump())
user = Users.insert_new_user(id, name, email, profile_image_url, role)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
)
if result and user:
return user

View File

@@ -28,6 +28,8 @@ class User(Model):
settings = JSONField(null=True)
info = JSONField(null=True)
oauth_sub = TextField(null=True, unique=True)
class Meta:
database = DB
@@ -53,6 +55,8 @@ class UserModel(BaseModel):
settings: Optional[UserSettings] = None
info: Optional[dict] = None
oauth_sub: Optional[str] = None
####################
# Forms
@@ -83,6 +87,7 @@ class UsersTable:
email: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
user = UserModel(
**{
@@ -94,6 +99,7 @@ class UsersTable:
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
"oauth_sub": oauth_sub,
}
)
result = User.create(**user.model_dump())
@@ -123,6 +129,13 @@ class UsersTable:
except:
return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
user = User.get(User.oauth_sub == sub)
return UserModel(**model_to_dict(user))
except:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**model_to_dict(user))
@@ -174,6 +187,18 @@ class UsersTable:
except:
return None
def update_user_oauth_sub_by_id(
self, id: str, oauth_sub: str
) -> Optional[UserModel]:
try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)

View File

@@ -10,7 +10,6 @@ import re
import uuid
import csv
from apps.webui.models.auths import (
SigninForm,
SignupForm,