Implement Phase 4~14: LangGraph Agent, RAG pipeline, Gradio Web UI, voice interface
- Upgrade LLM to Qwen3-14B-4bit with Thinking mode (MlxChatModel as LangChain BaseChatModel) - Add LangGraph ReAct agent with tool calling loop (search_documents, web_search, get_current_date, remember/recall_user_info) - Add RAG pipeline: BAAI/bge-m3 embeddings + Qdrant vector store + semantic chunking (SemanticSplitter via cosine similarity) - Replace fixed-size RecursiveCharacterTextSplitter with meaning-based SemanticSplitter (numpy only, no extra deps) - Add Gradio Web UI (app.py): chat, document ingestion, document management tabs - Add multi-user support (user_id isolation in DB + per-user agent cache + dropdown selector) - Add conversation history restore from MySQL on agent init (Phase 11) - Add UserProfileRepository for persistent user profile (remember/recall tools) - Add thread-local DB connections to fix pymysql thread-safety with LangGraph ToolNode - Add Phase 14 voice interface: Whisper STT (microphone → text) + macOS TTS (say -v Yuna) - Enforce search_documents-first policy in system prompt and tool descriptions - Update ROADMAP2.md: Phase 14 완료, Phase 13 청킹 부분 완료 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,248 @@
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncIterator
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph import START, MessagesState, StateGraph
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
|
||||
from services.agent.tools import get_current_date, make_memory_tools, make_retriever_tool, make_search_tool, web_search
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""LangGraph ReAct 에이전트 서비스.
|
||||
|
||||
Tool Calling 루프, 대화 히스토리, 조건부 라우팅을 LangGraph가 담당한다.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_model,
|
||||
retriever_service,
|
||||
system_prompt: str,
|
||||
rag_verbose: bool = False,
|
||||
rag_show_sources: bool = False,
|
||||
langgraph_verbose: bool = False,
|
||||
think_verbose: bool = False,
|
||||
user_profile_repository=None,
|
||||
conversation_repository=None,
|
||||
user_id: str = "default",
|
||||
):
|
||||
self._system_prompt = system_prompt
|
||||
self._rag_verbose = rag_verbose
|
||||
self._rag_show_sources = rag_show_sources
|
||||
self._langgraph_verbose = langgraph_verbose
|
||||
self._think_verbose = think_verbose
|
||||
self._source_buffer: list[dict] = []
|
||||
self._thread_id = "default"
|
||||
self._profile_repo = user_profile_repository
|
||||
self._conv_repo = conversation_repository
|
||||
self._conv_id: int | None = None
|
||||
self._pending_history: list = []
|
||||
self._user_id = user_id
|
||||
|
||||
if conversation_repository:
|
||||
try:
|
||||
self._conv_id = conversation_repository.get_latest_conversation_id(user_id)
|
||||
if self._conv_id is None:
|
||||
self._conv_id = conversation_repository.create_conversation(user_id)
|
||||
else:
|
||||
turns = conversation_repository.load_turns_after(self._conv_id, None, limit=10)
|
||||
for turn in turns:
|
||||
if turn["role"] == "user":
|
||||
self._pending_history.append(HumanMessage(content=turn["content"]))
|
||||
elif turn["role"] == "assistant":
|
||||
self._pending_history.append(AIMessage(content=turn["content"]))
|
||||
if self._pending_history:
|
||||
print(f"[Agent] 이전 대화 {len(self._pending_history) // 2}턴 복원")
|
||||
except Exception as e:
|
||||
print(f"[Agent] 이력 복원 실패: {e}")
|
||||
self._conv_id = None
|
||||
self._pending_history = []
|
||||
|
||||
if rag_show_sources:
|
||||
search_tool = make_search_tool(retriever_service, self._source_buffer)
|
||||
else:
|
||||
search_tool = make_retriever_tool(retriever_service)
|
||||
tools = [search_tool, web_search, get_current_date]
|
||||
if user_profile_repository is not None:
|
||||
remember_tool, recall_tool = make_memory_tools(user_profile_repository, user_id)
|
||||
tools += [remember_tool, recall_tool]
|
||||
llm_with_tools = chat_model.bind_tools(tools)
|
||||
|
||||
async def call_model(state: MessagesState, config: RunnableConfig) -> dict:
|
||||
system_content = self._system_prompt
|
||||
if self._profile_repo:
|
||||
profile = self._profile_repo.get_all(self._user_id)
|
||||
if profile:
|
||||
lines = "\n".join(f"- {k}: {v}" for k, v in profile.items())
|
||||
system_content += f"\n\n## 사용자 정보 (이전 대화에서 기억된 내용)\n{lines}"
|
||||
msgs = [SystemMessage(content=system_content)] + state["messages"]
|
||||
thinking_acc, content_acc, tool_calls_acc = "", "", []
|
||||
async for chunk in llm_with_tools.astream(msgs, config):
|
||||
t = chunk.additional_kwargs.get("thinking", "")
|
||||
if t:
|
||||
thinking_acc += t
|
||||
if chunk.content and isinstance(chunk.content, str):
|
||||
content_acc += chunk.content
|
||||
if chunk.tool_calls:
|
||||
tool_calls_acc.extend(chunk.tool_calls)
|
||||
extra = {"thinking": thinking_acc} if thinking_acc else {}
|
||||
return {"messages": [AIMessage(
|
||||
content=content_acc,
|
||||
tool_calls=tool_calls_acc,
|
||||
additional_kwargs=extra,
|
||||
)]}
|
||||
|
||||
builder = StateGraph(MessagesState)
|
||||
builder.add_node("agent", call_model)
|
||||
builder.add_node("tools", ToolNode(tools))
|
||||
builder.add_edge(START, "agent")
|
||||
builder.add_conditional_edges("agent", tools_condition)
|
||||
builder.add_edge("tools", "agent")
|
||||
|
||||
self._agent = builder.compile(checkpointer=MemorySaver())
|
||||
|
||||
@property
|
||||
def _config(self) -> dict:
|
||||
return {"configurable": {"thread_id": self._thread_id}}
|
||||
|
||||
async def stream_response(self, user_input: str, show_thinking: bool | None = None) -> AsyncIterator[str]:
|
||||
"""사용자 입력을 받아 응답 토큰을 순서대로 yield한다."""
|
||||
_think_verbose = show_thinking if show_thinking is not None else self._think_verbose
|
||||
self._source_buffer.clear()
|
||||
|
||||
# 재시작 후 첫 호출 시 MySQL 이력을 초기 상태에 주입
|
||||
if self._pending_history:
|
||||
all_messages = self._pending_history + [HumanMessage(content=user_input)]
|
||||
self._pending_history = []
|
||||
else:
|
||||
all_messages = [HumanMessage(content=user_input)]
|
||||
messages = {"messages": all_messages}
|
||||
response_content = "" # 실제 답변 내용만 누적 (MySQL 저장용)
|
||||
pending_tool_calls: dict = {} # tool_call_id → {name, args}
|
||||
prev_node: str = ""
|
||||
lg = self._langgraph_verbose
|
||||
thinking_open = False # [사고 과정] 헤더 출력 여부
|
||||
content_started = False # 노드 당 레이블 1회 출력 제어
|
||||
start_time = time.perf_counter()
|
||||
|
||||
async for chunk, metadata in self._agent.astream(
|
||||
messages, self._config, stream_mode="messages"
|
||||
):
|
||||
node = metadata.get("langgraph_node", "")
|
||||
|
||||
# ── 노드 전환 시 플래그 리셋 + 레이블 출력 ──────────────
|
||||
if node != prev_node:
|
||||
content_started = False
|
||||
if lg:
|
||||
if node == "agent":
|
||||
elapsed = time.perf_counter() - start_time
|
||||
label = "agent: 검색 결과 반영 중" if prev_node == "tools" else "agent: 질문 분석 중"
|
||||
yield f"\n[LangGraph → {label}] ({elapsed:.2f}s)\n"
|
||||
elif node == "tools":
|
||||
elapsed = time.perf_counter() - start_time
|
||||
yield f"\n[LangGraph → tools: 도구 실행 중] ({elapsed:.2f}s)\n"
|
||||
prev_node = node
|
||||
|
||||
# ── agent 노드 — AIMessageChunk만 처리 (중복 방지) ──────
|
||||
if node == "agent" and isinstance(chunk, AIMessageChunk):
|
||||
thinking = chunk.additional_kwargs.get("thinking", "")
|
||||
if thinking and _think_verbose:
|
||||
if not thinking_open:
|
||||
yield "\n[사고 과정]\n"
|
||||
thinking_open = True
|
||||
yield thinking
|
||||
|
||||
if chunk.tool_calls:
|
||||
if thinking_open:
|
||||
yield "\n[/사고 과정]\n"
|
||||
thinking_open = False
|
||||
for tc in chunk.tool_calls:
|
||||
pending_tool_calls[tc["id"]] = tc
|
||||
if tc.get("name") == "search_documents":
|
||||
query = tc.get("args", {}).get("query", "")
|
||||
yield f'\n문서 검색 중... ("{query}")\n' if query else "\n문서 검색 중...\n"
|
||||
elif tc.get("name") == "web_search":
|
||||
query = tc.get("args", {}).get("query", "")
|
||||
yield f'\n웹 검색 중... ("{query}")\n' if query else "\n웹 검색 중...\n"
|
||||
elif lg:
|
||||
args_str = ", ".join(f'{k}="{v}"' for k, v in tc["args"].items())
|
||||
yield f" [tool_call: {tc['name']}({args_str})]\n"
|
||||
|
||||
elif chunk.content:
|
||||
if thinking_open:
|
||||
yield "\n[/사고 과정]\n"
|
||||
thinking_open = False
|
||||
if lg and not content_started:
|
||||
yield "\n[LangGraph → agent: 최종 답변 생성]\n\n"
|
||||
content_started = True
|
||||
response_content += chunk.content
|
||||
yield chunk.content
|
||||
|
||||
# ── agent 노드 — AIMessage(최종 state) ──────────────────
|
||||
# 청크 스트리밍이 없었던 경우(edge case)에만 처리
|
||||
elif node == "agent" and isinstance(chunk, AIMessage):
|
||||
if not content_started and not thinking_open:
|
||||
thinking = chunk.additional_kwargs.get("thinking", "")
|
||||
if thinking and self._think_verbose:
|
||||
yield "\n[사고 과정]\n"
|
||||
yield thinking
|
||||
yield "\n[/사고 과정]\n"
|
||||
if chunk.content:
|
||||
if lg:
|
||||
yield "\n[LangGraph → agent: 최종 답변 생성]\n\n"
|
||||
response_content += chunk.content
|
||||
yield chunk.content
|
||||
|
||||
# ── tools 노드 ───────────────────────────────────────────
|
||||
elif node == "tools" and hasattr(chunk, "name") and chunk.name == "search_documents":
|
||||
if lg:
|
||||
result_lines = [b for b in chunk.content.split("\n\n") if b.strip()]
|
||||
yield f" [결과: {len(result_lines)}개 문서 반환 → agent 복귀]\n"
|
||||
|
||||
if self._rag_verbose:
|
||||
tc = pending_tool_calls.get(chunk.tool_call_id, {})
|
||||
query = tc.get("args", {}).get("query", "")
|
||||
yield f'\n[문서 검색: "{query}"]\n'
|
||||
for block in chunk.content.split("\n\n"):
|
||||
if block.strip():
|
||||
preview = block.strip().replace("\n", " ")[:80]
|
||||
yield f" → {preview}\n"
|
||||
yield "\n"
|
||||
|
||||
elif node == "tools" and hasattr(chunk, "name") and chunk.name == "web_search":
|
||||
if lg:
|
||||
result_lines = [b for b in chunk.content.split("\n\n") if b.strip()]
|
||||
yield f" [웹 검색 결과: {len(result_lines)}건 → agent 복귀]\n"
|
||||
|
||||
if thinking_open:
|
||||
yield "\n[/사고 과정]\n"
|
||||
|
||||
# 대화 내용을 MySQL에 저장
|
||||
if self._conv_repo and self._conv_id and response_content:
|
||||
try:
|
||||
self._conv_repo.save_message(self._conv_id, "user", user_input)
|
||||
self._conv_repo.save_message(self._conv_id, "assistant", response_content)
|
||||
except Exception as e:
|
||||
print(f"[Agent] 대화 저장 실패: {e}")
|
||||
|
||||
if self._rag_show_sources and self._source_buffer:
|
||||
yield "\n\n[참고 문서]\n"
|
||||
for src in self._source_buffer:
|
||||
filename = os.path.basename(src["source"])
|
||||
page = f" {src['page']}페이지" if "page" in src else ""
|
||||
yield f"- {filename}{page}\n"
|
||||
|
||||
def reset(self) -> None:
|
||||
"""새 thread_id로 대화 히스토리를 초기화한다."""
|
||||
self._thread_id = str(uuid.uuid4())
|
||||
self._pending_history = []
|
||||
if self._conv_repo:
|
||||
try:
|
||||
self._conv_id = self._conv_repo.create_conversation(self._user_id)
|
||||
except Exception:
|
||||
self._conv_id = None
|
||||
@@ -0,0 +1,96 @@
|
||||
from datetime import date
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
@tool
|
||||
def get_current_date() -> str:
|
||||
"""오늘 날짜를 반환합니다. 날짜·기간 관련 질문에 사용하세요."""
|
||||
return date.today().isoformat()
|
||||
|
||||
|
||||
@tool
|
||||
def web_search(query: str) -> str:
|
||||
"""최신 뉴스, 금리, 육아 정책 등 실시간 정보가 필요할 때 사용하세요. 저장된 문서에 없는 최신 정보를 검색합니다."""
|
||||
from duckduckgo_search import DDGS
|
||||
with DDGS() as ddgs:
|
||||
results = list(ddgs.text(query, max_results=5))
|
||||
if not results:
|
||||
return "검색 결과가 없습니다."
|
||||
return "\n\n".join(
|
||||
f"[{r['title']}]\n{r['body']}\n출처: {r['href']}"
|
||||
for r in results
|
||||
)
|
||||
|
||||
|
||||
def make_retriever_tool(retriever_service):
|
||||
"""as_retriever()를 사용하는 단순 검색 Tool (source_buffer 없음)."""
|
||||
retriever = retriever_service.as_retriever()
|
||||
|
||||
@tool
|
||||
def search_documents(query: str) -> str:
|
||||
"""등록된 문서(논문, 육아 가이드, 금융 자료 등)에서 관련 정보를 검색합니다.
|
||||
육아·금융 관련 질문이 오면 자신의 지식으로 답하기 전에 반드시 이 도구를 먼저 호출하세요.
|
||||
등록된 문서가 없거나 검색 결과가 없을 때만 자신의 학습 지식을 보조적으로 활용합니다."""
|
||||
docs = retriever.invoke(query)
|
||||
if not docs:
|
||||
return "관련 문서를 찾을 수 없습니다."
|
||||
return "\n\n".join(
|
||||
f"[문서 {i + 1}]\n{doc.page_content}" for i, doc in enumerate(docs)
|
||||
)
|
||||
|
||||
return search_documents
|
||||
|
||||
|
||||
def make_memory_tools(profile_repo, user_id: str = "default"):
|
||||
"""사용자 정보 저장/조회 Tool 쌍을 반환한다."""
|
||||
|
||||
@tool
|
||||
def remember_user_info(key: str, value: str) -> str:
|
||||
"""사용자 정보를 영구 저장합니다. 다음 대화에도 기억해야 할 정보를 저장하세요.
|
||||
- 아이 나이는 반드시 '생년(출생연도)'으로 저장하세요. 나이는 매년 바뀌지만 생년은 영구적입니다.
|
||||
예: key='첫째_이름' value='신도율', key='첫째_생년' value='2020'
|
||||
- 기타 key 예시: 재정_목표, 거주지, 직업, 자녀수"""
|
||||
profile_repo.remember(key, value, user_id=user_id)
|
||||
return f"'{key}' 정보를 기억했습니다: {value}"
|
||||
|
||||
@tool
|
||||
def recall_user_info(key: str) -> str:
|
||||
"""이전 대화에서 저장한 사용자 정보를 조회합니다."""
|
||||
value = profile_repo.recall(key, user_id=user_id)
|
||||
return value if value is not None else f"'{key}'에 대한 저장된 정보가 없습니다."
|
||||
|
||||
return remember_user_info, recall_user_info
|
||||
|
||||
|
||||
def make_search_tool(retriever_service, source_buffer: list | None = None):
|
||||
"""RetrieverService를 클로저로 감싼 문서 검색 Tool을 반환합니다.
|
||||
|
||||
source_buffer가 주어지면 검색된 문서의 메타데이터(source, page)를 누적 저장합니다.
|
||||
"""
|
||||
|
||||
@tool
|
||||
def search_documents(query: str) -> str:
|
||||
"""등록된 문서(논문, 육아 가이드, 금융 자료 등)에서 관련 정보를 검색합니다.
|
||||
육아·금융 관련 질문이 오면 자신의 지식으로 답하기 전에 반드시 이 도구를 먼저 호출하세요.
|
||||
등록된 문서가 없거나 검색 결과가 없을 때만 자신의 학습 지식을 보조적으로 활용합니다."""
|
||||
docs = retriever_service.search(query)
|
||||
|
||||
if source_buffer is not None:
|
||||
for doc in docs:
|
||||
src = doc.metadata.get("source", "")
|
||||
page = doc.metadata.get("page", None)
|
||||
if src:
|
||||
entry = {"source": src}
|
||||
if page is not None:
|
||||
entry["page"] = page + 1 # 0-indexed → 1-indexed
|
||||
if entry not in source_buffer:
|
||||
source_buffer.append(entry)
|
||||
|
||||
if not docs:
|
||||
return "관련 문서를 찾을 수 없습니다."
|
||||
return "\n\n".join(
|
||||
f"[문서 {i + 1}]\n{doc.page_content}" for i, doc in enumerate(docs)
|
||||
)
|
||||
|
||||
return search_documents
|
||||
@@ -8,14 +8,16 @@ class ConversationRepository:
|
||||
def __init__(self, db: DatabaseService):
|
||||
self._db = db
|
||||
|
||||
def create_conversation(self) -> int:
|
||||
def create_conversation(self, user_id: str = "default") -> int:
|
||||
return self._db.execute_write(
|
||||
"INSERT INTO td_conversations () VALUES ()"
|
||||
"INSERT INTO td_conversations (user_id) VALUES (%s)",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
def get_latest_conversation_id(self) -> int | None:
|
||||
def get_latest_conversation_id(self, user_id: str = "default") -> int | None:
|
||||
rows = self._db.execute(
|
||||
"SELECT id FROM td_conversations ORDER BY created_at DESC LIMIT 1"
|
||||
"SELECT id FROM td_conversations WHERE user_id = %s ORDER BY created_at DESC LIMIT 1",
|
||||
(user_id,),
|
||||
)
|
||||
return rows[0]["id"] if rows else None
|
||||
|
||||
|
||||
@@ -1,47 +1,81 @@
|
||||
from __future__ import annotations
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
|
||||
class DatabaseService:
|
||||
"""MySQL 연결을 캡슐화하는 서비스. 미설정 시 graceful skip."""
|
||||
"""MySQL 연결을 캡슐화하는 서비스. 미설정 시 graceful skip.
|
||||
|
||||
def __init__(self, host: str, port: int, db: str, user: str, password: str):
|
||||
스레드별 독립 연결(thread-local)을 사용해 LangGraph ToolNode의
|
||||
스레드 풀 실행과 pymysql 비안전성 문제를 해결한다.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
db: str,
|
||||
user: str,
|
||||
password: str,
|
||||
):
|
||||
self._config = dict(host=host, port=port, db=db, user=user, passwd=password)
|
||||
self._conn = None
|
||||
self._local = threading.local()
|
||||
|
||||
# ── DB 연결 ────────────────────────────────────────────────────────
|
||||
|
||||
def _get_conn(self):
|
||||
if not self._config["user"]:
|
||||
return None
|
||||
|
||||
import pymysql
|
||||
conn = getattr(self._local, "conn", None)
|
||||
if conn is None:
|
||||
try:
|
||||
self._local.conn = pymysql.connect(**self._config)
|
||||
except Exception as e:
|
||||
print(f"[DB] 연결 실패: {e}")
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
conn.ping(reconnect=True)
|
||||
except Exception:
|
||||
try:
|
||||
self._local.conn = pymysql.connect(**self._config)
|
||||
except Exception as e:
|
||||
print(f"[DB] 재연결 실패: {e}")
|
||||
return None
|
||||
return self._local.conn
|
||||
|
||||
def connect(self) -> None:
|
||||
if not self._config["user"]:
|
||||
return
|
||||
try:
|
||||
import pymysql
|
||||
self._conn = pymysql.connect(**self._config)
|
||||
except Exception as e:
|
||||
print(f"[DB] 연결 실패 (선택적 기능): {e}")
|
||||
self._get_conn()
|
||||
|
||||
def execute(self, sql: str, params: tuple = ()) -> list[dict[str, Any]]:
|
||||
if self._conn is None:
|
||||
conn = self._get_conn()
|
||||
if conn is None:
|
||||
return []
|
||||
cursor = self._conn.cursor()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, params)
|
||||
columns = [d[0] for d in cursor.description or []]
|
||||
return [dict(zip(columns, row)) for row in cursor.fetchall()]
|
||||
|
||||
def execute_write(self, sql: str, params: tuple = ()) -> int:
|
||||
"""INSERT/UPDATE/DELETE 실행 후 lastrowid 반환."""
|
||||
if self._conn is None:
|
||||
conn = self._get_conn()
|
||||
if conn is None:
|
||||
return 0
|
||||
cursor = self._conn.cursor()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, params)
|
||||
self._conn.commit()
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
def init_schema(self) -> None:
|
||||
if self._conn is None:
|
||||
conn = self._get_conn()
|
||||
if conn is None:
|
||||
return
|
||||
cursor = self._conn.cursor()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS td_conversations (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
user_id VARCHAR(50) NOT NULL DEFAULT 'default',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
@@ -55,9 +89,35 @@ class DatabaseService:
|
||||
FOREIGN KEY (conversation_id) REFERENCES td_conversations(id)
|
||||
)
|
||||
""")
|
||||
self._conn.commit()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS td_user_profile (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
user_id VARCHAR(50) NOT NULL DEFAULT 'default',
|
||||
key_name VARCHAR(100) NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY uq_user_key (user_id, key_name)
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
self._migrate_schema(conn)
|
||||
|
||||
def _migrate_schema(self, conn) -> None:
|
||||
cursor = conn.cursor()
|
||||
for sql in [
|
||||
"ALTER TABLE td_conversations ADD COLUMN user_id VARCHAR(50) NOT NULL DEFAULT 'default'",
|
||||
"ALTER TABLE td_user_profile ADD COLUMN user_id VARCHAR(50) NOT NULL DEFAULT 'default'",
|
||||
"ALTER TABLE td_user_profile DROP INDEX key_name",
|
||||
"ALTER TABLE td_user_profile ADD UNIQUE KEY uq_user_key (user_id, key_name)",
|
||||
]:
|
||||
try:
|
||||
cursor.execute(sql)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
conn = getattr(self._local, "conn", None)
|
||||
if conn:
|
||||
conn.close()
|
||||
self._local.conn = None
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
from services.db.mysql_service import DatabaseService
|
||||
|
||||
|
||||
class UserProfileRepository:
|
||||
"""td_user_profile 테이블을 통한 사용자 장기 메모리 저장소."""
|
||||
|
||||
def __init__(self, db: DatabaseService):
|
||||
self._db = db
|
||||
|
||||
def remember(self, key: str, value: str, user_id: str = "default") -> None:
|
||||
self._db.execute_write(
|
||||
"""INSERT INTO td_user_profile (user_id, key_name, value)
|
||||
VALUES (%s, %s, %s)
|
||||
ON DUPLICATE KEY UPDATE value = VALUES(value), updated_at = NOW()""",
|
||||
(user_id, key, value),
|
||||
)
|
||||
|
||||
def recall(self, key: str, user_id: str = "default") -> str | None:
|
||||
rows = self._db.execute(
|
||||
"SELECT value FROM td_user_profile WHERE user_id = %s AND key_name = %s",
|
||||
(user_id, key),
|
||||
)
|
||||
return rows[0]["value"] if rows else None
|
||||
|
||||
def get_all(self, user_id: str = "default") -> dict[str, str]:
|
||||
rows = self._db.execute(
|
||||
"SELECT key_name, value FROM td_user_profile WHERE user_id = %s ORDER BY updated_at",
|
||||
(user_id,),
|
||||
)
|
||||
return {r["key_name"]: r["value"] for r in rows}
|
||||
@@ -0,0 +1,337 @@
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Iterator, List, Optional
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import PrivateAttr, model_validator
|
||||
|
||||
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
|
||||
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
||||
|
||||
|
||||
class MlxChatModel(BaseChatModel):
|
||||
"""mlx-lm 기반 LangChain BaseChatModel.
|
||||
|
||||
LangGraph와 완전 호환 — Tool Calling, 스트리밍, bind_tools() 지원.
|
||||
Qwen3 thinking 모드 지원 — <think> 블록을 content와 분리해 additional_kwargs에 저장.
|
||||
"""
|
||||
|
||||
model_id: str
|
||||
max_tokens: int = 1024
|
||||
temp: float = 0.0
|
||||
enable_thinking: bool = True
|
||||
|
||||
_model: Any = PrivateAttr(default=None)
|
||||
_tokenizer: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _load(self) -> "MlxChatModel":
|
||||
from mlx_lm import load
|
||||
print(f"모델 로딩 중: {self.model_id}")
|
||||
self._model, self._tokenizer = load(self.model_id)
|
||||
return self
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "mlx-chat"
|
||||
|
||||
# ── 메시지 → chat dict 변환 ───────────────────────────────────
|
||||
|
||||
def _to_chat_dicts(self, messages: List[BaseMessage]) -> List[dict]:
|
||||
result = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, SystemMessage):
|
||||
result.append({"role": "system", "content": str(msg.content)})
|
||||
elif isinstance(msg, HumanMessage):
|
||||
result.append({"role": "user", "content": str(msg.content)})
|
||||
elif isinstance(msg, AIMessage):
|
||||
if msg.tool_calls:
|
||||
result.append({
|
||||
"role": "assistant",
|
||||
"content": str(msg.content) if msg.content else "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": json.dumps(tc["args"]),
|
||||
},
|
||||
}
|
||||
for tc in msg.tool_calls
|
||||
],
|
||||
})
|
||||
else:
|
||||
result.append({"role": "assistant", "content": str(msg.content)})
|
||||
elif isinstance(msg, ToolMessage):
|
||||
result.append({
|
||||
"role": "tool",
|
||||
"content": str(msg.content),
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
})
|
||||
return result
|
||||
|
||||
def _build_prompt(self, messages: List[BaseMessage], tools: Optional[list] = None) -> str:
|
||||
kwargs: dict = {
|
||||
"tokenize": False,
|
||||
"add_generation_prompt": True,
|
||||
}
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
# Qwen3 thinking 모드 — 지원하지 않는 모델은 무시됨
|
||||
try:
|
||||
kwargs["enable_thinking"] = self.enable_thinking
|
||||
return self._tokenizer.apply_chat_template(self._to_chat_dicts(messages), **kwargs)
|
||||
except TypeError:
|
||||
kwargs.pop("enable_thinking")
|
||||
return self._tokenizer.apply_chat_template(self._to_chat_dicts(messages), **kwargs)
|
||||
|
||||
# ── <think> 블록 파싱 (Qwen3) ────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _parse_thinking(text: str) -> tuple[str, str]:
|
||||
"""<think>...</think> 블록을 분리해 (thinking, clean_text) 반환."""
|
||||
match = _THINK_RE.search(text)
|
||||
if not match:
|
||||
return "", text
|
||||
thinking = match.group(1).strip()
|
||||
clean = _THINK_RE.sub("", text).strip()
|
||||
return thinking, clean
|
||||
|
||||
# ── Tool Call 파싱 ────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_calls(text: str) -> tuple[str, list]:
|
||||
matches = _TOOL_CALL_RE.findall(text)
|
||||
if not matches:
|
||||
return text, []
|
||||
|
||||
tool_calls = []
|
||||
for raw in matches:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
tool_calls.append({
|
||||
"id": f"call_{uuid.uuid4().hex[:8]}",
|
||||
"name": data["name"],
|
||||
"args": data.get("arguments", data.get("args", {})),
|
||||
"type": "tool_call",
|
||||
})
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
clean = _TOOL_CALL_RE.sub("", text).strip()
|
||||
return clean, tool_calls
|
||||
|
||||
# ── LangChain BaseChatModel 인터페이스 ────────────────────────
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager=None,
|
||||
**kwargs,
|
||||
) -> ChatResult:
|
||||
from mlx_lm import generate
|
||||
|
||||
tools = kwargs.get("tools")
|
||||
prompt = self._build_prompt(messages, tools)
|
||||
text = generate(
|
||||
self._model,
|
||||
self._tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
verbose=False,
|
||||
)
|
||||
thinking, after_think = self._parse_thinking(text)
|
||||
clean_text, tool_calls = self._parse_tool_calls(after_think)
|
||||
extra = {"thinking": thinking} if thinking else {}
|
||||
message = AIMessage(content=clean_text, tool_calls=tool_calls, additional_kwargs=extra)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager=None,
|
||||
**kwargs,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
from mlx_lm import stream_generate
|
||||
|
||||
tools = kwargs.get("tools")
|
||||
prompt = self._build_prompt(messages, tools)
|
||||
|
||||
OPEN_THINK = "<think>"
|
||||
CLOSE_THINK = "</think>"
|
||||
OPEN_TOOL = "<tool_call>"
|
||||
CLOSE_TOOL = "</tool_call>"
|
||||
SAFE = max(len(OPEN_THINK), len(CLOSE_THINK), len(OPEN_TOOL), len(CLOSE_TOOL))
|
||||
|
||||
# enable_thinking=False 모델은 <think> 블록을 생성하지 않으므로 post_think에서 시작
|
||||
state = "pre_think" if self.enable_thinking else "post_think"
|
||||
buf = ""
|
||||
out: list[ChatGenerationChunk] = []
|
||||
|
||||
def _think(text: str) -> None:
|
||||
out.append(ChatGenerationChunk(
|
||||
message=AIMessageChunk(content="", additional_kwargs={"thinking": text})
|
||||
))
|
||||
|
||||
def _content(text: str) -> None:
|
||||
out.append(ChatGenerationChunk(message=AIMessageChunk(content=text)))
|
||||
|
||||
def _tool(raw_json: str) -> None:
|
||||
try:
|
||||
data = json.loads(raw_json)
|
||||
tc = {
|
||||
"id": f"call_{uuid.uuid4().hex[:8]}",
|
||||
"name": data["name"],
|
||||
"args": data.get("arguments", data.get("args", {})),
|
||||
"type": "tool_call",
|
||||
}
|
||||
out.append(ChatGenerationChunk(message=AIMessageChunk(content="", tool_calls=[tc])))
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
def advance() -> None:
|
||||
nonlocal state, buf
|
||||
while buf:
|
||||
if state == "pre_think":
|
||||
idx = buf.find(OPEN_THINK)
|
||||
if idx == -1:
|
||||
safe = len(buf) - SAFE
|
||||
if safe > 0:
|
||||
_content(buf[:safe])
|
||||
buf = buf[safe:]
|
||||
return
|
||||
if idx > 0:
|
||||
_content(buf[:idx])
|
||||
buf = buf[idx + len(OPEN_THINK):]
|
||||
state = "in_think"
|
||||
|
||||
elif state == "in_think":
|
||||
idx = buf.find(CLOSE_THINK)
|
||||
if idx == -1:
|
||||
safe = len(buf) - SAFE
|
||||
if safe > 0:
|
||||
_think(buf[:safe])
|
||||
buf = buf[safe:]
|
||||
return
|
||||
if idx > 0:
|
||||
_think(buf[:idx])
|
||||
buf = buf[idx + len(CLOSE_THINK):].lstrip()
|
||||
state = "post_think"
|
||||
|
||||
elif state == "post_think":
|
||||
# </think> 이후 \n\n 같은 공백을 건너뜀
|
||||
buf = buf.lstrip()
|
||||
if not buf:
|
||||
return
|
||||
idx = buf.find(OPEN_TOOL)
|
||||
if idx == -1:
|
||||
# partial tag at end — hold and wait
|
||||
for i in range(len(OPEN_TOOL) - 1, 0, -1):
|
||||
if buf.endswith(OPEN_TOOL[:i]):
|
||||
safe_text = buf[:-i]
|
||||
if safe_text:
|
||||
_content(safe_text)
|
||||
buf = buf[-i:]
|
||||
return
|
||||
state = "in_answer" # no tool call coming
|
||||
elif idx == 0:
|
||||
buf = buf[len(OPEN_TOOL):]
|
||||
state = "in_tool"
|
||||
else:
|
||||
_content(buf[:idx])
|
||||
buf = buf[idx + len(OPEN_TOOL):]
|
||||
state = "in_tool"
|
||||
|
||||
elif state == "in_answer":
|
||||
_content(buf)
|
||||
buf = ""
|
||||
return
|
||||
|
||||
elif state == "in_tool":
|
||||
idx = buf.find(CLOSE_TOOL)
|
||||
if idx == -1:
|
||||
return # wait for complete JSON
|
||||
_tool(buf[:idx].strip())
|
||||
buf = buf[idx + len(CLOSE_TOOL):]
|
||||
state = "post_think" # may have more tool calls
|
||||
|
||||
for raw in stream_generate(self._model, self._tokenizer, prompt=prompt, max_tokens=self.max_tokens):
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(raw.text)
|
||||
buf += raw.text
|
||||
advance()
|
||||
yield from out
|
||||
out.clear()
|
||||
|
||||
# flush remaining buffer
|
||||
if buf:
|
||||
if state == "in_think":
|
||||
_think(buf)
|
||||
elif state == "in_answer":
|
||||
_content(buf)
|
||||
elif state in ("pre_think", "post_think"):
|
||||
clean, tcs = self._parse_tool_calls(buf)
|
||||
if clean:
|
||||
_content(clean)
|
||||
for tc in tcs:
|
||||
out.append(ChatGenerationChunk(message=AIMessageChunk(content="", tool_calls=[tc])))
|
||||
elif state == "in_tool":
|
||||
_tool(buf.strip())
|
||||
yield from out
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager=None,
|
||||
**kwargs,
|
||||
):
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=200)
|
||||
sentinel = object()
|
||||
exc_holder: list = []
|
||||
|
||||
def _run() -> None:
|
||||
try:
|
||||
for chunk in self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(chunk), loop).result()
|
||||
except Exception as exc:
|
||||
exc_holder.append(exc)
|
||||
finally:
|
||||
asyncio.run_coroutine_threadsafe(queue.put(sentinel), loop).result()
|
||||
|
||||
thread = threading.Thread(target=_run, daemon=True)
|
||||
thread.start()
|
||||
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is sentinel:
|
||||
break
|
||||
yield item
|
||||
|
||||
thread.join(timeout=5)
|
||||
if exc_holder:
|
||||
raise exc_holder[0]
|
||||
|
||||
def bind_tools(self, tools, tool_choice=None, **kwargs):
|
||||
formatted = [convert_to_openai_tool(t) for t in tools]
|
||||
return self.bind(tools=formatted, **kwargs)
|
||||
@@ -0,0 +1,107 @@
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from langchain_community.document_loaders import PDFPlumberLoader, TextLoader
|
||||
from langchain_core.documents import Document
|
||||
from langchain_qdrant import QdrantVectorStore
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue, FilterSelector
|
||||
|
||||
|
||||
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))
|
||||
|
||||
|
||||
class _SemanticSplitter:
|
||||
"""문장 임베딩 유사도 기반 청커.
|
||||
|
||||
인접 문장 간 코사인 유사도를 계산하고, 유사도가 낮은(= 의미 전환) 지점에서 청크를 분리한다.
|
||||
breakpoint_percentile=95이면 유사도 하위 5% 지점이 분리 경계가 된다.
|
||||
"""
|
||||
|
||||
_SENTENCE_RE = re.compile(r"(?<=[.!?。!?])\s+")
|
||||
|
||||
def __init__(self, embeddings, breakpoint_percentile: int = 95):
|
||||
self._embeddings = embeddings
|
||||
self._percentile = breakpoint_percentile
|
||||
|
||||
def split_documents(self, docs: list[Document]) -> list[Document]:
|
||||
result = []
|
||||
for doc in docs:
|
||||
for chunk_text in self._split_text(doc.page_content):
|
||||
result.append(Document(page_content=chunk_text, metadata=doc.metadata))
|
||||
return result
|
||||
|
||||
def _split_text(self, text: str) -> list[str]:
|
||||
sentences = [s for s in self._SENTENCE_RE.split(text.strip()) if s.strip()]
|
||||
if len(sentences) <= 1:
|
||||
return [text.strip()] if text.strip() else []
|
||||
|
||||
vecs = np.array(self._embeddings.embed_documents(sentences))
|
||||
similarities = [_cosine_similarity(vecs[i], vecs[i + 1]) for i in range(len(vecs) - 1)]
|
||||
threshold = float(np.percentile(similarities, 100 - self._percentile))
|
||||
breakpoints = [i + 1 for i, s in enumerate(similarities) if s < threshold]
|
||||
|
||||
chunks, start = [], 0
|
||||
for bp in breakpoints:
|
||||
chunk = " ".join(sentences[start:bp]).strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
start = bp
|
||||
tail = " ".join(sentences[start:]).strip()
|
||||
if tail:
|
||||
chunks.append(tail)
|
||||
return chunks
|
||||
|
||||
|
||||
class IngestionService:
|
||||
"""문서를 의미 단위 청크로 분할해 Qdrant에 저장하는 수집 파이프라인."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings,
|
||||
qdrant_url: str,
|
||||
collection_name: str,
|
||||
breakpoint_threshold_type: str = "percentile",
|
||||
):
|
||||
self._embeddings = embeddings
|
||||
self._qdrant_url = qdrant_url
|
||||
self._collection_name = collection_name
|
||||
# breakpoint_threshold_type은 향후 확장용으로 수용 (현재는 percentile 방식 고정)
|
||||
self._splitter = _SemanticSplitter(embeddings, breakpoint_percentile=95)
|
||||
self._client = QdrantClient(url=qdrant_url)
|
||||
|
||||
def _delete_by_source(self, source_path: str) -> None:
|
||||
"""같은 파일 경로로 저장된 기존 청크를 모두 삭제한다."""
|
||||
try:
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="metadata.source",
|
||||
match=MatchValue(value=source_path),
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass # 컬렉션이 없을 때(최초 수집) 무시
|
||||
|
||||
def ingest(self, file_paths: list[str]) -> int:
|
||||
docs = []
|
||||
for path in file_paths:
|
||||
self._delete_by_source(path)
|
||||
loader = PDFPlumberLoader(path) if path.endswith(".pdf") else TextLoader(path, encoding="utf-8")
|
||||
docs.extend(loader.load())
|
||||
|
||||
chunks = self._splitter.split_documents(docs)
|
||||
QdrantVectorStore.from_documents(
|
||||
documents=chunks,
|
||||
embedding=self._embeddings,
|
||||
url=self._qdrant_url,
|
||||
collection_name=self._collection_name,
|
||||
)
|
||||
return len(chunks)
|
||||
@@ -0,0 +1,67 @@
|
||||
from langchain_core.documents import Document
|
||||
from langchain_qdrant import QdrantVectorStore
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue, FilterSelector
|
||||
|
||||
|
||||
class RetrieverService:
|
||||
"""Qdrant 벡터 검색 서비스. LangGraph Tool 및 직접 검색 모두 지원."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings,
|
||||
qdrant_url: str,
|
||||
collection_name: str,
|
||||
top_k: int,
|
||||
):
|
||||
self._client = QdrantClient(url=qdrant_url)
|
||||
self._collection_name = collection_name
|
||||
self._store = QdrantVectorStore(
|
||||
client=self._client,
|
||||
collection_name=collection_name,
|
||||
embedding=embeddings,
|
||||
)
|
||||
self._top_k = top_k
|
||||
|
||||
def as_retriever(self):
|
||||
return self._store.as_retriever(search_kwargs={"k": self._top_k})
|
||||
|
||||
def search(self, query: str) -> list[Document]:
|
||||
return self._store.similarity_search(query, k=self._top_k)
|
||||
|
||||
def list_documents(self) -> list[str]:
|
||||
"""Qdrant에 저장된 고유 파일 경로 목록을 반환한다."""
|
||||
sources: set[str] = set()
|
||||
offset = None
|
||||
while True:
|
||||
results, next_offset = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
with_payload=True,
|
||||
limit=200,
|
||||
offset=offset,
|
||||
)
|
||||
for point in results:
|
||||
src = (point.payload or {}).get("metadata", {}).get("source", "")
|
||||
if src:
|
||||
sources.add(src)
|
||||
if next_offset is None:
|
||||
break
|
||||
offset = next_offset
|
||||
return sorted(sources)
|
||||
|
||||
def delete_document(self, source: str) -> None:
|
||||
"""파일 경로로 저장된 모든 청크를 Qdrant에서 삭제한다."""
|
||||
try:
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=Filter(
|
||||
must=[FieldCondition(
|
||||
key="metadata.source",
|
||||
match=MatchValue(value=source),
|
||||
)]
|
||||
)
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
Reference in New Issue
Block a user