initial version of using records/database filtering as context for assistant

This commit is contained in:
Ylber Gashi
2025-02-18 00:03:18 +01:00
parent 1217ee487b
commit a0a160cc8f
4 changed files with 769 additions and 2 deletions

View File

@@ -4,6 +4,9 @@ import {
generateNodeIndex,
getNodeModel,
NodeAttributes,
DatabaseNode,
RecordNode,
DatabaseAttributes,
} from '@colanode/core';
import { Document } from '@langchain/core/documents';
import { StateGraph, Annotation } from '@langchain/langgraph';
@@ -19,10 +22,12 @@ import {
generateFinalAnswer,
generateNoContextAnswer,
assessUserIntent,
generateDatabaseFilters,
} from '@/services/llm-service';
import { CallbackHandler } from 'langfuse-langchain';
import { fetchNode } from '@/lib/nodes';
import { sql } from 'kysely';
import { recordsRetrievalService } from '@/services/records-retrieval-service';
// ---------------------------------------------------------------------
// Job Input & Type Definitions
@@ -67,6 +72,21 @@ const ResponseState = Annotation.Root({
citations: Annotation<Array<{ sourceId: string; quote: string }>>(),
originalMessage: Annotation<any>(),
intent: Annotation<'retrieve' | 'no_context'>(),
databaseContext: Annotation<
Array<{
id: string;
name: string;
fields: Record<string, { type: string; name: string }>;
sampleRecords: any[];
}>
>(),
databaseFilters: Annotation<{
shouldFilter: boolean;
filters: Array<{
databaseId: string;
filters: any[]; // DatabaseViewFilterAttributes[]
}>;
}>(),
});
// ---------------------------------------------------------------------
@@ -123,7 +143,59 @@ async function fetchContextDocuments(state: typeof ResponseState.State) {
state.userId
),
]);
return { contextDocuments: [...nodeResults, ...documentResults] };
let databaseResults: Document[] = [];
// If we have database filters, fetch the filtered records
if (state.databaseFilters.shouldFilter) {
const filteredRecords = await Promise.all(
state.databaseFilters.filters.map(async (filter) => {
const records = await recordsRetrievalService.retrieveByFilters(
filter.databaseId,
state.workspaceId,
state.userId,
{
filters: filter.filters,
sorts: [],
page: 1,
count: 10, // Fetch top 10 matching records
}
);
// Get the database node to access its name
const dbNode = await fetchNode(filter.databaseId);
if (!dbNode || dbNode.type !== 'database') return [];
const dbAttributes = dbNode.attributes as DatabaseAttributes;
// Convert records to Documents
return records.map((record) => {
const recordNode = record as unknown as RecordNode;
return new Document({
pageContent: `Database Record from ${dbAttributes.name}:\n${Object.entries(
recordNode.attributes.fields || {}
)
.map(([key, value]) => `${key}: ${value}`)
.join('\n')}`,
metadata: {
id: record.id,
type: 'record',
createdAt: record.created_at,
author: record.created_by,
databaseId: filter.databaseId,
},
});
});
})
);
// Flatten the array of arrays into a single array
databaseResults = filteredRecords.flat();
}
return {
contextDocuments: [...nodeResults, ...documentResults, ...databaseResults],
};
}
async function fetchChatHistory(state: typeof ResponseState.State) {
@@ -240,6 +312,80 @@ async function generateResponse(state: typeof ResponseState.State) {
};
}
/**
 * Graph node: collects, for every database node in the workspace that the
 * user collaborates on, its id, name, field schema and a handful of sample
 * records. The result is stored in `databaseContext`.
 *
 * @param state Current graph state; reads `workspaceId` and `userId`.
 * @returns Partial state update containing `databaseContext`.
 */
