From 06bcdb03ac1471786f431af8af7c4dbfa9a2123e Mon Sep 17 00:00:00 2001 From: sal Date: Wed, 27 May 2026 14:06:22 +0900 Subject: [PATCH] Implement Phase 4~14: LangGraph Agent, RAG pipeline, Gradio Web UI, voice interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Upgrade LLM to Qwen3-14B-4bit with Thinking mode (MlxChatModel as LangChain BaseChatModel) - Add LangGraph ReAct agent with tool calling loop (search_documents, web_search, get_current_date, remember/recall_user_info) - Add RAG pipeline: BAAI/bge-m3 embeddings + Qdrant vector store + semantic chunking (SemanticSplitter via cosine similarity) - Replace fixed-size RecursiveCharacterTextSplitter with meaning-based SemanticSplitter (numpy only, no extra deps) - Add Gradio Web UI (app.py): chat, document ingestion, document management tabs - Add multi-user support (user_id isolation in DB + per-user agent cache + dropdown selector) - Add conversation history restore from MySQL on agent init (Phase 11) - Add UserProfileRepository for persistent user profile (remember/recall tools) - Add thread-local DB connections to fix pymysql thread-safety with LangGraph ToolNode - Add Phase 14 voice interface: Whisper STT (microphone → text) + macOS TTS (say -v Yuna) - Enforce search_documents-first policy in system prompt and tool descriptions - Update ROADMAP2.md: Phase 14 완료, Phase 13 청킹 부분 완료 Co-Authored-By: Claude Sonnet 4.6 --- .env.example | 7 +- app.py | 250 ++++++++++++++ config.py | 34 +- container.py | 58 ++++ docs/01-plan/features/rag-tool-chain.plan.md | 271 +++++++++++++++ docs/ROADMAP.md | 56 +++ docs/ROADMAP2.md | 224 ++++++++++++ ingest.py | 28 ++ main.py | 35 +- requirements.txt | 18 + services/agent/__init__.py | 0 services/agent/agent_service.py | 248 ++++++++++++++ services/agent/tools.py | 96 ++++++ services/db/conversation_repository.py | 10 +- services/db/mysql_service.py | 104 ++++-- services/db/user_profile_repository.py | 31 ++ services/model/mlx_chat_model.py | 337 +++++++++++++++++++ services/rag/__init__.py | 0 services/rag/ingestion_service.py | 107 ++++++ services/rag/retriever_service.py | 67 ++++ 20 files changed, 1934 insertions(+), 47 deletions(-) create mode 100644 app.py create mode 100644 docs/01-plan/features/rag-tool-chain.plan.md create mode 100644 docs/ROADMAP.md create mode 100644 docs/ROADMAP2.md create mode 100644 ingest.py create mode 100644 services/agent/__init__.py create mode 100644 services/agent/agent_service.py create mode 100644 services/agent/tools.py create mode 100644 services/db/user_profile_repository.py create mode 100644 services/model/mlx_chat_model.py create mode 100644 services/rag/__init__.py create mode 100644 services/rag/ingestion_service.py create mode 100644 services/rag/retriever_service.py diff --git a/.env.example b/.env.example index 7b998a6..cb160f8 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,5 @@ # LLM 모델 설정 -MODEL_ID=mlx-community/Qwen2.5-7B-Instruct-4bit +MODEL_ID=mlx-community/Qwen3-8B-4bit MAX_TOKENS=1024 MAX_HISTORY_TURNS=30 COMPACT_THRESHOLD=40 @@ -10,3 +10,8 @@ DB_PORT=3306 DB_NAME=youlbot DB_USER= DB_PASSWORD= + +# LangSmith 트레이싱 (Phase 7) — https://smith.langchain.com 에서 API 키 발급 +LANGCHAIN_TRACING_V2=false +LANGCHAIN_API_KEY= +LANGCHAIN_PROJECT=youlbot diff --git a/app.py b/app.py new file mode 100644 index 0000000..7eefd77 --- /dev/null +++ b/app.py @@ -0,0 +1,250 @@ +"""Gradio Web UI — 율봇 Phase 4 + Phase 9/10 + Phase 14(음성).""" +import os +import subprocess +import tempfile +import gradio as gr +from dotenv import load_dotenv +load_dotenv() + +from container import Container +from services.agent.agent_service import AgentService + +container = Container() + +db = container.db_service() +db.connect() +db.init_schema() + +ingestion = container.ingestion_service() +retriever = container.retriever_service() + +_cfg = container.config() +_agent_cache: dict[str, AgentService] = {} + +USER_LABELS = ["아록", "근혜", "도율", "하율"] +DEFAULT_USER = "아록" + +_whisper_model = None + + +def _get_whisper(): + global _whisper_model + if _whisper_model is None: + import whisper + _whisper_model = whisper.load_model(_cfg.whisper_model_size) + return _whisper_model + + +def transcribe_audio(filepath: str) -> str: + if not filepath: + return "" + model = _get_whisper() + result = model.transcribe(filepath, language="ko") + return result["text"].strip() + + +def tts_speak(text: str, voice: str) -> str | None: + """텍스트를 macOS say 명령어로 음성 변환, 재생용 wav 파일 경로 반환.""" + if not text: + return None + try: + tmp = tempfile.NamedTemporaryFile(suffix=".aiff", delete=False) + tmp.close() + subprocess.run( + ["say", "-v", voice, "-o", tmp.name, text], + check=True, + capture_output=True, + ) + return tmp.name + except Exception: + return None + + +def _get_agent(user_id: str) -> AgentService: + if user_id not in _agent_cache: + _agent_cache[user_id] = AgentService( + chat_model=container.chat_model(), + retriever_service=retriever, + system_prompt=_cfg.system_prompt, + rag_verbose=_cfg.rag_verbose, + rag_show_sources=_cfg.rag_show_sources, + langgraph_verbose=_cfg.langgraph_verbose, + think_verbose=_cfg.think_verbose, + user_profile_repository=container.user_profile_repository(), + conversation_repository=container.conversation_repository(), + user_id=user_id, + ) + return _agent_cache[user_id] + + +async def respond(message, history, show_thinking, user_id, use_tts): + if not message.strip(): + yield history, "", None + return + + agent = _get_agent(user_id) + history = list(history) + history.append({"role": "user", "content": message}) + history.append({"role": "assistant", "content": ""}) + yield history, "", None + + async for token in agent.stream_response(message, show_thinking=show_thinking): + history[-1]["content"] += token + yield history, "", None + + if use_tts: + response_text = history[-1]["content"] + audio_path = tts_speak(response_text, _cfg.tts_voice) + yield history, "", audio_path + + +def switch_user(user_id): + """사용자 전환 시 채팅 화면만 초기화 (대화 이력은 유지).""" + return [] + + +def reset_chat(user_id): + agent = _get_agent(user_id) + agent.reset() + return [] + + +def ingest_files(files): + if not files: + return "파일을 선택해주세요." + paths = [f if isinstance(f, str) else f.name for f in files] + try: + count = ingestion.ingest(paths) + names = ", ".join(p.split("/")[-1] for p in paths) + return f"완료: {names} → {count}개 청크 저장됨" + except Exception as e: + return f"오류: {e}" + + +def list_docs(): + try: + sources = retriever.list_documents() + return [[os.path.basename(s), s] for s in sources] + except Exception as e: + return [[f"오류: {e}", ""]] + + +def delete_doc(source): + if not source.strip(): + return "삭제할 파일 경로를 입력하세요.", list_docs() + try: + retriever.delete_document(source.strip()) + return f"삭제 완료: {os.path.basename(source.strip())}", list_docs() + except Exception as e: + return f"오류: {e}", list_docs() + + +with gr.Blocks(title="율봇") as demo: + gr.Markdown("# 율봇\n육아·금융 전문 AI 상담 도우미") + + user_state = gr.State(DEFAULT_USER) + + with gr.Tab("대화"): + with gr.Row(): + user_selector = gr.Dropdown( + choices=USER_LABELS, + value=DEFAULT_USER, + label="사용자", + scale=1, + ) + + chatbot = gr.Chatbot(label="율봇", height=500) + with gr.Row(): + msg_box = gr.Textbox( + placeholder="질문을 입력하세요... (Enter로 전송)", + label="", + scale=5, + autofocus=True, + ) + send_btn = gr.Button("전송", variant="primary", scale=1) + + # 음성 입력 (STT) + with gr.Row(): + audio_input = gr.Audio( + sources=["microphone"], + type="filepath", + label="음성으로 질문하기", + scale=4, + ) + transcribe_btn = gr.Button("음성 → 텍스트 변환", scale=1) + + with gr.Row(): + show_thinking = gr.Checkbox(label="사고 과정 표시", value=False) + use_tts = gr.Checkbox(label="음성으로 답변 읽기 (TTS)", value=False) + reset_btn = gr.Button("대화 초기화", size="sm") + + # TTS 출력 + tts_output = gr.Audio(label="음성 답변", autoplay=True, visible=False) + use_tts.change(lambda v: gr.Audio(visible=v), inputs=[use_tts], outputs=[tts_output]) + + user_selector.change( + switch_user, + inputs=[user_selector], + outputs=[chatbot], + ).then( + lambda u: u, inputs=[user_selector], outputs=[user_state] + ) + + transcribe_btn.click( + transcribe_audio, + inputs=[audio_input], + outputs=[msg_box], + ) + + send_btn.click( + respond, + inputs=[msg_box, chatbot, show_thinking, user_state, use_tts], + outputs=[chatbot, msg_box, tts_output], + ) + msg_box.submit( + respond, + inputs=[msg_box, chatbot, show_thinking, user_state, use_tts], + outputs=[chatbot, msg_box, tts_output], + ) + reset_btn.click(reset_chat, inputs=[user_state], outputs=[chatbot]) + + with gr.Tab("문서 등록"): + gr.Markdown("PDF 또는 TXT 파일을 업로드하면 율봇이 내용을 참고해 답변합니다.") + file_input = gr.File( + file_types=[".pdf", ".txt"], + file_count="multiple", + label="파일 선택", + ) + ingest_btn = gr.Button("문서 수집", variant="primary") + ingest_status = gr.Textbox(label="결과", interactive=False) + ingest_btn.click(ingest_files, inputs=[file_input], outputs=[ingest_status]) + + with gr.Tab("문서 관리"): + gr.Markdown("Qdrant에 등록된 문서 목록입니다. 불필요한 문서를 삭제할 수 있습니다.") + doc_table = gr.Dataframe( + headers=["파일명", "전체 경로"], + label="등록된 문서", + interactive=False, + ) + refresh_btn = gr.Button("새로고침") + gr.Markdown("---") + with gr.Row(): + delete_source = gr.Textbox( + label="삭제할 파일 경로", + placeholder="위 표에서 전체 경로를 복사해 붙여넣으세요", + scale=4, + ) + delete_btn = gr.Button("삭제", variant="stop", scale=1) + delete_status = gr.Textbox(label="결과", interactive=False) + + refresh_btn.click(list_docs, outputs=[doc_table]) + delete_btn.click( + delete_doc, + inputs=[delete_source], + outputs=[delete_status, doc_table], + ) + demo.load(list_docs, outputs=[doc_table]) + + +if __name__ == "__main__": + demo.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft()) diff --git a/config.py b/config.py index dde3e65..579b424 100644 --- a/config.py +++ b/config.py @@ -6,13 +6,16 @@ class Config(BaseSettings): env_file=".env", env_file_encoding="utf-8", frozen=True, + extra="ignore", ) # LLM - model_id: str = "mlx-community/Qwen2.5-7B-Instruct-4bit" + model_id: str = "mlx-community/Qwen3-14B-4bit" max_tokens: int = 1024 max_history_turns: int = 10 compact_threshold: int = 20 + enable_thinking: bool = True + think_verbose: bool = False # MySQL db_host: str = "localhost" @@ -21,11 +24,36 @@ class Config(BaseSettings): db_user: str = "" db_password: str = "" - system_prompt: str = """당신의 이름은 '율봇'입니다. 친절하고 따뜻한 한국어 상담 도우미입니다. + # Qdrant + qdrant_url: str = "http://localhost:6333" + qdrant_collection: str = "youlbot_docs" + + # Embedding + embedding_model_id: str = "BAAI/bge-m3" + embedding_device: str = "mps" + + # RAG + rag_top_k: int = 3 + semantic_breakpoint_threshold_type: str = "percentile" # percentile | standard_deviation | interquartile + rag_verbose: bool = False + rag_show_sources: bool = False + langgraph_verbose: bool = False + + # Voice (Phase 14) + whisper_model_size: str = "small" + tts_voice: str = "Yuna" # macOS say 명령어 한국어 음성 + + system_prompt: str = """모든 응답과 내부 사고 과정을 반드시 한국어로 작성하세요. + +당신의 이름은 '율봇'입니다. 친절하고 따뜻한 한국어 상담 도우미입니다. 육아와 금융 두 분야를 전문으로 합니다. - 육아: 아이 발달, 이유식, 수면, 훈육, 교육 등 부모가 궁금해하는 모든 것을 도와드립니다. - 금융: 저축, 투자, 보험, 대출, 세금 등 생활 금융 관련 질문에 답변드립니다. 항상 쉽고 친근한 말투로 설명하고, 전문 용어는 풀어서 설명합니다. -의학적 진단이나 법적 판단이 필요한 경우에는 반드시 전문가 상담을 권유합니다.""" +의학적 진단이나 법적 판단이 필요한 경우에는 반드시 전문가 상담을 권유합니다. + +## 문서 검색 규칙 +육아·금융 관련 질문이라면 자신의 학습 지식으로 직접 답하지 말고, 반드시 search_documents 도구를 먼저 호출하세요. +검색 결과가 없거나 관련 문서가 등록되어 있지 않은 경우에만 학습 지식을 보조적으로 활용합니다.""" diff --git a/container.py b/container.py index 4c9e12a..06a3a01 100644 --- a/container.py +++ b/container.py @@ -2,14 +2,20 @@ from dependency_injector import containers, providers from config import Config from services.model.mlx_model import MlxModelService +from services.model.mlx_chat_model import MlxChatModel from services.chat.history_service import HistoryService from services.chat.chat_service import ChatService from services.chat.compact_service import CompactService from services.db.mysql_service import DatabaseService from services.db.conversation_repository import ConversationRepository +from services.db.user_profile_repository import UserProfileRepository from services.ui.cli_service import CliUiService from services.events.event_bus import EventBus from services.events.handlers import StreamTokenHandler, StreamEndHandler +from langchain_huggingface import HuggingFaceEmbeddings +from services.rag.ingestion_service import IngestionService +from services.rag.retriever_service import RetrieverService +from services.agent.agent_service import AgentService class Container(containers.DeclarativeContainer): @@ -22,6 +28,14 @@ class Container(containers.DeclarativeContainer): model_id=providers.Callable(lambda c: c.model_id, config), ) + # LangGraph 에이전트용 BaseChatModel (Phase 1) + chat_model = providers.Singleton( + MlxChatModel, + model_id=providers.Callable(lambda c: c.model_id, config), + max_tokens=providers.Callable(lambda c: c.max_tokens, config), + enable_thinking=providers.Callable(lambda c: c.enable_thinking, config), + ) + compact_service = providers.Singleton( CompactService, model=model_service, @@ -41,6 +55,11 @@ class Container(containers.DeclarativeContainer): db=db_service, ) + user_profile_repository = providers.Singleton( + UserProfileRepository, + db=db_service, + ) + history_service = providers.Factory( HistoryService, system_prompt=providers.Callable(lambda c: c.system_prompt, config), @@ -62,3 +81,42 @@ class Container(containers.DeclarativeContainer): stream_token_handler = providers.Singleton(StreamTokenHandler) stream_end_handler = providers.Singleton(StreamEndHandler) + + # Phase 2 — RAG 파이프라인 + embeddings = providers.Singleton( + HuggingFaceEmbeddings, + model_name=providers.Callable(lambda c: c.embedding_model_id, config), + model_kwargs=providers.Callable(lambda c: {"device": c.embedding_device}, config), + ) + + ingestion_service = providers.Singleton( + IngestionService, + embeddings=embeddings, + qdrant_url=providers.Callable(lambda c: c.qdrant_url, config), + collection_name=providers.Callable(lambda c: c.qdrant_collection, config), + breakpoint_threshold_type=providers.Callable( + lambda c: c.semantic_breakpoint_threshold_type, config + ), + ) + + retriever_service = providers.Singleton( + RetrieverService, + embeddings=embeddings, + qdrant_url=providers.Callable(lambda c: c.qdrant_url, config), + collection_name=providers.Callable(lambda c: c.qdrant_collection, config), + top_k=providers.Callable(lambda c: c.rag_top_k, config), + ) + + # Phase 3 — LangGraph Agent + agent_service = providers.Singleton( + AgentService, + chat_model=chat_model, + retriever_service=retriever_service, + system_prompt=providers.Callable(lambda c: c.system_prompt, config), + rag_verbose=providers.Callable(lambda c: c.rag_verbose, config), + rag_show_sources=providers.Callable(lambda c: c.rag_show_sources, config), + langgraph_verbose=providers.Callable(lambda c: c.langgraph_verbose, config), + think_verbose=providers.Callable(lambda c: c.think_verbose, config), + user_profile_repository=user_profile_repository, + conversation_repository=conversation_repository, + ) diff --git a/docs/01-plan/features/rag-tool-chain.plan.md b/docs/01-plan/features/rag-tool-chain.plan.md new file mode 100644 index 0000000..242926d --- /dev/null +++ b/docs/01-plan/features/rag-tool-chain.plan.md @@ -0,0 +1,271 @@ +--- +template: plan +version: 1.4 +feature: rag-tool-chain +date: 2026-04-27 +author: sal +project: youlbot +status: Draft +--- + +# rag-tool-chain Planning Document + +> **Summary**: mlx-lm을 LangChain `BaseChatModel`로 래핑하고, LangGraph 에이전트로 RAG + Tool Calling을 통합한다. 커스텀 구현은 최소화하고 LangChain/LangGraph 생태계를 최대한 활용한다. +> +> **Project**: youlbot +> **Author**: sal +> **Date**: 2026-04-27 +> **Status**: Draft + +--- + +## Executive Summary + +| Perspective | Content | +|-------------|---------| +| **Problem** | 현재 율봇은 모델 파라미터 지식에만 의존하며, Tool Calling·RAG를 직접 구현하면 유지보수 부담이 큼 | +| **Solution** | mlx-lm을 `BaseChatModel`로 1회 래핑 후 LangGraph 에이전트와 LangChain RAG 생태계를 그대로 활용 | +| **Function/UX Effect** | 육아·금융 전문 문서 기반 답변, Tool 호출로 동적 정보 처리 가능 | +| **Core Value** | 커스텀 코드 최소화 — LangGraph가 Tool Calling 루프·상태 관리를 담당, LangChain이 RAG 파이프라인을 담당 | + +--- + +## Context Anchor + +| Key | Value | +|-----|-------| +| **WHY** | Tool Calling 루프·히스토리 관리·RAG 오케스트레이션을 직접 구현하면 버그 표면적이 넓고 유지보수 비용이 높음 | +| **WHO** | 개발자 (sal) — 단독 개발 | +| **RISK** | mlx-lm `BaseChatModel` 래퍼가 LangGraph와 완전 호환되는지 검증 필요 | +| **SUCCESS** | `create_react_agent(llm, tools)` 수준의 단순한 에이전트 구성으로 RAG·Tool Calling 동작 | +| **SCOPE** | Phase 1: mlx-lm BaseChatModel 래퍼 / Phase 2: RAG 파이프라인 / Phase 3: LangGraph 에이전트 통합 | + +--- + +## 1. Overview + +### 1.1 Architecture 결정 (Option B) + +``` +mlx-lm + └─ MlxChatModel(BaseChatModel) ← 1회 구현 (~80줄) + └─ LangGraph ReAct Agent ← Tool Calling 루프 내장 + ├─ RAG Tool ← LangChain-Qdrant 검색 + └─ 기타 Tools +``` + +**LangGraph가 처리하는 것 (커스텀 불필요):** +- Tool Calling 루프 (tool_call → 실행 → 재요청) +- 대화 상태 및 히스토리 관리 +- 조건부 라우팅 (일반 답변 vs Tool 호출) +- 최대 반복 횟수 제한 + +**LangChain이 처리하는 것 (커스텀 불필요):** +- 문서 로딩 (PDF, TXT, MD) +- 텍스트 청킹 +- 임베딩 생성 +- Qdrant 벡터 스토어 연동 + +**직접 구현하는 것 (최소):** +- `MlxChatModel(BaseChatModel)` — mlx-lm 래퍼 (~80줄) +- Tool 구현체 (비즈니스 로직 함수들) +- IoC Container 배선 + +### 1.2 Background +- 율봇의 도메인: 육아, 금융 — 신뢰성 있는 출처 기반 답변이 중요 +- Qwen2.5-7B-Instruct는 Tool Calling 네이티브 지원 +- LangGraph는 LangChain 공식 에이전트 오케스트레이션 프레임워크 (2024년 이후 표준) + +--- + +## 2. Scope + +### 2.1 In Scope + +**Phase 1 — MlxChatModel 래퍼** +- [ ] `services/model/mlx_chat_model.py` — `BaseChatModel` 서브클래스 + - `_generate()` — 단일 응답 (tool_call 포함 AIMessage 반환) + - `_stream()` — 스트리밍 청크 + - `bind_tools()` — LangChain 표준 Tool 바인딩 + +**Phase 2 — RAG 파이프라인** +- [ ] `services/rag/ingestion_service.py` — 문서 로드 → 청크 → 임베딩 → Qdrant 저장 +- [ ] `services/rag/retriever_service.py` — Qdrant 검색 → LangChain Tool 래핑 +- [ ] `config.py` 확장 — Qdrant, 임베딩 모델, RAG 설정 + +**Phase 3 — LangGraph 에이전트 통합** +- [ ] `services/agent/agent_service.py` — LangGraph `create_react_agent` 조립 +- [ ] `services/agent/tools.py` — Tool 구현체 (@tool 데코레이터) +- [ ] `container.py` 업데이트 — 신규 서비스 IoC 등록 +- [ ] 기존 `ChatService` 보존, `AgentService`로 선택적 전환 + +### 2.2 기존 코드 처리 + +| 기존 코드 | 처리 방향 | +|-----------|-----------| +| `AbstractModelService` + `MlxModelService` | 보존 (LangGraph 없는 단순 모드용) | +| `ChatService` | 보존 | +| `HistoryService` | LangGraph State로 대체 (Phase 3) | +| `CompactService` | LangGraph Memory 전략으로 추후 대체 | +| `EventBus` / `StreamTokenHandler` | LangGraph Streaming callback으로 대체 (Phase 3) | + +### 2.3 Out of Scope +- 웹 API 레이어 (FastAPI 등) +- 문서 관리 UI +- 외부 API 기반 Tool (날씨, 금융 API 등) — 추후 Phase +- LangGraph 퍼시스턴스 (체크포인터, 장기 메모리) — 추후 Phase + +--- + +## 3. Requirements + +### 3.1 Functional Requirements + +| ID | Requirement | Priority | +|----|-------------|----------| +| FR-01 | `MlxChatModel`이 LangChain `BaseChatModel` 인터페이스를 완전히 구현 | High | +| FR-02 | `bind_tools()`로 Tool을 바인딩하면 모델이 tool_call을 생성 | High | +| FR-03 | 문서(PDF, TXT, MD)를 Qdrant에 수집·저장하는 수집 파이프라인 | High | +| FR-04 | LangGraph ReAct 에이전트가 RAG Tool을 자동 호출하여 컨텍스트 확보 | High | +| FR-05 | Tool Calling 루프는 LangGraph가 관리 (직접 구현 금지) | High | +| FR-06 | 스트리밍 출력은 LangGraph의 `stream()` 인터페이스 활용 | Medium | + +### 3.2 Non-Functional Requirements + +| Category | Criteria | +|----------|----------| +| 커스텀 코드 최소화 | LangGraph/LangChain이 제공하는 기능은 직접 구현하지 않음 | +| 교체 용이성 | `MlxChatModel`을 `ChatOllama` 등으로 교체 시 `AgentService` 코드 변경 없음 | +| 성능 | 임베딩 모델 Singleton으로 1회만 로딩 | +| 안정성 | Tool 실행 실패 시 LangGraph가 에러를 메시지로 처리, 대화 중단 없음 | + +--- + +## 4. Architecture + +### 4.1 디렉터리 구조 + +``` +services/ + model/ + base.py # AbstractModelService (기존 유지) + mlx_model.py # MlxModelService (기존 유지) + mlx_chat_model.py # MlxChatModel : BaseChatModel (신규, Phase 1) + rag/ + __init__.py + ingestion_service.py # 문서 로드/청크/임베딩/Qdrant 저장 (Phase 2) + retriever_service.py # Qdrant 검색 → LangChain Retriever (Phase 2) + agent/ + __init__.py + agent_service.py # LangGraph create_react_agent 조립 (Phase 3) + tools.py # @tool 데코레이터 Tool 구현체 (Phase 3) + chat/ # 기존 전부 유지 + db/ # 기존 전부 유지 + events/ # 기존 전부 유지 + ui/ # 기존 전부 유지 +``` + +### 4.2 MlxChatModel 인터페이스 (Phase 1 핵심) + +```python +class MlxChatModel(BaseChatModel): + model_id: str + max_tokens: int = 1024 + + def _generate(self, messages, stop=None, **kwargs) -> ChatResult: + prompt = self._tokenizer.apply_chat_template(messages, ...) + text = generate(self._model, self._tokenizer, prompt, ...) + return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))]) + + def _stream(self, messages, stop=None, **kwargs) -> Iterator[ChatGenerationChunk]: + prompt = self._tokenizer.apply_chat_template(messages, ...) + for chunk in stream_generate(...): + yield ChatGenerationChunk(message=AIMessageChunk(content=chunk.text)) +``` + +### 4.3 LangGraph 에이전트 흐름 (Phase 3) + +```python +# AgentService의 핵심 — 대부분 라이브러리가 처리 +llm = MlxChatModel(model_id=config.model_id) +tools = [rag_search_tool, get_current_date_tool, ...] + +agent = create_react_agent(llm, tools) + +# 실행 — Tool Calling 루프, 히스토리, 에러 처리 모두 LangGraph 담당 +result = agent.invoke({"messages": [HumanMessage(content=user_input)]}) +``` + +### 4.4 RAG Tool 구조 (Phase 2 + Phase 3) + +```python +@tool +def search_documents(query: str) -> str: + """육아·금융 관련 문서에서 관련 내용을 검색합니다.""" + docs = retriever.invoke(query) + return format_docs(docs) +``` + +### 4.5 의존성 + +``` +# 신규 추가 +langchain-core +langchain-community # 문서 로더, HuggingFace 임베딩 +langchain-text-splitters +langchain-qdrant # Qdrant 벡터 스토어 +langgraph # 에이전트 오케스트레이션 +sentence-transformers # 로컬 임베딩 (BAAI/bge-m3) +qdrant-client +``` + +### 4.6 Config 확장 + +```python +# Qdrant +qdrant_host: str = "localhost" +qdrant_port: int = 6333 +qdrant_collection: str = "youlbot_docs" + +# Embedding +embedding_model_id: str = "BAAI/bge-m3" + +# RAG +rag_top_k: int = 3 +rag_score_threshold: float = 0.5 +``` + +--- + +## 5. Success Criteria + +- [ ] `MlxChatModel`이 `llm.invoke([HumanMessage(...)])` 호출로 정상 응답 +- [ ] `llm.bind_tools(tools).invoke(messages)` 호출 시 tool_call 포함 응답 생성 +- [ ] PDF/TXT 문서를 수집해 Qdrant에 저장, 쿼리로 관련 청크 검색 가능 +- [ ] LangGraph 에이전트가 RAG Tool을 자동 호출하고 결과를 반영하여 최종 답변 생성 +- [ ] `MlxChatModel`을 `ChatOllama`로 교체해도 `AgentService` 코드 변경 없음 + +--- + +## 6. Risks + +| Risk | Impact | Likelihood | Mitigation | +|------|--------|------------|------------| +| `MlxChatModel`의 tool_call 파싱이 LangGraph와 불일치 | High | Medium | Phase 1에서 단위 검증 후 Phase 3 진행 | +| Qwen2.5-7B의 ReAct 프롬프트 준수 불안정 | Medium | Medium | LangGraph 프롬프트 커스터마이징, few-shot 추가 | +| 로컬 임베딩 모델(BGE-M3) 최초 로딩 시간 (~30초) | Medium | High | Singleton 1회 로딩, 진행 안내 메시지 | +| Qdrant 미실행 시 에이전트 전체 불가 | High | Medium | RAG Tool 비활성화 config 플래그 | +| LangChain/LangGraph 버전 충돌 | Low | Low | 버전 고정, 의존성 테스트 | + +--- + +## 7. Architecture Decisions + +| Decision | Selected | Rationale | +|----------|----------|-----------| +| LLM 통합 방식 | mlx-lm → `BaseChatModel` 래퍼 (Option B) | mlx Apple Silicon 최적화 유지 + LangChain 생태계 전체 활용 | +| 에이전트 프레임워크 | LangGraph `create_react_agent` | Tool Calling 루프·상태 관리 직접 구현 불필요, LangChain 공식 표준 | +| Tool 정의 방식 | `@tool` 데코레이터 | LangGraph 표준, JSON 스키마 자동 생성 | +| 임베딩 모델 | BAAI/bge-m3 (로컬) | 한국어 포함 다국어 지원, 서버 불필요 | +| Qdrant 운영 | 로컬 Docker | 개발 단계 외부 의존 최소화 | +| 기존 코드 처리 | 보존 (병행 운영) | ChatService(단순 모드) / AgentService(RAG+Tool 모드) 선택적 사용 | diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md new file mode 100644 index 0000000..b49836b --- /dev/null +++ b/docs/ROADMAP.md @@ -0,0 +1,56 @@ +# 율봇 개발 로드맵 + +## 현재 구현 상태 (Phase 1~7 완료) + +| 영역 | 현황 | +|------|------| +| LLM | Qwen2.5-7B-Instruct-4bit (MLX, Apple Silicon) | +| Agent | LangGraph ReAct + Tool Calling + Thinking 모드 | +| RAG | Qdrant + BAAI/bge-m3 임베딩 | +| Tools | `search_documents`, `get_current_date`, `web_search`, `remember_user_info`, `recall_user_info` (5개) | +| UI | Gradio Web UI (`app.py`) + CLI (`main.py`) | +| Memory | LangGraph MemorySaver (세션 내) + MySQL (대화 영구 저장) + `td_user_profile` (장기 사용자 메모리) | +| Streaming | 비동기 토큰 스트리밍 + `` 블록 파싱 | +| Tracing | LangSmith 트레이싱 설정 완료 (`.env`에서 활성화 가능) | + +--- + +## ✅ Phase 4 — Web UI (Gradio) + +- `app.py` — Gradio ChatInterface + `stream_response()` 연결 +- PDF/TXT 파일 업로드 → 인제스트 버튼 +- 사고 과정(thinking) 표시 토글 +- 대화 초기화 버튼 + +--- + +## ✅ Phase 5 — 장기 사용자 메모리 + +- MySQL `td_user_profile` 테이블 + Tool 2개 등록 +- `remember_user_info(key, value)` — 영구 저장 (아이 생년, 재정 목표 등) +- `recall_user_info(key)` — 이전 저장 정보 조회 +- `UserProfileRepository` (`services/db/user_profile_repository.py`) + +--- + +## ✅ Phase 6 — 실시간 웹 검색 Tool + +- `web_search(query)` — DuckDuckGo (무료, API 키 불필요) +- 최신 금리, 육아 정책, 뉴스 등 실시간 정보 검색 가능 + +--- + +## ✅ Phase 7 — LangSmith 트레이싱 + +- `.env`에서 `LANGCHAIN_TRACING_V2=true` + `LANGCHAIN_API_KEY` 설정으로 활성화 +- Tool Call 실패 원인, RAG 청크 내용, 에이전트 루프 흐름 시각화 가능 + +--- + +## Phase 8 — 멀티모달 이미지 이해 ★☆☆ + +**배경**: 이유식 사진 → "이 재료로 만들 수 있는 이유식은?", 금융 서류 사진 → 내용 분석 등 이미지 기반 질문 처리. + +**제약**: Qwen2.5-7B는 이미지 미지원 → `mlx-community/Qwen2.5-VL-7B-Instruct-4bit` 모델 교체 필요. + +**난이도**: 높음 | **임팩트**: 높음 (장기 과제) diff --git a/docs/ROADMAP2.md b/docs/ROADMAP2.md new file mode 100644 index 0000000..1e7fdb1 --- /dev/null +++ b/docs/ROADMAP2.md @@ -0,0 +1,224 @@ +# 율봇 개발 로드맵 2 + +## 현재 구현 상태 (Phase 1~11 + Phase 14 완료, 버그 1~3 수정 완료, 모델 업그레이드) + +| 영역 | 현황 | +|------|------| +| LLM | Qwen3-14B-4bit (MLX, Apple Silicon) | +| Agent | LangGraph ReAct + Tool Calling + Thinking 모드 | +| RAG | Qdrant + BAAI/bge-m3 임베딩 | +| Tools | `search_documents`, `web_search`, `get_current_date`, `remember_user_info`, `recall_user_info` (5개) | +| UI | CLI + Gradio Web UI | +| Memory | LangGraph MemorySaver (세션 내) + MySQL 대화 저장 + 장기 사용자 프로필 | +| Tracing | LangSmith 트레이싱 | +| Streaming | 비동기 토큰 스트리밍 + `` 블록 파싱 | +| History Compact | 대화 20턴 초과 시 오래된 절반을 LLM으로 자동 요약 (`CompactService`) | + +--- + +## 버그 수정 현황 + +### ✅ 버그 1 — RAG 중복 수집 (수정 완료) +`IngestionService._delete_by_source()`를 구현해 같은 파일 경로로 저장된 기존 청크를 `ingest()` 시작 시 삭제한다. + +### ✅ 버그 2 — LangGraph MemorySaver와 MySQL 이력 미연동 (수정 완료) +`AgentService.__init__`에서 MySQL에 저장된 최근 10턴을 `_pending_history`로 불러온 뒤, 첫 `stream_response()` 호출 시 LangGraph 초기 메시지로 주입한다. + +### ✅ 버그 3 — 단일 사용자 전제 (수정 완료) +DB 스키마(`td_conversations.user_id`, `td_user_profile.user_id`)는 `_migrate_schema`로 자동 마이그레이션. `AgentService`에 `user_id` 파라미터 추가, 모든 Repository 호출에 전파. Gradio에 사용자 선택 드롭다운(아록/근혜/도율/하율) 추가 및 사용자별 에이전트 캐시 구현. + +--- + +## ✅ Phase 9 — 문서 관리 (완료) + +- `IngestionService._delete_by_source()` — 파일 경로 기반 중복 청크 삭제 +- `RetrieverService.list_documents()` — Qdrant scroll로 고유 source 목록 반환 +- `RetrieverService.delete_document(source)` — source 기준 청크 전체 삭제 +- Gradio "문서 관리" 탭 — 목록 테이블 + 경로 입력 삭제 버튼 + 앱 로드 시 자동 새로고침 + +--- + +## ✅ Phase 10 — 멀티유저 지원 (완료) + +Bug 3 수정 및 Phase 9 작업과 함께 완전 구현됨. + +- DB 마이그레이션: `mysql_service._migrate_schema()`가 `td_conversations`, `td_user_profile` 양쪽에 `user_id` 컬럼 자동 추가 +- `ConversationRepository`: `create_conversation(user_id)` / `get_latest_conversation_id(user_id)` — user_id 기반 격리 +- `AgentService`: `user_id` 파라미터 추가, 모든 프로필·대화 조회에 전파 +- `make_memory_tools(profile_repo, user_id)`: remember/recall 도구가 올바른 사용자 데이터만 접근 +- Gradio: 사용자 선택 드롭다운(아록/근혜/도율/하율, 기본값 아록) + `_agent_cache` 사전으로 사용자별 에이전트 분리 + +--- + +## ✅ Phase 11 — 대화 이력 복원 (수정 완료) + +버그 2와 함께 해결됨. +`AgentService` 초기화 시 MySQL에서 최근 10턴을 `_pending_history`에 로드 → 첫 메시지와 함께 LangGraph에 주입. + +```python +# agent_service.py 초기화 (구현됨) +turns = conversation_repository.load_turns_after(self._conv_id, None, limit=10) +# → HumanMessage / AIMessage 변환 후 _pending_history에 저장 +``` + +--- + +## Phase 12 — 답변 피드백 & 품질 개선 ★★☆ + +**배경**: 에이전트가 잘못된 답변을 해도 피드백 루프가 없어 개선이 어려움. + +**구현 범위**: +- Gradio 채팅 메시지마다 👍 / 👎 버튼 +- `td_feedback` 테이블에 메시지·평점 저장 +- LangSmith의 `run_id`와 연결해 피드백을 트레이스에 기록 (`langsmith.Client().create_feedback()`) + +```sql +CREATE TABLE td_feedback ( + id INT AUTO_INCREMENT PRIMARY KEY, + message TEXT, + response TEXT, + rating TINYINT, -- 1: 좋음, -1: 나쁨 + langsmith_run_id VARCHAR(100), + created_at DATETIME DEFAULT CURRENT_TIMESTAMP +); +``` + +**난이도**: 중간 | **임팩트**: 중간 (장기 품질 향상) + +--- + +## Phase 13 — RAG 품질 향상 (Reranker + 청킹 개선) ★★☆ (부분 완료) + +**배경**: 현재 고정 크기 청킹 + 벡터 유사도 검색만으로는 관련 없는 청크가 섞일 수 있음. + +**✅ Semantic Chunker — 완료** +- `_SemanticSplitter` 클래스 직접 구현 (`services/rag/ingestion_service.py`) +- `langchain-experimental` 사용 없이 numpy + 기존 BAAI/bge-m3 임베딩으로 구현 +- 인접 문장 간 코사인 유사도 계산 → 유사도 하위 5% 지점에서 청크 분리 +- `config.py`에서 `rag_chunk_size` / `rag_chunk_overlap` 제거 → `semantic_breakpoint_threshold_type` 추가 + +**🔲 미완 — Reranker** +1. **Reranker 추가** — `cross-encoder/ms-marco-MiniLM-L-6-v2`로 검색 결과 재순위 +2. **top_k 조정** — 검색 후 rerank → 상위 3개만 LLM에 전달 + +> 기존 Qdrant 저장 문서는 재등록해야 새 청킹 방식이 적용됨. + +**난이도**: 중간 | **임팩트**: 중간 (답변 정확도 향상) + +--- + +## ✅ Phase 14 — 음성 인터페이스 (완료) + +**배경**: 육아 중에는 손이 자유롭지 않아 타이핑이 어려움. 음성으로 질문하고 답변을 들을 수 있으면 핵심 사용 시나리오 커버. + +**구현 내용**: +- `openai-whisper` (small 모델) — 마이크 녹음 → 한국어 텍스트 변환, 지연 로딩 +- macOS `say -v Yuna` — 에이전트 응답을 음성으로 읽어줌 (aiff 파일 경유) +- Gradio "대화" 탭 확장 — 마이크 녹음 + "음성→텍스트 변환" 버튼 + "음성으로 답변 읽기" 체크박스 + TTS 오디오 플레이어 +- LLM/Agent 레이어 변경 없음 — 순수 I/O 어댑터로 구현 + +```python +# app.py — STT +def transcribe_audio(filepath: str) -> str: + result = whisper.load_model("small").transcribe(filepath, language="ko") + return result["text"].strip() + +# app.py — TTS +def tts_speak(text: str, voice: str) -> str | None: + subprocess.run(["say", "-v", voice, "-o", tmp.name, text], ...) +``` + +**config.py 추가**: `whisper_model_size = "small"`, `tts_voice = "Yuna"` + +**난이도**: 중간 | **임팩트**: 높음 (핵심 사용 시나리오) + +--- + +## Phase 15 — 예방접종·건강검진 알림 스케줄러 ★★☆ + +**배경**: 아이 생년을 기억하고 있으므로, 예방접종 일정(BCG, DTaP 등)을 자동 계산해 알림을 줄 수 있음. 율봇의 차별화 포인트. + +**구현 방식**: +- `td_user_profile`에서 아이 생년 조회 → 예방접종 스케줄 계산 Tool +- Gradio "건강 일정" 탭: 달력형 일정 표시 +- APScheduler로 당일 알림 (또는 Gradio 시작 시 오늘 일정 배너) + +```python +@tool +def get_vaccination_schedule(birth_year: int, birth_month: int) -> str: + """아이 생년월을 기반으로 예방접종 일정을 계산합니다.""" +``` + +**난이도**: 중간 | **임팩트**: 높음 (육아 특화 차별화) + +--- + +## Phase 16 — 모델 선택 (Claude API / OpenAI 옵션) ★☆☆ + +**배경**: 로컬 MLX 모델은 Apple Silicon 전용. 원격 접속 시나리오나 더 높은 품질이 필요할 때 Claude API/OpenAI를 선택할 수 있으면 유연성 확보. + +**구현 방식**: `config.py`에 `model_provider` 추가, `container.py`에서 provider별 chat_model 분기. + +```python +model_provider: str = "mlx" # "mlx" | "claude" | "openai" +``` + +**난이도**: 중간 | **임팩트**: 중간 + +--- + +## Phase 17 — Docker 컨테이너화 ★☆☆ + +**배경**: 현재 로컬 전용. 가족이나 지인도 쓸 수 있도록 서버 배포 가능한 형태로 패키징. + +**구현 범위**: +``` +docker-compose.yml +├── youlbot (Gradio app) +├── qdrant +└── mysql +``` + +> 주의: MLX는 Apple Silicon 전용이라 서버 배포 시 Phase 16(모델 선택)이 선행되어야 함. + +**난이도**: 높음 | **임팩트**: 중간 + +--- + +## Phase 18 — 멀티모달 이미지 이해 ★☆☆ + +**배경**: 이유식 사진 → 재료 분석, 금융 서류 사진 → 내용 해석 등. + +**제약**: Qwen3-8B는 이미지 미지원 → `mlx-community/Qwen2.5-VL-7B-Instruct-4bit` 교체 필요. + +**난이도**: 높음 | **임팩트**: 높음 (장기 과제) + +--- + +## 추천 진행 순서 + +``` +단기 (1~2주) 중기 (1개월) 장기 +──────────────── ────────────────── ────────────── +Phase 14 (음성) → Phase 13 (RAG품질) → Phase 17 (Docker) +Phase 15 (알림) Phase 16 (모델선택) Phase 18 (멀티모달) +Phase 12 (피드백) +``` + +### 우선순위 매트릭스 + +| Phase | 상태 | 난이도 | 임팩트 | 추천 순위 | +|-------|------|--------|--------|-----------| +| 버그 1 RAG 중복 | ✅ 완료 | — | — | — | +| 버그 2 이력 미연동 | ✅ 완료 | — | — | — | +| 버그 3 단일 사용자 | ✅ 완료 | — | — | — | +| Phase 9 문서 관리 | ✅ 완료 | — | — | — | +| Phase 10 멀티유저 | ✅ 완료 | — | — | — | +| Phase 11 이력 복원 | ✅ 완료 | — | — | — | +| Phase 14 음성 인터페이스 | ✅ 완료 | — | — | — | +| Phase 15 예방접종 알림 | 🔲 미완 | 중간 | 높음 | ⭐ 2순위 | +| Phase 12 피드백 | 🔲 미완 | 중간 | 중간 | 3순위 | +| Phase 13 RAG 품질 (청킹 완료, Reranker 미완) | 🔲 진행 중 | 중간 | 중간 | 4순위 | +| Phase 16 모델 선택 | 🔲 미완 | 중간 | 중간 | 5순위 | +| Phase 17 Docker | 🔲 미완 | 높음 | 중간 | 6순위 | +| Phase 18 멀티모달 | 🔲 미완 | 높음 | 높음 | 7순위 | diff --git a/ingest.py b/ingest.py new file mode 100644 index 0000000..7a5826c --- /dev/null +++ b/ingest.py @@ -0,0 +1,28 @@ +"""문서 수집 CLI. + +사용법: + python ingest.py <파일경로> [<파일경로> ...] + +예시: + python ingest.py docs/육아가이드.pdf docs/금융상품안내.txt +""" +import sys +from container import Container + + +def main() -> None: + files = sys.argv[1:] + if not files: + print("사용법: python ingest.py <파일경로> [<파일경로> ...]") + sys.exit(1) + + container = Container() + service = container.ingestion_service() + + print(f"{len(files)}개 파일 수집 시작...") + count = service.ingest(files) + print(f"완료: {count}개 청크가 Qdrant({container.config().qdrant_url})에 저장되었습니다.") + + +if __name__ == "__main__": + main() diff --git a/main.py b/main.py index 12673b1..1690c5f 100644 --- a/main.py +++ b/main.py @@ -1,25 +1,23 @@ +import asyncio + +from dotenv import load_dotenv +load_dotenv() + from container import Container -from services.chat.chat_service import ChatService -def main() -> None: +async def main_async() -> None: container = Container() - ui = container.ui_service() - model = container.model_service() - bus = container.event_bus() + db = container.db_service() - repo = container.conversation_repository() - - bus.subscribe(ChatService.EVENT_TOKEN, container.stream_token_handler()) - bus.subscribe(ChatService.EVENT_END, container.stream_end_handler()) - - ui.show_banner(container.config().model_id) - model.load() db.connect() db.init_schema() - chat = container.chat_service() + ui.show_banner(container.config().model_id) + + # AgentService 초기화 — MlxChatModel 모델 로딩 + LangGraph 그래프 구성 포함 + agent = container.agent_service() while True: try: @@ -36,15 +34,18 @@ def main() -> None: break if ui.is_reset_command(user_input): - repo.create_conversation() - chat = container.chat_service() + agent.reset() print("\n[대화가 초기화되었습니다.]\n") continue ui.show_assistant_prefix() - chat.respond(user_input) + async for token in agent.stream_response(user_input): + print(token, end="", flush=True) + print("\n") - db.close() + +def main() -> None: + asyncio.run(main_async()) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index aa2b00c..44f5403 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,21 @@ mlx-lm>=0.19.0 dependency-injector>=4.41.0 PyMySQL>=1.1.0 pydantic-settings>=2.0.0 +# Phase 1 — LangChain BaseChatModel +langchain-core>=0.3.0 +# Phase 2 — RAG pipeline (Qdrant, embeddings, document loading) +langchain-community>=0.3.0 +langchain-huggingface>=0.1.0 +langchain-text-splitters>=0.3.0 +langchain-qdrant>=0.2.0 +sentence-transformers>=3.0.0 +qdrant-client>=1.9.0 +pdfplumber>=0.11.0 +# Phase 3 — Agent orchestration +langgraph>=1.0.0 +# Phase 4 — Web UI +gradio>=4.0.0 +# Phase 6 — 웹 검색 Tool +duckduckgo-search>=6.0.0 +# Phase 14 — 음성 인터페이스 (STT) +openai-whisper>=20231117 diff --git a/services/agent/__init__.py b/services/agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/agent/agent_service.py b/services/agent/agent_service.py new file mode 100644 index 0000000..d96d37b --- /dev/null +++ b/services/agent/agent_service.py @@ -0,0 +1,248 @@ +import os +import time +import uuid +from typing import AsyncIterator + +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, SystemMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition + +from services.agent.tools import get_current_date, make_memory_tools, make_retriever_tool, make_search_tool, web_search + + +class AgentService: + """LangGraph ReAct 에이전트 서비스. + + Tool Calling 루프, 대화 히스토리, 조건부 라우팅을 LangGraph가 담당한다. + """ + + def __init__( + self, + chat_model, + retriever_service, + system_prompt: str, + rag_verbose: bool = False, + rag_show_sources: bool = False, + langgraph_verbose: bool = False, + think_verbose: bool = False, + user_profile_repository=None, + conversation_repository=None, + user_id: str = "default", + ): + self._system_prompt = system_prompt + self._rag_verbose = rag_verbose + self._rag_show_sources = rag_show_sources + self._langgraph_verbose = langgraph_verbose + self._think_verbose = think_verbose + self._source_buffer: list[dict] = [] + self._thread_id = "default" + self._profile_repo = user_profile_repository + self._conv_repo = conversation_repository + self._conv_id: int | None = None + self._pending_history: list = [] + self._user_id = user_id + + if conversation_repository: + try: + self._conv_id = conversation_repository.get_latest_conversation_id(user_id) + if self._conv_id is None: + self._conv_id = conversation_repository.create_conversation(user_id) + else: + turns = conversation_repository.load_turns_after(self._conv_id, None, limit=10) + for turn in turns: + if turn["role"] == "user": + self._pending_history.append(HumanMessage(content=turn["content"])) + elif turn["role"] == "assistant": + self._pending_history.append(AIMessage(content=turn["content"])) + if self._pending_history: + print(f"[Agent] 이전 대화 {len(self._pending_history) // 2}턴 복원") + except Exception as e: + print(f"[Agent] 이력 복원 실패: {e}") + self._conv_id = None + self._pending_history = [] + + if rag_show_sources: + search_tool = make_search_tool(retriever_service, self._source_buffer) + else: + search_tool = make_retriever_tool(retriever_service) + tools = [search_tool, web_search, get_current_date] + if user_profile_repository is not None: + remember_tool, recall_tool = make_memory_tools(user_profile_repository, user_id) + tools += [remember_tool, recall_tool] + llm_with_tools = chat_model.bind_tools(tools) + + async def call_model(state: MessagesState, config: RunnableConfig) -> dict: + system_content = self._system_prompt + if self._profile_repo: + profile = self._profile_repo.get_all(self._user_id) + if profile: + lines = "\n".join(f"- {k}: {v}" for k, v in profile.items()) + system_content += f"\n\n## 사용자 정보 (이전 대화에서 기억된 내용)\n{lines}" + msgs = [SystemMessage(content=system_content)] + state["messages"] + thinking_acc, content_acc, tool_calls_acc = "", "", [] + async for chunk in llm_with_tools.astream(msgs, config): + t = chunk.additional_kwargs.get("thinking", "") + if t: + thinking_acc += t + if chunk.content and isinstance(chunk.content, str): + content_acc += chunk.content + if chunk.tool_calls: + tool_calls_acc.extend(chunk.tool_calls) + extra = {"thinking": thinking_acc} if thinking_acc else {} + return {"messages": [AIMessage( + content=content_acc, + tool_calls=tool_calls_acc, + additional_kwargs=extra, + )]} + + builder = StateGraph(MessagesState) + builder.add_node("agent", call_model) + builder.add_node("tools", ToolNode(tools)) + builder.add_edge(START, "agent") + builder.add_conditional_edges("agent", tools_condition) + builder.add_edge("tools", "agent") + + self._agent = builder.compile(checkpointer=MemorySaver()) + + @property + def _config(self) -> dict: + return {"configurable": {"thread_id": self._thread_id}} + + async def stream_response(self, user_input: str, show_thinking: bool | None = None) -> AsyncIterator[str]: + """사용자 입력을 받아 응답 토큰을 순서대로 yield한다.""" + _think_verbose = show_thinking if show_thinking is not None else self._think_verbose + self._source_buffer.clear() + + # 재시작 후 첫 호출 시 MySQL 이력을 초기 상태에 주입 + if self._pending_history: + all_messages = self._pending_history + [HumanMessage(content=user_input)] + self._pending_history = [] + else: + all_messages = [HumanMessage(content=user_input)] + messages = {"messages": all_messages} + response_content = "" # 실제 답변 내용만 누적 (MySQL 저장용) + pending_tool_calls: dict = {} # tool_call_id → {name, args} + prev_node: str = "" + lg = self._langgraph_verbose + thinking_open = False # [사고 과정] 헤더 출력 여부 + content_started = False # 노드 당 레이블 1회 출력 제어 + start_time = time.perf_counter() + + async for chunk, metadata in self._agent.astream( + messages, self._config, stream_mode="messages" + ): + node = metadata.get("langgraph_node", "") + + # ── 노드 전환 시 플래그 리셋 + 레이블 출력 ────────────── + if node != prev_node: + content_started = False + if lg: + if node == "agent": + elapsed = time.perf_counter() - start_time + label = "agent: 검색 결과 반영 중" if prev_node == "tools" else "agent: 질문 분석 중" + yield f"\n[LangGraph → {label}] ({elapsed:.2f}s)\n" + elif node == "tools": + elapsed = time.perf_counter() - start_time + yield f"\n[LangGraph → tools: 도구 실행 중] ({elapsed:.2f}s)\n" + prev_node = node + + # ── agent 노드 — AIMessageChunk만 처리 (중복 방지) ────── + if node == "agent" and isinstance(chunk, AIMessageChunk): + thinking = chunk.additional_kwargs.get("thinking", "") + if thinking and _think_verbose: + if not thinking_open: + yield "\n[사고 과정]\n" + thinking_open = True + yield thinking + + if chunk.tool_calls: + if thinking_open: + yield "\n[/사고 과정]\n" + thinking_open = False + for tc in chunk.tool_calls: + pending_tool_calls[tc["id"]] = tc + if tc.get("name") == "search_documents": + query = tc.get("args", {}).get("query", "") + yield f'\n문서 검색 중... ("{query}")\n' if query else "\n문서 검색 중...\n" + elif tc.get("name") == "web_search": + query = tc.get("args", {}).get("query", "") + yield f'\n웹 검색 중... ("{query}")\n' if query else "\n웹 검색 중...\n" + elif lg: + args_str = ", ".join(f'{k}="{v}"' for k, v in tc["args"].items()) + yield f" [tool_call: {tc['name']}({args_str})]\n" + + elif chunk.content: + if thinking_open: + yield "\n[/사고 과정]\n" + thinking_open = False + if lg and not content_started: + yield "\n[LangGraph → agent: 최종 답변 생성]\n\n" + content_started = True + response_content += chunk.content + yield chunk.content + + # ── agent 노드 — AIMessage(최종 state) ────────────────── + # 청크 스트리밍이 없었던 경우(edge case)에만 처리 + elif node == "agent" and isinstance(chunk, AIMessage): + if not content_started and not thinking_open: + thinking = chunk.additional_kwargs.get("thinking", "") + if thinking and self._think_verbose: + yield "\n[사고 과정]\n" + yield thinking + yield "\n[/사고 과정]\n" + if chunk.content: + if lg: + yield "\n[LangGraph → agent: 최종 답변 생성]\n\n" + response_content += chunk.content + yield chunk.content + + # ── tools 노드 ─────────────────────────────────────────── + elif node == "tools" and hasattr(chunk, "name") and chunk.name == "search_documents": + if lg: + result_lines = [b for b in chunk.content.split("\n\n") if b.strip()] + yield f" [결과: {len(result_lines)}개 문서 반환 → agent 복귀]\n" + + if self._rag_verbose: + tc = pending_tool_calls.get(chunk.tool_call_id, {}) + query = tc.get("args", {}).get("query", "") + yield f'\n[문서 검색: "{query}"]\n' + for block in chunk.content.split("\n\n"): + if block.strip(): + preview = block.strip().replace("\n", " ")[:80] + yield f" → {preview}\n" + yield "\n" + + elif node == "tools" and hasattr(chunk, "name") and chunk.name == "web_search": + if lg: + result_lines = [b for b in chunk.content.split("\n\n") if b.strip()] + yield f" [웹 검색 결과: {len(result_lines)}건 → agent 복귀]\n" + + if thinking_open: + yield "\n[/사고 과정]\n" + + # 대화 내용을 MySQL에 저장 + if self._conv_repo and self._conv_id and response_content: + try: + self._conv_repo.save_message(self._conv_id, "user", user_input) + self._conv_repo.save_message(self._conv_id, "assistant", response_content) + except Exception as e: + print(f"[Agent] 대화 저장 실패: {e}") + + if self._rag_show_sources and self._source_buffer: + yield "\n\n[참고 문서]\n" + for src in self._source_buffer: + filename = os.path.basename(src["source"]) + page = f" {src['page']}페이지" if "page" in src else "" + yield f"- {filename}{page}\n" + + def reset(self) -> None: + """새 thread_id로 대화 히스토리를 초기화한다.""" + self._thread_id = str(uuid.uuid4()) + self._pending_history = [] + if self._conv_repo: + try: + self._conv_id = self._conv_repo.create_conversation(self._user_id) + except Exception: + self._conv_id = None diff --git a/services/agent/tools.py b/services/agent/tools.py new file mode 100644 index 0000000..fba00ff --- /dev/null +++ b/services/agent/tools.py @@ -0,0 +1,96 @@ +from datetime import date + +from langchain_core.tools import tool + + +@tool +def get_current_date() -> str: + """오늘 날짜를 반환합니다. 날짜·기간 관련 질문에 사용하세요.""" + return date.today().isoformat() + + +@tool +def web_search(query: str) -> str: + """최신 뉴스, 금리, 육아 정책 등 실시간 정보가 필요할 때 사용하세요. 저장된 문서에 없는 최신 정보를 검색합니다.""" + from duckduckgo_search import DDGS + with DDGS() as ddgs: + results = list(ddgs.text(query, max_results=5)) + if not results: + return "검색 결과가 없습니다." + return "\n\n".join( + f"[{r['title']}]\n{r['body']}\n출처: {r['href']}" + for r in results + ) + + +def make_retriever_tool(retriever_service): + """as_retriever()를 사용하는 단순 검색 Tool (source_buffer 없음).""" + retriever = retriever_service.as_retriever() + + @tool + def search_documents(query: str) -> str: + """등록된 문서(논문, 육아 가이드, 금융 자료 등)에서 관련 정보를 검색합니다. + 육아·금융 관련 질문이 오면 자신의 지식으로 답하기 전에 반드시 이 도구를 먼저 호출하세요. + 등록된 문서가 없거나 검색 결과가 없을 때만 자신의 학습 지식을 보조적으로 활용합니다.""" + docs = retriever.invoke(query) + if not docs: + return "관련 문서를 찾을 수 없습니다." + return "\n\n".join( + f"[문서 {i + 1}]\n{doc.page_content}" for i, doc in enumerate(docs) + ) + + return search_documents + + +def make_memory_tools(profile_repo, user_id: str = "default"): + """사용자 정보 저장/조회 Tool 쌍을 반환한다.""" + + @tool + def remember_user_info(key: str, value: str) -> str: + """사용자 정보를 영구 저장합니다. 다음 대화에도 기억해야 할 정보를 저장하세요. + - 아이 나이는 반드시 '생년(출생연도)'으로 저장하세요. 나이는 매년 바뀌지만 생년은 영구적입니다. + 예: key='첫째_이름' value='신도율', key='첫째_생년' value='2020' + - 기타 key 예시: 재정_목표, 거주지, 직업, 자녀수""" + profile_repo.remember(key, value, user_id=user_id) + return f"'{key}' 정보를 기억했습니다: {value}" + + @tool + def recall_user_info(key: str) -> str: + """이전 대화에서 저장한 사용자 정보를 조회합니다.""" + value = profile_repo.recall(key, user_id=user_id) + return value if value is not None else f"'{key}'에 대한 저장된 정보가 없습니다." + + return remember_user_info, recall_user_info + + +def make_search_tool(retriever_service, source_buffer: list | None = None): + """RetrieverService를 클로저로 감싼 문서 검색 Tool을 반환합니다. + + source_buffer가 주어지면 검색된 문서의 메타데이터(source, page)를 누적 저장합니다. + """ + + @tool + def search_documents(query: str) -> str: + """등록된 문서(논문, 육아 가이드, 금융 자료 등)에서 관련 정보를 검색합니다. + 육아·금융 관련 질문이 오면 자신의 지식으로 답하기 전에 반드시 이 도구를 먼저 호출하세요. + 등록된 문서가 없거나 검색 결과가 없을 때만 자신의 학습 지식을 보조적으로 활용합니다.""" + docs = retriever_service.search(query) + + if source_buffer is not None: + for doc in docs: + src = doc.metadata.get("source", "") + page = doc.metadata.get("page", None) + if src: + entry = {"source": src} + if page is not None: + entry["page"] = page + 1 # 0-indexed → 1-indexed + if entry not in source_buffer: + source_buffer.append(entry) + + if not docs: + return "관련 문서를 찾을 수 없습니다." + return "\n\n".join( + f"[문서 {i + 1}]\n{doc.page_content}" for i, doc in enumerate(docs) + ) + + return search_documents diff --git a/services/db/conversation_repository.py b/services/db/conversation_repository.py index bc96dc7..dc59776 100644 --- a/services/db/conversation_repository.py +++ b/services/db/conversation_repository.py @@ -8,14 +8,16 @@ class ConversationRepository: def __init__(self, db: DatabaseService): self._db = db - def create_conversation(self) -> int: + def create_conversation(self, user_id: str = "default") -> int: return self._db.execute_write( - "INSERT INTO td_conversations () VALUES ()" + "INSERT INTO td_conversations (user_id) VALUES (%s)", + (user_id,), ) - def get_latest_conversation_id(self) -> int | None: + def get_latest_conversation_id(self, user_id: str = "default") -> int | None: rows = self._db.execute( - "SELECT id FROM td_conversations ORDER BY created_at DESC LIMIT 1" + "SELECT id FROM td_conversations WHERE user_id = %s ORDER BY created_at DESC LIMIT 1", + (user_id,), ) return rows[0]["id"] if rows else None diff --git a/services/db/mysql_service.py b/services/db/mysql_service.py index 410f05e..86cfa39 100644 --- a/services/db/mysql_service.py +++ b/services/db/mysql_service.py @@ -1,47 +1,81 @@ from __future__ import annotations +import threading from typing import Any class DatabaseService: - """MySQL 연결을 캡슐화하는 서비스. 미설정 시 graceful skip.""" + """MySQL 연결을 캡슐화하는 서비스. 미설정 시 graceful skip. - def __init__(self, host: str, port: int, db: str, user: str, password: str): + 스레드별 독립 연결(thread-local)을 사용해 LangGraph ToolNode의 + 스레드 풀 실행과 pymysql 비안전성 문제를 해결한다. + """ + + def __init__( + self, + host: str, + port: int, + db: str, + user: str, + password: str, + ): self._config = dict(host=host, port=port, db=db, user=user, passwd=password) - self._conn = None + self._local = threading.local() + + # ── DB 연결 ──────────────────────────────────────────────────────── + + def _get_conn(self): + if not self._config["user"]: + return None + + import pymysql + conn = getattr(self._local, "conn", None) + if conn is None: + try: + self._local.conn = pymysql.connect(**self._config) + except Exception as e: + print(f"[DB] 연결 실패: {e}") + return None + else: + try: + conn.ping(reconnect=True) + except Exception: + try: + self._local.conn = pymysql.connect(**self._config) + except Exception as e: + print(f"[DB] 재연결 실패: {e}") + return None + return self._local.conn def connect(self) -> None: - if not self._config["user"]: - return - try: - import pymysql - self._conn = pymysql.connect(**self._config) - except Exception as e: - print(f"[DB] 연결 실패 (선택적 기능): {e}") + self._get_conn() def execute(self, sql: str, params: tuple = ()) -> list[dict[str, Any]]: - if self._conn is None: + conn = self._get_conn() + if conn is None: return [] - cursor = self._conn.cursor() + cursor = conn.cursor() cursor.execute(sql, params) columns = [d[0] for d in cursor.description or []] return [dict(zip(columns, row)) for row in cursor.fetchall()] def execute_write(self, sql: str, params: tuple = ()) -> int: - """INSERT/UPDATE/DELETE 실행 후 lastrowid 반환.""" - if self._conn is None: + conn = self._get_conn() + if conn is None: return 0 - cursor = self._conn.cursor() + cursor = conn.cursor() cursor.execute(sql, params) - self._conn.commit() + conn.commit() return cursor.lastrowid def init_schema(self) -> None: - if self._conn is None: + conn = self._get_conn() + if conn is None: return - cursor = self._conn.cursor() + cursor = conn.cursor() cursor.execute(""" CREATE TABLE IF NOT EXISTS td_conversations ( id INT AUTO_INCREMENT PRIMARY KEY, + user_id VARCHAR(50) NOT NULL DEFAULT 'default', created_at DATETIME DEFAULT CURRENT_TIMESTAMP ) """) @@ -55,9 +89,35 @@ class DatabaseService: FOREIGN KEY (conversation_id) REFERENCES td_conversations(id) ) """) - self._conn.commit() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS td_user_profile ( + id INT AUTO_INCREMENT PRIMARY KEY, + user_id VARCHAR(50) NOT NULL DEFAULT 'default', + key_name VARCHAR(100) NOT NULL, + value TEXT NOT NULL, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY uq_user_key (user_id, key_name) + ) + """) + conn.commit() + self._migrate_schema(conn) + + def _migrate_schema(self, conn) -> None: + cursor = conn.cursor() + for sql in [ + "ALTER TABLE td_conversations ADD COLUMN user_id VARCHAR(50) NOT NULL DEFAULT 'default'", + "ALTER TABLE td_user_profile ADD COLUMN user_id VARCHAR(50) NOT NULL DEFAULT 'default'", + "ALTER TABLE td_user_profile DROP INDEX key_name", + "ALTER TABLE td_user_profile ADD UNIQUE KEY uq_user_key (user_id, key_name)", + ]: + try: + cursor.execute(sql) + conn.commit() + except Exception: + pass def close(self) -> None: - if self._conn: - self._conn.close() - self._conn = None + conn = getattr(self._local, "conn", None) + if conn: + conn.close() + self._local.conn = None diff --git a/services/db/user_profile_repository.py b/services/db/user_profile_repository.py new file mode 100644 index 0000000..e5827f1 --- /dev/null +++ b/services/db/user_profile_repository.py @@ -0,0 +1,31 @@ +from __future__ import annotations +from services.db.mysql_service import DatabaseService + + +class UserProfileRepository: + """td_user_profile 테이블을 통한 사용자 장기 메모리 저장소.""" + + def __init__(self, db: DatabaseService): + self._db = db + + def remember(self, key: str, value: str, user_id: str = "default") -> None: + self._db.execute_write( + """INSERT INTO td_user_profile (user_id, key_name, value) + VALUES (%s, %s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value), updated_at = NOW()""", + (user_id, key, value), + ) + + def recall(self, key: str, user_id: str = "default") -> str | None: + rows = self._db.execute( + "SELECT value FROM td_user_profile WHERE user_id = %s AND key_name = %s", + (user_id, key), + ) + return rows[0]["value"] if rows else None + + def get_all(self, user_id: str = "default") -> dict[str, str]: + rows = self._db.execute( + "SELECT key_name, value FROM td_user_profile WHERE user_id = %s ORDER BY updated_at", + (user_id,), + ) + return {r["key_name"]: r["value"] for r in rows} diff --git a/services/model/mlx_chat_model.py b/services/model/mlx_chat_model.py new file mode 100644 index 0000000..7f8f038 --- /dev/null +++ b/services/model/mlx_chat_model.py @@ -0,0 +1,337 @@ +import json +import re +import uuid +from typing import Any, Iterator, List, Optional + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import PrivateAttr, model_validator + +_TOOL_CALL_RE = re.compile(r"\s*(.*?)\s*", re.DOTALL) +_THINK_RE = re.compile(r"(.*?)", re.DOTALL) + + +class MlxChatModel(BaseChatModel): + """mlx-lm 기반 LangChain BaseChatModel. + + LangGraph와 완전 호환 — Tool Calling, 스트리밍, bind_tools() 지원. + Qwen3 thinking 모드 지원 — 블록을 content와 분리해 additional_kwargs에 저장. + """ + + model_id: str + max_tokens: int = 1024 + temp: float = 0.0 + enable_thinking: bool = True + + _model: Any = PrivateAttr(default=None) + _tokenizer: Any = PrivateAttr(default=None) + + @model_validator(mode="after") + def _load(self) -> "MlxChatModel": + from mlx_lm import load + print(f"모델 로딩 중: {self.model_id}") + self._model, self._tokenizer = load(self.model_id) + return self + + @property + def _llm_type(self) -> str: + return "mlx-chat" + + # ── 메시지 → chat dict 변환 ─────────────────────────────────── + + def _to_chat_dicts(self, messages: List[BaseMessage]) -> List[dict]: + result = [] + for msg in messages: + if isinstance(msg, SystemMessage): + result.append({"role": "system", "content": str(msg.content)}) + elif isinstance(msg, HumanMessage): + result.append({"role": "user", "content": str(msg.content)}) + elif isinstance(msg, AIMessage): + if msg.tool_calls: + result.append({ + "role": "assistant", + "content": str(msg.content) if msg.content else "", + "tool_calls": [ + { + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": json.dumps(tc["args"]), + }, + } + for tc in msg.tool_calls + ], + }) + else: + result.append({"role": "assistant", "content": str(msg.content)}) + elif isinstance(msg, ToolMessage): + result.append({ + "role": "tool", + "content": str(msg.content), + "tool_call_id": msg.tool_call_id, + }) + return result + + def _build_prompt(self, messages: List[BaseMessage], tools: Optional[list] = None) -> str: + kwargs: dict = { + "tokenize": False, + "add_generation_prompt": True, + } + if tools: + kwargs["tools"] = tools + # Qwen3 thinking 모드 — 지원하지 않는 모델은 무시됨 + try: + kwargs["enable_thinking"] = self.enable_thinking + return self._tokenizer.apply_chat_template(self._to_chat_dicts(messages), **kwargs) + except TypeError: + kwargs.pop("enable_thinking") + return self._tokenizer.apply_chat_template(self._to_chat_dicts(messages), **kwargs) + + # ── 블록 파싱 (Qwen3) ──────────────────────────────── + + @staticmethod + def _parse_thinking(text: str) -> tuple[str, str]: + """... 블록을 분리해 (thinking, clean_text) 반환.""" + match = _THINK_RE.search(text) + if not match: + return "", text + thinking = match.group(1).strip() + clean = _THINK_RE.sub("", text).strip() + return thinking, clean + + # ── Tool Call 파싱 ──────────────────────────────────────────── + + @staticmethod + def _parse_tool_calls(text: str) -> tuple[str, list]: + matches = _TOOL_CALL_RE.findall(text) + if not matches: + return text, [] + + tool_calls = [] + for raw in matches: + try: + data = json.loads(raw) + tool_calls.append({ + "id": f"call_{uuid.uuid4().hex[:8]}", + "name": data["name"], + "args": data.get("arguments", data.get("args", {})), + "type": "tool_call", + }) + except (json.JSONDecodeError, KeyError): + continue + + clean = _TOOL_CALL_RE.sub("", text).strip() + return clean, tool_calls + + # ── LangChain BaseChatModel 인터페이스 ──────────────────────── + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager=None, + **kwargs, + ) -> ChatResult: + from mlx_lm import generate + + tools = kwargs.get("tools") + prompt = self._build_prompt(messages, tools) + text = generate( + self._model, + self._tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, + verbose=False, + ) + thinking, after_think = self._parse_thinking(text) + clean_text, tool_calls = self._parse_tool_calls(after_think) + extra = {"thinking": thinking} if thinking else {} + message = AIMessage(content=clean_text, tool_calls=tool_calls, additional_kwargs=extra) + return ChatResult(generations=[ChatGeneration(message=message)]) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager=None, + **kwargs, + ) -> Iterator[ChatGenerationChunk]: + from mlx_lm import stream_generate + + tools = kwargs.get("tools") + prompt = self._build_prompt(messages, tools) + + OPEN_THINK = "" + CLOSE_THINK = "" + OPEN_TOOL = "" + CLOSE_TOOL = "" + SAFE = max(len(OPEN_THINK), len(CLOSE_THINK), len(OPEN_TOOL), len(CLOSE_TOOL)) + + # enable_thinking=False 모델은 블록을 생성하지 않으므로 post_think에서 시작 + state = "pre_think" if self.enable_thinking else "post_think" + buf = "" + out: list[ChatGenerationChunk] = [] + + def _think(text: str) -> None: + out.append(ChatGenerationChunk( + message=AIMessageChunk(content="", additional_kwargs={"thinking": text}) + )) + + def _content(text: str) -> None: + out.append(ChatGenerationChunk(message=AIMessageChunk(content=text))) + + def _tool(raw_json: str) -> None: + try: + data = json.loads(raw_json) + tc = { + "id": f"call_{uuid.uuid4().hex[:8]}", + "name": data["name"], + "args": data.get("arguments", data.get("args", {})), + "type": "tool_call", + } + out.append(ChatGenerationChunk(message=AIMessageChunk(content="", tool_calls=[tc]))) + except (json.JSONDecodeError, KeyError): + pass + + def advance() -> None: + nonlocal state, buf + while buf: + if state == "pre_think": + idx = buf.find(OPEN_THINK) + if idx == -1: + safe = len(buf) - SAFE + if safe > 0: + _content(buf[:safe]) + buf = buf[safe:] + return + if idx > 0: + _content(buf[:idx]) + buf = buf[idx + len(OPEN_THINK):] + state = "in_think" + + elif state == "in_think": + idx = buf.find(CLOSE_THINK) + if idx == -1: + safe = len(buf) - SAFE + if safe > 0: + _think(buf[:safe]) + buf = buf[safe:] + return + if idx > 0: + _think(buf[:idx]) + buf = buf[idx + len(CLOSE_THINK):].lstrip() + state = "post_think" + + elif state == "post_think": + # 이후 \n\n 같은 공백을 건너뜀 + buf = buf.lstrip() + if not buf: + return + idx = buf.find(OPEN_TOOL) + if idx == -1: + # partial tag at end — hold and wait + for i in range(len(OPEN_TOOL) - 1, 0, -1): + if buf.endswith(OPEN_TOOL[:i]): + safe_text = buf[:-i] + if safe_text: + _content(safe_text) + buf = buf[-i:] + return + state = "in_answer" # no tool call coming + elif idx == 0: + buf = buf[len(OPEN_TOOL):] + state = "in_tool" + else: + _content(buf[:idx]) + buf = buf[idx + len(OPEN_TOOL):] + state = "in_tool" + + elif state == "in_answer": + _content(buf) + buf = "" + return + + elif state == "in_tool": + idx = buf.find(CLOSE_TOOL) + if idx == -1: + return # wait for complete JSON + _tool(buf[:idx].strip()) + buf = buf[idx + len(CLOSE_TOOL):] + state = "post_think" # may have more tool calls + + for raw in stream_generate(self._model, self._tokenizer, prompt=prompt, max_tokens=self.max_tokens): + if run_manager: + run_manager.on_llm_new_token(raw.text) + buf += raw.text + advance() + yield from out + out.clear() + + # flush remaining buffer + if buf: + if state == "in_think": + _think(buf) + elif state == "in_answer": + _content(buf) + elif state in ("pre_think", "post_think"): + clean, tcs = self._parse_tool_calls(buf) + if clean: + _content(clean) + for tc in tcs: + out.append(ChatGenerationChunk(message=AIMessageChunk(content="", tool_calls=[tc]))) + elif state == "in_tool": + _tool(buf.strip()) + yield from out + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager=None, + **kwargs, + ): + import asyncio + import threading + + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue(maxsize=200) + sentinel = object() + exc_holder: list = [] + + def _run() -> None: + try: + for chunk in self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ): + asyncio.run_coroutine_threadsafe(queue.put(chunk), loop).result() + except Exception as exc: + exc_holder.append(exc) + finally: + asyncio.run_coroutine_threadsafe(queue.put(sentinel), loop).result() + + thread = threading.Thread(target=_run, daemon=True) + thread.start() + + while True: + item = await queue.get() + if item is sentinel: + break + yield item + + thread.join(timeout=5) + if exc_holder: + raise exc_holder[0] + + def bind_tools(self, tools, tool_choice=None, **kwargs): + formatted = [convert_to_openai_tool(t) for t in tools] + return self.bind(tools=formatted, **kwargs) diff --git a/services/rag/__init__.py b/services/rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/rag/ingestion_service.py b/services/rag/ingestion_service.py new file mode 100644 index 0000000..1e9bd2a --- /dev/null +++ b/services/rag/ingestion_service.py @@ -0,0 +1,107 @@ +import re + +import numpy as np +from langchain_community.document_loaders import PDFPlumberLoader, TextLoader +from langchain_core.documents import Document +from langchain_qdrant import QdrantVectorStore +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, MatchValue, FilterSelector + + +def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10)) + + +class _SemanticSplitter: + """문장 임베딩 유사도 기반 청커. + + 인접 문장 간 코사인 유사도를 계산하고, 유사도가 낮은(= 의미 전환) 지점에서 청크를 분리한다. + breakpoint_percentile=95이면 유사도 하위 5% 지점이 분리 경계가 된다. + """ + + _SENTENCE_RE = re.compile(r"(?<=[.!?。!?])\s+") + + def __init__(self, embeddings, breakpoint_percentile: int = 95): + self._embeddings = embeddings + self._percentile = breakpoint_percentile + + def split_documents(self, docs: list[Document]) -> list[Document]: + result = [] + for doc in docs: + for chunk_text in self._split_text(doc.page_content): + result.append(Document(page_content=chunk_text, metadata=doc.metadata)) + return result + + def _split_text(self, text: str) -> list[str]: + sentences = [s for s in self._SENTENCE_RE.split(text.strip()) if s.strip()] + if len(sentences) <= 1: + return [text.strip()] if text.strip() else [] + + vecs = np.array(self._embeddings.embed_documents(sentences)) + similarities = [_cosine_similarity(vecs[i], vecs[i + 1]) for i in range(len(vecs) - 1)] + threshold = float(np.percentile(similarities, 100 - self._percentile)) + breakpoints = [i + 1 for i, s in enumerate(similarities) if s < threshold] + + chunks, start = [], 0 + for bp in breakpoints: + chunk = " ".join(sentences[start:bp]).strip() + if chunk: + chunks.append(chunk) + start = bp + tail = " ".join(sentences[start:]).strip() + if tail: + chunks.append(tail) + return chunks + + +class IngestionService: + """문서를 의미 단위 청크로 분할해 Qdrant에 저장하는 수집 파이프라인.""" + + def __init__( + self, + embeddings, + qdrant_url: str, + collection_name: str, + breakpoint_threshold_type: str = "percentile", + ): + self._embeddings = embeddings + self._qdrant_url = qdrant_url + self._collection_name = collection_name + # breakpoint_threshold_type은 향후 확장용으로 수용 (현재는 percentile 방식 고정) + self._splitter = _SemanticSplitter(embeddings, breakpoint_percentile=95) + self._client = QdrantClient(url=qdrant_url) + + def _delete_by_source(self, source_path: str) -> None: + """같은 파일 경로로 저장된 기존 청크를 모두 삭제한다.""" + try: + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector( + filter=Filter( + must=[ + FieldCondition( + key="metadata.source", + match=MatchValue(value=source_path), + ) + ] + ) + ), + ) + except Exception: + pass # 컬렉션이 없을 때(최초 수집) 무시 + + def ingest(self, file_paths: list[str]) -> int: + docs = [] + for path in file_paths: + self._delete_by_source(path) + loader = PDFPlumberLoader(path) if path.endswith(".pdf") else TextLoader(path, encoding="utf-8") + docs.extend(loader.load()) + + chunks = self._splitter.split_documents(docs) + QdrantVectorStore.from_documents( + documents=chunks, + embedding=self._embeddings, + url=self._qdrant_url, + collection_name=self._collection_name, + ) + return len(chunks) diff --git a/services/rag/retriever_service.py b/services/rag/retriever_service.py new file mode 100644 index 0000000..7c5958f --- /dev/null +++ b/services/rag/retriever_service.py @@ -0,0 +1,67 @@ +from langchain_core.documents import Document +from langchain_qdrant import QdrantVectorStore +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, MatchValue, FilterSelector + + +class RetrieverService: + """Qdrant 벡터 검색 서비스. LangGraph Tool 및 직접 검색 모두 지원.""" + + def __init__( + self, + embeddings, + qdrant_url: str, + collection_name: str, + top_k: int, + ): + self._client = QdrantClient(url=qdrant_url) + self._collection_name = collection_name + self._store = QdrantVectorStore( + client=self._client, + collection_name=collection_name, + embedding=embeddings, + ) + self._top_k = top_k + + def as_retriever(self): + return self._store.as_retriever(search_kwargs={"k": self._top_k}) + + def search(self, query: str) -> list[Document]: + return self._store.similarity_search(query, k=self._top_k) + + def list_documents(self) -> list[str]: + """Qdrant에 저장된 고유 파일 경로 목록을 반환한다.""" + sources: set[str] = set() + offset = None + while True: + results, next_offset = self._client.scroll( + collection_name=self._collection_name, + with_payload=True, + limit=200, + offset=offset, + ) + for point in results: + src = (point.payload or {}).get("metadata", {}).get("source", "") + if src: + sources.add(src) + if next_offset is None: + break + offset = next_offset + return sorted(sources) + + def delete_document(self, source: str) -> None: + """파일 경로로 저장된 모든 청크를 Qdrant에서 삭제한다.""" + try: + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector( + filter=Filter( + must=[FieldCondition( + key="metadata.source", + match=MatchValue(value=source), + )] + ) + ), + ) + except Exception: + pass