Separated AI embeddings into entry and message embedding tables; updated the migration script and embed jobs accordingly

This commit is contained in:
Ylber Gashi
2025-01-02 22:51:18 +01:00
parent e7d20f47f0
commit fb4783f4a0
5 changed files with 222 additions and 58 deletions

View File

@@ -848,13 +848,15 @@ const createEntryPathsTable: Migration = {
},
};
const createAIEmbeddingsTable: Migration = {
const createMessageEmbeddingsTable: Migration = {
up: async (db) => {
await db.schema
.createTable('ai_embeddings')
.addColumn('id', 'varchar(30)', (col) => col.notNull().primaryKey())
.addColumn('entity_id', 'varchar(30)', (col) => col.notNull())
.addColumn('entity_type', 'varchar(30)', (col) => col.notNull())
.createTable('message_embeddings')
.addColumn('message_id', 'varchar(30)', (col) => col.notNull())
.addColumn('chunk', 'integer', (col) => col.notNull())
.addColumn('parent_id', 'varchar(30)', (col) => col.notNull())
.addColumn('root_id', 'varchar(30)', (col) => col.notNull())
.addColumn('workspace_id', 'varchar(30)', (col) => col.notNull())
.addColumn('content', 'text', (col) => col.notNull())
.addColumn('embedding', sql`vector(2000)`, (col) => col.notNull())
.addColumn('fts', sql`tsvector GENERATED ALWAYS AS (to_tsvector('english', content)) STORED`)
@@ -864,8 +866,8 @@ const createAIEmbeddingsTable: Migration = {
.execute();
await sql`
CREATE INDEX ai_embeddings_embedding_idx
ON ai_embeddings
CREATE INDEX message_embeddings_embedding_idx
ON message_embeddings
USING hnsw(embedding vector_cosine_ops)
WITH (
m = 16,
@@ -874,13 +876,51 @@ const createAIEmbeddingsTable: Migration = {
`.execute(db);
await sql`
CREATE INDEX ai_embeddings_fts_idx
ON ai_embeddings
CREATE INDEX message_embeddings_fts_idx
ON message_embeddings
USING GIN (fts);
`.execute(db);
},
down: async (db) => {
await db.schema.dropTable('ai_embeddings').execute();
await db.schema.dropTable('message_embeddings').execute();
},
};
/**
 * Migration 00017: per-chunk AI embeddings for entries.
 *
 * One row per (entry_id, chunk) pair. The embed-entry job upserts with
 * ON CONFLICT (entry_id, chunk), which Postgres only accepts when a
 * unique or primary-key constraint covers exactly those columns — so the
 * composite primary key below is required, not optional.
 */
const createEntryEmbeddingsTable: Migration = {
  up: async (db) => {
    await db.schema
      .createTable('entry_embeddings')
      .addColumn('entry_id', 'varchar(30)', (col) => col.notNull())
      // Zero-based index of the chunk within the entry's extracted text.
      .addColumn('chunk', 'integer', (col) => col.notNull())
      .addColumn('parent_id', 'varchar(30)', (col) => col.notNull())
      .addColumn('root_id', 'varchar(30)', (col) => col.notNull())
      .addColumn('workspace_id', 'varchar(30)', (col) => col.notNull())
      .addColumn('content', 'text', (col) => col.notNull())
      // pgvector column; dimension must match aiSettings.openai.embeddingDimensions (2000).
      .addColumn('embedding', sql`vector(2000)`, (col) => col.notNull())
      // Generated full-text-search column; never written by application code.
      .addColumn('fts', sql`tsvector GENERATED ALWAYS AS (to_tsvector('english', content)) STORED`)
      .addColumn('metadata', 'jsonb')
      .addColumn('created_at', 'timestamptz', (col) => col.notNull())
      .addColumn('updated_at', 'timestamptz')
      // Composite PK: dedupes chunks and backs the embed job's
      // ON CONFLICT (entry_id, chunk) upsert.
      .addPrimaryKeyConstraint('entry_embeddings_pkey', ['entry_id', 'chunk'])
      .execute();

    // Approximate-nearest-neighbor index for cosine-similarity search.
    await sql`
      CREATE INDEX entry_embeddings_embedding_idx
      ON entry_embeddings
      USING hnsw(embedding vector_cosine_ops)
      WITH (
        m = 16,
        ef_construction = 64
      );
    `.execute(db);

    // GIN index over the generated tsvector for keyword search.
    await sql`
      CREATE INDEX entry_embeddings_fts_idx
      ON entry_embeddings
      USING GIN (fts);
    `.execute(db);
  },
  down: async (db) => {
    await db.schema.dropTable('entry_embeddings').execute();
  },
};
@@ -901,5 +941,6 @@ export const databaseMigrations: Record<string, Migration> = {
'00014_create_file_interactions_table': createFileInteractionsTable,
'00015_create_file_tombstones_table': createFileTombstonesTable,
'00016_create_collaborations_table': createCollaborationsTable,
'00017_create_ai_embeddings_table': createAIEmbeddingsTable,
'00017_create_entry_embeddings_table': createEntryEmbeddingsTable,
'00018_create_message_embeddings_table': createMessageEmbeddingsTable,
};

View File

@@ -287,10 +287,12 @@ export type SelectFileTombstone = Selectable<FileTombstoneTable>;
export type CreateFileTombstone = Insertable<FileTombstoneTable>;
export type UpdateFileTombstone = Updateable<FileTombstoneTable>;
interface AIEmbeddingsTable {
id: ColumnType<string, string, never>;
entity_id: ColumnType<string, string, never>;
entity_type: ColumnType<string, string, never>;
interface EntryEmbeddingTable {
entry_id: ColumnType<string, string, never>;
chunk: ColumnType<number, number, number>;
parent_id: ColumnType<string, string, never>;
root_id: ColumnType<string, string, never>;
workspace_id: ColumnType<string, string, never>;
content: ColumnType<string, string, string>;
embedding: ColumnType<number[], number[], number[]>;
fts: ColumnType<never, never, never>;
@@ -299,10 +301,27 @@ interface AIEmbeddingsTable {
updated_at: ColumnType<Date | null, Date | null, Date | null>;
}
export type SelectAIEmbedding = Selectable<AIEmbeddingsTable>;
export type CreateAIEmbedding = Insertable<AIEmbeddingsTable>;
export type UpdateAIEmbedding = Updateable<AIEmbeddingsTable>;
export type SelectEntryEmbedding = Selectable<EntryEmbeddingTable>;
export type CreateEntryEmbedding = Insertable<EntryEmbeddingTable>;
export type UpdateEntryEmbedding = Updateable<EntryEmbeddingTable>;
/**
 * Kysely row type for the `message_embeddings` table: one row per
 * (message_id, chunk) pair produced by the embed-message job.
 * ColumnType<Select, Insert, Update> — `never` in the update slot marks
 * a column that is immutable after insert.
 */
interface MessageEmbeddingTable {
  // ID of the source message; immutable after insert.
  message_id: ColumnType<string, string, never>;
  // Zero-based chunk index within the message content.
  chunk: ColumnType<number, number, number>;
  parent_id: ColumnType<string, string, never>;
  root_id: ColumnType<string, string, never>;
  workspace_id: ColumnType<string, string, never>;
  content: ColumnType<string, string, string>;
  // pgvector embedding, handled as a plain number array in application code.
  embedding: ColumnType<number[], number[], number[]>;
  // Generated tsvector column — never readable or writable from app code.
  fts: ColumnType<never, never, never>;
  // NOTE(review): column is jsonb; depending on the pg driver config, selects
  // may return parsed JSON rather than a string — confirm the `string` select type.
  metadata: ColumnType<string | null, string | null, string | null>;
  created_at: ColumnType<Date, Date, never>;
  updated_at: ColumnType<Date | null, Date | null, Date | null>;
}

export type SelectMessageEmbedding = Selectable<MessageEmbeddingTable>;
export type CreateMessageEmbedding = Insertable<MessageEmbeddingTable>;
export type UpdateMessageEmbedding = Updateable<MessageEmbeddingTable>;
export interface DatabaseSchema {
accounts: AccountTable;
@@ -321,5 +340,6 @@ export interface DatabaseSchema {
file_interactions: FileInteractionTable;
file_tombstones: FileTombstoneTable;
collaborations: CollaborationTable;
ai_embeddings: AIEmbeddingsTable;
entry_embeddings: EntryEmbeddingTable;
message_embeddings: MessageEmbeddingTable;
}

View File

@@ -3,7 +3,7 @@ import { ChunkingService } from '@/services/chunking-service';
import { database } from '@/data/database';
import { OpenAIEmbeddings } from '@langchain/openai';
import { aiSettings } from '@/lib/ai-settings';
import { generateId, IdType, extractEntryText } from '@colanode/core';
import { extractEntryText } from '@colanode/core';
export type EmbedEntryInput = {
type: 'embed_entry';
@@ -25,7 +25,7 @@ export const embedEntryHandler: JobHandler<EmbedEntryInput> = async (
const entry = await database
.selectFrom('entries')
.select(['id', 'attributes'])
.select(['id', 'attributes', 'parent_id', 'root_id', 'workspace_id'])
.where('id', '=', entryId)
.executeTakeFirst();
@@ -47,7 +47,22 @@ export const embedEntryHandler: JobHandler<EmbedEntryInput> = async (
dimensions: aiSettings.openai.embeddingDimensions,
});
const embeddingVectors = await embeddings.embedDocuments(chunks)
const existingEmbeddings = await database
.selectFrom('entry_embeddings')
.select(['chunk', 'content'])
.where('entry_id', '=', entryId)
.execute();
const embeddingsToCreateOrUpdate: {
entry_id: string;
chunk: number;
parent_id: string;
root_id: string;
workspace_id: string;
content: string;
embedding: number[];
metadata: string | null;
}[] = [];
for (let i = 0; i < chunks.length; i++) {
const chunk = chunks[i];
@@ -55,28 +70,66 @@ export const embedEntryHandler: JobHandler<EmbedEntryInput> = async (
continue;
}
const embedding = embeddingVectors[i]
if (!embedding) {
const existingEmbedding = existingEmbeddings.find(
(e) => e.chunk === i
);
if (existingEmbedding && existingEmbedding.content === chunk) {
continue;
}
const id = generateId(IdType.AiEmbedding);
const metadata = JSON.stringify({
chunk_index: i,
embeddingsToCreateOrUpdate.push({
entry_id: entryId,
chunk: i,
parent_id: entry.parent_id,
root_id: entry.root_id,
workspace_id: entry.workspace_id,
content: chunk,
embedding: [],
metadata: null,
});
}
const batchSize = aiSettings.openai.embeddingBatchSize;
for (let i = 0; i < embeddingsToCreateOrUpdate.length; i += batchSize) {
const batch = embeddingsToCreateOrUpdate.slice(i, i + batchSize);
const textsToEmbed = batch.map((item) => item.content);
const embeddingVectors = await embeddings.embedDocuments(textsToEmbed);
for (let j = 0; j < batch.length; j++) {
const vector = embeddingVectors[j];
const batchItem = batch[j];
if (vector && batchItem) {
batchItem.embedding = vector;
}
}
}
for (const item of embeddingsToCreateOrUpdate) {
if (item.embedding.length === 0) continue;
await database
.insertInto('ai_embeddings')
.values(
{
id,
entity_id: entryId,
entity_type: 'entry',
embedding: embedding,
content: chunk,
metadata,
created_at: new Date(),
.insertInto('entry_embeddings')
.values({
entry_id: item.entry_id,
chunk: item.chunk,
parent_id: item.parent_id,
root_id: item.root_id,
workspace_id: item.workspace_id,
content: item.content,
embedding: item.embedding,
metadata: item.metadata,
created_at: new Date(),
updated_at: new Date(),
})
.onConflict((oc) =>
oc.columns(['entry_id', 'chunk']).doUpdateSet({
content: item.content,
embedding: item.embedding,
metadata: item.metadata,
updated_at: new Date(),
})
)
.execute();
}
};
};

View File

@@ -3,7 +3,6 @@ import { ChunkingService } from '@/services/chunking-service';
import { database } from '@/data/database';
import { OpenAIEmbeddings } from '@langchain/openai';
import { aiSettings } from '@/lib/ai-settings';
import { generateId, IdType } from '@colanode/core';
export type EmbedMessageInput = {
type: 'embed_message';
@@ -32,9 +31,6 @@ export const embedMessageHandler: JobHandler<EmbedMessageInput> = async (
if (!message) {
return;
}
if (!message.content) {
return;
}
const chunkingService = new ChunkingService();
const chunks = await chunkingService.chunkText(message.content);
@@ -45,7 +41,22 @@ export const embedMessageHandler: JobHandler<EmbedMessageInput> = async (
dimensions: aiSettings.openai.embeddingDimensions,
});
const embeddingVectors = await embeddings.embedDocuments(chunks)
const existingEmbeddings = await database
.selectFrom('message_embeddings')
.select(['chunk', 'content'])
.where('message_id', '=', messageId)
.execute();
const embeddingsToCreateOrUpdate: {
message_id: string;
chunk: number;
parent_id: string;
root_id: string;
workspace_id: string;
content: string;
embedding: number[];
metadata: string | null;
}[] = [];
for (let i = 0; i < chunks.length; i++) {
const chunk = chunks[i];
@@ -53,28 +64,66 @@ export const embedMessageHandler: JobHandler<EmbedMessageInput> = async (
continue;
}
const embedding = embeddingVectors[i]
if (!embedding) {
const existingEmbedding = existingEmbeddings.find(
(e) => e.chunk === i
);
if (existingEmbedding && existingEmbedding.content === chunk) {
continue;
}
const id = generateId(IdType.AiEmbedding);
const metadata = JSON.stringify({
chunk_index: i,
embeddingsToCreateOrUpdate.push({
message_id: messageId,
chunk: i,
parent_id: message.parent_id,
root_id: message.root_id,
workspace_id: message.workspace_id,
content: chunk,
embedding: [],
metadata: null,
});
}
const batchSize = aiSettings.openai.embeddingBatchSize;
for (let i = 0; i < embeddingsToCreateOrUpdate.length; i += batchSize) {
const batch = embeddingsToCreateOrUpdate.slice(i, i + batchSize);
const textsToEmbed = batch.map((item) => item.content);
const embeddingVectors = await embeddings.embedDocuments(textsToEmbed);
for (let j = 0; j < batch.length; j++) {
const vector = embeddingVectors[j];
const batchItem = batch[j];
if (vector && batchItem) {
batchItem.embedding = vector;
}
}
}
for (const item of embeddingsToCreateOrUpdate) {
if (item.embedding.length === 0) continue;
await database
.insertInto('ai_embeddings')
.values(
{
id,
entity_id: messageId,
entity_type: 'message',
embedding: embedding,
content: chunk,
metadata,
created_at: new Date(),
.insertInto('message_embeddings')
.values({
message_id: item.message_id,
chunk: item.chunk,
parent_id: item.parent_id,
root_id: item.root_id,
workspace_id: item.workspace_id,
content: item.content,
embedding: item.embedding,
metadata: item.metadata,
created_at: new Date(),
updated_at: new Date(),
})
.onConflict((oc) =>
oc.columns(['message_id', 'chunk']).doUpdateSet({
content: item.content,
embedding: item.embedding,
metadata: item.metadata,
updated_at: new Date(),
})
)
.execute();
}
};

View File

@@ -3,6 +3,7 @@ export const aiSettings = {
apiKey: process.env.OPENAI_API_KEY || '',
embeddingModel: 'text-embedding-3-large',
embeddingDimensions: 2000,
embeddingBatchSize: 50,
},
chunking: {