implemented functionality to optionally provide docs or splitDocs to data retriever
pranav-kural committed Jul 13, 2024
1 parent 6cd321f commit 557a0eb
Showing 1 changed file with 44 additions and 36 deletions.
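
For context, a minimal usage sketch of the new option: when pre-loaded documents are passed as docs, the retriever skips the getDocs() file-loading step and works from the supplied array. This is not part of the commit itself; the import path, the 'text' data type, the file path value, and the generateEmbeddings flag are assumptions for illustration, based only on the names visible in this diff.

import {Document} from 'langchain/document';
import {getDataRetriever} from './src/rag/data-retrievers/data-retrievers'; // path assumed

async function buildRetrieverFromPreloadedDocs() {
  // Pre-loaded LangChain documents; with `docs` provided, getDocs() is not called.
  const docs: Document<Record<string, string>>[] = [
    new Document({
      pageContent: 'Order #123 was shipped on June 1.',
      metadata: {source: 'inline'},
    }),
  ];

  return getDataRetriever({
    dataType: 'text', // still required by the validation added earlier in this file
    filePath: 'data/orders.txt', // still required, even though loading is skipped
    docs, // new optional field introduced in this commit
    generateEmbeddings: true, // assumed discriminator flag; not shown in this diff
  });
}

Providing docs this way is useful when documents come from a source the built-in loaders don't cover, while still reusing the splitting and embedding steps.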
80 changes: 44 additions & 36 deletions src/rag/data-retrievers/data-retrievers.ts
@@ -1,27 +1,28 @@
-import { MemoryVectorStore } from "langchain/vectorstores/memory";
-import { GoogleGenerativeAIEmbeddings } from "@langchain/google-genai";
-import { formatDocumentsAsString } from "langchain/util/document";
-import { Runnable, RunnableConfig } from "@langchain/core/runnables";
-import { EmbeddingsInterface } from "@langchain/core/embeddings";
+import {MemoryVectorStore} from 'langchain/vectorstores/memory';
+import {GoogleGenerativeAIEmbeddings} from '@langchain/google-genai';
+import {formatDocumentsAsString} from 'langchain/util/document';
+import {Runnable, RunnableConfig} from '@langchain/core/runnables';
+import {EmbeddingsInterface} from '@langchain/core/embeddings';
 import {
   VectorStore,
   VectorStoreRetrieverInput,
-} from "@langchain/core/vectorstores";
+} from '@langchain/core/vectorstores';
+import {Document} from 'langchain/document';
 import {
   CSVLoaderOptions,
   JSONLoaderKeysToInclude,
   PDFLoaderOptions,
   SupportedDataLoaderTypes,
   getDocs,
-} from "../data-loaders/data-loaders";
+} from '../data-loaders/data-loaders';
 import {
   ChunkingConfig,
   DataSplitterConfig,
   SupportedDataSplitterTypes,
   runDataSplitter,
-} from "../data-splitters/data-splitters";
-import { GOOGLE_GENAI_EMBEDDING_MODELS } from "../data-embeddings/embedding-models";
-import { getEnvironmentVariable } from "../../utils/utils";
+} from '../data-splitters/data-splitters';
+import {GOOGLE_GENAI_EMBEDDING_MODELS} from '../data-embeddings/embedding-models';
+import {getEnvironmentVariable} from '../../utils/utils';

 /**
  * Type denoting a retriever that retrieves text data.
@@ -41,6 +42,8 @@ export type RetrievalOptions =
  * Represents the configuration for the retriever when generating embeddings.
  * @property {SupportedDataLoaderTypes} dataType - The type of data loader to use.
  * @property {string} filePath - The path to the file containing the data.
+ * @property {Document<Record<string, string>>[]} [docs] - Optional: Provide an array containing LangChain document objects for the data.
+ * @property {Document<Record<string, unknown>>[]} [splitDocs] - Optional: Provide an array containing LangChain document objects for the split data.
  * @property {JSONLoaderKeysToInclude} [jsonLoaderKeysToInclude] - The keys to include when loading JSON data.
  * @property {CSVLoaderOptions} [csvLoaderOptions] - The options for loading CSV data.
  * @property {PDFLoaderOptions} [pdfLoaderOptions] - The options for loading PDF data.
@@ -55,6 +58,8 @@ export type RetrieverConfigGeneratingEmbeddings = {
 export type RetrieverConfigGeneratingEmbeddings = {
   dataType: SupportedDataLoaderTypes;
   filePath: string;
+  docs?: Document<Record<string, string>>[];
+  splitDocs?: Document<Record<string, unknown>>[];
   jsonLoaderKeysToInclude?: JSONLoaderKeysToInclude;
   csvLoaderOptions?: CSVLoaderOptions;
   pdfLoaderOptions?: PDFLoaderOptions;
@@ -91,12 +96,12 @@ export type RetrieverConfig =
  * Task type for embedding content.
  */
 export enum TaskType {
-  TASK_TYPE_UNSPECIFIED = "TASK_TYPE_UNSPECIFIED",
-  RETRIEVAL_QUERY = "RETRIEVAL_QUERY",
-  RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT",
-  SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY",
-  CLASSIFICATION = "CLASSIFICATION",
-  CLUSTERING = "CLUSTERING",
+  TASK_TYPE_UNSPECIFIED = 'TASK_TYPE_UNSPECIFIED',
+  RETRIEVAL_QUERY = 'RETRIEVAL_QUERY',
+  RETRIEVAL_DOCUMENT = 'RETRIEVAL_DOCUMENT',
+  SEMANTIC_SIMILARITY = 'SEMANTIC_SIMILARITY',
+  CLASSIFICATION = 'CLASSIFICATION',
+  CLUSTERING = 'CLUSTERING',
 }

 /**
@@ -113,19 +118,19 @@ export const getAppropriateDataSplitter = (
   defaultSplitterConfig?: DataSplitterConfig;
 } => {
   switch (dataType) {
-    case "csv":
-    case "json":
+    case 'csv':
+    case 'json':
       return {
-        defaultDataSplitterType: "character",
+        defaultDataSplitterType: 'character',
         defaultSplitterConfig: {
           textSplitterConfig: {
-            separators: ["\n"],
+            separators: ['\n'],
           },
         },
       };
     default:
       return {
-        defaultDataSplitterType: "text",
+        defaultDataSplitterType: 'text',
       };
   }
 };
@@ -152,7 +157,7 @@ export const getDataRetriever = async (
   // vector store must be provided
   if (!config.vectorStore)
     throw new Error(
-      "Vector store must be provided when not generating embeddings"
+      'Vector store must be provided when not generating embeddings'
     );
   // return retriever
   else
@@ -164,36 +169,39 @@ export const getDataRetriever = async (
   // if generating embeddings, data type must be provided
   if (!config.dataType) {
     throw new Error(
-      "Data type and file path must be provided when generating embeddings"
+      'Data type and file path must be provided when generating embeddings'
     );
   }

   // if generating embeddings, file path must be provided
-  if (!config.filePath || config.filePath === "") {
-    throw new Error("Invalid file path. File path must be provided");
+  if (!config.filePath || config.filePath === '') {
+    throw new Error('Invalid file path. File path must be provided');
   }

   try {
     // Retrieve the documents from the specified file path
-    const docs = await getDocs(config.dataType, config.filePath);
+    const docs: Document<Record<string, string>>[] =
+      config.docs ?? (await getDocs(config.dataType, config.filePath));

-    const { defaultDataSplitterType, defaultSplitterConfig } =
+    const {defaultDataSplitterType, defaultSplitterConfig} =
       getAppropriateDataSplitter(config.dataType);

     // Split the retrieved documents into chunks using the data splitter
-    const splitDocs = await runDataSplitter({
-      docs,
-      dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType,
-      chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig,
-      splitterConfig: config.splitterConfig ?? defaultSplitterConfig,
-    });
+    const splitDocs: Document<Record<string, unknown>>[] =
+      config.splitDocs ??
+      (await runDataSplitter({
+        docs,
+        dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType,
+        chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig,
+        splitterConfig: config.splitterConfig ?? defaultSplitterConfig,
+      }));

     // embedding model - if not provided, use the default Google Generative AI Embeddings model
     const embeddings: EmbeddingsInterface =
       config.embeddingModel ??
       new GoogleGenerativeAIEmbeddings({
-        apiKey: getEnvironmentVariable("GOOGLE_GENAI_API_KEY"),
-        model: GOOGLE_GENAI_EMBEDDING_MODELS["text-embedding-004"].name,
+        apiKey: getEnvironmentVariable('GOOGLE_GENAI_API_KEY'),
+        model: GOOGLE_GENAI_EMBEDDING_MODELS['text-embedding-004'].name,
         taskType: TaskType.RETRIEVAL_DOCUMENT,
       });
@@ -207,7 +215,7 @@ export const getDataRetriever = async (
       config.embeddingModel ??
       new GoogleGenerativeAIEmbeddings({
         apiKey: process.env.GOOGLE_GENAI_API_KEY,
-        model: GOOGLE_GENAI_EMBEDDING_MODELS["text-embedding-004"].name,
+        model: GOOGLE_GENAI_EMBEDDING_MODELS['text-embedding-004'].name,
         taskType: TaskType.RETRIEVAL_QUERY,
       });
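Similarly, a hedged sketch of the other new option: when splitDocs is provided, runDataSplitter() is bypassed, though as written in this diff getDocs() still runs unless docs is also supplied, and dataType/filePath remain mandatory. As above, the import path, the 'text' data type, the file path value, and the generateEmbeddings flag are assumptions, not part of this commit.

import {Document} from 'langchain/document';
import {getDataRetriever} from './src/rag/data-retrievers/data-retrievers'; // path assumed

async function buildRetrieverFromPreSplitDocs() {
  // Already-chunked documents; with `splitDocs` provided, runDataSplitter() is not called.
  const splitDocs: Document<Record<string, unknown>>[] = [
    new Document({pageContent: 'Chunk 1 of pre-split content.', metadata: {chunk: 1}}),
    new Document({pageContent: 'Chunk 2 of pre-split content.', metadata: {chunk: 2}}),
  ];

  return getDataRetriever({
    dataType: 'text', // still required by the validation
    filePath: 'data/faq.txt', // still required; getDocs() runs unless `docs` is also given
    splitDocs, // new optional field introduced in this commit
    generateEmbeddings: true, // assumed discriminator flag; not shown in this diff
  });
}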
