From 40d094869fb214d7da8dc502d3233d34e607a5fe Mon Sep 17 00:00:00 2001
From: Ylber Gashi
Date: Mon, 17 Feb 2025 16:33:39 +0100
Subject: [PATCH] Add user-based access control to document and node retrieval services

---
 apps/server/src/jobs/assistant-response.ts     |   3 +-
 apps/server/src/jobs/embed-document.ts         |  39 ++---
 apps/server/src/jobs/embed-node.ts             |  24 +---
 apps/server/src/services/chunking-service.ts   | 133 +++++++++++-------
 .../services/document-retrieval-service.ts     |  25 +++-
 apps/server/src/services/llm-service.ts        | 113 ++++++++-------
 .../src/services/node-retrieval-service.ts     |  13 +-
 7 files changed, 194 insertions(+), 156 deletions(-)

diff --git a/apps/server/src/jobs/assistant-response.ts b/apps/server/src/jobs/assistant-response.ts
index 37ff9cca..60e3de1d 100644
--- a/apps/server/src/jobs/assistant-response.ts
+++ b/apps/server/src/jobs/assistant-response.ts
@@ -110,7 +110,8 @@ export const assistantResponseHandler = async (
   );
   const documentDocs = await documentRetrievalService.retrieve(
     rewrittenQuery,
-    workspaceId
+    workspaceId,
+    user.id
   );
   const allContext = [...nodeDocs, ...documentDocs];
   const reranked = await rerankDocuments(
diff --git a/apps/server/src/jobs/embed-document.ts b/apps/server/src/jobs/embed-document.ts
index 2001f4b8..a7ecb49c 100644
--- a/apps/server/src/jobs/embed-document.ts
+++ b/apps/server/src/jobs/embed-document.ts
@@ -2,7 +2,7 @@ import { OpenAIEmbeddings } from '@langchain/openai';
 import { ChunkingService } from '@/services/chunking-service';
 import { database } from '@/data/database';
 import { configuration } from '@/lib/configuration';
-import { CreateDocumentEmbedding } from '@/data/schema';
+import { CreateDocumentEmbedding, SelectNode } from '@/data/schema';
 import { sql } from 'kysely';
 import { fetchNode } from '@/lib/nodes';
 import { DocumentContent, extractBlockTexts } from '@colanode/core';
@@ -22,47 +22,26 @@ declare module '@/types/jobs' {
 }
 
 const extractDocumentText = async (
-  documentId: string,
+  node: SelectNode,
   content: DocumentContent
 ): Promise<string> => {
-  const sections: string[] = [];
-
-  const node = await fetchNode(documentId);
-  if (!node) {
-    return '';
-  }
-
   const nodeModel = getNodeModel(node.attributes.type);
   if (!nodeModel) {
     return '';
   }
 
-  const nodeName = nodeModel.getName(documentId, node.attributes);
-  if (nodeName) {
-    sections.push(`${node.attributes.type} "${nodeName}"`);
-  }
-
-  const attributesText = nodeModel.getAttributesText(
-    documentId,
-    node.attributes
-  );
-  if (attributesText) {
-    sections.push(attributesText);
-  }
-
-  const documentText = nodeModel.getDocumentText(documentId, content);
+  const documentText = nodeModel.getDocumentText(node.id, content);
   if (documentText) {
-    sections.push(documentText);
+    return documentText;
  } else {
     // Fallback to block text extraction if the model doesn't handle it
-    const blocksText = extractBlockTexts(documentId, content.blocks);
+    const blocksText = extractBlockTexts(node.id, content.blocks);
     if (blocksText) {
-      sections.push('Content:');
-      sections.push(blocksText);
+      return blocksText;
     }
   }
 
-  return sections.filter(Boolean).join('\n\n');
+  return '';
 };
 
 export const embedDocumentHandler = async (input: {
@@ -93,7 +72,7 @@ export const embedDocumentHandler = async (input: {
     return;
   }
 
-  const text = await extractDocumentText(documentId, document.content);
+  const text = await extractDocumentText(node, document.content);
   if (!text || text.trim() === '') {
     await database
       .deleteFrom('document_embeddings')
@@ -105,7 +84,7 @@ export const embedDocumentHandler = async (input: {
   const chunkingService = new ChunkingService();
   const chunks = await chunkingService.chunkText(text, {
     type: 'document',
-    id: documentId,
+    node: node,
   });
   const embeddings = new OpenAIEmbeddings({
     apiKey: configuration.ai.embedding.apiKey,
diff --git a/apps/server/src/jobs/embed-node.ts b/apps/server/src/jobs/embed-node.ts
index 9229bdfd..2927f7f0 100644
--- a/apps/server/src/jobs/embed-node.ts
+++ b/apps/server/src/jobs/embed-node.ts
@@ -26,33 +26,15 @@ const extractNodeText = async (
 ): Promise<string> => {
   if (!node) return '';
 
-  // Get the node model to use its text extraction methods
   const nodeModel = getNodeModel(node.attributes.type);
   if (!nodeModel) return '';
 
-  const sections: string[] = [];
-
-  // Get the node's name if available
-  const nodeName = nodeModel.getName(nodeId, node.attributes);
-  if (nodeName) {
-    sections.push(`${node.attributes.type} "${nodeName}"`);
-  }
-
-  // Get text from attributes (this handles message content, record fields, etc.)
   const attributesText = nodeModel.getAttributesText(nodeId, node.attributes);
   if (attributesText) {
-    sections.push(attributesText);
+    return attributesText;
   }
 
-  // For records, add database context
-  if (node.attributes.type === 'record') {
-    const databaseNode = await fetchNode(node.attributes.databaseId);
-    if (databaseNode?.attributes.type === 'database') {
-      sections.push(`In database "${databaseNode.attributes.name}"`);
-    }
-  }
-
-  return sections.filter(Boolean).join('\n');
+  return '';
 };
 
 export const embedNodeHandler = async (input: {
@@ -87,7 +69,7 @@ export const embedNodeHandler = async (input: {
   const chunkingService = new ChunkingService();
   const chunks = await chunkingService.chunkText(text, {
     type: 'node',
-    id: nodeId,
+    node,
   });
   const embeddings = new OpenAIEmbeddings({
     apiKey: configuration.ai.embedding.apiKey,
diff --git a/apps/server/src/services/chunking-service.ts b/apps/server/src/services/chunking-service.ts
index 8e70eadf..91e11d01 100644
--- a/apps/server/src/services/chunking-service.ts
+++ b/apps/server/src/services/chunking-service.ts
@@ -12,7 +12,6 @@ import type { SelectNode, SelectDocument, SelectUser } from '@/data/schema';
 
 type BaseMetadata = {
   id: string;
-  type: string;
   name?: string;
   createdAt: Date;
   createdBy: string;
@@ -21,22 +20,23 @@ type BaseMetadata = {
     id: string;
     type: string;
     name?: string;
+    path?: string;
   };
-  collaborators?: Array<{ id: string; name: string }>;
+  collaborators?: Array<{ id: string; name: string; role: string }>;
+  lastUpdated?: Date;
+  updatedBy?: { id: string; name: string };
+  workspace?: { id: string; name: string };
 };
 
-export type NodeMetadata = {
+export type NodeMetadata = BaseMetadata & {
   type: 'node';
-  metadata: BaseMetadata & {
-    fields?: Record | null;
-  };
+  nodeType: string;
+  fields?: Record | null;
 };
 
-export type DocumentMetadata = {
+export type DocumentMetadata = BaseMetadata & {
   type: 'document';
-  metadata: BaseMetadata & {
-    content: DocumentContent;
-  };
+  content: DocumentContent;
 };
 
 export type ChunkingMetadata = NodeMetadata | DocumentMetadata;
@@ -44,7 +44,7 @@ export class ChunkingService {
   public async chunkText(
     text: string,
-    metadata?: { type: 'node' | 'document'; id: string }
+    metadata?: { type: 'node' | 'document'; node: SelectNode }
   ): Promise {
     const chunkSize = configuration.ai.chunking.defaultChunkSize;
     const chunkOverlap = configuration.ai.chunking.defaultOverlap;
 
@@ -70,40 +70,25 @@ export class ChunkingService {
   private async fetchMetadata(metadata?: {
     type: 'node' | 'document';
-    id: string;
+    node: SelectNode;
   }): Promise<ChunkingMetadata | undefined> {
     if (!metadata) {
       return undefined;
     }
 
     if (metadata.type === 'node') {
-      const node = (await database
-        .selectFrom('nodes')
-        .selectAll()
-        .where('id', '=', metadata.id)
-        .executeTakeFirst()) as SelectNode | undefined;
-      if (!node) {
-        return undefined;
-      }
-
-      return this.buildNodeMetadata(node);
+      return this.buildNodeMetadata(metadata.node);
     } else {
       const document = (await database
         .selectFrom('documents')
         .selectAll()
-        .where('id', '=', metadata.id)
+        .where('id', '=', metadata.node.id)
         .executeTakeFirst()) as SelectDocument | undefined;
       if (!document) {
         return undefined;
       }
 
-      const node = (await database
-        .selectFrom('nodes')
-        .selectAll()
-        .where('id', '=', document.id)
-        .executeTakeFirst()) as SelectNode | undefined;
-
-      return this.buildDocumentMetadata(document, node);
+      return this.buildDocumentMetadata(document, metadata.node);
     }
   }
 
@@ -138,10 +123,9 @@ export class ChunkingService {
 
     return {
       type: 'node',
-      metadata: {
-        ...baseMetadata,
-        fields: 'fields' in node.attributes ? node.attributes.fields : null,
-      },
+      nodeType: node.attributes.type,
+      fields: 'fields' in node.attributes ? node.attributes.fields : null,
+      ...baseMetadata,
     };
   }
 
@@ -151,7 +135,6 @@ export class ChunkingService {
   ): Promise<DocumentMetadata> {
     let baseMetadata: BaseMetadata = {
       id: document.id,
-      type: 'document',
       createdAt: document.created_at,
       createdBy: document.created_by,
     };
@@ -172,14 +155,18 @@ export class ChunkingService {
           }
         }
       }
+
+      return {
+        type: 'document',
+        content: document.content,
+        ...baseMetadata,
+      };
     }
 
     return {
       type: 'document',
-      metadata: {
-        ...baseMetadata,
-        content: document.content,
-      },
+      content: document.content,
+      ...baseMetadata,
     };
   }
 
@@ -187,30 +174,46 @@ export class ChunkingService {
     const nodeModel = getNodeModel(node.attributes.type);
     const nodeName = nodeModel?.getName(node.id, node.attributes);
 
-    const author = (await database
+    const author = await database
       .selectFrom('users')
       .select(['id', 'name'])
       .where('id', '=', node.created_by)
-      .executeTakeFirst()) as SelectUser | undefined;
+      .executeTakeFirst();
+
+    const updatedBy = node.updated_by
+      ? await database
+          .selectFrom('users')
+          .select(['id', 'name'])
+          .where('id', '=', node.updated_by)
+          .executeTakeFirst()
+      : undefined;
+
+    const workspace = await database
+      .selectFrom('workspaces')
+      .select(['id', 'name'])
+      .where('id', '=', node.workspace_id)
+      .executeTakeFirst();
 
     return {
       id: node.id,
-      type: node.attributes.type,
-      name: nodeName ?? undefined,
+      name: nodeName ?? '',
       createdAt: node.created_at,
       createdBy: node.created_by,
       author: author ?? undefined,
+      lastUpdated: node.updated_at ?? undefined,
+      updatedBy: updatedBy ?? undefined,
+      workspace: workspace ?? undefined,
     };
   }
 
   private async buildParentContext(
     node: SelectNode
   ): Promise {
-    const parentNode = (await database
+    const parentNode = await database
       .selectFrom('nodes')
       .selectAll()
       .where('id', '=', node.parent_id)
-      .executeTakeFirst()) as SelectNode | undefined;
+      .executeTakeFirst();
 
     if (!parentNode) {
       return undefined;
     }
@@ -226,24 +229,58 @@ export class ChunkingService {
       parentNode.attributes
     );
 
+    // Get the full path by traversing up the tree
+    const pathNodes = await database
+      .selectFrom('node_paths')
+      .innerJoin('nodes', 'nodes.id', 'node_paths.ancestor_id')
+      .select(['nodes.id', 'nodes.attributes'])
+      .where('node_paths.descendant_id', '=', node.id)
+      .orderBy('node_paths.level', 'asc')
+      .execute();
+
+    const path = pathNodes
+      .map((n) => {
+        const model = getNodeModel(n.attributes.type);
+        return model?.getName(n.id, n.attributes) ?? 'Untitled';
+      })
+      .join(' / ');
+
     return {
       id: parentNode.id,
       type: parentNode.attributes.type,
       name: parentName ?? undefined,
+      path,
     };
   }
 
   private async fetchCollaborators(
     collaboratorIds: string[]
-  ): Promise<Array<{ id: string; name: string }>> {
+  ): Promise<Array<{ id: string; name: string; role: string }>> {
     if (!collaboratorIds.length) {
       return [];
     }
 
-    return database
+    const collaborators = await database
       .selectFrom('users')
       .select(['id', 'name'])
       .where('id', 'in', collaboratorIds)
-      .execute() as Promise<Array<{ id: string; name: string }>>;
+      .execute();
+
+    // Get roles for each collaborator
+    return Promise.all(
+      collaborators.map(async (c) => {
+        const collaboration = await database
+          .selectFrom('collaborations')
+          .select(['role'])
+          .where('collaborator_id', '=', c.id)
+          .executeTakeFirst();
+
+        return {
+          id: c.id,
+          name: c.name,
+          role: collaboration?.role ?? 'unknown',
+        };
+      })
+    );
   }
 }
diff --git a/apps/server/src/services/document-retrieval-service.ts b/apps/server/src/services/document-retrieval-service.ts
index 16cc7a2f..f48eee75 100644
--- a/apps/server/src/services/document-retrieval-service.ts
+++ b/apps/server/src/services/document-retrieval-service.ts
@@ -23,6 +23,7 @@ export class DocumentRetrievalService {
   public async retrieve(
     query: string,
     workspaceId: string,
+    userId: string,
     limit = configuration.ai.retrieval.hybridSearch.maxResults
   ): Promise {
     const embedding = await this.embeddings.embedQuery(query);
@@ -30,20 +31,34 @@
     const semanticResults = await this.semanticSearch(
       embedding,
       workspaceId,
+      userId,
+      limit
+    );
+    const keywordResults = await this.keywordSearch(
+      query,
+      workspaceId,
+      userId,
       limit
     );
-    const keywordResults = await this.keywordSearch(query, workspaceId, limit);
 
     return this.combineSearchResults(semanticResults, keywordResults);
   }
 
   private async semanticSearch(
     embedding: number[],
     workspaceId: string,
+    userId: string,
     limit: number
   ): Promise {
     const results = await database
       .selectFrom('document_embeddings')
       .innerJoin('documents', 'documents.id', 'document_embeddings.document_id')
+      .innerJoin('nodes', 'nodes.id', 'documents.id')
+      .innerJoin('collaborations', (join) =>
+        join
+          .onRef('collaborations.node_id', '=', 'nodes.root_id')
+          .on('collaborations.collaborator_id', '=', sql.lit(userId))
+          .on('collaborations.deleted_at', 'is', null)
+      )
       .select((eb) => [
         'document_embeddings.document_id as id',
         'document_embeddings.text',
@@ -77,11 +92,19 @@
   private async keywordSearch(
     query: string,
     workspaceId: string,
+    userId: string,
     limit: number
   ): Promise {
     const results = await database
       .selectFrom('document_embeddings')
       .innerJoin('documents', 'documents.id', 'document_embeddings.document_id')
+      .innerJoin('nodes', 'nodes.id', 'documents.id')
+      .innerJoin('collaborations', (join) =>
+        join
+          .onRef('collaborations.node_id', '=', 'nodes.root_id')
+          .on('collaborations.collaborator_id', '=', sql.lit(userId))
+          .on('collaborations.deleted_at', 'is', null)
+      )
       .select((eb) => [
         'document_embeddings.document_id as id',
         'document_embeddings.text',
diff --git a/apps/server/src/services/llm-service.ts b/apps/server/src/services/llm-service.ts
index d6723335..1598c5ab 100644
--- a/apps/server/src/services/llm-service.ts
+++ b/apps/server/src/services/llm-service.ts
@@ -7,15 +7,11 @@ import { HumanMessage } from '@langchain/core/messages';
 import { configuration } from '@/lib/configuration';
 import { Document } from '@langchain/core/documents';
 import { z } from 'zod';
-import { NodeAttributes, NodeType } from '@colanode/core';
 import type {
   ChunkingMetadata,
   NodeMetadata,
-  DocumentMetadata,
 } from '@/services/chunking-service';
 
-// Use proper Zod schemas and updated prompt templates
-
 const rerankedDocumentsSchema = z.object({
   rankings: z.array(
     z.object({
@@ -242,24 +238,13 @@ export async function assessUserIntent(
     : 'retrieve';
 }
 
-interface NodeContextData {
-  metadata: {
-    type: NodeType;
-    name?: string;
-    author?: { id: string; name: string };
-    parentContext?: {
-      type: string;
-      name?: string;
-    };
-    collaborators?: Array<{ id: string; name: string }>;
-    fields?: Record;
-  };
-}
-
 const getNodeContextPrompt = (metadata: NodeMetadata): string => {
   const basePrompt = `Given the following context about a {nodeType}:
 Name: {name}
-Created by: {authorName}
+Created by: {authorName} on {createdAt}
+Last updated: {lastUpdated} by {updatedByName}
+Location: {path}
+Workspace: {workspaceName}
 {additionalContext}
 
 Full content:
@@ -274,42 +259,56 @@ Generate a brief (50-100 tokens) contextual prefix that:
 3. Makes the chunk more understandable in isolation
 
 Do not repeat the chunk content. Return only the contextual prefix.`;
 
-  const getCollaboratorNames = (
-    collaborators?: Array<{ id: string; name: string }>
-  ) => collaborators?.map((c) => c.name).join(', ') ?? 'unknown';
+  const getCollaboratorInfo = (
+    collaborators?: Array<{ id: string; name: string; role: string }>
+  ) => {
+    if (!collaborators?.length) return 'No collaborators';
+    return collaborators.map((c) => `${c.name} (${c.role})`).join(', ');
+  };
 
-  switch (metadata.metadata.type) {
+  const formatDate = (date?: Date) => {
+    if (!date) return 'unknown';
+    return new Date(date).toLocaleString();
+  };
+
+  switch (metadata.nodeType) {
     case 'message':
       return basePrompt.replace(
         '{additionalContext}',
-        `In: ${metadata.metadata.parentContext?.type ?? 'unknown'} "${metadata.metadata.parentContext?.name ?? 'unknown'}"
-Participants: ${getCollaboratorNames(metadata.metadata.collaborators)}`
+        `In: ${metadata.parentContext?.type ?? 'unknown'} "${metadata.parentContext?.name ?? 'unknown'}"
+Path: ${metadata.parentContext?.path ?? 'unknown'}
+Participants: ${getCollaboratorInfo(metadata.collaborators)}`
       );
 
     case 'record':
       return basePrompt.replace(
         '{additionalContext}',
-        `Database: ${metadata.metadata.parentContext?.name ?? 'unknown'}
-Fields: ${Object.keys(metadata.metadata.fields ?? {}).join(', ')}`
+        `Database: ${metadata.parentContext?.name ?? 'unknown'}
+Path: ${metadata.parentContext?.path ?? 'unknown'}
+Fields: ${Object.keys(metadata.fields ?? {}).join(', ')}`
       );
 
     case 'page':
       return basePrompt.replace(
         '{additionalContext}',
-        `Location: ${metadata.metadata.parentContext?.name ? `in ${metadata.metadata.parentContext.name}` : 'root level'}`
+        `Location: ${metadata.parentContext?.path ?? 'root level'}
+Collaborators: ${getCollaboratorInfo(metadata.collaborators)}`
       );
 
     case 'database':
       return basePrompt.replace(
         '{additionalContext}',
-        `Fields: ${Object.keys(metadata.metadata.fields ?? {}).join(', ')}`
+        `Path: ${metadata.parentContext?.path ?? 'root level'}
+Fields: ${Object.keys(metadata.fields ?? {}).join(', ')}
+Collaborators: ${getCollaboratorInfo(metadata.collaborators)}`
       );
 
     case 'channel':
       return basePrompt.replace(
         '{additionalContext}',
         `Type: Channel
-Members: ${getCollaboratorNames(metadata.metadata.collaborators)}`
+Path: ${metadata.parentContext?.path ?? 'root level'}
+Members: ${getCollaboratorInfo(metadata.collaborators)}`
       );
 
     default:
@@ -321,6 +320,11 @@ interface PromptVariables {
   nodeType: string;
   name: string;
   authorName: string;
+  createdAt: string;
+  lastUpdated: string;
+  updatedByName: string;
+  path: string;
+  workspaceName: string;
   fullText: string;
   chunk: string;
   [key: string]: string;
 }
@@ -330,8 +334,10 @@ const documentContextPrompt = PromptTemplate.fromTemplate(
   `Given the following context about a document:
 Type: {nodeType}
 Name: {name}
-Parent: {parentName}
-Created by: {authorName}
+Location: {path}
+Created by: {authorName} on {createdAt}
+Last updated: {lastUpdated} by {updatedByName}
+Workspace: {workspaceName}
 
 Full content:
 {fullText}
@@ -362,31 +368,30 @@ export async function addContextToChunk(
   let prompt: string;
   let promptVars: PromptVariables;
 
+  const formatDate = (date?: Date) => {
+    if (!date) return 'unknown';
+    return new Date(date).toLocaleString();
+  };
+
+  const baseVars = {
+    nodeType: metadata.type === 'node' ? metadata.nodeType : metadata.type,
+    name: metadata.name ?? 'Untitled',
+    authorName: metadata.author?.name ?? 'Unknown',
+    createdAt: formatDate(metadata.createdAt),
+    lastUpdated: formatDate(metadata.lastUpdated),
+    updatedByName: metadata.updatedBy?.name ?? 'Unknown',
+    path: metadata.parentContext?.path ?? 'root level',
+    workspaceName: metadata.workspace?.name ?? 'Unknown Workspace',
+    fullText,
+    chunk,
+  };
+
   if (metadata.type === 'node') {
     prompt = getNodeContextPrompt(metadata);
-    promptVars = {
-      nodeType: metadata.metadata.type,
-      name: metadata.metadata.name ?? 'Untitled',
-      authorName: metadata.metadata.author?.name ?? 'Unknown',
-      fullText,
-      chunk,
-    };
+    promptVars = baseVars;
   } else {
-    prompt = await documentContextPrompt.format({
-      nodeType: metadata.metadata.type,
-      name: metadata.metadata.name ?? 'Untitled',
-      parentName: metadata.metadata.parentContext?.name ?? 'Unknown',
-      authorName: metadata.metadata.author?.name ?? 'Unknown',
-      fullText,
-      chunk,
-    });
-    promptVars = {
-      nodeType: metadata.metadata.type,
-      name: metadata.metadata.name ?? 'Untitled',
-      authorName: metadata.metadata.author?.name ?? 'Unknown',
-      fullText,
-      chunk,
-    };
+    prompt = await documentContextPrompt.format(baseVars);
+    promptVars = baseVars;
   }
 
   const formattedPrompt = Object.entries(promptVars).reduce(
diff --git a/apps/server/src/services/node-retrieval-service.ts b/apps/server/src/services/node-retrieval-service.ts
index 470d0f09..e01600a5 100644
--- a/apps/server/src/services/node-retrieval-service.ts
+++ b/apps/server/src/services/node-retrieval-service.ts
@@ -52,12 +52,17 @@ export class NodeRetrievalService {
     const results = await database
       .selectFrom('node_embeddings')
       .innerJoin('nodes', 'nodes.id', 'node_embeddings.node_id')
+      .innerJoin('collaborations', (join) =>
+        join
+          .onRef('collaborations.node_id', '=', 'node_embeddings.root_id')
+          .on('collaborations.collaborator_id', '=', sql.lit(userId))
+          .on('collaborations.deleted_at', 'is', null)
+      )
       .select((eb) => [
         'node_embeddings.node_id as id',
         'node_embeddings.text',
         'nodes.created_at',
         'node_embeddings.chunk as chunk_index',
-        // Wrap raw expression to satisfy type:
         sql`('[${embedding.join(',')}]'::vector) <=> node_embeddings.embedding_vector`.as(
           'similarity'
         ),
@@ -92,6 +97,12 @@
     const results = await database
       .selectFrom('node_embeddings')
       .innerJoin('nodes', 'nodes.id', 'node_embeddings.node_id')
+      .innerJoin('collaborations', (join) =>
+        join
+          .onRef('collaborations.node_id', '=', 'node_embeddings.root_id')
+          .on('collaborations.collaborator_id', '=', sql.lit(userId))
+          .on('collaborations.deleted_at', 'is', null)
+      )
       .select((eb) => [
         'node_embeddings.node_id as id',
         'node_embeddings.text',