0b50444e43
- IDEA-2 스마트 알림: td_reminders 테이블, set_reminder/list_reminders 도구,
SchedulerService(asyncio 60초 루프, D-7/D-1/D-0 Telegram push),
FastAPI lifespan 연동, GET /reminders/{user_id} 엔드포인트
- IDEA-1 대화 기반 RAG: IngestionService.store_text() 추가,
AgentService._maybe_index_conversation() — 응답 후 LLM 판단 → Qdrant 저장
(CONV_RAG_ENABLED=true 활성화, background task로 응답 속도 무관)
- IDEA-5 CRAG: AgentState에 crag_fallback_used 플래그 추가,
crag_check LangGraph 노드 — search_documents 결과 없으면 web_search 자동 주입,
route_after_crag으로 fallback 1회 루프 제어 (CRAG_ENABLED=true 활성화)
- IDEA-7 RAG Auto-Eval: eval/auto_tune.py — API 서버 없이 파라미터 조합별
context_precision/recall 비교, 최적 설정 추천
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
192 lines
7.3 KiB
Python
192 lines
7.3 KiB
Python
"""RAG 파라미터 자동 튜닝 스크립트 (IDEA-7)
|
|
|
|
API 서버 없이 RetrieverService를 직접 사용해 파라미터 조합별 context 품질을 비교한다.
|
|
평가 지표: context_precision, context_recall (RAGAS)
|
|
|
|
실행:
|
|
python eval/auto_tune.py [--dataset eval/dataset.jsonl]
|
|
|
|
출력:
|
|
eval/results/tune_YYYYMMDD_HHMMSS.json — 조합별 점수 및 추천 파라미터
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
ROOT = Path(__file__).resolve().parent.parent
|
|
sys.path.insert(0, str(ROOT))
|
|
os.chdir(ROOT)
|
|
|
|
from dotenv import load_dotenv
|
|
load_dotenv(ROOT / ".env")
|
|
|
|
# ── Compatibility shim (same as run_ragas.py) ─────────────────────────────────
# Make sure `langchain_community.chat_models.vertexai` is importable before
# ragas is imported below. When the real module is missing, register a stand-in
# module in sys.modules that exposes a `ChatVertexAI` attribute.
try:
    import langchain_community.chat_models.vertexai  # noqa: F401
except ModuleNotFoundError:
    # type(sys) is the module type; instantiate an empty module to use as stub.
    _stub = type(sys)("langchain_community.chat_models.vertexai")
    try:
        from langchain_google_vertexai import ChatVertexAI as _CV

        _stub.ChatVertexAI = _CV
    except ImportError:
        # Neither package is available: any placeholder attribute will do.
        _stub.ChatVertexAI = object
    sys.modules["langchain_community.chat_models.vertexai"] = _stub
|
|
|
|
from ragas import evaluate
|
|
from ragas.metrics import context_precision, context_recall
|
|
from ragas.embeddings import LangchainEmbeddingsWrapper
|
|
from ragas.llms import LangchainLLMWrapper
|
|
from datasets import Dataset
|
|
from ragas.run_config import RunConfig
|
|
|
|
from container import Container
|
|
from services.rag.retriever_service import RetrieverService
|
|
|
|
# Module-level dependency container, created once at import time. The helper
# functions in this file read their providers (embeddings, reranker, chat
# model, ...) from this instance.
_container = Container()
# Connect to the database and initialise its schema (per the method names)
# before any other provider is requested.
# NOTE(review): connect() and init_schema() are invoked on two separate
# db_service() lookups — presumably the provider hands back a shared instance;
# confirm against the Container definition.
_container.db_service().connect()
_container.db_service().init_schema()
# Settings object; _build_retriever reads the Qdrant URL/collection and the
# reranker / hybrid-search feature flags from it.
_cfg = _container.config()
|
|
|
|
# ── Parameter combinations to tune ────────────────────────────────────────────
# Each entry is one retriever configuration that run() evaluates:
# (variant name, top_k, rerank_fetch_k).
VARIANTS = [
    {"name": label, "top_k": k, "rerank_fetch_k": fetch_k}
    for label, k, fetch_k in (
        ("baseline", 3, 10),
        ("top_k_5", 5, 15),
        ("top_k_2", 2, 6),
        ("fetch_k_20", 3, 20),
    )
]
|
|
|
|
|
|
def _build_retriever(top_k: int, rerank_fetch_k: int) -> RetrieverService:
    """Create a RetrieverService for one parameter combination.

    Connection settings and feature flags come from the module-level config;
    only ``top_k`` and ``rerank_fetch_k`` vary between tuning variants.

    Args:
        top_k: Passed through as the retriever's ``top_k``.
        rerank_fetch_k: Passed through as the retriever's ``rerank_fetch_k``.

    Returns:
        A RetrieverService built directly, without going through the API server.
    """
    dense_embeddings = _container.embeddings()
    url = _cfg.qdrant_url
    collection = _cfg.qdrant_collection

    # Optional components are only requested from the container when enabled.
    reranker = None
    if _cfg.reranker_enabled:
        reranker = _container.reranker()

    sparse = None
    if _cfg.hybrid_search_enabled:
        sparse = _container.sparse_embeddings()

    return RetrieverService(
        embeddings=dense_embeddings,
        qdrant_url=url,
        collection_name=collection,
        top_k=top_k,
        reranker=reranker,
        rerank_fetch_k=rerank_fetch_k,
        sparse_embeddings=sparse,
    )
|
|
|
|
|
|
def _build_evaluator():
    """Choose the judge LLM for the evaluation and wrap it for ragas.

    Selection order: OpenAI when ``OPENAI_API_KEY`` is set, otherwise Anthropic
    when ``ANTHROPIC_API_KEY`` is set, otherwise the container's chat model.
    Provider packages are imported lazily, so only the selected one has to be
    installed.

    Returns:
        A ``LangchainLLMWrapper`` around the selected chat model.
    """
    env = os.environ
    if env.get("OPENAI_API_KEY"):
        from langchain_openai import ChatOpenAI

        print("[AutoTune] 평가 LLM: OpenAI GPT-4o-mini")
        judge = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    elif env.get("ANTHROPIC_API_KEY"):
        from langchain_anthropic import ChatAnthropic

        print("[AutoTune] 평가 LLM: Claude Haiku")
        judge = ChatAnthropic(model="claude-haiku-4-5-20251001", temperature=0)
    else:
        print("[AutoTune] 평가 LLM: 로컬 Qwen3")
        judge = _container.chat_model()
    return LangchainLLMWrapper(judge)
|
|
|
|
|
|
def run(dataset_path: str) -> None:
|
|
samples = []
|
|
with open(dataset_path, encoding="utf-8") as f:
|
|
for line in f:
|
|
line = line.strip()
|
|
if line:
|
|
samples.append(json.loads(line))
|
|
|
|
if not samples:
|
|
print(f"[오류] 데이터셋이 비어 있습니다: {dataset_path}")
|
|
sys.exit(1)
|
|
|
|
print(f"[AutoTune] 파라미터 튜닝 시작 — {len(samples)}개 질문, {len(VARIANTS)}개 조합\n")
|
|
|
|
llm = _build_evaluator()
|
|
emb = LangchainEmbeddingsWrapper(_container.embeddings())
|
|
run_cfg = RunConfig(timeout=300, max_retries=1, max_workers=1)
|
|
|
|
results = []
|
|
|
|
for variant in VARIANTS:
|
|
name = variant["name"]
|
|
print(f"── {name} (top_k={variant['top_k']}, fetch_k={variant['rerank_fetch_k']}) ──")
|
|
retriever = _build_retriever(variant["top_k"], variant["rerank_fetch_k"])
|
|
|
|
questions, ground_truths, contexts = [], [], []
|
|
for s in samples:
|
|
q = s["question"]
|
|
docs = retriever.search(q)
|
|
contexts.append([d.page_content for d in docs])
|
|
questions.append(q)
|
|
ground_truths.append(s["ground_truth"])
|
|
print(f" [{q[:40]}] → {len(docs)}개 청크")
|
|
|
|
ds = Dataset.from_dict({
|
|
"question": questions,
|
|
"contexts": contexts,
|
|
"ground_truth": ground_truths,
|
|
})
|
|
|
|
result = evaluate(
|
|
ds,
|
|
metrics=[context_precision, context_recall],
|
|
llm=llm,
|
|
embeddings=emb,
|
|
run_config=run_cfg,
|
|
raise_exceptions=False,
|
|
)
|
|
df = result.to_pandas()
|
|
|
|
def _score(col: str) -> float | None:
|
|
if col not in df.columns:
|
|
return None
|
|
val = df[col].dropna().mean()
|
|
return float(val) if val == val else None
|
|
|
|
scores = {
|
|
"context_precision": _score("context_precision"),
|
|
"context_recall": _score("context_recall"),
|
|
}
|
|
avg = sum(v for v in scores.values() if v is not None) / max(
|
|
sum(1 for v in scores.values() if v is not None), 1
|
|
)
|
|
results.append({**variant, "scores": scores, "avg": avg})
|
|
print(f" precision={scores['context_precision']}, recall={scores['context_recall']}, avg={avg:.3f}\n")
|
|
|
|
# ── 결과 출력 ──────────────────────────────────────────────────────────────
|
|
best = max(results, key=lambda r: r["avg"])
|
|
|
|
print("=" * 60)
|
|
print("AutoTune 결과")
|
|
print("=" * 60)
|
|
header = f"{'조합':<14} {'precision':>10} {'recall':>10} {'avg':>8}"
|
|
print(header)
|
|
print("-" * 60)
|
|
for r in sorted(results, key=lambda x: x["avg"], reverse=True):
|
|
marker = " ★" if r["name"] == best["name"] else ""
|
|
prec = f"{r['scores']['context_precision']:.3f}" if r['scores']['context_precision'] else "N/A"
|
|
rec = f"{r['scores']['context_recall']:.3f}" if r['scores']['context_recall'] else "N/A"
|
|
print(f" {r['name']:<12} {prec:>10} {rec:>10} {r['avg']:>8.3f}{marker}")
|
|
print("=" * 60)
|
|
print(f"\n추천: top_k={best['top_k']}, rerank_fetch_k={best['rerank_fetch_k']} ({best['name']})")
|
|
print(f" .env에 RAG_TOP_K={best['top_k']}, RERANKER_FETCH_K={best['rerank_fetch_k']} 설정\n")
|
|
|
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
results_dir = ROOT / "eval" / "results"
|
|
results_dir.mkdir(exist_ok=True)
|
|
out = results_dir / f"tune_{ts}.json"
|
|
out.write_text(
|
|
json.dumps({"timestamp": ts, "best": best, "all": results}, ensure_ascii=False, indent=2),
|
|
encoding="utf-8",
|
|
)
|
|
print(f"JSON 저장: {out}")
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: the only option is the dataset path, which defaults to
    # eval/dataset.jsonl under the project root.
    default_dataset = str(ROOT / "eval" / "dataset.jsonl")
    cli = argparse.ArgumentParser(description="RAG 파라미터 자동 튜닝")
    cli.add_argument("--dataset", default=default_dataset)
    options = cli.parse_args()
    run(options.dataset)
|