Add user-based access control to document and node retrieval services

This commit is contained in:
Ylber Gashi
2025-02-17 16:33:39 +01:00
parent 2509ad5d75
commit 40d094869f
7 changed files with 194 additions and 156 deletions

View File

@@ -110,7 +110,8 @@ export const assistantResponseHandler = async (
);
const documentDocs = await documentRetrievalService.retrieve(
rewrittenQuery,
workspaceId
workspaceId,
user.id
);
const allContext = [...nodeDocs, ...documentDocs];
const reranked = await rerankDocuments(

View File

@@ -2,7 +2,7 @@ import { OpenAIEmbeddings } from '@langchain/openai';
import { ChunkingService } from '@/services/chunking-service';
import { database } from '@/data/database';
import { configuration } from '@/lib/configuration';
import { CreateDocumentEmbedding } from '@/data/schema';
import { CreateDocumentEmbedding, SelectNode } from '@/data/schema';
import { sql } from 'kysely';
import { fetchNode } from '@/lib/nodes';
import { DocumentContent, extractBlockTexts } from '@colanode/core';
@@ -22,47 +22,26 @@ declare module '@/types/jobs' {
}
const extractDocumentText = async (
documentId: string,
node: SelectNode,
content: DocumentContent
): Promise<string> => {
const sections: string[] = [];
const node = await fetchNode(documentId);
if (!node) {
return '';
}
const nodeModel = getNodeModel(node.attributes.type);
if (!nodeModel) {
return '';
}
const nodeName = nodeModel.getName(documentId, node.attributes);
if (nodeName) {
sections.push(`${node.attributes.type} "${nodeName}"`);
}
const attributesText = nodeModel.getAttributesText(
documentId,
node.attributes
);
if (attributesText) {
sections.push(attributesText);
}
const documentText = nodeModel.getDocumentText(documentId, content);
const documentText = nodeModel.getDocumentText(node.id, content);
if (documentText) {
sections.push(documentText);
return documentText;
} else {
// Fallback to block text extraction if the model doesn't handle it
const blocksText = extractBlockTexts(documentId, content.blocks);
const blocksText = extractBlockTexts(node.id, content.blocks);
if (blocksText) {
sections.push('Content:');
sections.push(blocksText);
return blocksText;
}
}
return sections.filter(Boolean).join('\n\n');
return '';
};
export const embedDocumentHandler = async (input: {
@@ -93,7 +72,7 @@ export const embedDocumentHandler = async (input: {
return;
}
const text = await extractDocumentText(documentId, document.content);
const text = await extractDocumentText(node, document.content);
if (!text || text.trim() === '') {
await database
.deleteFrom('document_embeddings')
@@ -105,7 +84,7 @@ export const embedDocumentHandler = async (input: {
const chunkingService = new ChunkingService();
const chunks = await chunkingService.chunkText(text, {
type: 'document',
id: documentId,
node: node,
});
const embeddings = new OpenAIEmbeddings({
apiKey: configuration.ai.embedding.apiKey,

View File

@@ -26,33 +26,15 @@ const extractNodeText = async (
): Promise<string> => {
if (!node) return '';
// Get the node model to use its text extraction methods
const nodeModel = getNodeModel(node.attributes.type);
if (!nodeModel) return '';
const sections: string[] = [];
// Get the node's name if available
const nodeName = nodeModel.getName(nodeId, node.attributes);
if (nodeName) {
sections.push(`${node.attributes.type} "${nodeName}"`);
}
// Get text from attributes (this handles message content, record fields, etc.)
const attributesText = nodeModel.getAttributesText(nodeId, node.attributes);
if (attributesText) {
sections.push(attributesText);
return attributesText;
}
// For records, add database context
if (node.attributes.type === 'record') {
const databaseNode = await fetchNode(node.attributes.databaseId);
if (databaseNode?.attributes.type === 'database') {
sections.push(`In database "${databaseNode.attributes.name}"`);
}
}
return sections.filter(Boolean).join('\n');
return '';
};
export const embedNodeHandler = async (input: {
@@ -87,7 +69,7 @@ export const embedNodeHandler = async (input: {
const chunkingService = new ChunkingService();
const chunks = await chunkingService.chunkText(text, {
type: 'node',
id: nodeId,
node,
});
const embeddings = new OpenAIEmbeddings({
apiKey: configuration.ai.embedding.apiKey,

View File

@@ -12,7 +12,6 @@ import type { SelectNode, SelectDocument, SelectUser } from '@/data/schema';
type BaseMetadata = {
id: string;
type: string;
name?: string;
createdAt: Date;
createdBy: string;
@@ -21,22 +20,23 @@ type BaseMetadata = {
id: string;
type: string;
name?: string;
path?: string;
};
collaborators?: Array<{ id: string; name: string }>;
collaborators?: Array<{ id: string; name: string; role: string }>;
lastUpdated?: Date;
updatedBy?: { id: string; name: string };
workspace?: { id: string; name: string };
};
export type NodeMetadata = {
export type NodeMetadata = BaseMetadata & {
type: 'node';
metadata: BaseMetadata & {
nodeType: string;
fields?: Record<string, unknown> | null;
};
};
export type DocumentMetadata = {
export type DocumentMetadata = BaseMetadata & {
type: 'document';
metadata: BaseMetadata & {
content: DocumentContent;
};
};
export type ChunkingMetadata = NodeMetadata | DocumentMetadata;
@@ -44,7 +44,7 @@ export type ChunkingMetadata = NodeMetadata | DocumentMetadata;
export class ChunkingService {
public async chunkText(
text: string,
metadata?: { type: 'node' | 'document'; id: string }
metadata?: { type: 'node' | 'document'; node: SelectNode }
): Promise<string[]> {
const chunkSize = configuration.ai.chunking.defaultChunkSize;
const chunkOverlap = configuration.ai.chunking.defaultOverlap;
@@ -70,40 +70,25 @@ export class ChunkingService {
private async fetchMetadata(metadata?: {
type: 'node' | 'document';
id: string;
node: SelectNode;
}): Promise<ChunkingMetadata | undefined> {
if (!metadata) {
return undefined;
}
if (metadata.type === 'node') {
const node = (await database
.selectFrom('nodes')
.selectAll()
.where('id', '=', metadata.id)
.executeTakeFirst()) as SelectNode | undefined;
if (!node) {
return undefined;
}
return this.buildNodeMetadata(node);
return this.buildNodeMetadata(metadata.node);
} else {
const document = (await database
.selectFrom('documents')
.selectAll()
.where('id', '=', metadata.id)
.where('id', '=', metadata.node.id)
.executeTakeFirst()) as SelectDocument | undefined;
if (!document) {
return undefined;
}
const node = (await database
.selectFrom('nodes')
.selectAll()
.where('id', '=', document.id)
.executeTakeFirst()) as SelectNode | undefined;
return this.buildDocumentMetadata(document, node);
return this.buildDocumentMetadata(document, metadata.node);
}
}
@@ -138,10 +123,9 @@ export class ChunkingService {
return {
type: 'node',
metadata: {
...baseMetadata,
nodeType: node.attributes.type,
fields: 'fields' in node.attributes ? node.attributes.fields : null,
},
...baseMetadata,
};
}
@@ -151,7 +135,6 @@ export class ChunkingService {
): Promise<DocumentMetadata> {
let baseMetadata: BaseMetadata = {
id: document.id,
type: 'document',
createdAt: document.created_at,
createdBy: document.created_by,
};
@@ -172,14 +155,18 @@ export class ChunkingService {
}
}
}
return {
type: 'document',
content: document.content,
...baseMetadata,
};
}
return {
type: 'document',
metadata: {
...baseMetadata,
content: document.content,
},
...baseMetadata,
};
}
@@ -187,30 +174,46 @@ export class ChunkingService {
const nodeModel = getNodeModel(node.attributes.type);
const nodeName = nodeModel?.getName(node.id, node.attributes);
const author = (await database
const author = await database
.selectFrom('users')
.select(['id', 'name'])
.where('id', '=', node.created_by)
.executeTakeFirst()) as SelectUser | undefined;
.executeTakeFirst();
const updatedBy = node.updated_by
? await database
.selectFrom('users')
.select(['id', 'name'])
.where('id', '=', node.updated_by)
.executeTakeFirst()
: undefined;
const workspace = await database
.selectFrom('workspaces')
.select(['id', 'name'])
.where('id', '=', node.workspace_id)
.executeTakeFirst();
return {
id: node.id,
type: node.attributes.type,
name: nodeName ?? undefined,
name: nodeName ?? '',
createdAt: node.created_at,
createdBy: node.created_by,
author: author ?? undefined,
lastUpdated: node.updated_at ?? undefined,
updatedBy: updatedBy ?? undefined,
workspace: workspace ?? undefined,
};
}
private async buildParentContext(
node: SelectNode
): Promise<BaseMetadata['parentContext'] | undefined> {
const parentNode = (await database
const parentNode = await database
.selectFrom('nodes')
.selectAll()
.where('id', '=', node.parent_id)
.executeTakeFirst()) as SelectNode | undefined;
.executeTakeFirst();
if (!parentNode) {
return undefined;
@@ -226,24 +229,58 @@ export class ChunkingService {
parentNode.attributes
);
// Get the full path by traversing up the tree
const pathNodes = await database
.selectFrom('node_paths')
.innerJoin('nodes', 'nodes.id', 'node_paths.ancestor_id')
.select(['nodes.id', 'nodes.attributes'])
.where('node_paths.descendant_id', '=', node.id)
.orderBy('node_paths.level', 'asc')
.execute();
const path = pathNodes
.map((n) => {
const model = getNodeModel(n.attributes.type);
return model?.getName(n.id, n.attributes) ?? 'Untitled';
})
.join(' / ');
return {
id: parentNode.id,
type: parentNode.attributes.type,
name: parentName ?? undefined,
path,
};
}
private async fetchCollaborators(
collaboratorIds: string[]
): Promise<Array<{ id: string; name: string }>> {
): Promise<Array<{ id: string; name: string; role: string }>> {
if (!collaboratorIds.length) {
return [];
}
return database
const collaborators = await database
.selectFrom('users')
.select(['id', 'name'])
.where('id', 'in', collaboratorIds)
.execute() as Promise<Array<{ id: string; name: string }>>;
.execute();
// Get roles for each collaborator
return Promise.all(
collaborators.map(async (c) => {
const collaboration = await database
.selectFrom('collaborations')
.select(['role'])
.where('collaborator_id', '=', c.id)
.executeTakeFirst();
return {
id: c.id,
name: c.name,
role: collaboration?.role ?? 'unknown',
};
})
);
}
}

View File

@@ -23,6 +23,7 @@ export class DocumentRetrievalService {
public async retrieve(
query: string,
workspaceId: string,
userId: string,
limit = configuration.ai.retrieval.hybridSearch.maxResults
): Promise<Document[]> {
const embedding = await this.embeddings.embedQuery(query);
@@ -30,20 +31,34 @@ export class DocumentRetrievalService {
const semanticResults = await this.semanticSearch(
embedding,
workspaceId,
userId,
limit
);
const keywordResults = await this.keywordSearch(
query,
workspaceId,
userId,
limit
);
const keywordResults = await this.keywordSearch(query, workspaceId, limit);
return this.combineSearchResults(semanticResults, keywordResults);
}
private async semanticSearch(
embedding: number[],
workspaceId: string,
userId: string,
limit: number
): Promise<SearchResult[]> {
const results = await database
.selectFrom('document_embeddings')
.innerJoin('documents', 'documents.id', 'document_embeddings.document_id')
.innerJoin('nodes', 'nodes.id', 'documents.id')
.innerJoin('collaborations', (join) =>
join
.onRef('collaborations.node_id', '=', 'nodes.root_id')
.on('collaborations.collaborator_id', '=', sql.lit(userId))
.on('collaborations.deleted_at', 'is', null)
)
.select((eb) => [
'document_embeddings.document_id as id',
'document_embeddings.text',
@@ -77,11 +92,19 @@ export class DocumentRetrievalService {
private async keywordSearch(
query: string,
workspaceId: string,
userId: string,
limit: number
): Promise<SearchResult[]> {
const results = await database
.selectFrom('document_embeddings')
.innerJoin('documents', 'documents.id', 'document_embeddings.document_id')
.innerJoin('nodes', 'nodes.id', 'documents.id')
.innerJoin('collaborations', (join) =>
join
.onRef('collaborations.node_id', '=', 'nodes.root_id')
.on('collaborations.collaborator_id', '=', sql.lit(userId))
.on('collaborations.deleted_at', 'is', null)
)
.select((eb) => [
'document_embeddings.document_id as id',
'document_embeddings.text',

View File

@@ -7,15 +7,11 @@ import { HumanMessage } from '@langchain/core/messages';
import { configuration } from '@/lib/configuration';
import { Document } from '@langchain/core/documents';
import { z } from 'zod';
import { NodeAttributes, NodeType } from '@colanode/core';
import type {
ChunkingMetadata,
NodeMetadata,
DocumentMetadata,
} from '@/services/chunking-service';
// Use proper Zod schemas and updated prompt templates
const rerankedDocumentsSchema = z.object({
rankings: z.array(
z.object({
@@ -242,24 +238,13 @@ export async function assessUserIntent(
: 'retrieve';
}
interface NodeContextData {
metadata: {
type: NodeType;
name?: string;
author?: { id: string; name: string };
parentContext?: {
type: string;
name?: string;
};
collaborators?: Array<{ id: string; name: string }>;
fields?: Record<string, unknown>;
};
}
const getNodeContextPrompt = (metadata: NodeMetadata): string => {
const basePrompt = `Given the following context about a {nodeType}:
Name: {name}
Created by: {authorName}
Created by: {authorName} on {createdAt}
Last updated: {lastUpdated} by {updatedByName}
Location: {path}
Workspace: {workspaceName}
{additionalContext}
Full content:
@@ -274,42 +259,56 @@ Generate a brief (50-100 tokens) contextual prefix that:
3. Makes the chunk more understandable in isolation
Do not repeat the chunk content. Return only the contextual prefix.`;
const getCollaboratorNames = (
collaborators?: Array<{ id: string; name: string }>
) => collaborators?.map((c) => c.name).join(', ') ?? 'unknown';
const getCollaboratorInfo = (
collaborators?: Array<{ id: string; name: string; role: string }>
) => {
if (!collaborators?.length) return 'No collaborators';
return collaborators.map((c) => `${c.name} (${c.role})`).join(', ');
};
switch (metadata.metadata.type) {
const formatDate = (date?: Date) => {
if (!date) return 'unknown';
return new Date(date).toLocaleString();
};
switch (metadata.nodeType) {
case 'message':
return basePrompt.replace(
'{additionalContext}',
`In: ${metadata.metadata.parentContext?.type ?? 'unknown'} "${metadata.metadata.parentContext?.name ?? 'unknown'}"
Participants: ${getCollaboratorNames(metadata.metadata.collaborators)}`
`In: ${metadata.parentContext?.type ?? 'unknown'} "${metadata.parentContext?.name ?? 'unknown'}"
Path: ${metadata.parentContext?.path ?? 'unknown'}
Participants: ${getCollaboratorInfo(metadata.collaborators)}`
);
case 'record':
return basePrompt.replace(
'{additionalContext}',
`Database: ${metadata.metadata.parentContext?.name ?? 'unknown'}
Fields: ${Object.keys(metadata.metadata.fields ?? {}).join(', ')}`
`Database: ${metadata.parentContext?.name ?? 'unknown'}
Path: ${metadata.parentContext?.path ?? 'unknown'}
Fields: ${Object.keys(metadata.fields ?? {}).join(', ')}`
);
case 'page':
return basePrompt.replace(
'{additionalContext}',
`Location: ${metadata.metadata.parentContext?.name ? `in ${metadata.metadata.parentContext.name}` : 'root level'}`
`Location: ${metadata.parentContext?.path ?? 'root level'}
Collaborators: ${getCollaboratorInfo(metadata.collaborators)}`
);
case 'database':
return basePrompt.replace(
'{additionalContext}',
`Fields: ${Object.keys(metadata.metadata.fields ?? {}).join(', ')}`
`Path: ${metadata.parentContext?.path ?? 'root level'}
Fields: ${Object.keys(metadata.fields ?? {}).join(', ')}
Collaborators: ${getCollaboratorInfo(metadata.collaborators)}`
);
case 'channel':
return basePrompt.replace(
'{additionalContext}',
`Type: Channel
Members: ${getCollaboratorNames(metadata.metadata.collaborators)}`
Path: ${metadata.parentContext?.path ?? 'root level'}
Members: ${getCollaboratorInfo(metadata.collaborators)}`
);
default:
@@ -321,6 +320,11 @@ interface PromptVariables {
nodeType: string;
name: string;
authorName: string;
createdAt: string;
lastUpdated: string;
updatedByName: string;
path: string;
workspaceName: string;
fullText: string;
chunk: string;
[key: string]: string;
@@ -330,8 +334,10 @@ const documentContextPrompt = PromptTemplate.fromTemplate(
`Given the following context about a document:
Type: {nodeType}
Name: {name}
Parent: {parentName}
Created by: {authorName}
Location: {path}
Created by: {authorName} on {createdAt}
Last updated: {lastUpdated} by {updatedByName}
Workspace: {workspaceName}
Full content:
{fullText}
@@ -362,31 +368,30 @@ export async function addContextToChunk(
let prompt: string;
let promptVars: PromptVariables;
const formatDate = (date?: Date) => {
if (!date) return 'unknown';
return new Date(date).toLocaleString();
};
const baseVars = {
nodeType: metadata.type === 'node' ? metadata.nodeType : metadata.type,
name: metadata.name ?? 'Untitled',
authorName: metadata.author?.name ?? 'Unknown',
createdAt: formatDate(metadata.createdAt),
lastUpdated: formatDate(metadata.lastUpdated),
updatedByName: metadata.updatedBy?.name ?? 'Unknown',
path: metadata.parentContext?.path ?? 'root level',
workspaceName: metadata.workspace?.name ?? 'Unknown Workspace',
fullText,
chunk,
};
if (metadata.type === 'node') {
prompt = getNodeContextPrompt(metadata);
promptVars = {
nodeType: metadata.metadata.type,
name: metadata.metadata.name ?? 'Untitled',
authorName: metadata.metadata.author?.name ?? 'Unknown',
fullText,
chunk,
};
promptVars = baseVars;
} else {
prompt = await documentContextPrompt.format({
nodeType: metadata.metadata.type,
name: metadata.metadata.name ?? 'Untitled',
parentName: metadata.metadata.parentContext?.name ?? 'Unknown',
authorName: metadata.metadata.author?.name ?? 'Unknown',
fullText,
chunk,
});
promptVars = {
nodeType: metadata.metadata.type,
name: metadata.metadata.name ?? 'Untitled',
authorName: metadata.metadata.author?.name ?? 'Unknown',
fullText,
chunk,
};
prompt = await documentContextPrompt.format(baseVars);
promptVars = baseVars;
}
const formattedPrompt = Object.entries(promptVars).reduce(

View File

@@ -52,12 +52,17 @@ export class NodeRetrievalService {
const results = await database
.selectFrom('node_embeddings')
.innerJoin('nodes', 'nodes.id', 'node_embeddings.node_id')
.innerJoin('collaborations', (join) =>
join
.onRef('collaborations.node_id', '=', 'node_embeddings.root_id')
.on('collaborations.collaborator_id', '=', sql.lit(userId))
.on('collaborations.deleted_at', 'is', null)
)
.select((eb) => [
'node_embeddings.node_id as id',
'node_embeddings.text',
'nodes.created_at',
'node_embeddings.chunk as chunk_index',
// Wrap raw expression to satisfy type:
sql<number>`('[${embedding.join(',')}]'::vector) <=> node_embeddings.embedding_vector`.as(
'similarity'
),
@@ -92,6 +97,12 @@ export class NodeRetrievalService {
const results = await database
.selectFrom('node_embeddings')
.innerJoin('nodes', 'nodes.id', 'node_embeddings.node_id')
.innerJoin('collaborations', (join) =>
join
.onRef('collaborations.node_id', '=', 'node_embeddings.root_id')
.on('collaborations.collaborator_id', '=', sql.lit(userId))
.on('collaborations.deleted_at', 'is', null)
)
.select((eb) => [
'node_embeddings.node_id as id',
'node_embeddings.text',