async function fetchDatabaseContext(state: typeof ResponseState.State) {
  // Databases are matched through a collaboration on their root node.
  // Only node columns are selected: the rows are treated as DatabaseNode
  // below, so any same-named collaboration columns must not leak into them.
  const databases = await database
    .selectFrom('nodes as n')
    .innerJoin('collaborations as c', 'c.node_id', 'n.root_id')
    .where('n.type', '=', 'database')
    .where('n.workspace_id', '=', state.workspaceId)
    .where('c.collaborator_id', '=', state.userId)
    .where('c.deleted_at', 'is', null)
    .selectAll('n')
    .execute();

  // For each database, fetch schema and sample records
  const databaseContext = await Promise.all(
    databases.map(async (db) => {
      const dbNode = db as unknown as DatabaseNode;

      // First page of 5 unfiltered, unsorted records serves as the sample
      const sampleRecords = await recordsRetrievalService.retrieveByFilters(
        db.id,
        state.workspaceId,
        state.userId,
        {
          filters: [],
          sorts: [],
          page: 1,
          count: 5,
        }
      );

      // Keep only the parts of each field definition that the prompt needs
      const fields = Object.fromEntries(
        Object.entries(dbNode.attributes.fields || {}).map(([id, field]) => [
          id,
          { type: field.type, name: field.name },
        ])
      );

      return {
        id: db.id,
        name: dbNode.attributes.name || 'Untitled Database',
        fields,
        sampleRecords,
      };
    })
  );

  return { databaseContext };
}
/**
 * Graph node: turns the user's input into per-database filters via
 * `generateDatabaseFilters`. When the intent is `no_context` or no database
 * context is available, a disabled filter set is returned instead and the
 * generator is not called.
 *
 * @param state Current graph state; reads `intent`, `databaseContext` and `userInput`.
 * @returns Partial state update containing `databaseFilters`.
 */
async function generateDatabaseFilterAttributes(
  state: typeof ResponseState.State
) {
  // The intent is checked first so that databaseContext is only inspected
  // when the request actually wants retrieval.
  const canFilter =
    state.intent !== 'no_context' && state.databaseContext.length > 0;

  if (canFilter) {
    const databaseFilters = await generateDatabaseFilters({
      query: state.userInput,
      databases: state.databaseContext,
    });
    return { databaseFilters };
  }

  return {
    databaseFilters: {
      shouldFilter: false,
      filters: [],
    },
  };
}
// ---------------------------------------------------------------------
// Build the Response Chain Graph
// ---------------------------------------------------------------------
@@ -252,13 +398,17 @@ const assistantResponseChain = new StateGraph(ResponseState)
.addNode('generateResponse', generateResponse)
.addNode('assessIntent', assessIntent)
.addNode('generateNoContextResponse', generateNoContextResponse)
.addNode('fetchDatabaseContext', fetchDatabaseContext)
.addNode('generateDatabaseFilterAttributes', generateDatabaseFilterAttributes)
.addEdge('__start__', 'fetchChatHistory')
.addEdge('fetchChatHistory', 'assessIntent')
.addConditionalEdges('assessIntent', (state) => {
return state.intent === 'no_context'
? 'generateNoContextResponse'
: 'generateRewrittenQuery';
: 'fetchDatabaseContext';
})
.addEdge('fetchDatabaseContext', 'generateDatabaseFilterAttributes')
.addEdge('generateDatabaseFilterAttributes', 'generateRewrittenQuery')
.addEdge('generateRewrittenQuery', 'fetchContextDocuments')
.addEdge('fetchContextDocuments', 'rerankContextDocuments')
.addEdge('rerankContextDocuments', 'selectRelevantDocuments')

View File

