145b0cc96f
- Phase 12: FeedbackRepository + td_feedback 테이블, Gradio 👍/👎 이벤트, run_id 추적, LangSmith create_feedback() 연동 - Phase 13: 커스텀 _SemanticSplitter 제거 → langchain_experimental.SemanticChunker 교체, buffer_size/threshold_type 환경변수 적용 - Phase 13-B: RerankService (Cross-Encoder), RetrieverService.search()에 reranker 통합, tools.py as_retriever() → search() 전환 - Bug 5: mlx_chat_model enable_thinking 런타임 오버라이드, agent_service stream_mode=["messages","custom"] 이중 스트림, thinking 토큰 custom 이벤트로 emit - ROADMAP: LLM 모델명 8B 반영, RAG에 Reranker 추가, 추천 진행 순서 갱신 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
347 lines
13 KiB
Python
347 lines
13 KiB
Python
import json
import re
import uuid
from typing import Any, Iterator, List, Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import PrivateAttr, model_validator
|
|
|
|
# Captures the payload between <tool_call> and </tool_call>, with surrounding
# whitespace excluded from the group. DOTALL lets the payload span lines.
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
# Captures the raw text between <think> and </think> (non-greedy, may span lines).
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
|
|
|
|
|
class MlxChatModel(BaseChatModel):
    """LangChain ``BaseChatModel`` backed by mlx-lm.

    Implements ``_generate``, ``_stream``, ``_astream`` and ``bind_tools`` so the
    model can be used for tool calling and streaming (the original docstring
    states it is intended to be LangGraph compatible).

    Qwen3-style ``<think>...</think>`` blocks are separated from the visible
    answer and exposed through ``additional_kwargs["thinking"]``;
    ``<tool_call>...</tool_call>`` blocks are parsed into LangChain tool calls.
    """

    # Model identifier handed to ``mlx_lm.load``.
    model_id: str
    # Upper bound on generated tokens, forwarded as ``max_tokens``.
    max_tokens: int = 1024
    # NOTE(review): ``temp`` is declared but never forwarded to mlx-lm anywhere
    # in this class — confirm whether sampling temperature should be wired in.
    temp: float = 0.0
    # Default for the chat template's ``enable_thinking`` flag; can be
    # overridden per call through the ``enable_thinking`` keyword argument.
    enable_thinking: bool = True

    _model: Any = PrivateAttr(default=None)
    _tokenizer: Any = PrivateAttr(default=None)

    @model_validator(mode="after")
    def _load(self) -> "MlxChatModel":
        """Load the model and tokenizer once the fields have been validated."""
        # Imported lazily so importing this module does not require mlx-lm.
        from mlx_lm import load

        print(f"모델 로딩 중: {self.model_id}")
        self._model, self._tokenizer = load(self.model_id)
        return self

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this model type."""
        return "mlx-chat"

    # ── messages → chat dict conversion ───────────────────────────

    def _to_chat_dicts(self, messages: List[BaseMessage]) -> List[dict]:
        """Convert LangChain messages into role/content dicts for the chat template.

        System, human, AI and tool messages are converted; any other message
        type is skipped. AI messages carrying tool calls are rendered in the
        OpenAI ``tool_calls`` layout with JSON-encoded arguments.
        """
        result: List[dict] = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                result.append({"role": "system", "content": str(msg.content)})
            elif isinstance(msg, HumanMessage):
                result.append({"role": "user", "content": str(msg.content)})
            elif isinstance(msg, AIMessage):
                if msg.tool_calls:
                    result.append({
                        "role": "assistant",
                        "content": str(msg.content) if msg.content else "",
                        "tool_calls": [
                            {
                                "id": tc["id"],
                                "type": "function",
                                "function": {
                                    "name": tc["name"],
                                    # ensure_ascii=False keeps non-ASCII arguments
                                    # readable in the prompt instead of turning
                                    # them into \uXXXX escapes.
                                    "arguments": json.dumps(tc["args"], ensure_ascii=False),
                                },
                            }
                            for tc in msg.tool_calls
                        ],
                    })
                else:
                    result.append({"role": "assistant", "content": str(msg.content)})
            elif isinstance(msg, ToolMessage):
                result.append({
                    "role": "tool",
                    "content": str(msg.content),
                    "tool_call_id": msg.tool_call_id,
                })
        return result

    def _build_prompt(
        self,
        messages: List[BaseMessage],
        tools: Optional[list] = None,
        enable_thinking: Optional[bool] = None,
    ) -> str:
        """Render ``messages`` into a prompt string with the tokenizer's chat template.

        Args:
            messages: Conversation to render.
            tools: Optional tool schemas passed to the template as ``tools``.
            enable_thinking: Per-call override; ``None`` falls back to
                ``self.enable_thinking``.

        Returns:
            The untokenized prompt, including the generation prompt suffix.
        """
        use_thinking = self.enable_thinking if enable_thinking is None else enable_thinking
        chat = self._to_chat_dicts(messages)
        template_kwargs: dict = {
            "tokenize": False,
            "add_generation_prompt": True,
        }
        if tools:
            template_kwargs["tools"] = tools
        # Qwen3 thinking mode: try with the flag first and retry without it if
        # the tokenizer rejects the keyword with a TypeError.
        try:
            return self._tokenizer.apply_chat_template(
                chat, enable_thinking=use_thinking, **template_kwargs
            )
        except TypeError:
            return self._tokenizer.apply_chat_template(chat, **template_kwargs)

    # ── <think> block parsing (Qwen3) ─────────────────────────────

    @staticmethod
    def _parse_thinking(text: str) -> tuple[str, str]:
        """Split ``<think>`` content from ``text``; return ``(thinking, clean_text)``.

        With at least one complete block, ``thinking`` is the stripped content
        of the first block and ``clean_text`` is ``text`` with every complete
        block removed. An unterminated ``<think>`` (for example when generation
        stops at ``max_tokens``) is treated the way ``_stream`` treats it:
        everything after the tag is thinking, everything before it is content.
        Without any tag the text is returned unchanged.
        """
        match = _THINK_RE.search(text)
        if match:
            return match.group(1).strip(), _THINK_RE.sub("", text).strip()
        head, tag, tail = text.partition("<think>")
        if tag:
            return tail.strip(), head.strip()
        return "", text

    # ── tool call parsing ─────────────────────────────────────────

    @staticmethod
    def _tool_call_from_json(raw: str) -> Optional[dict]:
        """Build one LangChain tool-call dict from a ``<tool_call>`` JSON payload.

        Returns ``None`` when the payload is not valid JSON, is not a JSON
        object, or has no ``"name"`` key. Arguments are read from
        ``"arguments"``, falling back to ``"args"`` and then to ``{}``.
        """
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            return None
        # Valid JSON that is not an object (list, string, number) cannot
        # describe a tool call; reject it instead of failing on data["name"].
        if not isinstance(data, dict) or "name" not in data:
            return None
        return {
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "name": data["name"],
            "args": data.get("arguments", data.get("args", {})),
            "type": "tool_call",
        }

    @staticmethod
    def _parse_tool_calls(text: str) -> tuple[str, list]:
        """Extract ``<tool_call>`` blocks; return ``(clean_text, tool_calls)``.

        When no block is present the text is returned unchanged. Otherwise all
        blocks are removed from the text (and the result stripped), and blocks
        whose payload cannot be parsed are dropped from the tool-call list.
        """
        payloads = _TOOL_CALL_RE.findall(text)
        if not payloads:
            return text, []

        tool_calls = []
        for raw in payloads:
            call = MlxChatModel._tool_call_from_json(raw)
            if call is not None:
                tool_calls.append(call)

        clean = _TOOL_CALL_RE.sub("", text).strip()
        return clean, tool_calls

    # ── LangChain BaseChatModel interface ─────────────────────────

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ) -> ChatResult:
        """Generate one complete response.

        Recognised keyword arguments: ``tools`` (tool schemas for the chat
        template) and ``enable_thinking`` (per-call override of the field).

        NOTE(review): ``stop`` and ``run_manager`` are accepted for interface
        compatibility but are not used here.
        """
        from mlx_lm import generate

        tools = kwargs.get("tools")
        enable_thinking_override = kwargs.pop("enable_thinking", None)
        prompt = self._build_prompt(messages, tools, enable_thinking=enable_thinking_override)
        text = generate(
            self._model,
            self._tokenizer,
            prompt=prompt,
            max_tokens=self.max_tokens,
            verbose=False,
        )
        # Thinking is split off first so tool calls are only searched for in
        # the visible part of the answer.
        thinking, after_think = self._parse_thinking(text)
        clean_text, tool_calls = self._parse_tool_calls(after_think)
        extra = {"thinking": thinking} if thinking else {}
        message = AIMessage(content=clean_text, tool_calls=tool_calls, additional_kwargs=extra)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the response as thinking, content and tool-call chunks.

        The raw token stream is fed through a small state machine:

        * ``pre_think``  – before ``<think>``; text is emitted as content.
        * ``in_think``   – inside the block; text is emitted with empty content
          and ``additional_kwargs["thinking"]``.
        * ``post_think`` – leading whitespace is skipped, then either a
          ``<tool_call>`` or answer text follows.
        * ``in_answer``  – answer text; still watches for a later ``<tool_call>``.
        * ``in_tool``    – buffers until ``</tool_call>`` and emits one tool call.

        Text that could be the beginning of a tag is held back until the next
        token decides whether it is one.

        NOTE(review): ``stop`` is accepted but not used.
        """
        from mlx_lm import stream_generate

        tools = kwargs.get("tools")
        enable_thinking_override = kwargs.pop("enable_thinking", None)
        thinking_on = (
            self.enable_thinking if enable_thinking_override is None else enable_thinking_override
        )
        prompt = self._build_prompt(messages, tools, enable_thinking=thinking_on)

        OPEN_THINK = "<think>"
        CLOSE_THINK = "</think>"
        OPEN_TOOL = "<tool_call>"
        CLOSE_TOOL = "</tool_call>"
        # Number of trailing characters held back while looking for a think tag.
        SAFE = max(len(OPEN_THINK), len(CLOSE_THINK), len(OPEN_TOOL), len(CLOSE_TOOL))

        # With thinking disabled no <think> block is expected, so start
        # directly in post_think.
        state = "pre_think" if thinking_on else "post_think"
        buf = ""
        out: list[ChatGenerationChunk] = []

        def _think(text: str) -> None:
            out.append(ChatGenerationChunk(
                message=AIMessageChunk(content="", additional_kwargs={"thinking": text})
            ))

        def _content(text: str) -> None:
            out.append(ChatGenerationChunk(message=AIMessageChunk(content=text)))

        def _tool(raw_json: str) -> None:
            # Unparseable payloads are dropped, matching _parse_tool_calls.
            tc = self._tool_call_from_json(raw_json)
            if tc is not None:
                out.append(ChatGenerationChunk(message=AIMessageChunk(content="", tool_calls=[tc])))

        def _partial_tool_tag(text: str) -> int:
            """Return the length of the longest proper prefix of OPEN_TOOL ending ``text``."""
            for size in range(len(OPEN_TOOL) - 1, 0, -1):
                if text.endswith(OPEN_TOOL[:size]):
                    return size
            return 0

        def advance() -> None:
            nonlocal state, buf
            while buf:
                if state == "pre_think":
                    idx = buf.find(OPEN_THINK)
                    if idx == -1:
                        # Keep the last SAFE characters in case a tag is split
                        # across tokens.
                        safe = len(buf) - SAFE
                        if safe > 0:
                            _content(buf[:safe])
                            buf = buf[safe:]
                        return
                    if idx > 0:
                        _content(buf[:idx])
                    buf = buf[idx + len(OPEN_THINK):]
                    state = "in_think"

                elif state == "in_think":
                    idx = buf.find(CLOSE_THINK)
                    if idx == -1:
                        safe = len(buf) - SAFE
                        if safe > 0:
                            _think(buf[:safe])
                            buf = buf[safe:]
                        return
                    if idx > 0:
                        _think(buf[:idx])
                    buf = buf[idx + len(CLOSE_THINK):].lstrip()
                    state = "post_think"

                elif state == "post_think":
                    # Skip whitespace such as "\n\n" that follows </think>.
                    buf = buf.lstrip()
                    if not buf:
                        return
                    idx = buf.find(OPEN_TOOL)
                    if idx == -1:
                        held = _partial_tool_tag(buf)
                        if held:
                            # Possible start of a tag at the end: emit what is
                            # safe, hold the rest and wait for more input.
                            if len(buf) > held:
                                _content(buf[:-held])
                            buf = buf[-held:]
                            return
                        state = "in_answer"
                    else:
                        if idx > 0:
                            _content(buf[:idx])
                        buf = buf[idx + len(OPEN_TOOL):]
                        state = "in_tool"

                elif state == "in_answer":
                    # Keep looking for a tool call that follows answer text so
                    # streaming agrees with _generate, which parses tool calls
                    # anywhere in the output.
                    idx = buf.find(OPEN_TOOL)
                    if idx == -1:
                        cut = len(buf) - _partial_tool_tag(buf)
                        if cut > 0:
                            _content(buf[:cut])
                        buf = buf[cut:]
                        return
                    if idx > 0:
                        _content(buf[:idx])
                    buf = buf[idx + len(OPEN_TOOL):]
                    state = "in_tool"

                elif state == "in_tool":
                    idx = buf.find(CLOSE_TOOL)
                    if idx == -1:
                        return  # wait for the complete JSON payload
                    _tool(buf[:idx].strip())
                    buf = buf[idx + len(CLOSE_TOOL):]
                    state = "post_think"  # more tool calls may follow

        for raw in stream_generate(
            self._model, self._tokenizer, prompt=prompt, max_tokens=self.max_tokens
        ):
            if run_manager:
                run_manager.on_llm_new_token(raw.text)
            buf += raw.text
            advance()
            yield from out
            out.clear()

        # Flush whatever is still buffered when generation ends.
        if buf:
            if state == "in_think":
                _think(buf)
            elif state == "in_answer":
                _content(buf)
            elif state in ("pre_think", "post_think"):
                clean, tcs = self._parse_tool_calls(buf)
                if clean:
                    _content(clean)
                for tc in tcs:
                    out.append(ChatGenerationChunk(message=AIMessageChunk(content="", tool_calls=[tc])))
            elif state == "in_tool":
                _tool(buf.strip())
        yield from out

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ):
        """Asynchronously stream by running ``_stream`` in a worker thread.

        Chunks travel from the worker to the event loop through a bounded
        ``asyncio.Queue``. An exception raised by ``_stream`` is re-raised here
        after the chunks produced before it have been yielded.
        """
        import asyncio
        import threading

        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue(maxsize=200)
        sentinel = object()
        exc_holder: list = []
        # Set once the consumer stops listening, so the worker can stop too.
        abandoned = threading.Event()

        # _stream invokes the callback synchronously; an async run manager
        # would only hand back coroutines that are never awaited.
        sync_manager = run_manager
        if run_manager is not None and hasattr(run_manager, "get_sync"):
            sync_manager = run_manager.get_sync()

        def _put(item) -> None:
            # Blocks the worker until the loop has accepted the item, which
            # gives back-pressure through the bounded queue.
            asyncio.run_coroutine_threadsafe(queue.put(item), loop).result()

        def _run() -> None:
            try:
                for chunk in self._stream(
                    messages, stop=stop, run_manager=sync_manager, **kwargs
                ):
                    if abandoned.is_set():
                        return
                    _put(chunk)
            except Exception as exc:
                exc_holder.append(exc)
            finally:
                if not abandoned.is_set():
                    try:
                        _put(sentinel)
                    except RuntimeError:
                        # The loop is gone, so nobody is waiting for the sentinel.
                        pass

        thread = threading.Thread(target=_run, daemon=True)
        thread.start()

        try:
            while True:
                item = await queue.get()
                if item is sentinel:
                    break
                yield item
        finally:
            # Also reached when the consumer closes the generator early: tell
            # the worker to stop and drain the queue so a put() blocked on a
            # full queue can complete.
            abandoned.set()
            while not queue.empty():
                queue.get_nowait()

        # Join off the event loop: the worker's last put() is only resolved by
        # a loop callback, so a blocking join here would stall the loop until
        # the timeout expired.
        await loop.run_in_executor(None, thread.join, 5)
        if exc_holder:
            raise exc_holder[0]

    def bind_tools(self, tools, tool_choice=None, **kwargs):
        """Return a runnable with ``tools`` bound in OpenAI tool format.

        NOTE(review): ``tool_choice`` is accepted but not used.
        """
        formatted = [convert_to_openai_tool(t) for t in tools]
        return self.bind(tools=formatted, **kwargs)