From 01c8f4273ebe96f47a2ac811162fdf89f181a506 Mon Sep 17 00:00:00 2001 From: Ylber Gashi Date: Mon, 17 Feb 2025 13:56:31 +0100 Subject: [PATCH] Add vector embedding support for nodes and documents --- .../00013-create-vector-extension.ts | 10 + .../00014-create-node-embeddings-table.ts | 42 +++ .../00015-create-document-embeddings-table.ts | 43 +++ apps/server/src/data/migrations/index.ts | 6 + apps/server/src/data/schema.ts | 34 +++ apps/server/src/jobs/assistant-response.ts | 198 +++++++++++++ apps/server/src/jobs/embed-document.ts | 121 ++++++++ apps/server/src/jobs/embed-node.ts | 183 ++++++++++++ apps/server/src/jobs/index.ts | 7 +- apps/server/src/lib/configuration.ts | 162 ++++++++-- apps/server/src/lib/nodes.ts | 49 +++ apps/server/src/services/chunking-service.ts | 128 +++++--- .../services/document-retrieval-service.ts | 189 ++++++++++++ apps/server/src/services/llm-service.ts | 278 ++++++++++++++++++ .../src/services/node-retrieval-service.ts | 199 +++++++++++++ 15 files changed, 1575 insertions(+), 74 deletions(-) create mode 100644 apps/server/src/data/migrations/00013-create-vector-extension.ts create mode 100644 apps/server/src/data/migrations/00014-create-node-embeddings-table.ts create mode 100644 apps/server/src/data/migrations/00015-create-document-embeddings-table.ts create mode 100644 apps/server/src/jobs/assistant-response.ts create mode 100644 apps/server/src/jobs/embed-document.ts create mode 100644 apps/server/src/jobs/embed-node.ts create mode 100644 apps/server/src/services/document-retrieval-service.ts create mode 100644 apps/server/src/services/llm-service.ts create mode 100644 apps/server/src/services/node-retrieval-service.ts diff --git a/apps/server/src/data/migrations/00013-create-vector-extension.ts b/apps/server/src/data/migrations/00013-create-vector-extension.ts new file mode 100644 index 00000000..fd02edb4 --- /dev/null +++ b/apps/server/src/data/migrations/00013-create-vector-extension.ts @@ -0,0 +1,10 @@ +import { Migration, sql } from 'kysely'; + +export const createVectorExtension: Migration = { + up: async (db) => { + await sql`CREATE EXTENSION IF NOT EXISTS vector`.execute(db); + }, + down: async (db) => { + await sql`DROP EXTENSION IF EXISTS vector`.execute(db); + }, +}; diff --git a/apps/server/src/data/migrations/00014-create-node-embeddings-table.ts b/apps/server/src/data/migrations/00014-create-node-embeddings-table.ts new file mode 100644 index 00000000..a9664792 --- /dev/null +++ b/apps/server/src/data/migrations/00014-create-node-embeddings-table.ts @@ -0,0 +1,42 @@ +import { Migration, sql } from 'kysely'; + +export const createNodeEmbeddingsTable: Migration = { + up: async (db) => { + await db.schema + .createTable('node_embeddings') + .addColumn('node_id', 'varchar(30)', (col) => col.notNull()) + .addColumn('chunk', 'integer', (col) => col.notNull()) + .addColumn('parent_id', 'varchar(30)') + .addColumn('root_id', 'varchar(30)', (col) => col.notNull()) + .addColumn('workspace_id', 'varchar(30)', (col) => col.notNull()) + .addColumn('text', 'text', (col) => col.notNull()) + .addColumn('embedding_vector', sql`vector(2000)`, (col) => col.notNull()) + .addColumn( + 'search_vector', + sql`tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED` + ) + .addColumn('created_at', 'timestamptz', (col) => col.notNull()) + .addColumn('updated_at', 'timestamptz') + .addPrimaryKeyConstraint('node_embeddings_pkey', ['node_id', 'chunk']) + .execute(); + + await sql` + CREATE INDEX node_embeddings_embedding_vector_idx + ON node_embeddings + 
USING hnsw(embedding_vector vector_cosine_ops)
+      WITH (
+        m = 16,
+        ef_construction = 64
+      );
+    `.execute(db);
+
+    await sql`
+      CREATE INDEX node_embeddings_search_vector_idx
+      ON node_embeddings
+      USING GIN (search_vector);
+    `.execute(db);
+  },
+  down: async (db) => {
+    await db.schema.dropTable('node_embeddings').execute();
+  },
+};
diff --git a/apps/server/src/data/migrations/00015-create-document-embeddings-table.ts b/apps/server/src/data/migrations/00015-create-document-embeddings-table.ts
new file mode 100644
index 00000000..612cd5c3
--- /dev/null
+++ b/apps/server/src/data/migrations/00015-create-document-embeddings-table.ts
@@ -0,0 +1,43 @@
+import { Migration, sql } from 'kysely';
+
+export const createDocumentEmbeddingsTable: Migration = {
+  up: async (db) => {
+    await db.schema
+      .createTable('document_embeddings')
+      .addColumn('document_id', 'varchar(30)', (col) => col.notNull())
+      .addColumn('chunk', 'integer', (col) => col.notNull())
+      .addColumn('workspace_id', 'varchar(30)', (col) => col.notNull())
+      .addColumn('text', 'text', (col) => col.notNull())
+      .addColumn('embedding_vector', sql`vector(2000)`, (col) => col.notNull())
+      .addColumn(
+        'search_vector',
+        sql`tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED`
+      )
+      .addColumn('created_at', 'timestamptz', (col) => col.notNull())
+      .addColumn('updated_at', 'timestamptz')
+      .addPrimaryKeyConstraint('document_embeddings_pkey', [
+        'document_id',
+        'chunk',
+      ])
+      .execute();
+
+    await sql`
+      CREATE INDEX document_embeddings_embedding_vector_idx
+      ON document_embeddings
+      USING hnsw(embedding_vector vector_cosine_ops)
+      WITH (
+        m = 16,
+        ef_construction = 64
+      );
+    `.execute(db);
+
+    await sql`
+      CREATE INDEX document_embeddings_search_vector_idx
+      ON document_embeddings
+      USING GIN (search_vector);
+    `.execute(db);
+  },
+  down: async (db) => {
+    await db.schema.dropTable('document_embeddings').execute();
+  },
+};
diff --git a/apps/server/src/data/migrations/index.ts b/apps/server/src/data/migrations/index.ts
index a74eae31..74f37b60 100644
--- a/apps/server/src/data/migrations/index.ts
+++ b/apps/server/src/data/migrations/index.ts
@@ -12,6 +12,9 @@ import { createNodePathsTable } from './00009-create-node-paths-table';
 import { createCollaborationsTable } from './00010-create-collaborations-table';
 import { createDocumentsTable } from './00011-create-documents-table';
 import { createDocumentUpdatesTable } from './00012-create-document-updates-table';
+import { createVectorExtension } from './00013-create-vector-extension';
+import { createNodeEmbeddingsTable } from './00014-create-node-embeddings-table';
+import { createDocumentEmbeddingsTable } from './00015-create-document-embeddings-table';
 
 export const databaseMigrations: Record<string, Migration> = {
   '00001_create_accounts_table': createAccountsTable,
@@ -26,4 +29,7 @@
   '00010_create_collaborations_table': createCollaborationsTable,
   '00011_create_documents_table': createDocumentsTable,
   '00012_create_document_updates_table': createDocumentUpdatesTable,
+  '00013_create_vector_extension': createVectorExtension,
+  '00014_create_node_embeddings_table': createNodeEmbeddingsTable,
+  '00015_create_document_embeddings_table': createDocumentEmbeddingsTable,
 };
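Note: the HNSW indexes above are built with `vector_cosine_ops`, so they serve pgvector's cosine-distance operator `<=>` used by the retrieval services later in this patch. A minimal sketch of the kind of query they accelerate — `nearestNodeChunks` is illustrative and not part of the patch; the query vector is passed as a bound string parameter and cast to `vector`:

    import { sql } from 'kysely';
    import { database } from '@/data/database';

    // Illustrative only: fetch the ten chunks nearest to a query embedding.
    // JSON.stringify(embedding) yields '[0.1,0.2,...]', which Postgres casts
    // to the vector type; `<=>` then returns cosine distance (lower = closer).
    const nearestNodeChunks = (embedding: number[], workspaceId: string) =>
      database
        .selectFrom('node_embeddings')
        .select(['node_id', 'chunk', 'text'])
        .select(
          sql<number>`(${JSON.stringify(embedding)}::vector) <=> embedding_vector`.as(
            'distance'
          )
        )
        .where('workspace_id', '=', workspaceId)
        .orderBy('distance', 'asc')
        .limit(10)
        .execute();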
diff --git a/apps/server/src/data/schema.ts b/apps/server/src/data/schema.ts
index 11c4179e..ac079845 100644
--- a/apps/server/src/data/schema.ts
+++ b/apps/server/src/data/schema.ts
@@ -221,6 +221,38 @@
 export type SelectDocumentUpdate = Selectable<DocumentUpdateTable>;
 export type CreateDocumentUpdate = Insertable<DocumentUpdateTable>;
 export type UpdateDocumentUpdate = Updateable<DocumentUpdateTable>;
 
+interface NodeEmbeddingTable {
+  node_id: ColumnType<string, string, never>;
+  chunk: ColumnType<number, number, never>;
+  parent_id: ColumnType<string | null, string | null, string | null>;
+  root_id: ColumnType<string, string, never>;
+  workspace_id: ColumnType<string, string, never>;
+  text: ColumnType<string, string, string>;
+  embedding_vector: ColumnType<number[], number[], number[]>;
+  search_vector: ColumnType<string, never, never>;
+  created_at: ColumnType<Date, Date, never>;
+  updated_at: ColumnType<Date | null, Date | null, Date>;
+}
+
+export type SelectNodeEmbedding = Selectable<NodeEmbeddingTable>;
+export type CreateNodeEmbedding = Insertable<NodeEmbeddingTable>;
+export type UpdateNodeEmbedding = Updateable<NodeEmbeddingTable>;
+
+interface DocumentEmbeddingTable {
+  document_id: ColumnType<string, string, never>;
+  chunk: ColumnType<number, number, never>;
+  workspace_id: ColumnType<string, string, never>;
+  text: ColumnType<string, string, string>;
+  embedding_vector: ColumnType<number[], number[], number[]>;
+  search_vector: ColumnType<string, never, never>;
+  created_at: ColumnType<Date, Date, never>;
+  updated_at: ColumnType<Date | null, Date | null, Date>;
+}
+
+export type SelectDocumentEmbedding = Selectable<DocumentEmbeddingTable>;
+export type CreateDocumentEmbedding = Insertable<DocumentEmbeddingTable>;
+export type UpdateDocumentEmbedding = Updateable<DocumentEmbeddingTable>;
+
 export interface DatabaseSchema {
   accounts: AccountTable;
   devices: DeviceTable;
@@ -234,4 +266,6 @@
   collaborations: CollaborationTable;
   documents: DocumentTable;
   document_updates: DocumentUpdateTable;
+  node_embeddings: NodeEmbeddingTable;
+  document_embeddings: DocumentEmbeddingTable;
 }
diff --git a/apps/server/src/jobs/assistant-response.ts b/apps/server/src/jobs/assistant-response.ts
new file mode 100644
index 00000000..37ff9cca
--- /dev/null
+++ b/apps/server/src/jobs/assistant-response.ts
@@ -0,0 +1,198 @@
+import {
+  generateId,
+  IdType,
+  extractBlockTexts,
+  MessageAttributes,
+} from '@colanode/core';
+import { Document } from '@langchain/core/documents';
+import { database } from '@/data/database';
+import { eventBus } from '@/lib/event-bus';
+import { configuration } from '@/lib/configuration';
+import { fetchNode } from '@/lib/nodes';
+import { nodeRetrievalService } from '@/services/node-retrieval-service';
+import { documentRetrievalService } from '@/services/document-retrieval-service';
+import { JobHandler } from '@/types/jobs';
+import {
+  rewriteQuery,
+  rerankDocuments,
+  generateFinalAnswer,
+  generateNoContextAnswer,
+  assessUserIntent,
+} from '@/services/llm-service';
+
+export type AssistantResponseInput = {
+  type: 'assistant_response';
+  nodeId: string;
+  workspaceId: string;
+};
+
+declare module '@/types/jobs' {
+  interface JobMap {
+    assistant_response: {
+      input: AssistantResponseInput;
+    };
+  }
+}
+
+export const assistantResponseHandler: JobHandler<AssistantResponseInput> = async (
+  input
+) => {
+  const { nodeId, workspaceId } = input;
+  console.log('Starting assistant response handler', { nodeId, workspaceId });
+  if (!configuration.ai.enabled) return;
+
+  const node = await fetchNode(nodeId);
+  if (!node) return;
+
+  // Nodes of type 'message' carry the user query.
+  const userInputText = extractBlockTexts(
+    node.id,
+    (node.attributes as MessageAttributes).content
+  );
+  if (!userInputText || userInputText.trim() === '') return;
+
+  // Fetch user details (created_by is the author of the message)
+  const user = await database
+    .selectFrom('users')
+    .where('id', '=', node.created_by)
+    .selectAll()
+    .executeTakeFirst();
+  if (!user) return;
+
+  // Conversation history: sibling messages under the same parent
+  const chatHistoryNodes = await database
+    .selectFrom('nodes')
+    .selectAll()
+    .where('parent_id', '=', node.parent_id)
+    .where('id', '!=', node.id)
+    .orderBy('created_at', 'asc')
+    .limit(10)
+    .execute();
+
+  const chatHistory = chatHistoryNodes.map(
+    (n) =>
+      new Document({
+        pageContent:
+          extractBlockTexts(
+            n.id,
+            (n.attributes as MessageAttributes).content
+          ) || '',
+        metadata: { id: n.id, type: n.type, createdAt: n.created_at },
+      })
+  );
+
+  const formattedChatHistory = chatHistory
+    .map((doc) => {
+      const ts = doc.metadata.createdAt
+        ? new Date(doc.metadata.createdAt).toLocaleString()
+        : 'Unknown';
+      return `- [${ts}] ${doc.metadata.id}: ${doc.pageContent}`;
+    })
+    .join('\n');
+
+  const intent = await assessUserIntent(userInputText, formattedChatHistory);
+
+  let finalAnswer: string;
+  let citations: Array<{ sourceId: string; quote: string }>;
+
+  if (intent === 'no_context') {
+    finalAnswer = await generateNoContextAnswer(
+      userInputText,
+      formattedChatHistory
+    );
+    citations = [];
+  } else {
+    const rewrittenQuery = await rewriteQuery(userInputText);
+    const nodeDocs = await nodeRetrievalService.retrieve(
+      rewrittenQuery,
+      workspaceId,
+      user.id
+    );
+    const documentDocs = await documentRetrievalService.retrieve(
+      rewrittenQuery,
+      workspaceId
+    );
+    const allContext = [...nodeDocs, ...documentDocs];
+    const reranked = await rerankDocuments(
+      allContext.map((doc) => ({
+        content: doc.pageContent,
+        type: doc.metadata.type,
+        sourceId: doc.metadata.id,
+      })),
+      rewrittenQuery
+    );
+    // Keep only the top-ranked sources when building the prompt.
+    const topContext = reranked.slice(0, 5);
+    const topIds = new Set(topContext.map((r) => r.sourceId));
+    const contextDocs = allContext.filter((doc) =>
+      topIds.has(doc.metadata.id)
+    );
+    // Retrieval services tag chunks as 'node' (messages, records) or
+    // 'document' (page content).
+    const formattedMessages = contextDocs
+      .filter((doc) => doc.metadata.type === 'node')
+      .map((doc) => {
+        const ts = doc.metadata.createdAt
+          ? new Date(doc.metadata.createdAt).toLocaleString()
+          : 'Unknown';
+        return `- [${ts}] ${doc.metadata.id}: ${doc.pageContent}`;
+      })
+      .join('\n');
+    const formattedDocuments = contextDocs
+      .filter((doc) => doc.metadata.type === 'document')
+      .map((doc) => {
+        const ts = doc.metadata.createdAt
+          ? new Date(doc.metadata.createdAt).toLocaleString()
+          : 'Unknown';
+        return `- [${ts}] ${doc.metadata.id}: ${doc.pageContent}`;
+      })
+      .join('\n');
+    const promptArgs = {
+      currentTimestamp: new Date().toISOString(),
+      workspaceName: workspaceId, // Or retrieve workspace details
+      userName: user.name || 'User',
+      userEmail: user.email || 'unknown@example.com',
+      formattedChatHistory,
+      formattedMessages,
+      formattedDocuments,
+      question: userInputText,
+    };
+    const result = await generateFinalAnswer(promptArgs);
+    finalAnswer = result.answer;
+    citations = result.citations;
+  }
+
+  // Create a response node (answer message). Citations are collected
+  // above but not yet rendered into the response blocks.
+  const responseNodeId = generateId(IdType.Node);
+  const blockId = generateId(IdType.Block);
+  const responseAttributes = {
+    type: 'message',
+    subtype: 'standard',
+    content: {
+      [blockId]: {
+        id: blockId,
+        type: 'paragraph',
+        parentId: responseNodeId,
+        index: 'a',
+        content: [{ type: 'text', text: finalAnswer }],
+      },
+    },
+    parentId: node.parent_id,
+    referenceId: node.id,
+  };
+
+  await database
+    .insertInto('nodes')
+    .values({
+      id: responseNodeId,
+      root_id: node.root_id,
+      workspace_id: workspaceId,
+      attributes: JSON.stringify(responseAttributes),
+      state: Buffer.from(''), // Initial empty state
+      created_at: new Date(),
+      created_by: 'colanode_ai',
+    })
+    .executeTakeFirst();
+
+  eventBus.publish({
+    type: 'node_created',
+    nodeId: responseNodeId,
+    rootId: node.root_id,
+    workspaceId,
+  });
+};
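Note: `assistant_response` is registered in the job handler map below, but the patch does not show where the job is enqueued. A sketch of a plausible trigger in the message-creation path — the `mentionsAssistant` condition is hypothetical; only the `jobService.addJob` shape is taken from this patch:

    import { jobService } from '@/services/job-service';

    // Hypothetical call site: enqueue an assistant response for a new
    // message that addresses the assistant. The detection logic is not
    // part of this patch.
    const maybeScheduleAssistantResponse = async (
      nodeId: string,
      workspaceId: string,
      mentionsAssistant: boolean
    ) => {
      if (!mentionsAssistant) return;
      await jobService.addJob(
        { type: 'assistant_response', nodeId, workspaceId },
        { jobId: `assistant_response:${nodeId}` }
      );
    };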
diff --git a/apps/server/src/jobs/embed-document.ts b/apps/server/src/jobs/embed-document.ts
new file mode 100644
index 00000000..243f4da6
--- /dev/null
+++ b/apps/server/src/jobs/embed-document.ts
@@ -0,0 +1,121 @@
+import { OpenAIEmbeddings } from '@langchain/openai';
+import { extractBlockTexts } from '@colanode/core';
+import { sql } from 'kysely';
+import { ChunkingService } from '@/services/chunking-service';
+import { database } from '@/data/database';
+import { configuration } from '@/lib/configuration';
+import { CreateDocumentEmbedding } from '@/data/schema';
+import { fetchNode } from '@/lib/nodes';
+import { JobHandler } from '@/types/jobs';
+
+export type EmbedDocumentInput = {
+  type: 'embed_document';
+  documentId: string;
+};
+
+declare module '@/types/jobs' {
+  interface JobMap {
+    embed_document: {
+      input: EmbedDocumentInput;
+    };
+  }
+}
+
+export const embedDocumentHandler: JobHandler<EmbedDocumentInput> = async (
+  input
+) => {
+  if (!configuration.ai.enabled) return;
+  const { documentId } = input;
+
+  const document = await database
+    .selectFrom('documents')
+    .select(['id', 'content', 'workspace_id', 'created_at'])
+    .where('id', '=', documentId)
+    .executeTakeFirst();
+  if (!document) return;
+
+  // Documents share their id with the node they belong to; fetch that
+  // node plus its parent and root to build a short context header
+  // (node type, page name, space name).
+  const node = await fetchNode(documentId);
+  let header = '';
+  if (node) {
+    header = `${node.attributes.type} "${(node.attributes as any).name || ''}"`;
+    const parent = node.parent_id ? await fetchNode(node.parent_id) : null;
+    const parentName = parent && (parent.attributes as any).name;
+    if (parentName) {
+      header += ` in "${parentName}"`;
+    }
+    const root = node.root_id ? await fetchNode(node.root_id) : null;
+    const rootName = root && (root.attributes as any).name;
+    if (rootName) {
+      header += ` [${rootName}]`;
+    }
+  }
+
+  // Extract text from the document content blocks (pages and rich-text
+  // use extractBlockTexts).
+  const docText = extractBlockTexts(documentId, document.content.blocks) || '';
+  const fullText = header ? `${header}\n\nContent:\n${docText}` : docText;
+  if (!fullText.trim()) {
+    await database
+      .deleteFrom('document_embeddings')
+      .where('document_id', '=', documentId)
+      .execute();
+    return;
+  }
+
+  const chunkingService = new ChunkingService();
+  const chunks = await chunkingService.chunkText(fullText, {
+    type: 'document',
+    id: documentId,
+  });
+
+  // Drop chunks that no longer exist after re-chunking a shorter text.
+  await database
+    .deleteFrom('document_embeddings')
+    .where('document_id', '=', documentId)
+    .where('chunk', '>=', chunks.length)
+    .execute();
+
+  const embeddings = new OpenAIEmbeddings({
+    apiKey: configuration.ai.embedding.apiKey,
+    modelName: configuration.ai.embedding.modelName,
+    dimensions: configuration.ai.embedding.dimensions,
+  });
+
+  const existingEmbeddings = await database
+    .selectFrom('document_embeddings')
+    .select(['chunk', 'text'])
+    .where('document_id', '=', documentId)
+    .execute();
+
+  // Re-embed only chunks whose text actually changed.
+  const embeddingsToUpsert: CreateDocumentEmbedding[] = [];
+  for (let i = 0; i < chunks.length; i++) {
+    const chunk = chunks[i];
+    if (!chunk) continue;
+    const existing = existingEmbeddings.find((e) => e.chunk === i);
+    if (existing && existing.text === chunk) continue;
+    embeddingsToUpsert.push({
+      document_id: documentId,
+      chunk: i,
+      workspace_id: document.workspace_id,
+      text: chunk,
+      embedding_vector: [],
+      created_at: new Date(),
+    });
+  }
+  if (embeddingsToUpsert.length === 0) return;
+
+  const batchSize = configuration.ai.embedding.batchSize;
+  for (let i = 0; i < embeddingsToUpsert.length; i += batchSize) {
+    const batch = embeddingsToUpsert.slice(i, i + batchSize);
+    const textsToEmbed = batch.map((item) => item.text);
+    const embeddingVectors = await embeddings.embedDocuments(textsToEmbed);
+    for (let j = 0; j < batch.length; j++) {
+      const batchItem = batch[j];
+      if (batchItem) {
+        batchItem.embedding_vector = embeddingVectors[j] ?? [];
+      }
+    }
+  }
+
+  for (const embedding of embeddingsToUpsert) {
+    await database
+      .insertInto('document_embeddings')
+      .values({
+        document_id: embedding.document_id,
+        chunk: embedding.chunk,
+        workspace_id: embedding.workspace_id,
+        text: embedding.text,
+        embedding_vector: sql.raw(
+          `'[${embedding.embedding_vector.join(',')}]'::vector`
+        ),
+        created_at: embedding.created_at,
+      })
+      .onConflict((oc) =>
+        oc.columns(['document_id', 'chunk']).doUpdateSet({
+          text: sql.ref('excluded.text'),
+          embedding_vector: sql.ref('excluded.embedding_vector'),
+          updated_at: new Date(),
+        })
+      )
+      .execute();
+  }
+};
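Note: likewise, nothing in this patch enqueues `embed_document`. Assuming it belongs in the document-update path, a sketch that mirrors the `embed_node` scheduling added to `@/lib/nodes` below (the call site itself is an assumption):

    import { jobService } from '@/services/job-service';
    import { configuration } from '@/lib/configuration';

    // Sketch: debounce document embedding behind a stable job id, exactly
    // like the embed_node scheduling in @/lib/nodes.
    const scheduleDocumentEmbedding = async (documentId: string) => {
      await jobService.addJob(
        { type: 'embed_document', documentId },
        {
          jobId: `embed_document:${documentId}`,
          delay: configuration.ai.entryEmbedDelay,
        }
      );
    };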
diff --git a/apps/server/src/jobs/embed-node.ts b/apps/server/src/jobs/embed-node.ts
new file mode 100644
index 00000000..0ead62cd
--- /dev/null
+++ b/apps/server/src/jobs/embed-node.ts
@@ -0,0 +1,183 @@
+import { OpenAIEmbeddings } from '@langchain/openai';
+import { extractBlockTexts, NodeAttributes, FieldValue } from '@colanode/core';
+import { ChunkingService } from '@/services/chunking-service';
+import { database } from '@/data/database';
+import { configuration } from '@/lib/configuration';
+import { CreateNodeEmbedding } from '@/data/schema';
+import { sql } from 'kysely';
+import { fetchNode } from '@/lib/nodes';
+import { JobHandler } from '@/types/jobs';
+
+export type EmbedNodeInput = {
+  type: 'embed_node';
+  nodeId: string;
+};
+
+declare module '@/types/jobs' {
+  interface JobMap {
+    embed_node: {
+      input: EmbedNodeInput;
+    };
+  }
+}
+
+const formatFieldValue = (fieldValue: FieldValue): string => {
+  switch (fieldValue.type) {
+    case 'boolean':
+      return fieldValue.value ? 'Yes' : 'No';
+    case 'string_array':
+      return (fieldValue.value as string[]).join(', ');
+    case 'number':
+    case 'string':
+    case 'text':
+      return String(fieldValue.value);
+    default:
+      return '';
+  }
+};
+
+const extractNodeText = async (
+  nodeId: string,
+  attributes: NodeAttributes
+): Promise<string> => {
+  switch (attributes.type) {
+    case 'message':
+      return extractBlockTexts(nodeId, attributes.content) ?? '';
+    case 'record': {
+      const sections: string[] = [];
+
+      // Fetch the database node to get its name
+      const databaseNode = await fetchNode(attributes.databaseId);
+      const databaseName =
+        databaseNode?.attributes.type === 'database'
+          ? databaseNode.attributes.name
+          : attributes.databaseId;
+
+      // Add record context with the database name
+      sections.push(
+        `Record "${attributes.name}" in database "${databaseName}"`
+      );
+
+      // Process field values
+      Object.entries(attributes.fields).forEach(([, fieldValue]) => {
+        if (!fieldValue || !('type' in fieldValue)) {
+          return;
+        }
+
+        const value = formatFieldValue(fieldValue as FieldValue);
+        if (value) {
+          sections.push(value);
+        }
+      });
+
+      return sections.join('\n');
+    }
+    default:
+      return '';
+  }
+};
+
+export const embedNodeHandler: JobHandler<EmbedNodeInput> = async (input) => {
+  if (!configuration.ai.enabled) {
+    return;
+  }
+
+  const { nodeId } = input;
+  const node = await fetchNode(nodeId);
+  if (!node) {
+    return;
+  }
+
+  // Skip page nodes (document content is handled separately in the
+  // embed_document job)
+  if (node.type === 'page') {
+    return;
+  }
+
+  const text = await extractNodeText(node.id, node.attributes);
+  if (!text || text.trim() === '') {
+    await database
+      .deleteFrom('node_embeddings')
+      .where('node_id', '=', nodeId)
+      .execute();
+    return;
+  }
+
+  const chunkingService = new ChunkingService();
+  const chunks = await chunkingService.chunkText(text, {
+    type: 'node',
+    id: nodeId,
+  });
+
+  // Drop chunks that no longer exist after re-chunking a shorter text.
+  await database
+    .deleteFrom('node_embeddings')
+    .where('node_id', '=', nodeId)
+    .where('chunk', '>=', chunks.length)
+    .execute();
+
+  const embeddings = new OpenAIEmbeddings({
+    apiKey: configuration.ai.embedding.apiKey,
+    modelName: configuration.ai.embedding.modelName,
+    dimensions: configuration.ai.embedding.dimensions,
+  });
+
+  const existingEmbeddings = await database
+    .selectFrom('node_embeddings')
+    .select(['chunk', 'text'])
+    .where('node_id', '=', nodeId)
+    .execute();
+
+  // Re-embed only chunks whose text actually changed.
+  const embeddingsToCreateOrUpdate: CreateNodeEmbedding[] = [];
+  for (let i = 0; i < chunks.length; i++) {
+    const chunk = chunks[i];
+    if (!chunk) continue;
+    const existing = existingEmbeddings.find((e) => e.chunk === i);
+    if (existing && existing.text === chunk) continue;
+    embeddingsToCreateOrUpdate.push({
+      node_id: nodeId,
+      chunk: i,
+      parent_id: node.parent_id,
+      root_id: node.root_id,
+      workspace_id: node.workspace_id,
+      text: chunk,
+      embedding_vector: [],
+      created_at: new Date(),
+    });
+  }
+  if (embeddingsToCreateOrUpdate.length === 0) {
+    return;
+  }
+
+  const batchSize = configuration.ai.embedding.batchSize;
+  for (let i = 0; i < embeddingsToCreateOrUpdate.length; i += batchSize) {
+    const batch = embeddingsToCreateOrUpdate.slice(i, i + batchSize);
+    const textsToEmbed = batch.map((item) => item.text);
+    const embeddingVectors = await embeddings.embedDocuments(textsToEmbed);
+    for (let j = 0; j < batch.length; j++) {
+      const vector = embeddingVectors[j];
+      const batchItem = batch[j];
+      if (batchItem) {
+        batchItem.embedding_vector = vector ?? [];
+      }
+    }
+  }
+
+  for (const embedding of embeddingsToCreateOrUpdate) {
+    await database
+      .insertInto('node_embeddings')
+      .values({
+        node_id: embedding.node_id,
+        chunk: embedding.chunk,
+        parent_id: embedding.parent_id,
+        root_id: embedding.root_id,
+        workspace_id: embedding.workspace_id,
+        text: embedding.text,
+        embedding_vector: sql.raw(
+          `'[${embedding.embedding_vector.join(',')}]'::vector`
+        ),
+        created_at: embedding.created_at,
+      })
+      .onConflict((oc) =>
+        oc.columns(['node_id', 'chunk']).doUpdateSet({
+          text: sql.ref('excluded.text'),
+          embedding_vector: sql.ref('excluded.embedding_vector'),
+          updated_at: new Date(),
+        })
+      )
+      .execute();
+  }
+};
diff --git a/apps/server/src/jobs/index.ts b/apps/server/src/jobs/index.ts
index 70ab60ec..24829312 100644
--- a/apps/server/src/jobs/index.ts
+++ b/apps/server/src/jobs/index.ts
@@ -2,7 +2,9 @@ import { cleanNodeDataHandler } from '@/jobs/clean-node-data';
 import { cleanWorkspaceDataHandler } from '@/jobs/clean-workspace-data';
 import { JobHandler, JobMap } from '@/types/jobs';
 import { sendEmailVerifyEmailHandler } from '@/jobs/send-email-verify-email';
-
+import { embedNodeHandler } from './embed-node';
+import { embedDocumentHandler } from './embed-document';
+import { assistantResponseHandler } from './assistant-response';
 type JobHandlerMap = {
   [K in keyof JobMap]: JobHandler<JobMap[K]['input']>;
 };
@@ -11,4 +13,7 @@
   send_email_verify_email: sendEmailVerifyEmailHandler,
   clean_workspace_data: cleanWorkspaceDataHandler,
   clean_node_data: cleanNodeDataHandler,
+  embed_node: embedNodeHandler,
+  embed_document: embedDocumentHandler,
+  assistant_response: assistantResponseHandler,
 };
diff --git a/apps/server/src/lib/configuration.ts b/apps/server/src/lib/configuration.ts
index 66244091..972a4733 100644
--- a/apps/server/src/lib/configuration.ts
+++ b/apps/server/src/lib/configuration.ts
@@ -69,26 +69,63 @@ export interface SmtpConfiguration {
   };
 }
 
+export type AIProvider = 'openai' | 'google';
+
+export interface AIProviderConfiguration {
+  apiKey: string;
+  enabled?: boolean;
+}
+
+export interface AIModelConfiguration {
+  provider: AIProvider;
+  modelName: string;
+  temperature: number;
+}
+
 export interface AiConfiguration {
   enabled: boolean;
   entryEmbedDelay: number;
-  openai: OpenAiConfiguration;
+  providers: {
+    openai: AIProviderConfiguration;
+    google: AIProviderConfiguration;
+  };
+  langfuse: {
+    publicKey: string;
+    secretKey: string;
+    baseUrl: string;
+  };
+  models: {
+    queryRewrite: AIModelConfiguration;
+    response: AIModelConfiguration;
+    rerank: AIModelConfiguration;
+    summarization: AIModelConfiguration;
+    contextEnhancer: AIModelConfiguration;
+    noContext: AIModelConfiguration;
+    intentRecognition: AIModelConfiguration;
+  };
+  embedding: {
+    provider: AIProvider;
+    apiKey: string;
+    modelName: string;
+    dimensions: number;
+    batchSize: number;
+  };
   chunking: ChunkingConfiguration;
-}
-
-export interface OpenAiConfiguration {
-  apiKey: string;
-  embeddingModel: string;
-  embeddingDimensions: number;
-  embeddingBatchSize: number;
+  retrieval: RetrievalConfiguration;
 }
 
 export interface ChunkingConfiguration {
   defaultChunkSize: number;
   defaultOverlap: number;
   enhanceWithContext: boolean;
-  contextEnhancerModel: string;
-  contextEnhancerTemperature: number;
+}
+
+export interface RetrievalConfiguration {
+  hybridSearch: {
+    semanticSearchWeight: number;
+    keywordSearchWeight: number;
+    maxResults: number;
+  };
 }
 
 const getRequiredEnv
= (env: string): string => { @@ -172,15 +209,84 @@ export const configuration: Configuration = { entryEmbedDelay: parseInt( getOptionalEnv('AI_ENTRY_EMBED_DELAY') || '60000' ), - openai: { - apiKey: getOptionalEnv('OPENAI_API_KEY') || '', - embeddingModel: getOptionalEnv('OPENAI_EMBEDDING_MODEL') || '', - embeddingDimensions: parseInt( - getOptionalEnv('OPENAI_EMBEDDING_DIMENSIONS') || '2000' - ), - embeddingBatchSize: parseInt( - getOptionalEnv('OPENAI_EMBEDDING_BATCH_SIZE') || '50' - ), + providers: { + openai: { + apiKey: getOptionalEnv('OPENAI_API_KEY') || '', + enabled: getOptionalEnv('OPENAI_ENABLED') === 'true', + }, + google: { + apiKey: getOptionalEnv('GOOGLE_API_KEY') || '', + enabled: getOptionalEnv('GOOGLE_ENABLED') === 'true', + }, + }, + langfuse: { + publicKey: getOptionalEnv('LANGFUSE_PUBLIC_KEY') || '', + secretKey: getOptionalEnv('LANGFUSE_SECRET_KEY') || '', + baseUrl: + getOptionalEnv('LANGFUSE_BASE_URL') || 'https://cloud.langfuse.com', + }, + models: { + queryRewrite: { + provider: (getOptionalEnv('QUERY_REWRITE_PROVIDER') || + 'openai') as AIProvider, + modelName: getOptionalEnv('QUERY_REWRITE_MODEL') || 'gpt-4o-mini', + temperature: parseFloat( + getOptionalEnv('QUERY_REWRITE_TEMPERATURE') || '0.3' + ), + }, + response: { + provider: (getOptionalEnv('RESPONSE_PROVIDER') || + 'openai') as AIProvider, + modelName: getOptionalEnv('RESPONSE_MODEL') || 'gpt-4o-mini', + temperature: parseFloat( + getOptionalEnv('RESPONSE_TEMPERATURE') || '0.3' + ), + }, + rerank: { + provider: (getOptionalEnv('RERANK_PROVIDER') || 'openai') as AIProvider, + modelName: getOptionalEnv('RERANK_MODEL') || 'gpt-4o-mini', + temperature: parseFloat(getOptionalEnv('RERANK_TEMPERATURE') || '0.3'), + }, + summarization: { + provider: (getOptionalEnv('SUMMARIZATION_PROVIDER') || + 'openai') as AIProvider, + modelName: getOptionalEnv('SUMMARIZATION_MODEL') || 'gpt-4o-mini', + temperature: parseFloat( + getOptionalEnv('SUMMARIZATION_TEMPERATURE') || '0.3' + ), + }, + contextEnhancer: { + provider: (getOptionalEnv('CHUNK_CONTEXT_PROVIDER') || + 'openai') as AIProvider, + modelName: getOptionalEnv('CHUNK_CONTEXT_MODEL') || 'gpt-4o-mini', + temperature: parseFloat( + getOptionalEnv('CHUNK_CONTEXT_TEMPERATURE') || '0.3' + ), + }, + noContext: { + provider: (getOptionalEnv('NO_CONTEXT_PROVIDER') || + 'openai') as AIProvider, + modelName: getOptionalEnv('NO_CONTEXT_MODEL') || 'gpt-4o-mini', + temperature: parseFloat( + getOptionalEnv('NO_CONTEXT_TEMPERATURE') || '0.3' + ), + }, + intentRecognition: { + provider: (getOptionalEnv('INTENT_RECOGNITION_PROVIDER') || + 'openai') as AIProvider, + modelName: getOptionalEnv('INTENT_RECOGNITION_MODEL') || 'gpt-4o-mini', + temperature: parseFloat( + getOptionalEnv('INTENT_RECOGNITION_TEMPERATURE') || '0.3' + ), + }, + }, + embedding: { + provider: (getOptionalEnv('EMBEDDING_PROVIDER') || + 'openai') as AIProvider, + modelName: getOptionalEnv('EMBEDDING_MODEL') || 'text-embedding-3-large', + dimensions: parseInt(getOptionalEnv('EMBEDDING_DIMENSIONS') || '2000'), + apiKey: getOptionalEnv('EMBEDDING_API_KEY') || '', + batchSize: parseInt(getOptionalEnv('EMBEDDING_BATCH_SIZE') || '50'), }, chunking: { defaultChunkSize: parseInt( @@ -191,11 +297,19 @@ export const configuration: Configuration = { ), enhanceWithContext: getOptionalEnv('CHUNK_ENHANCE_WITH_CONTEXT') === 'true', - contextEnhancerModel: - getOptionalEnv('CHUNK_CONTEXT_ENHANCER_MODEL') || 'gpt-4o-mini', - contextEnhancerTemperature: parseFloat( - getOptionalEnv('CHUNK_CONTEXT_ENHANCER_TEMPERATURE') || '0.3' - ), 
+    },
+    retrieval: {
+      hybridSearch: {
+        semanticSearchWeight: parseFloat(
+          getOptionalEnv('RETRIEVAL_HYBRID_SEARCH_SEMANTIC_WEIGHT') || '0.7'
+        ),
+        keywordSearchWeight: parseFloat(
+          getOptionalEnv('RETRIEVAL_HYBRID_SEARCH_KEYWORD_WEIGHT') || '0.3'
+        ),
+        maxResults: parseInt(
+          getOptionalEnv('RETRIEVAL_HYBRID_SEARCH_MAX_RESULTS') || '20'
+        ),
+      },
     },
   },
 };
diff --git a/apps/server/src/lib/nodes.ts b/apps/server/src/lib/nodes.ts
index 4f2ea9f9..bcae67f5 100644
--- a/apps/server/src/lib/nodes.ts
+++ b/apps/server/src/lib/nodes.ts
@@ -35,6 +35,7 @@ import {
   checkCollaboratorChanges,
 } from '@/lib/collaborations';
 import { jobService } from '@/services/job-service';
+import { configuration } from '@/lib/configuration';
 
 const debug = createDebugger('server:lib:nodes');
 
@@ -169,6 +170,18 @@ export const createNode = async (
     });
   }
 
+  // Schedule node embedding
+  await jobService.addJob(
+    {
+      type: 'embed_node',
+      nodeId: input.nodeId,
+    },
+    {
+      jobId: `embed_node:${input.nodeId}`,
+      delay: configuration.ai.entryEmbedDelay,
+    }
+  );
+
   return {
     node: createdNode,
   };
@@ -290,6 +303,18 @@ export const tryUpdateNode = async (
     });
   }
 
+  // Schedule node embedding
+  await jobService.addJob(
+    {
+      type: 'embed_node',
+      nodeId: input.nodeId,
+    },
+    {
+      jobId: `embed_node:${input.nodeId}`,
+      delay: configuration.ai.entryEmbedDelay,
+    }
+  );
+
   return {
     type: 'success',
     output: {
@@ -396,6 +421,18 @@ export const createNodeFromMutation = async (
     });
   }
 
+  // Schedule node embedding
+  await jobService.addJob(
+    {
+      type: 'embed_node',
+      nodeId: mutation.id,
+    },
+    {
+      jobId: `embed_node:${mutation.id}`,
+      delay: configuration.ai.entryEmbedDelay,
+    }
+  );
+
   return {
     node: createdNode,
   };
@@ -527,6 +564,18 @@ const tryUpdateNodeFromMutation = async (
     });
   }
 
+  // Schedule node embedding
+  await jobService.addJob(
+    {
+      type: 'embed_node',
+      nodeId: mutation.id,
+    },
+    {
+      jobId: `embed_node:${mutation.id}`,
+      delay: configuration.ai.entryEmbedDelay,
+    }
+  );
+
   return {
     type: 'success',
     output: {
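Note: the same scheduling block now appears at four call sites in `@/lib/nodes`. A shared helper would keep them in sync; a sketch (the helper name is illustrative). Reusing a stable `jobId` together with the `AI_ENTRY_EMBED_DELAY` delay acts as a debounce, provided the underlying queue ignores duplicate job ids while a job is pending (BullMQ behaves this way):

    import { jobService } from '@/services/job-service';
    import { configuration } from '@/lib/configuration';

    // Sketch: one helper for the four identical call sites above.
    const scheduleNodeEmbedding = async (nodeId: string) => {
      await jobService.addJob(
        { type: 'embed_node', nodeId },
        {
          jobId: `embed_node:${nodeId}`,
          delay: configuration.ai.entryEmbedDelay,
        }
      );
    };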
diff --git a/apps/server/src/services/chunking-service.ts b/apps/server/src/services/chunking-service.ts
index b60a98e2..f02acc69 100644
--- a/apps/server/src/services/chunking-service.ts
+++ b/apps/server/src/services/chunking-service.ts
@@ -1,68 +1,98 @@
 import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
-import { ChatOpenAI } from '@langchain/openai';
-import { SystemMessage } from '@langchain/core/messages';
-
+import { sql } from 'kysely';
 import { configuration } from '@/lib/configuration';
+import { database } from '@/data/database';
+import { addContextToChunk } from '@/services/llm-service';
+
+export type ChunkingMetadata = {
+  nodeId: string;
+  type: string;
+  name?: string;
+  parentName?: string;
+  spaceName?: string;
+  createdAt: Date;
+};
 
 export class ChunkingService {
-  public async chunkText(text: string): Promise<string[]> {
+  // Unified chunkText that optionally enriches each chunk with context
+  // metadata.
+  public async chunkText(
+    text: string,
+    metadataInfo?: { type: 'node' | 'document'; id: string }
+  ): Promise<string[]> {
     const chunkSize = configuration.ai.chunking.defaultChunkSize;
     const chunkOverlap = configuration.ai.chunking.defaultOverlap;
-
     const splitter = new RecursiveCharacterTextSplitter({
       chunkSize,
       chunkOverlap,
     });
-
     const docs = await splitter.createDocuments([text]);
-    let chunks = docs.map((doc) => doc.pageContent);
-
-    chunks = chunks.filter((c) => c.trim().length > 10);
-
+    let chunks = docs
+      .map((doc) => doc.pageContent)
+      .filter((c) => c.trim().length > 10);
     if (configuration.ai.chunking.enhanceWithContext) {
-      const enriched: string[] = [];
-      for (const chunk of chunks) {
-        const c = await this.addContextToChunk(chunk, text);
-        enriched.push(c);
-      }
-      return enriched;
+      // Fetch unified metadata (a single query per source)
+      const metadata = metadataInfo
+        ? await this.fetchMetadata(metadataInfo)
+        : undefined;
+      chunks = await Promise.all(
+        chunks.map(async (chunk) => {
+          return addContextToChunk(chunk, text, metadata);
+        })
+      );
     }
-
     return chunks;
   }
 
-  private async addContextToChunk(
-    chunk: string,
-    fullText: string
-  ): Promise<string> {
-    try {
-      const chat = new ChatOpenAI({
-        openAIApiKey: configuration.ai.openai.apiKey,
-        modelName: configuration.ai.chunking.contextEnhancerModel,
-        temperature: configuration.ai.chunking.contextEnhancerTemperature,
-        maxTokens: 200,
-      });
-
-      const prompt = `
-<document>
-${fullText}
-</document>
-
-Here is the chunk we want to situate in context:
-<chunk>
-${chunk}
-</chunk>
-
-Generate a short (50-100 tokens) contextual prefix that seamlessly provides background or location info for the chunk. Then include the original chunk text below it without any extra separator.
-      `;
-
-      const response = await chat.invoke([new SystemMessage(prompt)]);
-      const context = (response.content.toString() ?? '').trim();
-
-      return `${context} ${chunk}`;
-    } catch (err) {
-      console.error('Error adding context to chunk:', err);
-      return chunk;
-    }
+  // A unified metadata fetch which joins the node with its parent and
+  // root (space) to gather naming details.
+  private async fetchMetadata(info: {
+    type: 'node' | 'document';
+    id: string;
+  }): Promise<ChunkingMetadata | undefined> {
+    if (info.type === 'node') {
+      // Fetch the node along with its parent (if any) and the root
+      // (assumed to be the space).
+      const result = await database
+        .selectFrom('nodes')
+        .leftJoin('nodes as parent', 'nodes.parent_id', 'parent.id')
+        .leftJoin('nodes as root', 'nodes.root_id', 'root.id')
+        .select([
+          'nodes.id as nodeId',
+          'nodes.type',
+          sql<string | null>`nodes.attributes->>'name'`.as('name'),
+          sql<string | null>`parent.attributes->>'name'`.as('parentName'),
+          sql<string | null>`root.attributes->>'name'`.as('spaceName'),
+          'nodes.created_at as createdAt',
+        ])
+        .where('nodes.id', '=', info.id)
+        .executeTakeFirst();
+      if (!result) return undefined;
+      return {
+        nodeId: result.nodeId,
+        type: result.type,
+        name: result.name ?? undefined,
+        parentName: result.parentName ?? undefined,
+        spaceName: result.spaceName ?? undefined,
+        createdAt: result.createdAt,
+      };
+    } else {
+      // For documents, derive the same metadata from the associated node.
+      const result = await database
+        .selectFrom('documents')
+        .innerJoin('nodes', 'documents.id', 'nodes.id')
+        .select([
+          'nodes.id as nodeId',
+          'nodes.type',
+          sql<string | null>`nodes.attributes->>'name'`.as('name'),
+          'nodes.created_at as createdAt',
+        ])
+        .where('documents.id', '=', info.id)
+        .executeTakeFirst();
+      if (!result) return undefined;
+      return {
+        nodeId: result.nodeId,
+        type: result.type,
+        name: result.name ?? undefined,
+        createdAt: result.createdAt,
+      };
+    }
   }
 }
diff --git a/apps/server/src/services/document-retrieval-service.ts b/apps/server/src/services/document-retrieval-service.ts
new file mode 100644
index 00000000..16cc7a2f
--- /dev/null
+++ b/apps/server/src/services/document-retrieval-service.ts
@@ -0,0 +1,189 @@
+import { OpenAIEmbeddings } from '@langchain/openai';
+import { Document } from '@langchain/core/documents';
+import { database } from '@/data/database';
+import { configuration } from '@/lib/configuration';
+import { sql } from 'kysely';
+
+type SearchResult = {
+  id: string;
+  text: string;
+  score: number;
+  type: 'semantic' | 'keyword';
+  createdAt?: Date;
+  chunkIndex: number;
+};
+
+export class DocumentRetrievalService {
+  private embeddings = new OpenAIEmbeddings({
+    apiKey: configuration.ai.embedding.apiKey,
+    modelName: configuration.ai.embedding.modelName,
+    dimensions: configuration.ai.embedding.dimensions,
+  });
+
+  public async retrieve(
+    query: string,
+    workspaceId: string,
+    limit = configuration.ai.retrieval.hybridSearch.maxResults
+  ): Promise<Document[]> {
+    const embedding = await this.embeddings.embedQuery(query);
+    if (!embedding) return [];
+    const semanticResults = await this.semanticSearch(
+      embedding,
+      workspaceId,
+      limit
+    );
+    const keywordResults = await this.keywordSearch(query, workspaceId, limit);
+    return this.combineSearchResults(semanticResults, keywordResults);
+  }
+
+  private async semanticSearch(
+    embedding: number[],
+    workspaceId: string,
+    limit: number
+  ): Promise<SearchResult[]> {
+    const results = await database
+      .selectFrom('document_embeddings')
+      .innerJoin('documents', 'documents.id', 'document_embeddings.document_id')
+      .select([
+        'document_embeddings.document_id as id',
+        'document_embeddings.text',
+        'documents.created_at',
+        'document_embeddings.chunk as chunk_index',
+        // Bind the query vector as a parameter and cast it; interpolating
+        // the numbers inside a quoted literal would not survive kysely's
+        // parameter substitution.
+        sql<number>`(${JSON.stringify(embedding)}::vector) <=> document_embeddings.embedding_vector`.as(
+          'similarity'
+        ),
+      ])
+      .where('document_embeddings.workspace_id', '=', workspaceId)
+      .groupBy([
+        'document_embeddings.document_id',
+        'document_embeddings.text',
+        'documents.created_at',
+        'document_embeddings.chunk',
+      ])
+      .orderBy('similarity', 'asc')
+      .limit(limit)
+      .execute();
+
+    return results.map((result) => ({
+      id: result.id,
+      text: result.text,
+      score: result.similarity,
+      type: 'semantic' as const,
+      createdAt: result.created_at,
+      chunkIndex: result.chunk_index,
+    }));
+  }
+
+  private async keywordSearch(
+    query: string,
+    workspaceId: string,
+    limit: number
+  ): Promise<SearchResult[]> {
+    const results = await database
+      .selectFrom('document_embeddings')
+      .innerJoin('documents', 'documents.id', 'document_embeddings.document_id')
+      .select([
+        'document_embeddings.document_id as id',
+        'document_embeddings.text',
+        'documents.created_at',
+        'document_embeddings.chunk as chunk_index',
+        sql<number>`ts_rank(document_embeddings.search_vector, websearch_to_tsquery('english', ${query}))`.as(
+          'rank'
+        ),
+      ])
+      .where('document_embeddings.workspace_id', '=', workspaceId)
+      .where(
+        sql<boolean>`document_embeddings.search_vector @@ websearch_to_tsquery('english', ${query})`
+      )
+      .groupBy([
+        'document_embeddings.document_id',
+        'document_embeddings.text',
+        'documents.created_at',
+        'document_embeddings.chunk',
+      ])
+      .orderBy('rank', 'desc')
+      .limit(limit)
+      .execute();
+
+    return results.map((result) => ({
+      id: result.id,
+      text: result.text,
+      score: result.rank,
+      type: 'keyword' as const,
+      createdAt: result.created_at,
+      chunkIndex: result.chunk_index,
+    }));
+  }
+
+  private combineSearchResults(
+    semanticResults: SearchResult[],
+    keywordResults: SearchResult[]
+  ): Document[] {
+    const { semanticSearchWeight, keywordSearchWeight } =
+      configuration.ai.retrieval.hybridSearch;
+    const maxSemanticScore = Math.max(
+      ...semanticResults.map((r) => r.score),
+      1
+    );
+    const maxKeywordScore = Math.max(...keywordResults.map((r) => r.score), 1);
+    const combined = new Map<string, SearchResult & { finalScore: number }>();
+    const createKey = (result: SearchResult) =>
+      `${result.id}-${result.chunkIndex}`;
+    // Boost chunks younger than seven days by up to 20%.
+    const calculateRecencyBoost = (
+      createdAt: Date | undefined | null
+    ): number => {
+      if (!createdAt) return 1;
+      const now = new Date();
+      const ageInDays =
+        (now.getTime() - createdAt.getTime()) / (1000 * 60 * 60 * 24);
+      return ageInDays <= 7 ? 1 + (1 - ageInDays / 7) * 0.2 : 1;
+    };
+
+    semanticResults.forEach((result) => {
+      const key = createKey(result);
+      const recencyBoost = calculateRecencyBoost(result.createdAt);
+      // Cosine distance: lower is better, so invert before weighting.
+      const normalizedScore =
+        ((maxSemanticScore - result.score) / maxSemanticScore) *
+        semanticSearchWeight;
+      combined.set(key, {
+        ...result,
+        finalScore: normalizedScore * recencyBoost,
+      });
+    });
+
+    keywordResults.forEach((result) => {
+      const key = createKey(result);
+      const recencyBoost = calculateRecencyBoost(result.createdAt);
+      const normalizedScore =
+        (result.score / maxKeywordScore) * keywordSearchWeight;
+      if (combined.has(key)) {
+        const existing = combined.get(key)!;
+        existing.finalScore += normalizedScore * recencyBoost;
+      } else {
+        combined.set(key, {
+          ...result,
+          finalScore: normalizedScore * recencyBoost,
+        });
+      }
+    });
+
+    return Array.from(combined.values())
+      .sort((a, b) => b.finalScore - a.finalScore)
+      .map(
+        (result) =>
+          new Document({
+            pageContent: result.text,
+            metadata: {
+              id: result.id,
+              score: result.finalScore,
+              createdAt: result.createdAt,
+              type: 'document',
+              chunkIndex: result.chunkIndex,
+            },
+          })
+      );
+  }
+}
+
+export const documentRetrievalService = new DocumentRetrievalService();
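Note: to make the hybrid weighting in combineSearchResults concrete, a worked example with the default weights (0.7 semantic, 0.3 keyword); all numbers are illustrative:

    // Semantic scores are cosine distances (lower is better); keyword
    // scores are ts_rank values (higher is better).
    //
    //   semantic: distance 0.20 with max distance 0.80
    //     ((0.80 - 0.20) / 0.80) * 0.7 = 0.525
    //   keyword: rank 0.04 with max rank 0.05
    //     (0.04 / 0.05) * 0.3 = 0.24
    //
    // A chunk found by both searches scores 0.525 + 0.24 = 0.765 before
    // the recency boost, which multiplies each contribution by up to 1.2
    // for chunks younger than seven days.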
diff --git a/apps/server/src/services/llm-service.ts b/apps/server/src/services/llm-service.ts
new file mode 100644
index 00000000..f4df156f
--- /dev/null
+++ b/apps/server/src/services/llm-service.ts
@@ -0,0 +1,278 @@
+import { ChatOpenAI } from '@langchain/openai';
+import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
+import { PromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts';
+import { StringOutputParser } from '@langchain/core/output_parsers';
+import { HumanMessage } from '@langchain/core/messages';
+import { configuration } from '@/lib/configuration';
+import { Document } from '@langchain/core/documents';
+import { z } from 'zod';
+
+// Zod schemas for structured model output.
+
+const rerankedDocumentsSchema = z.object({
+  rankings: z.array(
+    z.object({
+      index: z.number(),
+      score: z.number().min(0).max(1),
+      type: z.enum(['node', 'document']),
+      sourceId: z.string(),
+    })
+  ),
+});
+type RerankedDocuments = z.infer<typeof rerankedDocumentsSchema>;
+
+const citedAnswerSchema = z.object({
+  answer: z.string(),
+  citations: z.array(
+    z.object({
+      sourceId: z.string(),
+      quote: z.string(),
+    })
+  ),
+});
+type CitedAnswer = z.infer<typeof citedAnswerSchema>;
+
+export function getChatModel(
+  task: keyof typeof configuration.ai.models
+): ChatOpenAI | ChatGoogleGenerativeAI {
+  const modelConfig = configuration.ai.models[task];
+  if (!configuration.ai.enabled) {
+    throw new Error('AI is disabled.');
+  }
+  const providerConfig = configuration.ai.providers[modelConfig.provider];
+  if (!providerConfig.enabled) {
+    throw new Error(`${modelConfig.provider} provider is disabled.`);
+  }
+  switch (modelConfig.provider) {
+    case 'openai':
+      return new ChatOpenAI({
+        modelName: modelConfig.modelName,
+        temperature: modelConfig.temperature,
+        openAIApiKey: providerConfig.apiKey,
+      });
+    case 'google':
+      return new ChatGoogleGenerativeAI({
+        modelName: modelConfig.modelName,
+        temperature: modelConfig.temperature,
+        apiKey: providerConfig.apiKey,
+      });
+    default:
+      throw new Error(`Unsupported AI provider: ${modelConfig.provider}`);
+  }
+}
+
+// Prompt templates.
+const queryRewritePrompt = PromptTemplate.fromTemplate(
+  `You are an expert at rewriting queries for information retrieval within Colanode.
+
+Guidelines:
+1. Extract the core information need.
+2. Remove filler words.
+3. Preserve key technical terms and dates.
+
+Original query:
+{query}
+
+Rewrite the query and return only the rewritten version.`
+);
+
+const summarizationPrompt = PromptTemplate.fromTemplate(
+  `Summarize the following text focusing on key points relevant to the user's query.
+If the text is short (<100 characters), return it as is.
+
+Text: {text}
+User Query: {query}`
+);
+
+const rerankPrompt = PromptTemplate.fromTemplate(
+  `Re-rank the following list of documents by their relevance to the query.
+For each document, provide:
+- Original index (from input)
+- A relevance score between 0 and 1
+- Document type (node or document)
+- Source ID
+
+User query:
+{query}
+
+Documents:
+{context}
+
+Return an array of rankings in JSON format.`
+);
+
+const answerPrompt = ChatPromptTemplate.fromTemplate(
+  `You are Colanode's AI assistant.
+
+CURRENT TIME: {currentTimestamp}
+WORKSPACE: {workspaceName}
+USER: {userName} ({userEmail})
+
+CONVERSATION HISTORY:
+{formattedChatHistory}
+
+RELATED CONTEXT:
+Messages:
+{formattedMessages}
+Documents:
+{formattedDocuments}
+
+USER QUERY:
+{question}
+
+Provide a clear, professional answer. Then, in a separate citations array, list exact quotes (with source IDs) used to form your answer.
+Return the result as JSON with keys "answer" and "citations".`
+);
+
+const intentRecognitionPrompt = PromptTemplate.fromTemplate(
+  `Determine if the following user query requires retrieving additional context.
+Return exactly one value: "retrieve" or "no_context".
+
+Conversation History:
+{formattedChatHistory}
+
+User Query:
+{question}`
+);
+
+const noContextPrompt = PromptTemplate.fromTemplate(
+  `Answer the following query concisely using general knowledge, without retrieving additional context.
+
+Conversation History:
+{formattedChatHistory}
+
+User Query:
+{question}
+
+Return only the answer.`
+);
+
+export async function rewriteQuery(query: string): Promise<string> {
+  const task = 'queryRewrite';
+  const model = getChatModel(task);
+  return queryRewritePrompt
+    .pipe(model)
+    .pipe(new StringOutputParser())
+    .invoke({ query });
+}
+
+export async function summarizeDocument(
+  document: Document,
+  query: string
+): Promise<string> {
+  const task = 'summarization';
+  const model = getChatModel(task);
+  return summarizationPrompt
+    .pipe(model)
+    .pipe(new StringOutputParser())
+    .invoke({ text: document.pageContent, query });
+}
+
+export async function rerankDocuments(
+  documents: { content: string; type: string; sourceId: string }[],
+  query: string
+): Promise<
+  Array<{ index: number; score: number; type: string; sourceId: string }>
+> {
+  const task = 'rerank';
+  const model = getChatModel(task).withStructuredOutput(
+    rerankedDocumentsSchema
+  );
+  const formattedContext = documents
+    .map(
+      (doc, idx) =>
+        `${idx}. Type: ${doc.type}, Content: ${doc.content}, ID: ${doc.sourceId}\n`
+    )
+    .join('\n');
+  const result = (await rerankPrompt
+    .pipe(model)
+    .invoke({ query, context: formattedContext })) as RerankedDocuments;
+  return result.rankings;
+}
+
+export async function generateFinalAnswer(promptArgs: {
+  currentTimestamp: string;
+  workspaceName: string;
+  userName: string;
+  userEmail: string;
+  formattedChatHistory: string;
+  formattedMessages: string;
+  formattedDocuments: string;
+  question: string;
+}): Promise<{
+  answer: string;
+  citations: Array<{ sourceId: string; quote: string }>;
+}> {
+  const task = 'response';
+  const model = getChatModel(task).withStructuredOutput(citedAnswerSchema);
+  return (await answerPrompt.pipe(model).invoke(promptArgs)) as CitedAnswer;
+}
+
+export async function generateNoContextAnswer(
+  query: string,
+  chatHistory: string = ''
+): Promise<string> {
+  const task = 'noContext';
+  const model = getChatModel(task);
+  return noContextPrompt
+    .pipe(model)
+    .pipe(new StringOutputParser())
+    .invoke({ question: query, formattedChatHistory: chatHistory });
+}
+
+export async function assessUserIntent(
+  query: string,
+  chatHistory: string
+): Promise<'retrieve' | 'no_context'> {
+  const task = 'intentRecognition';
+  const model = getChatModel(task);
+  const result = await intentRecognitionPrompt
+    .pipe(model)
+    .pipe(new StringOutputParser())
+    .invoke({ question: query, formattedChatHistory: chatHistory });
+  return result.trim().toLowerCase() === 'no_context'
+    ? 'no_context'
+    : 'retrieve';
+}
+
+export async function addContextToChunk(
+  chunk: string,
+  fullText: string,
+  metadata?: {
+    type?: string;
+    name?: string;
+    parentName?: string;
+    spaceName?: string;
+  }
+): Promise<string> {
+  try {
+    const task = 'contextEnhancer';
+    const model = getChatModel(task);
+    // Fold the available metadata into a single context line; the model
+    // returns only the prefix, which is prepended below.
+    const promptTemplate = PromptTemplate.fromTemplate(
+      `Using the following context information:
+{contextInfo}
+
+Full content:
+{fullText}
+
+Given this chunk:
+{chunk}
+
+Generate a short (50-100 tokens) contextual prefix that situates the chunk. Return only the prefix; do not repeat the chunk.`
+    );
+    const contextInfo = metadata
+      ? `Type: ${metadata.type}, Name: ${metadata.name || 'N/A'}, Parent: ${metadata.parentName || 'N/A'}, Space: ${metadata.spaceName || 'N/A'}`
+      : 'No additional context available.';
+    const prompt = await promptTemplate.format({
+      contextInfo,
+      fullText,
+      chunk,
+    });
+    const response = await model.invoke([
+      new HumanMessage({ content: prompt }),
+    ]);
+    const prefix = (response.content.toString() || '').trim();
+    return prefix ? `${prefix}\n\n${chunk}` : chunk;
+  } catch (err) {
+    console.error('Error in addContextToChunk:', err);
+    return chunk;
+  }
+}
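Note: the node-retrieval service below mirrors the document-retrieval service above almost line for line; the scoring helpers could be extracted and shared. A sketch of one such helper, lifted from the two identical `calculateRecencyBoost` closures (the shared-module location is an assumption):

    // Shared recency boost: chunks younger than seven days get up to +20%.
    export const calculateRecencyBoost = (
      createdAt: Date | undefined | null
    ): number => {
      if (!createdAt) return 1;
      const ageInDays =
        (Date.now() - createdAt.getTime()) / (1000 * 60 * 60 * 24);
      return ageInDays <= 7 ? 1 + (1 - ageInDays / 7) * 0.2 : 1;
    };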
diff --git a/apps/server/src/services/node-retrieval-service.ts b/apps/server/src/services/node-retrieval-service.ts
new file mode 100644
index 00000000..470d0f09
--- /dev/null
+++ b/apps/server/src/services/node-retrieval-service.ts
@@ -0,0 +1,199 @@
+import { OpenAIEmbeddings } from '@langchain/openai';
+import { Document } from '@langchain/core/documents';
+import { database } from '@/data/database';
+import { configuration } from '@/lib/configuration';
+import { sql } from 'kysely';
+
+type SearchResult = {
+  id: string;
+  text: string;
+  score: number;
+  type: 'semantic' | 'keyword';
+  createdAt?: Date;
+  chunkIndex: number;
+};
+
+export class NodeRetrievalService {
+  private embeddings = new OpenAIEmbeddings({
+    apiKey: configuration.ai.embedding.apiKey,
+    modelName: configuration.ai.embedding.modelName,
+    dimensions: configuration.ai.embedding.dimensions,
+  });
+
+  public async retrieve(
+    query: string,
+    workspaceId: string,
+    userId: string,
+    limit = configuration.ai.retrieval.hybridSearch.maxResults
+  ): Promise<Document[]> {
+    // NOTE: userId is accepted for per-user permission filtering, but the
+    // searches below do not apply it yet.
+    const embedding = await this.embeddings.embedQuery(query);
+    if (!embedding) return [];
+    const semanticResults = await this.semanticSearch(
+      embedding,
+      workspaceId,
+      userId,
+      limit
+    );
+    const keywordResults = await this.keywordSearch(
+      query,
+      workspaceId,
+      userId,
+      limit
+    );
+    return this.combineSearchResults(semanticResults, keywordResults);
+  }
+
+  private async semanticSearch(
+    embedding: number[],
+    workspaceId: string,
+    userId: string,
+    limit: number
+  ): Promise<SearchResult[]> {
+    const results = await database
+      .selectFrom('node_embeddings')
+      .innerJoin('nodes', 'nodes.id', 'node_embeddings.node_id')
+      .select([
+        'node_embeddings.node_id as id',
+        'node_embeddings.text',
+        'nodes.created_at',
+        'node_embeddings.chunk as chunk_index',
+        // Bind the query vector as a parameter and cast it to `vector`.
+        sql<number>`(${JSON.stringify(embedding)}::vector) <=> node_embeddings.embedding_vector`.as(
+          'similarity'
+        ),
+      ])
+      .where('node_embeddings.workspace_id', '=', workspaceId)
+      .groupBy([
+        'node_embeddings.node_id',
+        'node_embeddings.text',
+        'nodes.created_at',
+        'node_embeddings.chunk',
+      ])
+      .orderBy('similarity', 'asc')
+      .limit(limit)
+      .execute();
+
+    return results.map((result) => ({
+      id: result.id,
+      text: result.text,
+      score: result.similarity,
+      type: 'semantic' as const,
+      createdAt: result.created_at,
+      chunkIndex: result.chunk_index,
+    }));
+  }
+
+  private async keywordSearch(
+    query: string,
+    workspaceId: string,
+    userId: string,
+    limit: number
+  ): Promise<SearchResult[]> {
+    const results = await database
+      .selectFrom('node_embeddings')
+      .innerJoin('nodes', 'nodes.id', 'node_embeddings.node_id')
+      .select([
+        'node_embeddings.node_id as id',
+        'node_embeddings.text',
+        'nodes.created_at',
+        'node_embeddings.chunk as chunk_index',
+        sql<number>`ts_rank(node_embeddings.search_vector, websearch_to_tsquery('english', ${query}))`.as(
+          'rank'
+        ),
+      ])
+      .where('node_embeddings.workspace_id', '=', workspaceId)
+      .where(
+        sql<boolean>`node_embeddings.search_vector @@ websearch_to_tsquery('english', ${query})`
+      )
+      .groupBy([
+        'node_embeddings.node_id',
+        'node_embeddings.text',
+        'nodes.created_at',
+        'node_embeddings.chunk',
+      ])
+      .orderBy('rank', 'desc')
+      .limit(limit)
+      .execute();
+
+    return results.map((result) => ({
+      id: result.id,
+      text: result.text,
+      score: result.rank,
+      type: 'keyword' as const,
+      createdAt: result.created_at,
+      chunkIndex: result.chunk_index,
+    }));
+  }
+
+  private combineSearchResults(
+    semanticResults: SearchResult[],
+    keywordResults: SearchResult[]
+  ): Document[] {
+    const { semanticSearchWeight, keywordSearchWeight } =
+      configuration.ai.retrieval.hybridSearch;
+    const maxSemanticScore = Math.max(
+      ...semanticResults.map((r) => r.score),
+      1
+    );
+    const maxKeywordScore = Math.max(...keywordResults.map((r) => r.score), 1);
+    const combined = new Map<string, SearchResult & { finalScore: number }>();
+    const createKey = (result: SearchResult) =>
+      `${result.id}-${result.chunkIndex}`;
+    // Boost chunks younger than seven days by up to 20%.
+    const calculateRecencyBoost = (
+      createdAt: Date | undefined | null
+    ): number => {
+      if (!createdAt) return 1;
+      const now = new Date();
+      const ageInDays =
+        (now.getTime() - createdAt.getTime()) / (1000 * 60 * 60 * 24);
+      return ageInDays <= 7 ? 1 + (1 - ageInDays / 7) * 0.2 : 1;
+    };
+
+    semanticResults.forEach((result) => {
+      const key = createKey(result);
+      const recencyBoost = calculateRecencyBoost(result.createdAt);
+      // Cosine distance: lower is better, so invert before weighting.
+      const normalizedScore =
+        ((maxSemanticScore - result.score) / maxSemanticScore) *
+        semanticSearchWeight;
+      combined.set(key, {
+        ...result,
+        finalScore: normalizedScore * recencyBoost,
+      });
+    });
+
+    keywordResults.forEach((result) => {
+      const key = createKey(result);
+      const recencyBoost = calculateRecencyBoost(result.createdAt);
+      const normalizedScore =
+        (result.score / maxKeywordScore) * keywordSearchWeight;
+      if (combined.has(key)) {
+        const existing = combined.get(key)!;
+        existing.finalScore += normalizedScore * recencyBoost;
+      } else {
+        combined.set(key, {
+          ...result,
+          finalScore: normalizedScore * recencyBoost,
+        });
+      }
+    });
+
+    return Array.from(combined.values())
+      .sort((a, b) => b.finalScore - a.finalScore)
+      .map(
+        (result) =>
+          new Document({
+            pageContent: result.text,
+            metadata: {
+              id: result.id,
+              score: result.finalScore,
+              createdAt: result.createdAt,
+              type: 'node',
+              chunkIndex: result.chunkIndex,
+            },
+          })
+      );
+  }
+}
+
+export const nodeRetrievalService = new NodeRetrievalService();
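Note for operators: the environment variables this patch reads (names taken from configuration.ts; the values shown are the code's defaults, and every variable is optional). Each model task — QUERY_REWRITE, RESPONSE, RERANK, SUMMARIZATION, CHUNK_CONTEXT, NO_CONTEXT, INTENT_RECOGNITION — accepts the same _PROVIDER, _MODEL, and _TEMPERATURE triplet:

    # Providers
    OPENAI_ENABLED=true
    OPENAI_API_KEY=
    GOOGLE_ENABLED=false
    GOOGLE_API_KEY=

    # Embeddings (the migrations declare vector(2000) columns, so
    # EMBEDDING_DIMENSIONS must stay at 2000 unless the tables change)
    EMBEDDING_PROVIDER=openai
    EMBEDDING_API_KEY=
    EMBEDDING_MODEL=text-embedding-3-large
    EMBEDDING_DIMENSIONS=2000
    EMBEDDING_BATCH_SIZE=50
    AI_ENTRY_EMBED_DELAY=60000

    # Example model override
    RESPONSE_PROVIDER=openai
    RESPONSE_MODEL=gpt-4o-mini
    RESPONSE_TEMPERATURE=0.3

    # Hybrid retrieval
    RETRIEVAL_HYBRID_SEARCH_SEMANTIC_WEIGHT=0.7
    RETRIEVAL_HYBRID_SEARCH_KEYWORD_WEIGHT=0.3
    RETRIEVAL_HYBRID_SEARCH_MAX_RESULTS=20

    # Chunk context enrichment
    CHUNK_ENHANCE_WITH_CONTEXT=false

    # Langfuse (read into configuration; tracing wiring is not shown in this patch)
    LANGFUSE_PUBLIC_KEY=
    LANGFUSE_SECRET_KEY=
    LANGFUSE_BASE_URL=https://cloud.langfuse.com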