@@ -102,6 +102,7 @@ export interface AiConfiguration {
contextEnhancer: AIModelConfiguration;
noContext: AIModelConfiguration;
intentRecognition: AIModelConfiguration;
databaseFilter: AIModelConfiguration;
};
embedding: {
provider: AIProvider;
@@ -279,6 +280,14 @@ export const configuration: Configuration = {
getOptionalEnv('INTENT_RECOGNITION_TEMPERATURE') || '0.3'
),
},
databaseFilter: {
provider: (getOptionalEnv('DATABASE_FILTER_PROVIDER') ||
'openai') as AIProvider,
modelName: getOptionalEnv('DATABASE_FILTER_MODEL') || 'gpt-4o-mini',
temperature: parseFloat(
getOptionalEnv('DATABASE_FILTER_TEMPERATURE') || '0.3'
),
},
},
embedding: {
provider: (getOptionalEnv('EMBEDDING_PROVIDER') ||

View File

@@ -35,6 +35,18 @@ const citedAnswerSchema = z.object({
});
type CitedAnswer = z.infer<typeof citedAnswerSchema>;
// Filters proposed for one database. The individual filter objects are left
// unvalidated (z.any) because DatabaseViewFilterAttributes is a complex type.
const databaseFilterEntrySchema = z.object({
  databaseId: z.string(),
  filters: z.array(z.any()),
});

// Structured output expected from the database-filter model call: a flag
// saying whether filtering applies, plus the filters grouped per database.
const databaseFilterSchema = z.object({
  shouldFilter: z.boolean(),
  filters: z.array(databaseFilterEntrySchema),
});

type DatabaseFilterResult = z.infer<typeof databaseFilterSchema>;
export function getChatModel(
task: keyof typeof configuration.ai.models
): ChatOpenAI | ChatGoogleGenerativeAI {
@@ -150,6 +162,49 @@ User Query:
Return only the answer.`
);
// Prompt asking the model to translate a natural-language query into
// per-database filter attributes. Template variables: {databasesInfo}, {query}.
// Single braces delimit template variables, so the literal braces of the JSON
// example are doubled ({{ and }}); left single, the template cannot be parsed.
const databaseFilterPrompt = ChatPromptTemplate.fromTemplate(
`You are an expert at analyzing natural language queries and converting them into structured database filters.
Available Databases:
{databasesInfo}
User Query:
{query}
Your task is to:
1. Determine if this query is asking or makes sense to answer by filtering/searching databases
2. If yes, generate appropriate filter attributes for each relevant database
3. If no, return shouldFilter: false
Return a JSON object with:
- shouldFilter: boolean
- filters: array of objects with:
- databaseId: string
- filters: array of DatabaseViewFilterAttributes
Only include databases that are relevant to the query.
For each filter, use the exact field IDs from the database schema.
Use appropriate operators based on field types.
Example Response:
{{
"shouldFilter": true,
"filters": [
{{
"databaseId": "db1",
"filters": [
{{
"type": "field",
"fieldId": "field1",
"operator": "contains",
"value": "search term"
}}
]
}}
]
}}`
);
export async function rewriteQuery(query: string): Promise<string> {
const task = 'queryRewrite';
const model = getChatModel(task);
@@ -410,3 +465,43 @@ export async function addContextToChunk(
return chunk;
}
}
/**
 * Asks the `databaseFilter` model to decide whether the query should be
 * answered by filtering databases and, if so, which filters to apply.
 *
 * @param args.query The user's query text.
 * @param args.databases Databases available for filtering: id, name, field
 *   schema keyed by field id, and sample records.
 * @returns The structured output of the model, shaped by `databaseFilterSchema`.
 */
export async function generateDatabaseFilters(args: {
  query: string;
  databases: Array<{
    id: string;
    name: string;
    fields: Record<string, { type: string; name: string }>;
    sampleRecords: any[];
  }>;
}): Promise<DatabaseFilterResult> {
  const task = 'databaseFilter';
  const model = getChatModel(task).withStructuredOutput(databaseFilterSchema);

  // Renders one stored field value as readable text. Values that are objects
  // carrying their data under a `value` key are unwrapped first; without this
  // they would be interpolated as "[object Object]".
  const describeValue = (raw: unknown): string => {
    const inner =
      raw !== null && typeof raw === 'object' && 'value' in raw
        ? (raw as { value: unknown }).value
        : raw;
    if (inner === null || inner === undefined) {
      return '';
    }
    return typeof inner === 'object' ? JSON.stringify(inner) : String(inner);
  };

  // Format database information for the prompt
  const databasesInfo = args.databases
    .map((db) => {
      const fieldLines = Object.entries(db.fields)
        .map(
          ([id, field]) => `- ${field.name} (ID: ${id}, Type: ${field.type})`
        )
        .join('\n');

      const recordLines = db.sampleRecords
        .map((record, i) => {
          // A sample record may come without any field values
          const values = Object.entries(record?.attributes?.fields ?? {})
            // Fall back to the raw id when the schema has no such field
            .map(
              ([fieldId, value]) =>
                `${db.fields[fieldId]?.name ?? fieldId}: ${describeValue(value)}`
            )
            .join(', ');
          return `${i + 1}. ${values}`;
        })
        .join('\n');

      return [
        '',
        `Database: ${db.name} (ID: ${db.id})`,
        'Fields:',
        fieldLines,
        'Sample Records:',
        recordLines,
        '',
      ].join('\n');
    })
    .join('\n\n');

  return databaseFilterPrompt
    .pipe(model)
    .invoke({ query: args.query, databasesInfo });
}

View File

@@ -0,0 +1,513 @@
import { sql, Kysely, Expression, SqlBool } from 'kysely';
import { database } from '@/data/database';
import {
BooleanFieldAttributes,
CreatedAtFieldAttributes,
DatabaseNode,
DateFieldAttributes,
EmailFieldAttributes,
FieldAttributes,
isStringArray,
NumberFieldAttributes,
PhoneFieldAttributes,
SelectFieldAttributes,
TextFieldAttributes,
UrlFieldAttributes,
DatabaseViewFieldFilterAttributes,
DatabaseViewFilterAttributes,
DatabaseViewSortAttributes,
MultiSelectFieldAttributes,
} from '@colanode/core';
import { DatabaseSchema } from '@/data/schema';
/** Input for a filtered, sorted and paginated record listing. */
interface FilterInput {
  /** Filters to apply to the listing. */
  filters: DatabaseViewFilterAttributes[];
  /** Sort orders to apply to the listing. */
  sorts: DatabaseViewSortAttributes[];
  /** Page number; page 1 maps to offset 0. */
  page: number;
  /** Number of records per page. */
  count: number;
}

/** Input for a free-text record search. */
interface SearchInput {
  /** Text to search for. */
  searchQuery: string;
  /** Record ids to leave out of the result. */
  exclude?: string[];
}

/** Field kinds that are all filtered with the plain-text operators. */
type TextBasedFieldAttributes =
  | EmailFieldAttributes
  | PhoneFieldAttributes
  | TextFieldAttributes
  | UrlFieldAttributes;
/**
 * Read-only access to database records stored as `record` nodes.
 *
 * Every query is restricted to records whose parent is the given database,
 * that belong to the given workspace, and whose root node has a non-deleted
 * collaboration for the given user.
 */
export class RecordsRetrievalService {
  constructor(private readonly db: Kysely<DatabaseSchema>) {}

  /**
   * Returns one page of records of a database, filtered and sorted.
   *
   * @param databaseId Id of the database node whose records are listed.
   * @param workspaceId Workspace the database and records must belong to.
   * @param userId Collaborator on whose behalf the records are read.
   * @param input Filters, sorts and pagination (`page` is 1-based).
   * @returns The matching rows (all columns of the joined tables).
   * @throws Error('Database not found') when the database node does not exist.
   */
  public async retrieveByFilters(
    databaseId: string,
    workspaceId: string,
    userId: string,
    input: FilterInput
  ) {
    const database = await this.fetchDatabase(databaseId, workspaceId);
    const fields = database.attributes.fields ?? {};

    const filterQuery = this.buildFiltersQuery(input.filters, fields);
    // Fall back to a stable id order when no usable sort remains; this also
    // covers the case where every requested sort references an unknown field.
    const orderByQuery =
      this.buildSortOrdersQuery(input.sorts, fields) ?? sql`n.id ASC`;
    // Clamp so that page values below 1 cannot produce a negative offset
    const offset = Math.max(input.page - 1, 0) * input.count;

    let query = this.db
      .selectFrom('nodes as n')
      .innerJoin('collaborations as c', 'c.node_id', 'n.root_id')
      .where('n.parent_id', '=', databaseId)
      .where('n.type', '=', 'record')
      .where('n.workspace_id', '=', workspaceId)
      .where('c.collaborator_id', '=', userId)
      .where('c.deleted_at', 'is', null);

    if (filterQuery) {
      // Query builders are immutable: the result of where() must be kept,
      // otherwise the filters are never applied.
      query = query.where(filterQuery);
    }

    // NOTE(review): selectAll() returns the columns of both joined tables;
    // confirm that callers are not affected by same-named columns.
    return query
      .orderBy(orderByQuery)
      .limit(input.count)
      .offset(offset)
      .selectAll()
      .execute();
  }

  /**
   * Full-text search over the record name and all field values.
   * An empty search query returns every accessible record of the database.
   *
   * @param databaseId Id of the database node whose records are searched.
   * @param workspaceId Workspace the records must belong to.
   * @param userId Collaborator on whose behalf the records are read.
   * @param input Search text and optional record ids to exclude.
   * @returns The matching rows (all columns of the joined tables).
   */
  public async searchRecords(
    databaseId: string,
    workspaceId: string,
    userId: string,
    input: SearchInput
  ) {
    if (!input.searchQuery) {
      return this.fetchAllRecords(
        databaseId,
        workspaceId,
        userId,
        input.exclude
      );
    }

    // The whole condition is parenthesized so that its OR cannot detach from
    // the AND-ed database/workspace/collaboration constraints around it.
    const searchCondition = sql<SqlBool>`(
      to_tsvector('english', n.attributes->>'name') @@ plainto_tsquery('english', ${input.searchQuery})
      OR EXISTS (
        SELECT 1
        FROM jsonb_each_text(n.attributes->'fields') fields
        WHERE to_tsvector('english', fields.value::text) @@ plainto_tsquery('english', ${input.searchQuery})
      )
    )`;

    let query = this.db
      .selectFrom('nodes as n')
      .innerJoin('collaborations as c', 'c.node_id', 'n.root_id')
      .where('n.parent_id', '=', databaseId)
      .where('n.type', '=', 'record')
      .where('n.workspace_id', '=', workspaceId)
      .where('c.collaborator_id', '=', userId)
      .where('c.deleted_at', 'is', null)
      .where(searchCondition);

    if (input.exclude?.length) {
      // Keep the returned builder, otherwise the exclusion is dropped
      query = query.where('n.id', 'not in', input.exclude);
    }

    return query.selectAll().execute();
  }

  /** Returns every accessible record of the database, minus `exclude`. */
  private async fetchAllRecords(
    databaseId: string,
    workspaceId: string,
    userId: string,
    exclude?: string[]
  ) {
    const excluded = exclude ?? [];

    return this.db
      .selectFrom('nodes as n')
      .innerJoin('collaborations as c', 'c.node_id', 'n.root_id')
      .where('n.parent_id', '=', databaseId)
      .where('n.type', '=', 'record')
      .where('n.workspace_id', '=', workspaceId)
      .where('c.collaborator_id', '=', userId)
      .where('c.deleted_at', 'is', null)
      .$if(excluded.length > 0, (qb) => qb.where('n.id', 'not in', excluded))
      .selectAll()
      .execute();
  }

  /**
   * Loads the database node by id within the workspace.
   *
   * @throws Error('Database not found') when no such node exists.
   */
  private async fetchDatabase(
    databaseId: string,
    workspaceId: string
  ): Promise<DatabaseNode> {
    const row = await this.db
      .selectFrom('nodes')
      .where('id', '=', databaseId)
      .where('workspace_id', '=', workspaceId)
      .where('type', '=', 'database')
      .selectAll()
      .executeTakeFirst();

    if (!row) {
      throw new Error('Database not found');
    }

    return row as unknown as DatabaseNode;
  }

  /** SQL fragment reading a record field's `value` as text (field id bound as a parameter). */
  private fieldText(fieldId: string) {
    return sql<string | null>`n.attributes->'fields'->${fieldId}->>'value'`;
  }

  /** SQL fragment reading a record field's `value` as jsonb (field id bound as a parameter). */
  private fieldJson(fieldId: string) {
    return sql`n.attributes->'fields'->${fieldId}->'value'`;
  }

  /**
   * Combines all translatable filters with AND.
   *
   * @returns A parenthesized condition, or undefined when nothing is left to apply.
   */
  private buildFiltersQuery(
    filters: DatabaseViewFilterAttributes[],
    fields: Record<string, FieldAttributes>
  ): Expression<SqlBool> | undefined {
    if (filters.length === 0) {
      return undefined;
    }

    const filterQueries = filters
      .map((filter) => this.buildFilterQuery(filter, fields))
      .filter((query): query is Expression<SqlBool> => query !== null);

    if (filterQueries.length === 0) {
      return undefined;
    }

    return sql<SqlBool>`(${sql.join(filterQueries, sql` AND `)})`;
  }

  /**
   * Translates one filter into SQL according to the type of its field.
   * Group filters, unknown fields and unsupported field types yield null and
   * are therefore ignored.
   */
  private buildFilterQuery(
    filter: DatabaseViewFilterAttributes,
    fields: Record<string, FieldAttributes>
  ): Expression<SqlBool> | null {
    if (filter.type === 'group') {
      return null;
    }

    const field = fields[filter.fieldId];
    if (!field) {
      return null;
    }

    switch (field.type) {
      case 'boolean':
        return this.buildBooleanFilterQuery(filter, field);
      case 'created_at':
        return this.buildCreatedAtFilterQuery(filter, field);
      case 'date':
        return this.buildDateFilterQuery(filter, field);
      case 'email':
      case 'phone':
      case 'url':
      case 'text':
        return this.buildTextFilterQuery(
          filter,
          field as TextBasedFieldAttributes
        );
      case 'multi_select':
        return this.buildMultiSelectFilterQuery(filter, field);
      case 'number':
        return this.buildNumberFilterQuery(filter, field);
      case 'select':
        return this.buildSelectFilterQuery(filter, field);
      default:
        return null;
    }
  }

  /** Boolean fields: `is_true` / `is_false`; a missing value counts as false. */
  private buildBooleanFilterQuery(
    filter: DatabaseViewFieldFilterAttributes,
    field: BooleanFieldAttributes
  ): Expression<SqlBool> | null {
    const value = this.fieldText(field.id);

    if (filter.operator === 'is_true') {
      return sql<SqlBool>`(${value})::boolean = true`;
    }

    if (filter.operator === 'is_false') {
      return sql<SqlBool>`((${value})::boolean = false OR ${value} IS NULL)`;
    }

    return null;
  }

  /** Number fields: emptiness checks and numeric comparisons against a number value. */
  private buildNumberFilterQuery(
    filter: DatabaseViewFieldFilterAttributes,
    field: NumberFieldAttributes
  ): Expression<SqlBool> | null {
    const stored = this.fieldText(field.id);

    if (filter.operator === 'is_empty') {
      return sql<SqlBool>`${stored} IS NULL`;
    }

    if (filter.operator === 'is_not_empty') {
      return sql<SqlBool>`${stored} IS NOT NULL`;
    }

    if (filter.value === null || typeof filter.value !== 'number') {
      return null;
    }

    const value = filter.value;
    let operator: string;

    switch (filter.operator) {
      case 'is_equal_to':
        operator = '=';
        break;
      case 'is_not_equal_to':
        operator = '!=';
        break;
      case 'is_greater_than':
        operator = '>';
        break;
      case 'is_less_than':
        operator = '<';
        break;
      case 'is_greater_than_or_equal_to':
        operator = '>=';
        break;
      case 'is_less_than_or_equal_to':
        operator = '<=';
        break;
      default:
        return null;
    }

    // The operator comes from the fixed list above, so sql.raw is safe here
    return sql<SqlBool>`(${stored})::numeric ${sql.raw(operator)} ${value}`;
  }

  /** Text-like fields: emptiness checks, equality and case-insensitive matching. */
  private buildTextFilterQuery(
    filter: DatabaseViewFieldFilterAttributes,
    field: TextBasedFieldAttributes
  ): Expression<SqlBool> | null {
    const stored = this.fieldText(field.id);

    if (filter.operator === 'is_empty') {
      return sql<SqlBool>`${stored} IS NULL`;
    }

    if (filter.operator === 'is_not_empty') {
      return sql<SqlBool>`${stored} IS NOT NULL`;
    }

    if (filter.value === null || typeof filter.value !== 'string') {
      return null;
    }

    const value = filter.value;
    switch (filter.operator) {
      case 'is_equal_to':
        return sql<SqlBool>`${stored} = ${value}`;
      case 'is_not_equal_to':
        return sql<SqlBool>`${stored} != ${value}`;
      case 'contains':
        return sql<SqlBool>`${stored} ILIKE ${'%' + value + '%'}`;
      case 'does_not_contain':
        return sql<SqlBool>`${stored} NOT ILIKE ${'%' + value + '%'}`;
      case 'starts_with':
        return sql<SqlBool>`${stored} ILIKE ${value + '%'}`;
      case 'ends_with':
        return sql<SqlBool>`${stored} ILIKE ${'%' + value}`;
      default:
        return null;
    }
  }

  // Aliases of the text filter builder for the other text-like field kinds
  private buildEmailFilterQuery = this.buildTextFilterQuery;
  private buildPhoneFilterQuery = this.buildTextFilterQuery;
  private buildUrlFilterQuery = this.buildTextFilterQuery;

  /** Select fields: emptiness checks and membership in a non-empty list of strings. */
  private buildSelectFilterQuery(
    filter: DatabaseViewFieldFilterAttributes,
    field: SelectFieldAttributes
  ): Expression<SqlBool> | null {
    const stored = this.fieldText(field.id);

    if (filter.operator === 'is_empty') {
      return sql<SqlBool>`${stored} IS NULL`;
    }

    if (filter.operator === 'is_not_empty') {
      return sql<SqlBool>`${stored} IS NOT NULL`;
    }

    if (!isStringArray(filter.value) || filter.value.length === 0) {
      return null;
    }

    switch (filter.operator) {
      case 'is_in':
        return sql<SqlBool>`${stored} IN (${sql.join(filter.value)})`;
      case 'is_not_in':
        return sql<SqlBool>`${stored} NOT IN (${sql.join(filter.value)})`;
      default:
        return null;
    }
  }

  /**
   * Multi-select fields: a missing value or an empty array counts as empty;
   * `is_in` matches when at least one stored element is in the given list.
   */
  private buildMultiSelectFilterQuery(
    filter: DatabaseViewFieldFilterAttributes,
    field: MultiSelectFieldAttributes
  ): Expression<SqlBool> | null {
    const storedText = this.fieldText(field.id);
    const storedJson = this.fieldJson(field.id);

    if (filter.operator === 'is_empty') {
      return sql<SqlBool>`(${storedText} IS NULL OR jsonb_array_length(${storedJson}) = 0)`;
    }

    if (filter.operator === 'is_not_empty') {
      return sql<SqlBool>`(${storedText} IS NOT NULL AND jsonb_array_length(${storedJson}) > 0)`;
    }

    if (!isStringArray(filter.value) || filter.value.length === 0) {
      return null;
    }

    switch (filter.operator) {
      case 'is_in':
        return sql<SqlBool>`EXISTS (
          SELECT 1
          FROM jsonb_array_elements_text(${storedJson}) val
          WHERE val IN (${sql.join(filter.value)})
        )`;
      case 'is_not_in':
        return sql<SqlBool>`NOT EXISTS (
          SELECT 1
          FROM jsonb_array_elements_text(${storedJson}) val
          WHERE val IN (${sql.join(filter.value)})
        )`;
      default:
        return null;
    }
  }

  /**
   * Shared by the date and created_at builders: validates the filter value
   * as a date and maps the filter operator to a SQL comparison operator.
   *
   * @returns The SQL operator and the date part (before the `T`) of the
   *   value's ISO string, or null when the value or operator is unusable.
   */
  private resolveDateComparison(
    filter: DatabaseViewFieldFilterAttributes
  ): { operator: string; dateString: string | undefined } | null {
    if (filter.value === null || typeof filter.value !== 'string') {
      return null;
    }

    const date = new Date(filter.value);
    if (isNaN(date.getTime())) {
      return null;
    }

    let operator: string;
    switch (filter.operator) {
      case 'is_equal_to':
        operator = '=';
        break;
      case 'is_not_equal_to':
        operator = '!=';
        break;
      case 'is_on_or_after':
        operator = '>=';
        break;
      case 'is_on_or_before':
        operator = '<=';
        break;
      case 'is_after':
        operator = '>';
        break;
      case 'is_before':
        operator = '<';
        break;
      default:
        return null;
    }

    return { operator, dateString: date.toISOString().split('T')[0] };
  }

  /** Date fields: emptiness checks and day-level comparisons. */
  private buildDateFilterQuery(
    filter: DatabaseViewFieldFilterAttributes,
    field: DateFieldAttributes
  ): Expression<SqlBool> | null {
    const stored = this.fieldText(field.id);

    if (filter.operator === 'is_empty') {
      return sql<SqlBool>`${stored} IS NULL`;
    }

    if (filter.operator === 'is_not_empty') {
      return sql<SqlBool>`${stored} IS NOT NULL`;
    }

    const comparison = this.resolveDateComparison(filter);
    if (!comparison) {
      return null;
    }

    return sql<SqlBool>`DATE(${stored}) ${sql.raw(comparison.operator)} ${comparison.dateString}`;
  }

  /** created_at fields: same operators as date fields, applied to `n.created_at`. */
  private buildCreatedAtFilterQuery(
    filter: DatabaseViewFieldFilterAttributes,
    _: CreatedAtFieldAttributes
  ): Expression<SqlBool> | null {
    if (filter.operator === 'is_empty') {
      return sql<SqlBool>`n.created_at IS NULL`;
    }

    if (filter.operator === 'is_not_empty') {
      return sql<SqlBool>`n.created_at IS NOT NULL`;
    }

    const comparison = this.resolveDateComparison(filter);
    if (!comparison) {
      return null;
    }

    return sql<SqlBool>`DATE(n.created_at) ${sql.raw(comparison.operator)} ${comparison.dateString}`;
  }

  /**
   * Builds the ORDER BY list from the sorts that reference known fields.
   *
   * @returns The comma-joined orderings, or undefined when none is usable.
   */
  private buildSortOrdersQuery(
    sorts: DatabaseViewSortAttributes[],
    fields: Record<string, FieldAttributes>
  ): Expression<unknown> | undefined {
    const orderings = sorts
      .map((sort) => this.buildSortOrderQuery(sort, fields))
      .filter((ordering): ordering is Expression<unknown> => ordering !== null);

    if (orderings.length === 0) {
      return undefined;
    }

    return sql.join(orderings, sql`, `);
  }

  /**
   * Builds one ORDER BY entry. The field id is bound as a parameter and the
   * direction is reduced to ASC/DESC, so neither can inject SQL text.
   */
  private buildSortOrderQuery(
    sort: DatabaseViewSortAttributes,
    fields: Record<string, FieldAttributes>
  ): Expression<unknown> | null {
    const field = fields[sort.fieldId];
    if (!field) {
      return null;
    }

    // Anything other than "desc" (case-insensitive) sorts ascending
    const direction = sql.raw(
      String(sort.direction).toLowerCase() === 'desc' ? 'DESC' : 'ASC'
    );

    if (field.type === 'created_at') {
      return sql`n.created_at ${direction}`;
    }

    if (field.type === 'created_by') {
      return sql`n.created_by ${direction}`;
    }

    return sql`${this.fieldText(sort.fieldId)} ${direction}`;
  }
}
// Shared service instance bound to the `database` connection imported above.
export const recordsRetrievalService = new RecordsRetrievalService(database);