Refactor chunking and embedding services with improved metadata handling and code clarity

Ylber Gashi
2025-02-18 19:33:38 +01:00
parent a0a160cc8f
commit 2ba0d9e6be
4 changed files with 125 additions and 127 deletions

View File

@@ -34,7 +34,7 @@ const extractDocumentText = async (
if (documentText) {
return documentText;
} else {
// Fallback to block text extraction if the model doesn't handle it
// Fallback to block text extraction if the node model doesn't handle it
const blocksText = extractBlockTexts(node.id, content.blocks);
if (blocksText) {
return blocksText;
@@ -53,19 +53,19 @@ export const embedDocumentHandler = async (input: {
}
const { documentId } = input;
const document = await database
.selectFrom('documents')
.select(['id', 'content', 'workspace_id', 'created_at'])
.where('id', '=', documentId)
.executeTakeFirst();
if (!document) {
return;
}
const node = await fetchNode(documentId);
if (!node) return;
if (!node) {
return;
}
const nodeModel = getNodeModel(node.attributes.type);
if (!nodeModel?.documentSchema) {
@@ -82,7 +82,7 @@ export const embedDocumentHandler = async (input: {
}
const chunkingService = new ChunkingService();
const chunks = await chunkingService.chunkText(text, {
const textChunks = await chunkingService.chunkText(text, {
type: 'document',
node: node,
});
@@ -99,16 +99,16 @@ export const embedDocumentHandler = async (input: {
.execute();
const embeddingsToCreateOrUpdate: CreateDocumentEmbedding[] = [];
for (let i = 0; i < chunks.length; i++) {
const chunk = chunks[i];
if (!chunk) continue;
for (let i = 0; i < textChunks.length; i++) {
const textChunk = textChunks[i];
if (!textChunk) continue;
const existing = existingEmbeddings.find((e) => e.chunk === i);
if (existing && existing.text === chunk) continue;
if (existing && existing.text === textChunk) continue;
embeddingsToCreateOrUpdate.push({
document_id: documentId,
chunk: i,
workspace_id: document.workspace_id,
text: chunk,
text: textChunk,
embedding_vector: [],
created_at: new Date(),
});
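The loop in this handler only queues a chunk for (re-)embedding when no row exists for that index or the stored text differs. Below is a minimal, standalone sketch of that diffing step; the types are simplified stand-ins for the real CreateDocumentEmbedding and database rows, not the project's actual schema.

// Simplified stand-ins for the schema types used in the handler.
type ExistingEmbedding = { chunk: number; text: string };
type PendingEmbedding = { chunk: number; text: string };

// Returns only the chunks whose text is new or has changed at a given index,
// so unchanged embeddings are never recomputed.
const diffChunks = (
  textChunks: string[],
  existing: ExistingEmbedding[]
): PendingEmbedding[] => {
  const pending: PendingEmbedding[] = [];
  for (let i = 0; i < textChunks.length; i++) {
    const textChunk = textChunks[i];
    if (!textChunk) {
      continue;
    }
    const match = existing.find((e) => e.chunk === i);
    if (match && match.text === textChunk) {
      continue; // identical text at the same index: keep the stored embedding
    }
    pending.push({ chunk: i, text: textChunk });
  }
  return pending;
};

// Example: only index 1 needs a new embedding.
console.log(
  diffChunks(
    ['hello world', 'changed text'],
    [
      { chunk: 0, text: 'hello world' },
      { chunk: 1, text: 'old text' },
    ]
  )
);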

View File

@@ -2,7 +2,7 @@ import { OpenAIEmbeddings } from '@langchain/openai';
import { ChunkingService } from '@/services/chunking-service';
import { database } from '@/data/database';
import { configuration } from '@/lib/configuration';
import { CreateNodeEmbedding } from '@/data/schema';
import { CreateNodeEmbedding, SelectNode } from '@/data/schema';
import { sql } from 'kysely';
import { fetchNode } from '@/lib/nodes';
import { getNodeModel } from '@colanode/core';
@@ -20,16 +20,17 @@ declare module '@/types/jobs' {
}
}
const extractNodeText = async (
nodeId: string,
node: Awaited<ReturnType<typeof fetchNode>>
): Promise<string> => {
if (!node) return '';
const extractNodeText = (node: SelectNode): string => {
if (!node) {
return '';
}
const nodeModel = getNodeModel(node.attributes.type);
if (!nodeModel) return '';
if (!nodeModel) {
return '';
}
const attributesText = nodeModel.getAttributesText(nodeId, node.attributes);
const attributesText = nodeModel.getAttributesText(node.id, node.attributes);
if (attributesText) {
return attributesText;
}
@@ -51,13 +52,13 @@ export const embedNodeHandler = async (input: {
return;
}
// Skip nodes that are handled by document embeddings
const nodeModel = getNodeModel(node.attributes.type);
// Skip nodes that are handled by embed documents job
if (!nodeModel || nodeModel.documentSchema) {
return;
}
const text = await extractNodeText(nodeId, node);
const text = extractNodeText(node);
if (!text || text.trim() === '') {
await database
.deleteFrom('node_embeddings')
@@ -67,7 +68,7 @@ export const embedNodeHandler = async (input: {
}
const chunkingService = new ChunkingService();
const chunks = await chunkingService.chunkText(text, {
const textChunks = await chunkingService.chunkText(text, {
type: 'node',
node,
});
@@ -84,18 +85,24 @@ export const embedNodeHandler = async (input: {
.execute();
const embeddingsToCreateOrUpdate: CreateNodeEmbedding[] = [];
for (let i = 0; i < chunks.length; i++) {
const chunk = chunks[i];
if (!chunk) continue;
for (let i = 0; i < textChunks.length; i++) {
const textChunk = textChunks[i];
if (!textChunk) {
continue;
}
const existing = existingEmbeddings.find((e) => e.chunk === i);
if (existing && existing.text === chunk) continue;
if (existing && existing.text === textChunk) {
continue;
}
embeddingsToCreateOrUpdate.push({
node_id: nodeId,
chunk: i,
parent_id: node.parent_id,
root_id: node.root_id,
workspace_id: node.workspace_id,
text: chunk,
text: textChunk,
embedding_vector: [],
created_at: new Date(),
});
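extractNodeText is now a plain synchronous function over an already-fetched node rather than an async lookup by nodeId. A rough illustrative sketch of that shape follows; the node type and model registry here are hypothetical simplifications of SelectNode and getNodeModel from @colanode/core.

// Hypothetical, simplified node shape and model registry.
type NodeLike = {
  id: string;
  attributes: { type: string; [key: string]: unknown };
};
type NodeModelLike = {
  documentSchema?: unknown;
  getAttributesText: (id: string, attributes: NodeLike['attributes']) => string | null;
};

const models: Record<string, NodeModelLike> = {
  channel: {
    getAttributesText: (_id, attributes) => String(attributes.name ?? ''),
  },
};

// Synchronous: the caller has already fetched the node, so no awaits are needed.
const extractNodeText = (node: NodeLike): string => {
  const model = models[node.attributes.type];
  if (!model) {
    return '';
  }
  // Nodes whose content lives in a document are skipped by the caller
  // (documentSchema is set), so only attribute text matters here.
  return model.getAttributesText(node.id, node.attributes) ?? '';
};

console.log(
  extractNodeText({ id: 'ch1', attributes: { type: 'channel', name: 'general' } })
);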

View File

@@ -2,41 +2,36 @@ import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { configuration } from '@/lib/configuration';
import { database } from '@/data/database';
import { addContextToChunk } from '@/services/llm-service';
import {
DocumentContent,
getNodeModel,
Node,
NodeAttributes,
} from '@colanode/core';
import type { SelectNode, SelectDocument, SelectUser } from '@/data/schema';
import { DocumentContent, getNodeModel, NodeType } from '@colanode/core';
import type { SelectNode, SelectDocument } from '@/data/schema';
type BaseMetadata = {
id: string;
name?: string;
createdAt: Date;
createdBy: string;
updatedAt?: Date | null;
updatedBy?: string | null;
author?: { id: string; name: string };
lastAuthor?: { id: string; name: string };
parentContext?: {
id: string;
type: string;
name?: string;
path?: string;
};
collaborators?: Array<{ id: string; name: string; role: string }>;
lastUpdated?: Date;
updatedBy?: { id: string; name: string };
collaborators?: Array<{ id: string; name: string }>;
workspace?: { id: string; name: string };
};
export type NodeMetadata = BaseMetadata & {
type: 'node';
nodeType: string;
nodeType: NodeType;
fields?: Record<string, unknown> | null;
};
export type DocumentMetadata = BaseMetadata & {
type: 'document';
content: DocumentContent;
};
export type ChunkingMetadata = NodeMetadata | DocumentMetadata;
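ChunkingMetadata is a discriminated union on the type field, so downstream code can narrow to the node or document variant without casts. A minimal sketch of that narrowing, using simplified local types in place of the real NodeMetadata and DocumentMetadata:

// Simplified local stand-ins for NodeMetadata / DocumentMetadata.
type NodeMeta = { type: 'node'; nodeType: string; name?: string };
type DocumentMeta = { type: 'document'; content: unknown; name?: string };
type Metadata = NodeMeta | DocumentMeta;

// Narrowing on the shared `type` discriminant exposes the variant-specific fields.
const describe = (metadata: Metadata): string => {
  if (metadata.type === 'node') {
    return `node of type ${metadata.nodeType}`;
  }
  return `document ${metadata.name ?? 'untitled'}`;
};

console.log(describe({ type: 'node', nodeType: 'channel' }));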
@@ -52,19 +47,26 @@ export class ChunkingService {
chunkSize,
chunkOverlap,
});
const docs = await splitter.createDocuments([text]);
let chunks = docs.map((doc) => doc.pageContent);
chunks = chunks.filter((c) => c.trim().length > 10);
chunks = chunks.filter((c) => c.trim().length > 5); // We skip chunks that are 5 characters or less
if (configuration.ai.chunking.enhanceWithContext) {
const enrichedMetadata = await this.fetchMetadata(metadata);
const enriched: string[] = [];
for (const chunk of chunks) {
const c = await addContextToChunk(chunk, text, enrichedMetadata);
enriched.push(c);
const enrichedChunk = await addContextToChunk(
chunk,
text,
enrichedMetadata
);
enriched.push(enrichedChunk);
}
return enriched;
}
return chunks;
}
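The chunkText method splits with RecursiveCharacterTextSplitter, drops chunks of 5 characters or fewer, and optionally enriches each chunk with context. A stripped-down sketch of that pipeline follows; the size and overlap values are assumptions (the real ones come from configuration.ai.chunking), and a no-op stands in for addContextToChunk.

import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';

// Placeholder for the real addContextToChunk from the LLM service.
const addContext = async (chunk: string): Promise<string> => chunk;

const chunkText = async (
  text: string,
  enhanceWithContext: boolean
): Promise<string[]> => {
  // Assumed sizes; the real values are read from configuration.
  const splitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000, chunkOverlap: 200 });
  const docs = await splitter.createDocuments([text]);

  // Skip chunks that are 5 characters or less after trimming.
  const chunks = docs.map((doc) => doc.pageContent).filter((c) => c.trim().length > 5);

  if (!enhanceWithContext) {
    return chunks;
  }

  const enriched: string[] = [];
  for (const chunk of chunks) {
    enriched.push(await addContext(chunk));
  }
  return enriched;
};

chunkText('Some longer document text...', false).then((chunks) => console.log(chunks.length));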
@@ -79,11 +81,11 @@ export class ChunkingService {
if (metadata.type === 'node') {
return this.buildNodeMetadata(metadata.node);
} else {
const document = (await database
const document = await database
.selectFrom('documents')
.selectAll()
.where('id', '=', metadata.node.id)
.executeTakeFirst()) as SelectDocument | undefined;
.executeTakeFirst();
if (!document) {
return undefined;
}
@@ -92,28 +94,27 @@ export class ChunkingService {
}
}
private async buildNodeMetadata(node: SelectNode): Promise<NodeMetadata> {
private async buildNodeMetadata(
node: SelectNode
): Promise<NodeMetadata | undefined> {
const nodeModel = getNodeModel(node.attributes.type);
if (!nodeModel) {
throw new Error(`No model found for node type: ${node.attributes.type}`);
return undefined;
}
const baseMetadata = await this.buildBaseMetadata(node);
if (!baseMetadata) {
return undefined;
}
// Add collaborators if the node type supports them
if ('collaborators' in node.attributes) {
baseMetadata.collaborators = await this.fetchCollaborators(
Object.keys(
(
node.attributes as NodeAttributes & {
collaborators: Record<string, string>;
}
).collaborators
)
Object.keys(node.attributes.collaborators)
);
}
// Add parent context if needed
// Add parent context if the node has a parent
if (node.parent_id) {
const parentContext = await this.buildParentContext(node);
if (parentContext) {
@@ -132,47 +133,48 @@ export class ChunkingService {
private async buildDocumentMetadata(
document: SelectDocument,
node?: SelectNode
): Promise<DocumentMetadata> {
): Promise<DocumentMetadata | undefined> {
let baseMetadata: BaseMetadata = {
id: document.id,
createdAt: document.created_at,
createdBy: document.created_by,
};
if (node) {
const nodeModel = getNodeModel(node.attributes.type);
if (nodeModel) {
const nodeName = nodeModel.getName(node.id, node.attributes);
if (nodeName) {
baseMetadata.name = nodeName;
}
if (!node) {
return undefined;
}
// Add parent context if available
if (node.parent_id) {
const parentContext = await this.buildParentContext(node);
if (parentContext) {
baseMetadata.parentContext = parentContext;
}
}
const nodeModel = getNodeModel(node.attributes.type);
if (nodeModel) {
const nodeName = nodeModel.getName(node.id, node.attributes);
if (nodeName) {
baseMetadata.name = nodeName;
}
return {
type: 'document',
content: document.content,
...baseMetadata,
};
// Add parent context if available
if (node.parent_id) {
const parentContext = await this.buildParentContext(node);
if (parentContext) {
baseMetadata.parentContext = parentContext;
}
}
}
return {
type: 'document',
content: document.content,
...baseMetadata,
};
}
private async buildBaseMetadata(node: SelectNode): Promise<BaseMetadata> {
private async buildBaseMetadata(
node: SelectNode
): Promise<BaseMetadata | undefined> {
const nodeModel = getNodeModel(node.attributes.type);
const nodeName = nodeModel?.getName(node.id, node.attributes);
if (!nodeModel) {
return undefined;
}
const nodeName = nodeModel.getName(node.id, node.attributes);
const author = await database
.selectFrom('users')
@@ -180,7 +182,7 @@ export class ChunkingService {
.where('id', '=', node.created_by)
.executeTakeFirst();
const updatedBy = node.updated_by
const lastAuthor = node.updated_by
? await database
.selectFrom('users')
.select(['id', 'name'])
@@ -199,10 +201,11 @@ export class ChunkingService {
name: nodeName ?? '',
createdAt: node.created_at,
createdBy: node.created_by,
author: author ?? undefined,
lastUpdated: node.updated_at ?? undefined,
updatedBy: updatedBy ?? undefined,
workspace: workspace ?? undefined,
updatedAt: node.updated_at,
updatedBy: node.updated_by,
author: author,
lastAuthor: lastAuthor,
workspace: workspace,
};
}
@@ -255,7 +258,7 @@ export class ChunkingService {
private async fetchCollaborators(
collaboratorIds: string[]
): Promise<Array<{ id: string; name: string; role: string }>> {
): Promise<Array<{ id: string; name: string }>> {
if (!collaboratorIds.length) {
return [];
}
@@ -266,21 +269,9 @@ export class ChunkingService {
.where('id', 'in', collaboratorIds)
.execute();
// Get roles for each collaborator
return Promise.all(
collaborators.map(async (c) => {
const collaboration = await database
.selectFrom('collaborations')
.select(['role'])
.where('collaborator_id', '=', c.id)
.executeTakeFirst();
return {
id: c.id,
name: c.name,
role: collaboration?.role ?? 'unknown',
};
})
);
return collaborators.map((c) => ({
id: c.id,
name: c.name,
}));
}
}
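fetchCollaborators now resolves all collaborator names with a single IN query and no longer looks up a role per user. A rough sketch of the same shape, using an in-memory list in place of the Kysely users table:

type User = { id: string; name: string };

// In-memory stand-in for the `users` table queried with where('id', 'in', ids).
const users: User[] = [
  { id: 'u1', name: 'Ada' },
  { id: 'u2', name: 'Grace' },
];

const fetchCollaborators = (
  collaboratorIds: string[]
): Array<{ id: string; name: string }> => {
  if (!collaboratorIds.length) {
    return [];
  }
  // One lookup for all ids, then a plain map — no per-collaborator role query.
  return users
    .filter((u) => collaboratorIds.includes(u.id))
    .map((u) => ({ id: u.id, name: u.name }));
};

console.log(fetchCollaborators(['u1', 'u2']));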

View File

@@ -1,9 +1,8 @@
// Updated llm-service.ts
import { ChatOpenAI } from '@langchain/openai';
import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
import { PromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts';
import { StringOutputParser } from '@langchain/core/output_parsers';
import { HumanMessage } from '@langchain/core/messages';
import { SystemMessage } from '@langchain/core/messages';
import { configuration } from '@/lib/configuration';
import { Document } from '@langchain/core/documents';
import { z } from 'zod';
@@ -297,8 +296,8 @@ const getNodeContextPrompt = (metadata: NodeMetadata): string => {
const basePrompt = `Given the following context about a {nodeType}:
Name: {name}
Created by: {authorName} on {createdAt}
Last updated: {lastUpdated} by {updatedByName}
Location: {path}
Last updated: {updatedAt} by {lastAuthorName}
Path: {path}
Workspace: {workspaceName}
{additionalContext}
@@ -314,17 +313,8 @@ Generate a brief (50-100 tokens) contextual prefix that:
3. Makes the chunk more understandable in isolation
Do not repeat the chunk content. Return only the contextual prefix.`;
const getCollaboratorInfo = (
collaborators?: Array<{ id: string; name: string; role: string }>
) => {
if (!collaborators?.length) return 'No collaborators';
return collaborators.map((c) => `${c.name} (${c.role})`).join(', ');
};
const formatDate = (date?: Date) => {
if (!date) return 'unknown';
return new Date(date).toLocaleString();
};
const collaborators =
metadata.collaborators?.map((c) => `${c.name}`).join(', ') ?? '';
switch (metadata.nodeType) {
case 'message':
@@ -332,7 +322,7 @@ Do not repeat the chunk content. Return only the contextual prefix.`;
'{additionalContext}',
`In: ${metadata.parentContext?.type ?? 'unknown'} "${metadata.parentContext?.name ?? 'unknown'}"
Path: ${metadata.parentContext?.path ?? 'unknown'}
Participants: ${getCollaboratorInfo(metadata.collaborators)}`
Participants: ${collaborators}`
);
case 'record':
@@ -347,7 +337,7 @@ Fields: ${Object.keys(metadata.fields ?? {}).join(', ')}`
return basePrompt.replace(
'{additionalContext}',
`Location: ${metadata.parentContext?.path ?? 'root level'}
Collaborators: ${getCollaboratorInfo(metadata.collaborators)}`
Collaborators: ${collaborators}`
);
case 'database':
@@ -355,7 +345,7 @@ Collaborators: ${getCollaboratorInfo(metadata.collaborators)}`
'{additionalContext}',
`Path: ${metadata.parentContext?.path ?? 'root level'}
Fields: ${Object.keys(metadata.fields ?? {}).join(', ')}
Collaborators: ${getCollaboratorInfo(metadata.collaborators)}`
Collaborators: ${collaborators}`
);
case 'channel':
@@ -363,7 +353,7 @@ Collaborators: ${getCollaboratorInfo(metadata.collaborators)}`
'{additionalContext}',
`Type: Channel
Path: ${metadata.parentContext?.path ?? 'root level'}
Members: ${getCollaboratorInfo(metadata.collaborators)}`
Members: ${collaborators}`
);
default:
@@ -374,10 +364,10 @@ Members: ${getCollaboratorInfo(metadata.collaborators)}`
interface PromptVariables {
nodeType: string;
name: string;
authorName: string;
createdAt: string;
lastUpdated: string;
updatedByName: string;
updatedAt: string;
authorName: string;
lastAuthorName: string;
path: string;
workspaceName: string;
fullText: string;
@@ -391,7 +381,7 @@ Type: {nodeType}
Name: {name}
Location: {path}
Created by: {authorName} on {createdAt}
Last updated: {lastUpdated} by {updatedByName}
Last updated: {updatedAt} by {lastAuthorName}
Workspace: {workspaceName}
Full content:
@@ -413,6 +403,14 @@ export async function addContextToChunk(
metadata?: ChunkingMetadata
): Promise<string> {
try {
if (!chunk || chunk.trim() === '') {
return chunk;
}
if (!fullText || fullText.trim() === '') {
return chunk;
}
if (!metadata) {
return chunk;
}
@@ -425,22 +423,24 @@ export async function addContextToChunk(
const formatDate = (date?: Date) => {
if (!date) return 'unknown';
return new Date(date).toLocaleString();
return new Date(date).toUTCString();
};
const baseVars = {
nodeType: metadata.type === 'node' ? metadata.nodeType : metadata.type,
name: metadata.name ?? 'Untitled',
authorName: metadata.author?.name ?? 'Unknown',
createdAt: formatDate(metadata.createdAt),
lastUpdated: formatDate(metadata.lastUpdated),
updatedByName: metadata.updatedBy?.name ?? 'Unknown',
path: metadata.parentContext?.path ?? 'root level',
updatedAt: metadata.updatedAt ? formatDate(metadata.updatedAt) : '',
authorName: metadata.author?.name ?? 'Unknown',
lastAuthorName: metadata.lastAuthor?.name ?? '',
path: metadata.parentContext?.path ?? '',
workspaceName: metadata.workspace?.name ?? 'Unknown Workspace',
fullText,
chunk,
};
//TODO: if metadata is empty, use a default context prompt for chunk
if (metadata.type === 'node') {
prompt = getNodeContextPrompt(metadata);
promptVars = baseVars;
@@ -455,7 +455,7 @@ export async function addContextToChunk(
);
const response = await model.invoke([
new HumanMessage({ content: formattedPrompt }),
new SystemMessage({ content: formattedPrompt }),
]);
const prefix = (response.content.toString() || '').trim();
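addContextToChunk now returns the chunk unchanged when the chunk, the full text, or the metadata is missing, and formats dates with toUTCString. A compact sketch of those guards and the variable preparation follows; the model call is a fake stand-in for the real ChatOpenAI invocation with a SystemMessage, and prepending the prefix to the chunk is an assumption, since the diff only shows the prefix being computed.

type Metadata = {
  name?: string;
  createdAt: Date;
  updatedAt?: Date | null;
  author?: { id: string; name: string };
  lastAuthor?: { id: string; name: string };
  workspace?: { id: string; name: string };
};

// Stand-in for the real chat-model call.
const fakeModel = async (prompt: string): Promise<string> =>
  `[context for: ${prompt.slice(0, 20)}...]`;

const addContextToChunk = async (
  chunk: string,
  fullText: string,
  metadata?: Metadata
): Promise<string> => {
  // Guard clauses: nothing to enrich, or nothing to enrich with.
  if (!chunk || chunk.trim() === '') {
    return chunk;
  }
  if (!fullText || fullText.trim() === '') {
    return chunk;
  }
  if (!metadata) {
    return chunk;
  }

  const formatDate = (date?: Date | null) =>
    date ? new Date(date).toUTCString() : 'unknown';

  const vars = {
    name: metadata.name ?? 'Untitled',
    createdAt: formatDate(metadata.createdAt),
    updatedAt: metadata.updatedAt ? formatDate(metadata.updatedAt) : '',
    authorName: metadata.author?.name ?? 'Unknown',
    lastAuthorName: metadata.lastAuthor?.name ?? '',
    workspaceName: metadata.workspace?.name ?? 'Unknown Workspace',
  };

  const prompt = `Name: ${vars.name}\nCreated by: ${vars.authorName} on ${vars.createdAt}\nLast updated: ${vars.updatedAt} by ${vars.lastAuthorName}\nWorkspace: ${vars.workspaceName}`;
  const prefix = (await fakeModel(prompt)).trim();

  // Assumed behaviour: prepend the generated prefix to the chunk.
  return prefix ? `${prefix}\n\n${chunk}` : chunk;
};

addContextToChunk('Quarterly goals...', 'Full document text', { createdAt: new Date() }).then(console.log);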