Add vector embedding support for nodes and documents

This commit is contained in:
Ylber Gashi
2025-02-17 13:56:31 +01:00
parent f4ad5756f3
commit 01c8f4273e
15 changed files with 1575 additions and 74 deletions

View File

@@ -0,0 +1,10 @@
import { Migration, sql } from 'kysely';
export const createVectorExtension: Migration = {
up: async (db) => {
await sql`CREATE EXTENSION IF NOT EXISTS vector`.execute(db);
},
down: async (db) => {
await sql`DROP EXTENSION IF EXISTS vector`.execute(db);
},
};

View File

@@ -0,0 +1,42 @@
import { Migration, sql } from 'kysely';
export const createNodeEmbeddingsTable: Migration = {
up: async (db) => {
await db.schema
.createTable('node_embeddings')
.addColumn('node_id', 'varchar(30)', (col) => col.notNull())
.addColumn('chunk', 'integer', (col) => col.notNull())
.addColumn('parent_id', 'varchar(30)')
.addColumn('root_id', 'varchar(30)', (col) => col.notNull())
.addColumn('workspace_id', 'varchar(30)', (col) => col.notNull())
.addColumn('text', 'text', (col) => col.notNull())
.addColumn('embedding_vector', sql`vector(2000)`, (col) => col.notNull())
.addColumn(
'search_vector',
sql`tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED`
)
.addColumn('created_at', 'timestamptz', (col) => col.notNull())
.addColumn('updated_at', 'timestamptz')
.addPrimaryKeyConstraint('node_embeddings_pkey', ['node_id', 'chunk'])
.execute();
await sql`
CREATE INDEX node_embeddings_embedding_vector_idx
ON node_embeddings
USING hnsw(embedding_vector vector_cosine_ops)
WITH (
m = 16,
ef_construction = 64
);
`.execute(db);
await sql`
CREATE INDEX node_embeddings_search_vector_idx
ON node_embeddings
USING GIN (search_vector);
`.execute(db);
},
down: async (db) => {
await db.schema.dropTable('node_embeddings').execute();
},
};
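The `m` and `ef_construction` parameters tune index build quality; at query time, pgvector's `hnsw.ef_search` setting (default 40) trades recall against latency. Note that 2000 dimensions is pgvector's upper limit for indexed vector columns, which is presumably why the embedding dimension is capped there. A minimal sketch of a query that exercises this index, assuming the Kysely `database` instance used elsewhere in this commit and a precomputed `queryVector`:

import { sql } from 'kysely';
import { database } from '@/data/database';

export const findNearestNodeChunks = async (
  queryVector: number[],
  workspaceId: string
) => {
  // Raise ef_search above the default of 40 for better recall at some latency cost.
  // SET is connection-scoped; in a pooled setup use SET LOCAL inside a transaction.
  await sql`SET hnsw.ef_search = 100`.execute(database);
  return database
    .selectFrom('node_embeddings')
    .select(['node_id', 'chunk', 'text'])
    .where('workspace_id', '=', workspaceId)
    // <=> is pgvector's cosine distance operator; ordering by it lets the
    // HNSW index drive the nearest-neighbor scan.
    .orderBy(sql`embedding_vector <=> ${`[${queryVector.join(',')}]`}::vector`)
    .limit(10)
    .execute();
};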

View File

@@ -0,0 +1,43 @@
import { Migration, sql } from 'kysely';
export const createDocumentEmbeddingsTable: Migration = {
up: async (db) => {
await db.schema
.createTable('document_embeddings')
.addColumn('document_id', 'varchar(30)', (col) => col.notNull())
.addColumn('chunk', 'integer', (col) => col.notNull())
.addColumn('workspace_id', 'varchar(30)', (col) => col.notNull())
.addColumn('text', 'text', (col) => col.notNull())
.addColumn('embedding_vector', sql`vector(2000)`, (col) => col.notNull())
.addColumn(
'search_vector',
sql`tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED`
)
.addColumn('created_at', 'timestamptz', (col) => col.notNull())
.addColumn('updated_at', 'timestamptz')
.addPrimaryKeyConstraint('document_embeddings_pkey', [
'document_id',
'chunk',
])
.execute();
await sql`
CREATE INDEX document_embeddings_embedding_vector_idx
ON document_embeddings
USING hnsw(embedding_vector vector_cosine_ops)
WITH (
m = 16,
ef_construction = 64
);
`.execute(db);
await sql`
CREATE INDEX document_embeddings_search_vector_idx
ON document_embeddings
USING GIN (search_vector);
`.execute(db);
},
down: async (db) => {
await db.schema.dropTable('document_embeddings').execute();
},
};

View File

@@ -12,6 +12,9 @@ import { createNodePathsTable } from './00009-create-node-paths-table';
import { createCollaborationsTable } from './00010-create-collaborations-table';
import { createDocumentsTable } from './00011-create-documents-table';
import { createDocumentUpdatesTable } from './00012-create-document-updates-table';
import { createVectorExtension } from './00013-create-vector-extension';
import { createNodeEmbeddingsTable } from './00014-create-node-embeddings-table';
import { createDocumentEmbeddingsTable } from './00015-create-document-embeddings-table';
export const databaseMigrations: Record<string, Migration> = {
  '00001_create_accounts_table': createAccountsTable,
@@ -26,4 +29,7 @@ export const databaseMigrations: Record<string, Migration> = {
  '00010_create_collaborations_table': createCollaborationsTable,
  '00011_create_documents_table': createDocumentsTable,
  '00012_create_document_updates_table': createDocumentUpdatesTable,
  '00013_create_vector_extension': createVectorExtension,
  '00014_create_node_embeddings_table': createNodeEmbeddingsTable,
  '00015_create_document_embeddings_table': createDocumentEmbeddingsTable,
};

View File

@@ -221,6 +221,38 @@ export type SelectDocumentUpdate = Selectable<DocumentUpdateTable>;
export type CreateDocumentUpdate = Insertable<DocumentUpdateTable>;
export type UpdateDocumentUpdate = Updateable<DocumentUpdateTable>;
interface NodeEmbeddingTable {
node_id: ColumnType<string, string, never>;
chunk: ColumnType<number, number, number>;
parent_id: ColumnType<string | null, string | null, string | null>;
root_id: ColumnType<string, string, never>;
workspace_id: ColumnType<string, string, never>;
text: ColumnType<string, string, string>;
embedding_vector: ColumnType<number[], number[], number[]>;
search_vector: ColumnType<never, never, never>;
created_at: ColumnType<Date, Date, never>;
updated_at: ColumnType<Date | null, Date | null, Date | null>;
}
export type SelectNodeEmbedding = Selectable<NodeEmbeddingTable>;
export type CreateNodeEmbedding = Insertable<NodeEmbeddingTable>;
export type UpdateNodeEmbedding = Updateable<NodeEmbeddingTable>;
interface DocumentEmbeddingTable {
document_id: ColumnType<string, string, never>;
chunk: ColumnType<number, number, number>;
workspace_id: ColumnType<string, string, never>;
text: ColumnType<string, string, string>;
embedding_vector: ColumnType<number[], number[], number[]>;
search_vector: ColumnType<never, never, never>;
created_at: ColumnType<Date, Date, never>;
updated_at: ColumnType<Date | null, Date | null, Date | null>;
}
export type SelectDocumentEmbedding = Selectable<DocumentEmbeddingTable>;
export type CreateDocumentEmbedding = Insertable<DocumentEmbeddingTable>;
export type UpdateDocumentEmbedding = Updateable<DocumentEmbeddingTable>;
export interface DatabaseSchema {
  accounts: AccountTable;
  devices: DeviceTable;
@@ -234,4 +266,6 @@ export interface DatabaseSchema {
  collaborations: CollaborationTable;
  documents: DocumentTable;
  document_updates: DocumentUpdateTable;
  node_embeddings: NodeEmbeddingTable;
  document_embeddings: DocumentEmbeddingTable;
}

View File

@@ -0,0 +1,198 @@
import {
  generateId,
  IdType,
  extractBlockTexts,
  MessageAttributes,
} from '@colanode/core';
import { Document } from '@langchain/core/documents';
import { database } from '@/data/database';
import { eventBus } from '@/lib/event-bus';
import { configuration } from '@/lib/configuration';
import { fetchNode } from '@/lib/nodes';
import { nodeRetrievalService } from '@/services/node-retrieval-service';
import { documentRetrievalService } from '@/services/document-retrieval-service';
import { JobHandler } from '@/types/jobs';
import {
rewriteQuery,
rerankDocuments,
generateFinalAnswer,
generateNoContextAnswer,
assessUserIntent,
} from '@/services/llm-service';
export type AssistantResponseInput = {
type: 'assistant_response';
nodeId: string;
workspaceId: string;
};
declare module '@/types/jobs' {
interface JobMap {
assistant_response: {
input: AssistantResponseInput;
};
}
}
export const assistantResponseHandler: JobHandler<AssistantResponseInput> = async (
  input
) => {
const { nodeId, workspaceId } = input;
console.log('Starting assistant response handler', { nodeId, workspaceId });
if (!configuration.ai.enabled) return;
const node = await fetchNode(nodeId);
if (!node) return;
// Assume nodes of type 'message' carry the user query.
const userInputText = extractBlockTexts(
node.id,
(node.attributes as MessageAttributes).content
);
if (!userInputText || userInputText.trim() === '') return;
// Fetch user details (assuming created_by is the user id)
const user = await database
.selectFrom('users')
.where('id', '=', node.created_by)
.selectAll()
.executeTakeFirst();
if (!user) return;
// Get conversation history: for example, sibling nodes (other messages with the same parent)
const chatHistoryNodes = await database
.selectFrom('nodes')
.selectAll()
.where('parent_id', '=', node.parent_id)
.where('id', '!=', node.id)
.orderBy('created_at', 'asc')
.limit(10)
.execute();
const chatHistory = chatHistoryNodes.map(
(n) =>
new Document({
pageContent:
extractBlockTexts(
n.id,
(n.attributes as MessageAttributes).content
) || '',
metadata: { id: n.id, type: n.type, createdAt: n.created_at },
})
);
const formattedChatHistory = chatHistory
.map((doc) => {
const ts = doc.metadata.createdAt
? new Date(doc.metadata.createdAt).toLocaleString()
: 'Unknown';
return `- [${ts}] ${doc.metadata.id}: ${doc.pageContent}`;
})
.join('\n');
const intent = await assessUserIntent(userInputText, formattedChatHistory);
let finalAnswer: string;
let citations: Array<{ sourceId: string; quote: string }>;
if (intent === 'no_context') {
finalAnswer = await generateNoContextAnswer(
userInputText,
formattedChatHistory
);
citations = [];
} else {
const rewrittenQuery = await rewriteQuery(userInputText);
const nodeDocs = await nodeRetrievalService.retrieve(
rewrittenQuery,
workspaceId,
user.id
);
const documentDocs = await documentRetrievalService.retrieve(
rewrittenQuery,
workspaceId
);
const allContext = [...nodeDocs, ...documentDocs];
const reranked = await rerankDocuments(
allContext.map((doc, idx) => ({
content: doc.pageContent,
type: doc.metadata.type,
sourceId: doc.metadata.id,
})),
rewrittenQuery
);
    // Map the top-ranked entries back to their source documents
    const topContext = reranked
      .slice(0, 5)
      .map((r) => allContext[r.index])
      .filter((doc): doc is Document => !!doc);
    const formattedMessages = topContext
      .filter((doc) => doc.metadata.type === 'message')
      .map((doc) => {
        const ts = doc.metadata.createdAt
          ? new Date(doc.metadata.createdAt).toLocaleString()
          : 'Unknown';
        return `- [${ts}] ${doc.metadata.id}: ${doc.pageContent}`;
      })
      .join('\n');
    const formattedDocuments = topContext
      .filter(
        (doc) =>
          doc.metadata.type === 'page' || doc.metadata.type === 'document'
      )
      .map((doc) => {
        const ts = doc.metadata.createdAt
          ? new Date(doc.metadata.createdAt).toLocaleString()
          : 'Unknown';
        return `- [${ts}] ${doc.metadata.id}: ${doc.pageContent}`;
      })
      .join('\n');
const promptArgs = {
currentTimestamp: new Date().toISOString(),
workspaceName: workspaceId, // Or retrieve workspace details
userName: user.name || 'User',
userEmail: user.email || 'unknown@example.com',
formattedChatHistory,
formattedMessages,
formattedDocuments,
question: userInputText,
};
const result = await generateFinalAnswer(promptArgs);
finalAnswer = result.answer;
citations = result.citations;
}
// Create a response node (answer message)
const responseNodeId = generateId(IdType.Node);
const responseAttributes = {
type: 'message',
subtype: 'standard',
content: {
[generateId(IdType.Block)]: {
id: generateId(IdType.Block),
type: 'paragraph',
parentId: responseNodeId,
index: 'a',
content: [{ type: 'text', text: finalAnswer }],
},
},
parentId: node.parent_id,
referenceId: node.id,
};
await database
.insertInto('nodes')
.values({
id: responseNodeId,
root_id: node.root_id,
workspace_id: workspaceId,
attributes: JSON.stringify(responseAttributes),
state: Buffer.from(''), // Initial empty state
created_at: new Date(),
created_by: 'colanode_ai',
})
.executeTakeFirst();
eventBus.publish({
type: 'node_created',
nodeId: responseNodeId,
rootId: node.root_id,
workspaceId,
});
};
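For reference, a sketch of how this job would be enqueued from a message-handling path, following the same addJob(payload, options) shape this commit uses for the embedding jobs; the trigger site itself is an assumption, not part of this diff:

import { jobService } from '@/services/job-service';

// Hypothetical call site: request an AI answer for a user's message node.
export const requestAssistantResponse = async (
  messageNodeId: string,
  workspaceId: string
) => {
  await jobService.addJob(
    {
      type: 'assistant_response',
      nodeId: messageNodeId,
      workspaceId,
    },
    {
      // A stable jobId prevents duplicate responses for the same message.
      jobId: `assistant_response:${messageNodeId}`,
    }
  );
};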

View File

@@ -0,0 +1,121 @@
import { OpenAIEmbeddings } from '@langchain/openai';
import { ChunkingService } from '@/services/chunking-service';
import { database } from '@/data/database';
import { configuration } from '@/lib/configuration';
import { CreateDocumentEmbedding } from '@/data/schema';
import { sql } from 'kysely';
import { fetchNodeWithContext } from '@/services/node-retrieval-service';
import { DocumentContent, extractBlockTexts } from '@colanode/core';
import { JobHandler } from '@/types/jobs';
export type EmbedDocumentInput = {
  type: 'embed_document';
  documentId: string;
};
declare module '@/types/jobs' {
  interface JobMap {
    embed_document: {
      input: EmbedDocumentInput;
    };
  }
}
export const embedDocumentHandler: JobHandler<EmbedDocumentInput> = async (
input
) => {
if (!configuration.ai.enabled) return;
const { documentId } = input;
// Retrieve document along with its associated node in one query
const document = await database
.selectFrom('documents')
.select(['id', 'content', 'workspace_id', 'created_at'])
.where('id', '=', documentId)
.executeTakeFirst();
if (!document) return;
// Fetch associated node for context (page node, etc.)
const nodeContext = await fetchNodeWithContext(documentId);
// If available, include node type and parent (space) name in the context.
let header = '';
if (nodeContext) {
const { node, parent, root } = nodeContext;
header = `${node.attributes.type} "${(node.attributes as any).name || ''}"`;
if (parent && parent.attributes.name) {
header += ` in "${parent.attributes.name}"`;
}
if (root && root.attributes.name) {
header += ` [${root.attributes.name}]`;
}
}
// Extract text from the document content. (For pages and rich text, use extractBlockTexts.)
const docText =
  extractBlockTexts(documentId, (document.content as DocumentContent).blocks) || '';
const fullText = header ? `${header}\n\nContent:\n${docText}` : docText;
if (!fullText.trim()) {
await database
.deleteFrom('document_embeddings')
.where('document_id', '=', documentId)
.execute();
return;
}
const chunkingService = new ChunkingService();
const chunks = await chunkingService.chunkText(fullText, {
type: 'document',
id: documentId,
});
const embeddings = new OpenAIEmbeddings({
apiKey: configuration.ai.embedding.apiKey,
modelName: configuration.ai.embedding.modelName,
dimensions: configuration.ai.embedding.dimensions,
});
const existingEmbeddings = await database
.selectFrom('document_embeddings')
.select(['chunk', 'text'])
.where('document_id', '=', documentId)
.execute();
const embeddingsToUpsert: CreateDocumentEmbedding[] = [];
for (let i = 0; i < chunks.length; i++) {
    const chunk = chunks[i];
    if (!chunk) continue;
    const existing = existingEmbeddings.find((e) => e.chunk === i);
    if (existing && existing.text === chunk) continue;
embeddingsToUpsert.push({
document_id: documentId,
chunk: i,
workspace_id: document.workspace_id,
text: chunk,
embedding_vector: [],
created_at: new Date(),
});
}
const batchSize = configuration.ai.embedding.batchSize;
for (let i = 0; i < embeddingsToUpsert.length; i += batchSize) {
const batch = embeddingsToUpsert.slice(i, i + batchSize);
const textsToEmbed = batch.map((item) => item.text);
const embeddingVectors = await embeddings.embedDocuments(textsToEmbed);
for (let j = 0; j < batch.length; j++) {
batch[j].embedding_vector = embeddingVectors[j] ?? [];
}
}
for (const embedding of embeddingsToUpsert) {
await database
.insertInto('document_embeddings')
.values({
document_id: embedding.document_id,
chunk: embedding.chunk,
workspace_id: embedding.workspace_id,
text: embedding.text,
embedding_vector: sql.raw(
`'[${embedding.embedding_vector.join(',')}]'::vector`
),
created_at: embedding.created_at,
})
.onConflict((oc) =>
oc.columns(['document_id', 'chunk']).doUpdateSet({
text: sql.ref('excluded.text'),
embedding_vector: sql.ref('excluded.embedding_vector'),
updated_at: new Date(),
})
)
.execute();
}
};
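Because unchanged chunks are skipped, the handler is cheap to re-run, which makes it safe to debounce. A sketch of scheduling it after a document update, mirroring the embed_node scheduling this commit adds in lib/nodes.ts (the exact call site is an assumption):

import { jobService } from '@/services/job-service';
import { configuration } from '@/lib/configuration';

export const scheduleDocumentEmbedding = async (documentId: string) => {
  await jobService.addJob(
    { type: 'embed_document', documentId },
    {
      // A stable jobId plus the embed delay debounces rapid successive edits.
      jobId: `embed_document:${documentId}`,
      delay: configuration.ai.entryEmbedDelay,
    }
  );
};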

View File

@@ -0,0 +1,183 @@
import { OpenAIEmbeddings } from '@langchain/openai';
import { extractBlockTexts, NodeAttributes, FieldValue } from '@colanode/core';
import { ChunkingService } from '@/services/chunking-service';
import { database } from '@/data/database';
import { configuration } from '@/lib/configuration';
import { CreateNodeEmbedding } from '@/data/schema';
import { sql } from 'kysely';
import { fetchNode } from '@/lib/nodes';
import { JobHandler } from '@/types/jobs';
export type EmbedNodeInput = {
type: 'embed_node';
nodeId: string;
};
declare module '@/types/jobs' {
interface JobMap {
embed_node: {
input: EmbedNodeInput;
};
}
}
const formatFieldValue = (fieldValue: FieldValue): string => {
switch (fieldValue.type) {
case 'boolean':
return `${fieldValue.value ? 'Yes' : 'No'}`;
case 'string_array':
return (fieldValue.value as string[]).join(', ');
case 'number':
case 'string':
case 'text':
return String(fieldValue.value);
default:
return '';
}
};
const extractNodeText = async (
nodeId: string,
attributes: NodeAttributes
): Promise<string> => {
switch (attributes.type) {
case 'message':
return extractBlockTexts(nodeId, attributes.content) ?? '';
case 'record': {
const sections: string[] = [];
// Fetch the database node to get its name
const databaseNode = await fetchNode(attributes.databaseId);
const databaseName =
databaseNode?.attributes.type === 'database'
? databaseNode.attributes.name
: attributes.databaseId;
      // Add record context with the database name
      sections.push(
        `Record "${attributes.name}" in database "${databaseName}"`
      );
// Process field value
Object.entries(attributes.fields).forEach(([fieldName, fieldValue]) => {
if (!fieldValue || !('type' in fieldValue)) {
return;
}
const value = formatFieldValue(fieldValue as FieldValue);
if (value) {
sections.push(value);
}
});
return sections.join('\n');
}
default:
return '';
}
};
export const embedNodeHandler: JobHandler<EmbedNodeInput> = async (
  input
) => {
if (!configuration.ai.enabled) {
return;
}
const { nodeId } = input;
const node = await fetchNode(nodeId);
if (!node) {
return;
}
// Skip page nodes (document content is handled separately in the embed document job)
if (node.type === 'page') {
return;
}
const text = await extractNodeText(node.id, node.attributes);
if (!text || text.trim() === '') {
await database
.deleteFrom('node_embeddings')
.where('node_id', '=', nodeId)
.execute();
return;
}
const chunkingService = new ChunkingService();
const chunks = await chunkingService.chunkText(text, {
type: 'node',
id: nodeId,
});
const embeddings = new OpenAIEmbeddings({
apiKey: configuration.ai.embedding.apiKey,
modelName: configuration.ai.embedding.modelName,
dimensions: configuration.ai.embedding.dimensions,
});
const existingEmbeddings = await database
.selectFrom('node_embeddings')
.select(['chunk', 'text'])
.where('node_id', '=', nodeId)
.execute();
const embeddingsToCreateOrUpdate: CreateNodeEmbedding[] = [];
for (let i = 0; i < chunks.length; i++) {
const chunk = chunks[i];
if (!chunk) continue;
const existing = existingEmbeddings.find((e) => e.chunk === i);
if (existing && existing.text === chunk) continue;
embeddingsToCreateOrUpdate.push({
node_id: nodeId,
chunk: i,
parent_id: node.parent_id,
root_id: node.root_id,
workspace_id: node.workspace_id,
text: chunk,
embedding_vector: [],
created_at: new Date(),
});
}
const batchSize = configuration.ai.embedding.batchSize;
for (let i = 0; i < embeddingsToCreateOrUpdate.length; i += batchSize) {
const batch = embeddingsToCreateOrUpdate.slice(i, i + batchSize);
const textsToEmbed = batch.map((item) => item.text);
const embeddingVectors = await embeddings.embedDocuments(textsToEmbed);
for (let j = 0; j < batch.length; j++) {
const vector = embeddingVectors[j];
const batchItem = batch[j];
if (batchItem) {
batchItem.embedding_vector = vector ?? [];
}
}
}
if (embeddingsToCreateOrUpdate.length === 0) {
return;
}
for (const embedding of embeddingsToCreateOrUpdate) {
await database
.insertInto('node_embeddings')
.values({
node_id: embedding.node_id,
chunk: embedding.chunk,
parent_id: embedding.parent_id,
root_id: embedding.root_id,
workspace_id: embedding.workspace_id,
text: embedding.text,
embedding_vector: sql.raw(
`'[${embedding.embedding_vector.join(',')}]'::vector`
),
created_at: embedding.created_at,
})
.onConflict((oc) =>
oc.columns(['node_id', 'chunk']).doUpdateSet({
text: sql.ref('excluded.text'),
embedding_vector: sql.ref('excluded.embedding_vector'),
updated_at: new Date(),
})
)
.execute();
}
};

View File

@@ -2,7 +2,9 @@ import { cleanNodeDataHandler } from '@/jobs/clean-node-data';
import { cleanWorkspaceDataHandler } from '@/jobs/clean-workspace-data';
import { JobHandler, JobMap } from '@/types/jobs';
import { sendEmailVerifyEmailHandler } from '@/jobs/send-email-verify-email';
import { embedNodeHandler } from './embed-node';
import { embedDocumentHandler } from './embed-document';
import { assistantResponseHandler } from './assistant-response';
type JobHandlerMap = {
  [K in keyof JobMap]: JobHandler<JobMap[K]['input']>;
};
@@ -11,4 +13,7 @@ export const jobHandlerMap: JobHandlerMap = {
  send_email_verify_email: sendEmailVerifyEmailHandler,
  clean_workspace_data: cleanWorkspaceDataHandler,
  clean_node_data: cleanNodeDataHandler,
  embed_node: embedNodeHandler,
  embed_document: embedDocumentHandler,
  assistant_response: assistantResponseHandler,
};

View File

@@ -69,26 +69,63 @@ export interface SmtpConfiguration {
  };
}
export type AIProvider = 'openai' | 'google';
export interface AIProviderConfiguration {
  apiKey: string;
  enabled?: boolean;
}
export interface AIModelConfiguration {
  provider: AIProvider;
  modelName: string;
  temperature: number;
}
export interface AiConfiguration {
  enabled: boolean;
  entryEmbedDelay: number;
  providers: {
    openai: AIProviderConfiguration;
    google: AIProviderConfiguration;
  };
  langfuse: {
    publicKey: string;
    secretKey: string;
    baseUrl: string;
  };
  models: {
    queryRewrite: AIModelConfiguration;
    response: AIModelConfiguration;
    rerank: AIModelConfiguration;
    summarization: AIModelConfiguration;
    contextEnhancer: AIModelConfiguration;
    noContext: AIModelConfiguration;
    intentRecognition: AIModelConfiguration;
  };
  embedding: {
    provider: AIProvider;
    apiKey: string;
    modelName: string;
    dimensions: number;
    batchSize: number;
  };
  chunking: ChunkingConfiguration;
  retrieval: RetrievalConfiguration;
}
export interface ChunkingConfiguration {
  defaultChunkSize: number;
  defaultOverlap: number;
  enhanceWithContext: boolean;
}
export interface RetrievalConfiguration {
  hybridSearch: {
    semanticSearchWeight: number;
    keywordSearchWeight: number;
    maxResults: number;
  };
}
const getRequiredEnv = (env: string): string => {
@@ -172,15 +209,84 @@ export const configuration: Configuration = {
    entryEmbedDelay: parseInt(
      getOptionalEnv('AI_ENTRY_EMBED_DELAY') || '60000'
    ),
    providers: {
      openai: {
        apiKey: getOptionalEnv('OPENAI_API_KEY') || '',
        enabled: getOptionalEnv('OPENAI_ENABLED') === 'true',
      },
      google: {
        apiKey: getOptionalEnv('GOOGLE_API_KEY') || '',
        enabled: getOptionalEnv('GOOGLE_ENABLED') === 'true',
      },
    },
    langfuse: {
      publicKey: getOptionalEnv('LANGFUSE_PUBLIC_KEY') || '',
      secretKey: getOptionalEnv('LANGFUSE_SECRET_KEY') || '',
      baseUrl:
        getOptionalEnv('LANGFUSE_BASE_URL') || 'https://cloud.langfuse.com',
    },
    models: {
      queryRewrite: {
        provider: (getOptionalEnv('QUERY_REWRITE_PROVIDER') ||
          'openai') as AIProvider,
        modelName: getOptionalEnv('QUERY_REWRITE_MODEL') || 'gpt-4o-mini',
        temperature: parseFloat(
          getOptionalEnv('QUERY_REWRITE_TEMPERATURE') || '0.3'
        ),
      },
      response: {
        provider: (getOptionalEnv('RESPONSE_PROVIDER') ||
          'openai') as AIProvider,
        modelName: getOptionalEnv('RESPONSE_MODEL') || 'gpt-4o-mini',
        temperature: parseFloat(
          getOptionalEnv('RESPONSE_TEMPERATURE') || '0.3'
        ),
      },
      rerank: {
        provider: (getOptionalEnv('RERANK_PROVIDER') || 'openai') as AIProvider,
        modelName: getOptionalEnv('RERANK_MODEL') || 'gpt-4o-mini',
        temperature: parseFloat(getOptionalEnv('RERANK_TEMPERATURE') || '0.3'),
      },
      summarization: {
        provider: (getOptionalEnv('SUMMARIZATION_PROVIDER') ||
          'openai') as AIProvider,
        modelName: getOptionalEnv('SUMMARIZATION_MODEL') || 'gpt-4o-mini',
        temperature: parseFloat(
          getOptionalEnv('SUMMARIZATION_TEMPERATURE') || '0.3'
        ),
      },
      contextEnhancer: {
        provider: (getOptionalEnv('CHUNK_CONTEXT_PROVIDER') ||
          'openai') as AIProvider,
        modelName: getOptionalEnv('CHUNK_CONTEXT_MODEL') || 'gpt-4o-mini',
        temperature: parseFloat(
          getOptionalEnv('CHUNK_CONTEXT_TEMPERATURE') || '0.3'
        ),
      },
      noContext: {
        provider: (getOptionalEnv('NO_CONTEXT_PROVIDER') ||
          'openai') as AIProvider,
        modelName: getOptionalEnv('NO_CONTEXT_MODEL') || 'gpt-4o-mini',
        temperature: parseFloat(
          getOptionalEnv('NO_CONTEXT_TEMPERATURE') || '0.3'
        ),
      },
      intentRecognition: {
        provider: (getOptionalEnv('INTENT_RECOGNITION_PROVIDER') ||
          'openai') as AIProvider,
        modelName: getOptionalEnv('INTENT_RECOGNITION_MODEL') || 'gpt-4o-mini',
        temperature: parseFloat(
          getOptionalEnv('INTENT_RECOGNITION_TEMPERATURE') || '0.3'
        ),
      },
    },
    embedding: {
      provider: (getOptionalEnv('EMBEDDING_PROVIDER') ||
        'openai') as AIProvider,
      modelName: getOptionalEnv('EMBEDDING_MODEL') || 'text-embedding-3-large',
      dimensions: parseInt(getOptionalEnv('EMBEDDING_DIMENSIONS') || '2000'),
      apiKey: getOptionalEnv('EMBEDDING_API_KEY') || '',
      batchSize: parseInt(getOptionalEnv('EMBEDDING_BATCH_SIZE') || '50'),
    },
    chunking: {
      defaultChunkSize: parseInt(
@@ -191,11 +297,19 @@ export const configuration: Configuration = {
      ),
      enhanceWithContext:
        getOptionalEnv('CHUNK_ENHANCE_WITH_CONTEXT') === 'true',
    },
    retrieval: {
      hybridSearch: {
        semanticSearchWeight: parseFloat(
          getOptionalEnv('RETRIEVAL_HYBRID_SEARCH_SEMANTIC_WEIGHT') || '0.7'
        ),
        keywordSearchWeight: parseFloat(
          getOptionalEnv('RETRIEVAL_HYBRID_SEARCH_KEYWORD_WEIGHT') || '0.3'
        ),
        maxResults: parseInt(
          getOptionalEnv('RETRIEVAL_HYBRID_SEARCH_MAX_RESULTS') || '20'
        ),
      },
    },
  },
};
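For reference, a minimal environment configuration that exercises these settings with OpenAI for both chat models and embeddings; values are illustrative, and EMBEDDING_DIMENSIONS must match the vector(2000) columns created by the migrations:

OPENAI_API_KEY=sk-...
OPENAI_ENABLED=true
AI_ENTRY_EMBED_DELAY=60000
EMBEDDING_PROVIDER=openai
EMBEDDING_API_KEY=sk-...
EMBEDDING_MODEL=text-embedding-3-large
EMBEDDING_DIMENSIONS=2000
EMBEDDING_BATCH_SIZE=50
RESPONSE_PROVIDER=openai
RESPONSE_MODEL=gpt-4o-mini
RETRIEVAL_HYBRID_SEARCH_SEMANTIC_WEIGHT=0.7
RETRIEVAL_HYBRID_SEARCH_KEYWORD_WEIGHT=0.3
RETRIEVAL_HYBRID_SEARCH_MAX_RESULTS=20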

View File

@@ -35,6 +35,7 @@ import {
  checkCollaboratorChanges,
} from '@/lib/collaborations';
import { jobService } from '@/services/job-service';
import { configuration } from '@/lib/configuration';
const debug = createDebugger('server:lib:nodes');
@@ -169,6 +170,18 @@ export const createNode = async (
    });
  }
  // Schedule node embedding
  await jobService.addJob(
    {
      type: 'embed_node',
      nodeId: input.nodeId,
    },
    {
      jobId: `embed_node:${input.nodeId}`,
      delay: configuration.ai.entryEmbedDelay,
    }
  );
  return {
    node: createdNode,
  };
@@ -290,6 +303,18 @@ export const tryUpdateNode = async (
    });
  }
  // Schedule node embedding
  await jobService.addJob(
    {
      type: 'embed_node',
      nodeId: input.nodeId,
    },
    {
      jobId: `embed_node:${input.nodeId}`,
      delay: configuration.ai.entryEmbedDelay,
    }
  );
  return {
    type: 'success',
    output: {
@@ -396,6 +421,18 @@ export const createNodeFromMutation = async (
    });
  }
  // Schedule node embedding
  await jobService.addJob(
    {
      type: 'embed_node',
      nodeId: mutation.id,
    },
    {
      jobId: `embed_node:${mutation.id}`,
      delay: configuration.ai.entryEmbedDelay,
    }
  );
  return {
    node: createdNode,
  };
@@ -527,6 +564,18 @@ const tryUpdateNodeFromMutation = async (
    });
  }
  // Schedule node embedding
  await jobService.addJob(
    {
      type: 'embed_node',
      nodeId: mutation.id,
    },
    {
      jobId: `embed_node:${mutation.id}`,
      delay: configuration.ai.entryEmbedDelay,
    }
  );
  return {
    type: 'success',
    output: {

View File

@@ -1,68 +1,98 @@
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { sql } from 'kysely';
import { configuration } from '@/lib/configuration';
import { database } from '@/data/database';
import { addContextToChunk } from '@/services/llm-service';
export type ChunkingMetadata = {
  nodeId: string;
  type: string;
  name?: string;
  parentName?: string;
  spaceName?: string;
  createdAt: Date;
};
export class ChunkingService {
  // Unified chunkText that optionally enriches each chunk with context metadata.
  public async chunkText(
    text: string,
    metadataInfo?: { type: 'node' | 'document'; id: string }
  ): Promise<string[]> {
    const chunkSize = configuration.ai.chunking.defaultChunkSize;
    const chunkOverlap = configuration.ai.chunking.defaultOverlap;
    const splitter = new RecursiveCharacterTextSplitter({
      chunkSize,
      chunkOverlap,
    });
    const docs = await splitter.createDocuments([text]);
    let chunks = docs
      .map((doc) => doc.pageContent)
      .filter((c) => c.trim().length > 10);
    if (configuration.ai.chunking.enhanceWithContext) {
      // Fetch unified metadata (using a single query where possible)
      const metadata = metadataInfo
        ? await this.fetchMetadata(metadataInfo)
        : undefined;
      chunks = await Promise.all(
        chunks.map((chunk) => addContextToChunk(chunk, text, metadata))
      );
    }
    return chunks;
  }
  // A unified metadata fetch that joins the node with its parent and root (space).
  private async fetchMetadata(info: {
    type: 'node' | 'document';
    id: string;
  }): Promise<ChunkingMetadata | undefined> {
    if (info.type === 'node') {
      // Fetch the node along with its parent (if any) and the root (assumed to be the space).
      // The JSON name lookups are wrapped in sql expressions so Kysely accepts them.
      const result = await database
        .selectFrom('nodes')
        .leftJoin('nodes as parent', 'nodes.parent_id', 'parent.id')
        .leftJoin('nodes as root', 'nodes.root_id', 'root.id')
        .select([
          'nodes.id as nodeId',
          'nodes.type',
          sql<string | null>`nodes.attributes->>'name'`.as('name'),
          sql<string | null>`parent.attributes->>'name'`.as('parentName'),
          sql<string | null>`root.attributes->>'name'`.as('spaceName'),
          'nodes.created_at as createdAt',
        ])
        .where('nodes.id', '=', info.id)
        .executeTakeFirst();
      if (!result) return undefined;
      return {
        nodeId: result.nodeId,
        type: result.type,
        name: result.name ?? undefined,
        parentName: result.parentName ?? undefined,
        spaceName: result.spaceName ?? undefined,
        createdAt: result.createdAt,
      };
    }
    // For documents, derive similar metadata from the associated node.
    const result = await database
      .selectFrom('documents')
      .innerJoin('nodes', 'documents.id', 'nodes.id')
      .select([
        'nodes.id as nodeId',
        'nodes.type',
        sql<string | null>`nodes.attributes->>'name'`.as('name'),
        'nodes.created_at as createdAt',
      ])
      .where('documents.id', '=', info.id)
      .executeTakeFirst();
    if (!result) return undefined;
    return {
      nodeId: result.nodeId,
      type: result.type,
      name: result.name ?? undefined,
      createdAt: result.createdAt,
    };
  }
}

View File

@@ -0,0 +1,189 @@
import { OpenAIEmbeddings } from '@langchain/openai';
import { Document } from '@langchain/core/documents';
import { database } from '@/data/database';
import { configuration } from '@/lib/configuration';
import { sql } from 'kysely';
type SearchResult = {
id: string;
text: string;
score: number;
type: 'semantic' | 'keyword';
createdAt?: Date;
chunkIndex: number;
};
export class DocumentRetrievalService {
private embeddings = new OpenAIEmbeddings({
apiKey: configuration.ai.embedding.apiKey,
modelName: configuration.ai.embedding.modelName,
dimensions: configuration.ai.embedding.dimensions,
});
public async retrieve(
query: string,
workspaceId: string,
limit = configuration.ai.retrieval.hybridSearch.maxResults
): Promise<Document[]> {
const embedding = await this.embeddings.embedQuery(query);
if (!embedding) return [];
const semanticResults = await this.semanticSearch(
embedding,
workspaceId,
limit
);
const keywordResults = await this.keywordSearch(query, workspaceId, limit);
return this.combineSearchResults(semanticResults, keywordResults);
}
private async semanticSearch(
embedding: number[],
workspaceId: string,
limit: number
): Promise<SearchResult[]> {
const results = await database
.selectFrom('document_embeddings')
.innerJoin('documents', 'documents.id', 'document_embeddings.document_id')
.select((eb) => [
'document_embeddings.document_id as id',
'document_embeddings.text',
'documents.created_at',
'document_embeddings.chunk as chunk_index',
        // Bind the query embedding as a parameter and cast it to a pgvector value;
        // interpolating it inside quotes would leave the placeholder stuck in a string literal.
        sql<number>`document_embeddings.embedding_vector <=> ${`[${embedding.join(',')}]`}::vector`.as(
          'similarity'
        ),
])
.where('document_embeddings.workspace_id', '=', workspaceId)
.groupBy([
'document_embeddings.document_id',
'document_embeddings.text',
'documents.created_at',
'document_embeddings.chunk',
])
.orderBy('similarity', 'asc')
.limit(limit)
.execute();
return results.map((result) => ({
id: result.id,
text: result.text,
score: result.similarity,
type: 'semantic',
createdAt: result.created_at,
chunkIndex: result.chunk_index,
}));
}
private async keywordSearch(
query: string,
workspaceId: string,
limit: number
): Promise<SearchResult[]> {
const results = await database
.selectFrom('document_embeddings')
.innerJoin('documents', 'documents.id', 'document_embeddings.document_id')
.select((eb) => [
'document_embeddings.document_id as id',
'document_embeddings.text',
'documents.created_at',
'document_embeddings.chunk as chunk_index',
sql<number>`ts_rank(document_embeddings.search_vector, websearch_to_tsquery('english', ${query}))`.as(
'rank'
),
])
.where('document_embeddings.workspace_id', '=', workspaceId)
.where(
(eb) =>
sql`document_embeddings.search_vector @@ websearch_to_tsquery('english', ${query})`
)
.groupBy([
'document_embeddings.document_id',
'document_embeddings.text',
'documents.created_at',
'document_embeddings.chunk',
])
.orderBy('rank', 'desc')
.limit(limit)
.execute();
return results.map((result) => ({
id: result.id,
text: result.text,
score: result.rank,
type: 'keyword',
createdAt: result.created_at,
chunkIndex: result.chunk_index,
}));
}
private combineSearchResults(
semanticResults: SearchResult[],
keywordResults: SearchResult[]
): Document[] {
const { semanticSearchWeight, keywordSearchWeight } =
configuration.ai.retrieval.hybridSearch;
const maxSemanticScore = Math.max(
...semanticResults.map((r) => r.score),
1
);
const maxKeywordScore = Math.max(...keywordResults.map((r) => r.score), 1);
const combined = new Map<string, SearchResult & { finalScore: number }>();
const createKey = (result: SearchResult) =>
`${result.id}-${result.chunkIndex}`;
const calculateRecencyBoost = (
createdAt: Date | undefined | null
): number => {
if (!createdAt) return 1;
const now = new Date();
const ageInDays =
(now.getTime() - createdAt.getTime()) / (1000 * 60 * 60 * 24);
return ageInDays <= 7 ? 1 + (1 - ageInDays / 7) * 0.2 : 1;
};
semanticResults.forEach((result) => {
const key = createKey(result);
const recencyBoost = calculateRecencyBoost(result.createdAt);
const normalizedScore =
((maxSemanticScore - result.score) / maxSemanticScore) *
semanticSearchWeight;
combined.set(key, {
...result,
finalScore: normalizedScore * recencyBoost,
});
});
keywordResults.forEach((result) => {
const key = createKey(result);
const recencyBoost = calculateRecencyBoost(result.createdAt);
const normalizedScore =
(result.score / maxKeywordScore) * keywordSearchWeight;
if (combined.has(key)) {
const existing = combined.get(key)!;
existing.finalScore += normalizedScore * recencyBoost;
} else {
combined.set(key, {
...result,
finalScore: normalizedScore * recencyBoost,
});
}
});
return Array.from(combined.values())
.sort((a, b) => b.finalScore - a.finalScore)
.map(
(result) =>
new Document({
pageContent: result.text,
metadata: {
id: result.id,
score: result.finalScore,
createdAt: result.createdAt,
type: 'document',
chunkIndex: result.chunkIndex,
},
})
);
}
}
export const documentRetrievalService = new DocumentRetrievalService();
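To make the weighting concrete, a small worked example of the normalization above with the default weights (0.7 semantic, 0.3 keyword); the numbers are illustrative:

// Semantic scores are cosine distances (lower is better), so they are inverted:
//   normalized = ((maxDistance - distance) / maxDistance) * semanticSearchWeight
//   e.g. maxDistance 0.8, hit at 0.2  ->  ((0.8 - 0.2) / 0.8) * 0.7 = 0.525
// Keyword scores are ts_rank values (higher is better):
//   normalized = (rank / maxRank) * keywordSearchWeight
//   e.g. maxRank 0.5, hit at 0.4     ->  (0.4 / 0.5) * 0.3 = 0.24
// A chunk surfaced by both searches sums the two contributions:
//   0.525 + 0.24 = 0.765
// and each contribution is further multiplied by a recency boost of up to 1.2
// for content created within the last 7 days.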

View File

@@ -0,0 +1,278 @@
import { ChatOpenAI } from '@langchain/openai';
import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
import { PromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts';
import { StringOutputParser } from '@langchain/core/output_parsers';
import { HumanMessage } from '@langchain/core/messages';
import { configuration } from '@/lib/configuration';
import { Document } from '@langchain/core/documents';
import { z } from 'zod';
// Use proper Zod schemas and updated prompt templates
const rerankedDocumentsSchema = z.object({
rankings: z.array(
z.object({
index: z.number(),
score: z.number().min(0).max(1),
type: z.enum(['node', 'document']),
sourceId: z.string(),
})
),
});
type RerankedDocuments = z.infer<typeof rerankedDocumentsSchema>;
const citedAnswerSchema = z.object({
answer: z.string(),
citations: z.array(
z.object({
sourceId: z.string(),
quote: z.string(),
})
),
});
type CitedAnswer = z.infer<typeof citedAnswerSchema>;
export function getChatModel(
task: keyof typeof configuration.ai.models
): ChatOpenAI | ChatGoogleGenerativeAI {
const modelConfig = configuration.ai.models[task];
if (!configuration.ai.enabled) {
throw new Error('AI is disabled.');
}
const providerConfig = configuration.ai.providers[modelConfig.provider];
if (!providerConfig.enabled) {
throw new Error(`${modelConfig.provider} provider is disabled.`);
}
switch (modelConfig.provider) {
case 'openai':
return new ChatOpenAI({
modelName: modelConfig.modelName,
temperature: modelConfig.temperature,
openAIApiKey: providerConfig.apiKey,
});
case 'google':
return new ChatGoogleGenerativeAI({
modelName: modelConfig.modelName,
temperature: modelConfig.temperature,
apiKey: providerConfig.apiKey,
});
default:
throw new Error(`Unsupported AI provider: ${modelConfig.provider}`);
}
}
// Updated prompt templates using type-safe node types and context
const queryRewritePrompt = PromptTemplate.fromTemplate(
`You are an expert at rewriting queries for information retrieval within Colanode.
Guidelines:
1. Extract the core information need.
2. Remove filler words.
3. Preserve key technical terms and dates.
Original query:
{query}
Rewrite the query and return only the rewritten version.`
);
const summarizationPrompt = PromptTemplate.fromTemplate(
`Summarize the following text focusing on key points relevant to the user's query.
If the text is short (<100 characters), return it as is.
Text: {text}
User Query: {query}`
);
const rerankPrompt = PromptTemplate.fromTemplate(
`Re-rank the following list of documents by their relevance to the query.
For each document, provide:
- Original index (from input)
- A relevance score between 0 and 1
- Document type (node or document)
- Source ID
User query:
{query}
Documents:
{context}
Return an array of rankings in JSON format.`
);
const answerPrompt = ChatPromptTemplate.fromTemplate(
`You are Colanode's AI assistant.
CURRENT TIME: {currentTimestamp}
WORKSPACE: {workspaceName}
USER: {userName} ({userEmail})
CONVERSATION HISTORY:
{formattedChatHistory}
RELATED CONTEXT:
Messages:
{formattedMessages}
Documents:
{formattedDocuments}
USER QUERY:
{question}
Provide a clear, professional answer. Then, in a separate citations array, list exact quotes (with source IDs) used to form your answer.
Return the result as JSON with keys "answer" and "citations".`
);
const intentRecognitionPrompt = PromptTemplate.fromTemplate(
`Determine if the following user query requires retrieving additional context.
Return exactly one value: "retrieve" or "no_context".
Conversation History:
{formattedChatHistory}
User Query:
{question}`
);
const noContextPrompt = PromptTemplate.fromTemplate(
`Answer the following query concisely using general knowledge, without retrieving additional context.
Conversation History:
{formattedChatHistory}
User Query:
{question}
Return only the answer.`
);
export async function rewriteQuery(query: string): Promise<string> {
const task = 'queryRewrite';
const model = getChatModel(task);
return queryRewritePrompt
.pipe(model)
.pipe(new StringOutputParser())
.invoke({ query });
}
export async function summarizeDocument(
document: Document,
query: string
): Promise<string> {
const task = 'summarization';
const model = getChatModel(task);
return summarizationPrompt
.pipe(model)
.pipe(new StringOutputParser())
.invoke({ text: document.pageContent, query });
}
export async function rerankDocuments(
documents: { content: string; type: string; sourceId: string }[],
query: string
): Promise<
Array<{ index: number; score: number; type: string; sourceId: string }>
> {
const task = 'rerank';
const model = getChatModel(task).withStructuredOutput(
rerankedDocumentsSchema
);
const formattedContext = documents
.map(
(doc, idx) =>
`${idx}. Type: ${doc.type}, Content: ${doc.content}, ID: ${doc.sourceId}\n`
)
.join('\n');
const result = (await rerankPrompt
.pipe(model)
.invoke({ query, context: formattedContext })) as RerankedDocuments;
return result.rankings;
}
export async function generateFinalAnswer(promptArgs: {
currentTimestamp: string;
workspaceName: string;
userName: string;
userEmail: string;
formattedChatHistory: string;
formattedMessages: string;
formattedDocuments: string;
question: string;
}): Promise<{
answer: string;
citations: Array<{ sourceId: string; quote: string }>;
}> {
const task = 'response';
const model = getChatModel(task).withStructuredOutput(citedAnswerSchema);
return (await answerPrompt.pipe(model).invoke(promptArgs)) as CitedAnswer;
}
export async function generateNoContextAnswer(
query: string,
chatHistory: string = ''
): Promise<string> {
const task = 'noContext';
const model = getChatModel(task);
return noContextPrompt
.pipe(model)
.pipe(new StringOutputParser())
.invoke({ question: query, formattedChatHistory: chatHistory });
}
export async function assessUserIntent(
query: string,
chatHistory: string
): Promise<'retrieve' | 'no_context'> {
const task = 'intentRecognition';
const model = getChatModel(task);
const result = await intentRecognitionPrompt
.pipe(model)
.pipe(new StringOutputParser())
.invoke({ question: query, formattedChatHistory: chatHistory });
return result.trim().toLowerCase() === 'no_context'
? 'no_context'
: 'retrieve';
}
export async function addContextToChunk(
chunk: string,
fullText: string,
metadata?: any
): Promise<string> {
try {
const task = 'contextEnhancer';
const model = getChatModel(task);
// Choose a prompt variant based on metadata type (node/document) if available.
const promptTemplate = PromptTemplate.fromTemplate(
`Using the following context information:
{contextInfo}
Full content:
{fullText}
Given this chunk:
{chunk}
Generate a short (50-100 tokens) contextual prefix (do not repeat the chunk) and prepend it to the chunk.`
);
const contextInfo = metadata
? `Type: ${metadata.type}, Name: ${metadata.name || 'N/A'}, Parent: ${metadata.parentName || 'N/A'}, Space: ${metadata.spaceName || 'N/A'}`
: 'No additional context available.';
const prompt = await promptTemplate.format({
contextInfo,
fullText,
chunk,
});
const response = await model.invoke([
new HumanMessage({ content: prompt }),
]);
const prefix = (response.content.toString() || '').trim();
return prefix ? `${prefix}\n\n${chunk}` : chunk;
} catch (err) {
console.error('Error in addContextToChunk:', err);
return chunk;
}
}

View File

@@ -0,0 +1,199 @@
import { OpenAIEmbeddings } from '@langchain/openai';
import { Document } from '@langchain/core/documents';
import { database } from '@/data/database';
import { configuration } from '@/lib/configuration';
import { sql } from 'kysely';
type SearchResult = {
id: string;
text: string;
score: number;
type: 'semantic' | 'keyword';
createdAt?: Date;
chunkIndex: number;
};
export class NodeRetrievalService {
private embeddings = new OpenAIEmbeddings({
apiKey: configuration.ai.embedding.apiKey,
modelName: configuration.ai.embedding.modelName,
dimensions: configuration.ai.embedding.dimensions,
});
public async retrieve(
query: string,
workspaceId: string,
userId: string,
limit = configuration.ai.retrieval.hybridSearch.maxResults
): Promise<Document[]> {
const embedding = await this.embeddings.embedQuery(query);
if (!embedding) return [];
const semanticResults = await this.semanticSearch(
embedding,
workspaceId,
userId,
limit
);
const keywordResults = await this.keywordSearch(
query,
workspaceId,
userId,
limit
);
return this.combineSearchResults(semanticResults, keywordResults);
}
private async semanticSearch(
embedding: number[],
workspaceId: string,
userId: string,
limit: number
): Promise<SearchResult[]> {
const results = await database
.selectFrom('node_embeddings')
.innerJoin('nodes', 'nodes.id', 'node_embeddings.node_id')
.select((eb) => [
'node_embeddings.node_id as id',
'node_embeddings.text',
'nodes.created_at',
'node_embeddings.chunk as chunk_index',
        // Bind the query embedding as a parameter and cast it to a pgvector value;
        // interpolating it inside quotes would leave the placeholder stuck in a string literal.
        sql<number>`node_embeddings.embedding_vector <=> ${`[${embedding.join(',')}]`}::vector`.as(
          'similarity'
        ),
])
.where('node_embeddings.workspace_id', '=', workspaceId)
.groupBy([
'node_embeddings.node_id',
'node_embeddings.text',
'nodes.created_at',
'node_embeddings.chunk',
])
.orderBy('similarity', 'asc')
.limit(limit)
.execute();
return results.map((result) => ({
id: result.id,
text: result.text,
score: result.similarity,
type: 'semantic',
createdAt: result.created_at,
chunkIndex: result.chunk_index,
}));
}
private async keywordSearch(
query: string,
workspaceId: string,
userId: string,
limit: number
): Promise<SearchResult[]> {
const results = await database
.selectFrom('node_embeddings')
.innerJoin('nodes', 'nodes.id', 'node_embeddings.node_id')
.select((eb) => [
'node_embeddings.node_id as id',
'node_embeddings.text',
'nodes.created_at',
'node_embeddings.chunk as chunk_index',
sql<number>`ts_rank(node_embeddings.search_vector, websearch_to_tsquery('english', ${query}))`.as(
'rank'
),
])
.where('node_embeddings.workspace_id', '=', workspaceId)
.where(
(eb) =>
sql`node_embeddings.search_vector @@ websearch_to_tsquery('english', ${query})`
)
.groupBy([
'node_embeddings.node_id',
'node_embeddings.text',
'nodes.created_at',
'node_embeddings.chunk',
])
.orderBy('rank', 'desc')
.limit(limit)
.execute();
return results.map((result) => ({
id: result.id,
text: result.text,
score: result.rank,
type: 'keyword',
createdAt: result.created_at,
chunkIndex: result.chunk_index,
}));
}
private combineSearchResults(
semanticResults: SearchResult[],
keywordResults: SearchResult[]
): Document[] {
const { semanticSearchWeight, keywordSearchWeight } =
configuration.ai.retrieval.hybridSearch;
const maxSemanticScore = Math.max(
...semanticResults.map((r) => r.score),
1
);
const maxKeywordScore = Math.max(...keywordResults.map((r) => r.score), 1);
const combined = new Map<string, SearchResult & { finalScore: number }>();
const createKey = (result: SearchResult) =>
`${result.id}-${result.chunkIndex}`;
const calculateRecencyBoost = (
createdAt: Date | undefined | null
): number => {
if (!createdAt) return 1;
const now = new Date();
const ageInDays =
(now.getTime() - createdAt.getTime()) / (1000 * 60 * 60 * 24);
return ageInDays <= 7 ? 1 + (1 - ageInDays / 7) * 0.2 : 1;
};
semanticResults.forEach((result) => {
const key = createKey(result);
const recencyBoost = calculateRecencyBoost(result.createdAt);
const normalizedScore =
((maxSemanticScore - result.score) / maxSemanticScore) *
semanticSearchWeight;
combined.set(key, {
...result,
finalScore: normalizedScore * recencyBoost,
});
});
keywordResults.forEach((result) => {
const key = createKey(result);
const recencyBoost = calculateRecencyBoost(result.createdAt);
const normalizedScore =
(result.score / maxKeywordScore) * keywordSearchWeight;
if (combined.has(key)) {
const existing = combined.get(key)!;
existing.finalScore += normalizedScore * recencyBoost;
} else {
combined.set(key, {
...result,
finalScore: normalizedScore * recencyBoost,
});
}
});
return Array.from(combined.values())
.sort((a, b) => b.finalScore - a.finalScore)
.map(
(result) =>
new Document({
pageContent: result.text,
metadata: {
id: result.id,
score: result.finalScore,
createdAt: result.createdAt,
type: 'node',
chunkIndex: result.chunkIndex,
},
})
);
}
}
export const nodeRetrievalService = new NodeRetrievalService();