Skip to content

Commit

Permalink
[Feat] Rag scrap notice and embedding for vectorDB (#191)
Browse files Browse the repository at this point in the history
* setting: Chroma Vector DB 의존성 설정

* feat: 환경설정 파일 수정

* feat(QueryVectorStoreAdapter): QueryVectorStoreAdapter를 ChromaVectorStore를 사용하여 구현

* feat(Notice): Notice 테이블에 embedded boolean 필드 추가

해당 공지가 임베딩 되었는지 확인하는 컬럼을 추가한다

* feat(NoticeTextParserTemplate): 공지의 본문, 제목, 아이디를 파싱하는 ParserTemplate 구현

* test: ChromaDB test container 설정

* feat(NoticeApiClient): 단일 페이지를 scrap하는 requestSinglePageWithUrl 구현

* fix(NoticeJdbcRepository): 공지에 추가된 embedded 필드를 위해 bulk insert method 일부 수정

* feat(NoticeRepository): updateNoticeEmbeddingStatus, findNotYetEmbeddingNotice 메서드 구현

* fix(KuisHomepageNoticeTextParser): 본문을 포함하는 추가 테그를 파싱하는 로직 추가

* feat(KuisHomepageNoticeInfo): textParser 의존성 추가

* feat(ChromaVectorStoreAdapter): ChromaVector 구현

* test(KuisHomepageNoticeScraperTemplateTest): 임베딩 테스트 scrapForEmbedding 작성

* feat(RAGConfiguration): RAG 환경설정 구현

* feat(NoticeEmbeddingUpdater): 공지 embedding을 위한 Updater 구현

* feat: 공지 updater 작업 수행 시간 변경

* chore: 설정파일에 collection-name 추가

* fix(ChromaVectorStoreAdapter): embedding 메서드 수정과 테스트 추가

* feat(ChromaVectorStoreAdapter): 유사도 임계치 제거

유사도가 낮아도 답변을 꼭 생성하는쪽으로 구현

* feat: 사용하지 않는 RestTemplateConfig 제거

* chore: Public 접근 제어자 제거

* feat(ChromaVectorStoreAdapter): Top-K 를 2로 변경

* feat(User): 한달 질문 가능 횟수를 3번으로 변경

* feat(UserUpdater#questionCountReset): 매달 마지막날 사용자 질문 카운트 초기화 작업 구현

* feat(UserRegisterNonChainingFilter): 사용자 중복 등록 예외 로그를 남기도록 처리

* feat(UserUpdater): 사용자 제거작업 중지

* setting: ai max token 1000으로 변경

* feat(RAGQueryApiV2): RAGQueryApi 문서화

* refactor: SecurityRequirement에서 상수를 사용하도록 변경

* feat(User): 사용자 질문 횟수 2로 제한
  • Loading branch information
zbqmgldjfh authored Jul 22, 2024
1 parent c8cd35a commit 1fb4c97
Show file tree
Hide file tree
Showing 61 changed files with 2,343 additions and 121 deletions.
8 changes: 5 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dependencies {

// AI
implementation "org.springframework.ai:spring-ai-openai-spring-boot-starter:${springAiVersion}"
implementation "org.springframework.ai:spring-ai-chroma-store-spring-boot-starter:${springAiVersion}"

// DB
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
Expand Down Expand Up @@ -102,9 +103,10 @@ dependencies {

// Test Container
testImplementation 'org.junit.jupiter:junit-jupiter:5.8.1'
testImplementation 'org.testcontainers:testcontainers:1.19.3'
testImplementation 'org.testcontainers:junit-jupiter:1.19.3'
testImplementation 'org.testcontainers:mariadb:1.19.3'
testImplementation 'org.testcontainers:testcontainers:1.19.8'
testImplementation 'org.testcontainers:junit-jupiter:1.19.8'
testImplementation 'org.testcontainers:mariadb:1.19.8'
testImplementation 'org.testcontainers:chromadb:1.19.8'
}

dependencyManagement {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import com.kustacks.kuring.ai.application.port.in.RAGQueryUseCase;
import com.kustacks.kuring.common.annotation.RestWebAdapter;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.security.SecurityRequirement;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestParam;
import reactor.core.publisher.Flux;

@Tag(name = "AI-Query", description = "AI Assistant")
@RequiredArgsConstructor
@RestWebAdapter(path = "/api/v2/ai/messages")
public class RAGQueryApiV2 {
Expand All @@ -20,10 +23,10 @@ public class RAGQueryApiV2 {
private final RAGQueryUseCase ragQueryUseCase;

@Operation(summary = "사용자 AI에 질문요청", description = "사용자가 궁금한 학교 정보를 AI에게 질문합니다.")
@SecurityRequirement(name = "User-Token")
@SecurityRequirement(name = USER_TOKEN_HEADER_KEY)
@GetMapping(produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> askAIQuery(
@RequestParam("question") String question,
@Parameter(description = "사용자 질문") @RequestParam("question") String question,
@RequestHeader(USER_TOKEN_HEADER_KEY) String id
) {
return ragQueryUseCase.askAiModel(question, id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

@Slf4j
@Component
@Profile("dev | local | test")
@Profile("dev | test")
@RequiredArgsConstructor
public class InMemoryQueryAiModelAdapter implements QueryAiModelPort {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@Slf4j
@Component
@Profile("prod")
@Profile("prod | local")
@RequiredArgsConstructor
public class QueryAiModelAdapter implements QueryAiModelPort {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.kustacks.kuring.ai.adapter.out.persistence;

import com.kustacks.kuring.ai.application.port.out.CommandVectorStorePort;
import com.kustacks.kuring.ai.application.port.out.QueryVectorStorePort;
import com.kustacks.kuring.notice.domain.CategoryName;
import com.kustacks.kuring.worker.parser.notice.PageTextDto;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.ChromaVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.context.annotation.Profile;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;

import java.util.List;

@Component
@Profile("prod | local")
@RequiredArgsConstructor
public class ChromaVectorStoreAdapter implements QueryVectorStorePort, CommandVectorStorePort {

private static final int TOP_K = 2;

private final ChromaVectorStore chromaVectorStore;

@Override
public List<String> findSimilarityContents(String question) {
return chromaVectorStore.similaritySearch(
SearchRequest.query(question).withTopK(TOP_K)
).stream()
.map(Document::getContent)
.toList();
}

@Override
public void embedding(List<PageTextDto> extractTextResults, CategoryName categoryName) {
TokenTextSplitter textSplitter = new TokenTextSplitter();

for (PageTextDto textResult : extractTextResults) {
if (textResult.text().isBlank()) continue;

List<Document> documents = createDocuments(categoryName, textResult);
List<Document> splitDocuments = textSplitter.apply(documents);
chromaVectorStore.accept(splitDocuments);
}
}

private List<Document> createDocuments(CategoryName categoryName, PageTextDto textResult) {
Resource resource = new ByteArrayResource(textResult.text().getBytes()) {
@Override
public String getFilename() {
return textResult.title();
}
};

TextReader textReader = new TextReader(resource);
textReader.getCustomMetadata().put("articleId", textResult.articleId());
textReader.getCustomMetadata().put("category", categoryName.getName());
return textReader.get();
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package com.kustacks.kuring.ai.adapter.out.persistence;

import com.kustacks.kuring.ai.application.port.out.CommandVectorStorePort;
import com.kustacks.kuring.ai.application.port.out.QueryVectorStorePort;
import com.kustacks.kuring.notice.domain.CategoryName;
import com.kustacks.kuring.worker.parser.notice.PageTextDto;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
Expand All @@ -12,10 +15,10 @@
import java.util.stream.Stream;

@Slf4j
@Profile("local | dev | test")
@Profile("dev | test")
@Component
@RequiredArgsConstructor
public class InMemoryQueryVectorStoreAdapter implements QueryVectorStorePort {
public class InMemoryVectorStoreAdapter implements QueryVectorStorePort, CommandVectorStorePort {

@Override
public List<String> findSimilarityContents(String question) {
Expand All @@ -28,11 +31,18 @@ public List<String> findSimilarityContents(String question) {
.toList();
}

@Override
public void embedding(List<PageTextDto> extractTextResults, CategoryName categoryName) {
log.info("[InMemoryQueryVectorStoreAdapter] embedding {}", categoryName);
}

private Document createDocument(HashMap<String, Object> metadata) {
return new Document(
"a5a7414f-f676-409b-9f2e-1042f9846c97",
"● 등록금 전액 완납 또는 분할납부 1차분을 정해진 기간에 미납할 경우 분할납부 신청은 자동 취소되며, 미납 등록금은 이후\n" +
"추가 등록기간에 전액 납부해야 함.\n",
"""
● 등록금 전액 완납 또는 분할납부 1차분을 정해진 기간에 미납할 경우 분할납부 신청은 자동 취소되며,
미납 등록금은 이후 추가 등록기간에 전액 납부해야 함.\n
""",
metadata);
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.kustacks.kuring.ai.application.port.out;

import com.kustacks.kuring.notice.domain.CategoryName;
import com.kustacks.kuring.worker.parser.notice.PageTextDto;

import java.util.List;

public interface CommandVectorStorePort {
void embedding(List<PageTextDto> extractTextResults, CategoryName categoryName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ public class RAGQueryService implements RAGQueryUseCase {

@Override
public Flux<String> askAiModel(String question, String id) {
Prompt completePrompt = buildCompletePrompt(question);
ragEventPort.userDecreaseQuestionCountEvent(id);
Prompt completePrompt = buildCompletePrompt(question);
return ragChatModel.call(completePrompt);
}

Expand All @@ -45,7 +45,7 @@ private void init() {

private Prompt buildCompletePrompt(String question) {
List<String> similarDocuments = vectorStorePort.findSimilarityContents(question);
if(similarDocuments.isEmpty()) {
if (similarDocuments.isEmpty()) {
throw new InvalidStateException(ErrorCode.AI_SIMILAR_DOCUMENTS_NOT_FOUND);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,24 @@
import com.kustacks.kuring.auth.exception.RegisterException;
import com.kustacks.kuring.auth.handler.AuthenticationFailureHandler;
import com.kustacks.kuring.auth.handler.AuthenticationSuccessHandler;
import com.kustacks.kuring.common.properties.ServerProperties;
import com.kustacks.kuring.message.application.port.in.FirebaseWithUserUseCase;
import com.kustacks.kuring.message.application.port.in.dto.UserSubscribeCommand;
import com.kustacks.kuring.common.properties.ServerProperties;
import com.kustacks.kuring.user.application.port.out.UserCommandPort;
import com.kustacks.kuring.user.domain.User;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.dao.DuplicateKeyException;
import org.springframework.web.servlet.HandlerInterceptor;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.stream.Collectors;

import static com.kustacks.kuring.message.application.service.FirebaseSubscribeService.ALL_DEVICE_SUBSCRIBED_TOPIC;

@Slf4j
@RequiredArgsConstructor
public class UserRegisterNonChainingFilter implements HandlerInterceptor {

Expand All @@ -35,7 +38,7 @@ public class UserRegisterNonChainingFilter implements HandlerInterceptor {
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
try {
if(request.getMethod().equals(REGISTER_HTTP_METHOD)) {
if (request.getMethod().equals(REGISTER_HTTP_METHOD)) {
String userFcmToken = convert(request);
register(userFcmToken);
afterAuthentication(request, response);
Expand All @@ -50,7 +53,12 @@ public boolean preHandle(HttpServletRequest request, HttpServletResponse respons
}

private void register(String userFcmToken) {
userCommandPort.save(new User(userFcmToken));
try {
userCommandPort.save(new User(userFcmToken));
} catch (DuplicateKeyException e) { // 이미 등록된 사용자에 대한 처리는 필요없다
log.warn("User already exists: {}", userFcmToken, e);
}

UserSubscribeCommand command =
new UserSubscribeCommand(
userFcmToken,
Expand Down
18 changes: 17 additions & 1 deletion src/main/java/com/kustacks/kuring/config/RAGConfiguration.java
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
package com.kustacks.kuring.config;

import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chroma.ChromaApi;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.ChromaVectorStore;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Profile;
import org.springframework.core.io.Resource;
import org.springframework.web.client.RestTemplate;

import java.io.File;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

@Slf4j
@Profile("local")
@Configuration
public class RAGConfiguration {

Expand All @@ -28,6 +30,20 @@ public class RAGConfiguration {
@Value("vectorstore.json")
private String vectorStoreName;

@Profile("test")
@Bean
public ChromaApi chromaApi(RestTemplate restTemplate) {
String chromaUrl = "http://127.0.0.1:8000";
return new ChromaApi(chromaUrl, restTemplate);
}

@Profile("test")
@Bean
public ChromaVectorStore chromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi) {
return new ChromaVectorStore(embeddingModel, chromaApi, false);
}

@Profile("local")
@Bean
public SimpleVectorStore simpleVectorStore(EmbeddingModel embeddingModel) {
SimpleVectorStore simpleVectorStore = new SimpleVectorStore(embeddingModel);
Expand Down
23 changes: 0 additions & 23 deletions src/main/java/com/kustacks/kuring/config/RestTemplateConfig.java

This file was deleted.

Loading

0 comments on commit 1fb4c97

Please sign in to comment.