From 557a0eb8979a74007e6bdb3a76cfac518e48e592 Mon Sep 17 00:00:00 2001 From: pranav-kural Date: Sat, 13 Jul 2024 15:06:16 -0400 Subject: [PATCH] implemented functionality to optionally provide `docs` or `splitDocs` to data retriever --- src/rag/data-retrievers/data-retrievers.ts | 80 ++++++++++++---------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/src/rag/data-retrievers/data-retrievers.ts b/src/rag/data-retrievers/data-retrievers.ts index aaa4315..3d11351 100644 --- a/src/rag/data-retrievers/data-retrievers.ts +++ b/src/rag/data-retrievers/data-retrievers.ts @@ -1,27 +1,28 @@ -import { MemoryVectorStore } from "langchain/vectorstores/memory"; -import { GoogleGenerativeAIEmbeddings } from "@langchain/google-genai"; -import { formatDocumentsAsString } from "langchain/util/document"; -import { Runnable, RunnableConfig } from "@langchain/core/runnables"; -import { EmbeddingsInterface } from "@langchain/core/embeddings"; +import {MemoryVectorStore} from 'langchain/vectorstores/memory'; +import {GoogleGenerativeAIEmbeddings} from '@langchain/google-genai'; +import {formatDocumentsAsString} from 'langchain/util/document'; +import {Runnable, RunnableConfig} from '@langchain/core/runnables'; +import {EmbeddingsInterface} from '@langchain/core/embeddings'; import { VectorStore, VectorStoreRetrieverInput, -} from "@langchain/core/vectorstores"; +} from '@langchain/core/vectorstores'; +import {Document} from 'langchain/document'; import { CSVLoaderOptions, JSONLoaderKeysToInclude, PDFLoaderOptions, SupportedDataLoaderTypes, getDocs, -} from "../data-loaders/data-loaders"; +} from '../data-loaders/data-loaders'; import { ChunkingConfig, DataSplitterConfig, SupportedDataSplitterTypes, runDataSplitter, -} from "../data-splitters/data-splitters"; -import { GOOGLE_GENAI_EMBEDDING_MODELS } from "../data-embeddings/embedding-models"; -import { getEnvironmentVariable } from "../../utils/utils"; +} from '../data-splitters/data-splitters'; +import {GOOGLE_GENAI_EMBEDDING_MODELS} from '../data-embeddings/embedding-models'; +import {getEnvironmentVariable} from '../../utils/utils'; /** * Type denoting a retriever that retrieves text data. @@ -41,6 +42,8 @@ export type RetrievalOptions = * Represents the configuration for the retriever when generating embeddings. * @property {SupportedDataLoaderTypes} dataType - The type of data loader to use. * @property {string} filePath - The path to the file containing the data. + * @property {Document>[]} [docs] - Optional: Provide an array containing LangChain document objects for the data. + * @property {Document>[]} [splitDocs] - Optional: Provide an array containing LangChain document objects for the split data. * @property {JSONLoaderKeysToInclude} [jsonLoaderKeysToInclude] - The keys to include when loading JSON data. * @property {CSVLoaderOptions} [csvLoaderOptions] - The options for loading CSV data. * @property {PDFLoaderOptions} [pdfLoaderOptions] - The options for loading PDF data. @@ -55,6 +58,8 @@ export type RetrievalOptions = export type RetrieverConfigGeneratingEmbeddings = { dataType: SupportedDataLoaderTypes; filePath: string; + docs?: Document>[]; + splitDocs?: Document>[]; jsonLoaderKeysToInclude?: JSONLoaderKeysToInclude; csvLoaderOptions?: CSVLoaderOptions; pdfLoaderOptions?: PDFLoaderOptions; @@ -91,12 +96,12 @@ export type RetrieverConfig = * Task type for embedding content. */ export enum TaskType { - TASK_TYPE_UNSPECIFIED = "TASK_TYPE_UNSPECIFIED", - RETRIEVAL_QUERY = "RETRIEVAL_QUERY", - RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT", - SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY", - CLASSIFICATION = "CLASSIFICATION", - CLUSTERING = "CLUSTERING", + TASK_TYPE_UNSPECIFIED = 'TASK_TYPE_UNSPECIFIED', + RETRIEVAL_QUERY = 'RETRIEVAL_QUERY', + RETRIEVAL_DOCUMENT = 'RETRIEVAL_DOCUMENT', + SEMANTIC_SIMILARITY = 'SEMANTIC_SIMILARITY', + CLASSIFICATION = 'CLASSIFICATION', + CLUSTERING = 'CLUSTERING', } /** @@ -113,19 +118,19 @@ export const getAppropriateDataSplitter = ( defaultSplitterConfig?: DataSplitterConfig; } => { switch (dataType) { - case "csv": - case "json": + case 'csv': + case 'json': return { - defaultDataSplitterType: "character", + defaultDataSplitterType: 'character', defaultSplitterConfig: { textSplitterConfig: { - separators: ["\n"], + separators: ['\n'], }, }, }; default: return { - defaultDataSplitterType: "text", + defaultDataSplitterType: 'text', }; } }; @@ -152,7 +157,7 @@ export const getDataRetriever = async ( // vector store must be provided if (!config.vectorStore) throw new Error( - "Vector store must be provided when not generating embeddings" + 'Vector store must be provided when not generating embeddings' ); // return retriever else @@ -164,36 +169,39 @@ export const getDataRetriever = async ( // if generating embeddings, data type must be provided if (!config.dataType) { throw new Error( - "Data type and file path must be provided when generating embeddings" + 'Data type and file path must be provided when generating embeddings' ); } // if generating embeddings, file path must be provided - if (!config.filePath || config.filePath === "") { - throw new Error("Invalid file path. File path must be provided"); + if (!config.filePath || config.filePath === '') { + throw new Error('Invalid file path. File path must be provided'); } try { // Retrieve the documents from the specified file path - const docs = await getDocs(config.dataType, config.filePath); + const docs: Document>[] = + config.docs ?? (await getDocs(config.dataType, config.filePath)); - const { defaultDataSplitterType, defaultSplitterConfig } = + const {defaultDataSplitterType, defaultSplitterConfig} = getAppropriateDataSplitter(config.dataType); // Split the retrieved documents into chunks using the data splitter - const splitDocs = await runDataSplitter({ - docs, - dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType, - chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig, - splitterConfig: config.splitterConfig ?? defaultSplitterConfig, - }); + const splitDocs: Document>[] = + config.splitDocs ?? + (await runDataSplitter({ + docs, + dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType, + chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig, + splitterConfig: config.splitterConfig ?? defaultSplitterConfig, + })); // embedding model - if not provided, use the default Google Generative AI Embeddings model const embeddings: EmbeddingsInterface = config.embeddingModel ?? new GoogleGenerativeAIEmbeddings({ - apiKey: getEnvironmentVariable("GOOGLE_GENAI_API_KEY"), - model: GOOGLE_GENAI_EMBEDDING_MODELS["text-embedding-004"].name, + apiKey: getEnvironmentVariable('GOOGLE_GENAI_API_KEY'), + model: GOOGLE_GENAI_EMBEDDING_MODELS['text-embedding-004'].name, taskType: TaskType.RETRIEVAL_DOCUMENT, }); @@ -207,7 +215,7 @@ export const getDataRetriever = async ( config.embeddingModel ?? new GoogleGenerativeAIEmbeddings({ apiKey: process.env.GOOGLE_GENAI_API_KEY, - model: GOOGLE_GENAI_EMBEDDING_MODELS["text-embedding-004"].name, + model: GOOGLE_GENAI_EMBEDDING_MODELS['text-embedding-004'].name, taskType: TaskType.RETRIEVAL_QUERY, });