Files
youlbot/services/model/mlx_chat_model.py
T
shinalok 145b0cc96f Implement Phase 12 feedback, Phase 13 Semantic Chunker, Phase 13-B Reranker, Bug 5 thinking fix
- Phase 12: FeedbackRepository + td_feedback 테이블, Gradio 👍/👎 이벤트, run_id 추적, LangSmith create_feedback() 연동
- Phase 13: 커스텀 _SemanticSplitter 제거 → langchain_experimental.SemanticChunker 교체, buffer_size/threshold_type 환경변수 적용
- Phase 13-B: RerankService (Cross-Encoder), RetrieverService.search()에 reranker 통합, tools.py as_retriever() → search() 전환
- Bug 5: mlx_chat_model enable_thinking 런타임 오버라이드, agent_service stream_mode=["messages","custom"] 이중 스트림, thinking 토큰 custom 이벤트로 emit
- ROADMAP: LLM 모델명 8B 반영, RAG에 Reranker 추가, 추천 진행 순서 갱신

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-29 17:41:36 +09:00

347 lines
13 KiB
Python

import json
import re
import uuid
from typing import Any, Iterator, List, Optional
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import PrivateAttr, model_validator
# One complete <tool_call>…</tool_call> block. Group 1 is the payload with the
# whitespace next to the tags trimmed; DOTALL lets the payload span lines.
_TOOL_CALL_RE = re.compile(
    r"<tool_call>\s*(.*?)\s*</tool_call>",
    flags=re.DOTALL,
)
# One complete <think>…</think> block. Group 1 is the raw (untrimmed) reasoning
# text; DOTALL lets it span lines.
_THINK_RE = re.compile(
    r"<think>(.*?)</think>",
    flags=re.DOTALL,
)
class MlxChatModel(BaseChatModel):
    """LangChain ``BaseChatModel`` backed by mlx-lm.

    Fully compatible with LangGraph: supports tool calling, streaming and
    ``bind_tools()``.

    Supports the Qwen3 thinking mode: ``<think>`` blocks are separated from
    the visible answer and stored in ``additional_kwargs["thinking"]``.

    Notes:
        * ``stop`` is accepted by the generation methods for interface
          compatibility but is not applied to the generated text.
        * ``enable_thinking`` may be overridden per call by passing
          ``enable_thinking=...`` as a keyword argument.
    """

    # Identifier handed to ``mlx_lm.load``.
    model_id: str
    # Upper bound on generated tokens per call.
    max_tokens: int = 1024
    # Sampling temperature; values <= 0 leave mlx_lm's default sampler in place.
    temp: float = 0.0
    # Default for the chat template's ``enable_thinking`` switch.
    enable_thinking: bool = True

    _model: Any = PrivateAttr(default=None)
    _tokenizer: Any = PrivateAttr(default=None)

    @model_validator(mode="after")
    def _load(self) -> "MlxChatModel":
        """Load the model and tokenizer once the fields have been validated."""
        from mlx_lm import load

        print(f"모델 로딩 중: {self.model_id}")
        self._model, self._tokenizer = load(self.model_id)
        return self

    @property
    def _llm_type(self) -> str:
        """Short identifier LangChain uses for this model type."""
        return "mlx-chat"

    def _sampling_kwargs(self) -> dict:
        """Return the extra mlx_lm generation kwargs implied by ``temp``.

        When ``temp`` is not positive nothing is added, so the call to mlx_lm
        is exactly the one made without any temperature configured.
        """
        if self.temp <= 0:
            return {}
        from mlx_lm.sample_utils import make_sampler

        return {"sampler": make_sampler(temp=self.temp)}

    # ── message → chat dict conversion ────────────────────────────
    def _to_chat_dicts(self, messages: List[BaseMessage]) -> List[dict]:
        """Convert LangChain messages into chat-template dicts.

        System, human, AI and tool messages are mapped to the ``system``,
        ``user``, ``assistant`` and ``tool`` roles. Messages of any other
        type are skipped.
        """
        chat: List[dict] = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                chat.append({"role": "system", "content": str(msg.content)})
            elif isinstance(msg, HumanMessage):
                chat.append({"role": "user", "content": str(msg.content)})
            elif isinstance(msg, AIMessage):
                entry: dict = {"role": "assistant", "content": str(msg.content)}
                if msg.tool_calls:
                    entry["content"] = str(msg.content) if msg.content else ""
                    entry["tool_calls"] = [
                        {
                            "id": tc["id"],
                            "type": "function",
                            "function": {
                                "name": tc["name"],
                                # Keep non-ASCII arguments readable in the
                                # prompt instead of \uXXXX escapes.
                                "arguments": json.dumps(
                                    tc["args"], ensure_ascii=False
                                ),
                            },
                        }
                        for tc in msg.tool_calls
                    ]
                chat.append(entry)
            elif isinstance(msg, ToolMessage):
                chat.append({
                    "role": "tool",
                    "content": str(msg.content),
                    "tool_call_id": msg.tool_call_id,
                })
        return chat

    def _build_prompt(
        self,
        messages: List[BaseMessage],
        tools: Optional[list] = None,
        enable_thinking: Optional[bool] = None,
    ) -> str:
        """Render ``messages`` into a prompt string via the chat template.

        Args:
            messages: Conversation to render.
            tools: Tool schemas forwarded to the template when non-empty.
            enable_thinking: Per-call override; ``None`` falls back to the
                instance-level ``enable_thinking`` field.

        Returns:
            The untokenized prompt, ending with the generation prompt.
        """
        thinking = self.enable_thinking if enable_thinking is None else enable_thinking
        # Converted once, outside the try, so a TypeError raised by the
        # conversion itself is not mistaken for an unsupported template kwarg.
        chat = self._to_chat_dicts(messages)
        template_kwargs: dict = {
            "tokenize": False,
            "add_generation_prompt": True,
        }
        if tools:
            template_kwargs["tools"] = tools
        # Qwen3 thinking mode; retry without the switch if the tokenizer
        # rejects the keyword.
        try:
            return self._tokenizer.apply_chat_template(
                chat, enable_thinking=thinking, **template_kwargs
            )
        except TypeError:
            return self._tokenizer.apply_chat_template(chat, **template_kwargs)

    # ── <think> block parsing (Qwen3) ─────────────────────────────
    @staticmethod
    def _parse_thinking(text: str) -> tuple[str, str]:
        """Split ``<think>`` blocks off ``text``; return ``(thinking, clean_text)``.

        All complete blocks are collected (joined with newlines). A trailing
        ``<think>`` that was never closed, e.g. because generation stopped
        mid-thought, is treated as thinking too, which matches how the
        streaming path flushes an unfinished think block. Text without any
        ``<think>`` tag is returned unchanged.
        """
        open_tag = "<think>"
        blocks = _THINK_RE.findall(text)
        remainder = _THINK_RE.sub("", text)
        dangling_at = remainder.find(open_tag)
        if not blocks and dangling_at == -1:
            return "", text
        if dangling_at != -1:
            blocks.append(remainder[dangling_at + len(open_tag):])
            remainder = remainder[:dangling_at]
        stripped = (block.strip() for block in blocks)
        thinking = "\n".join(block for block in stripped if block)
        return thinking, remainder.strip()

    # ── tool-call parsing ─────────────────────────────────────────
    @staticmethod
    def _tool_call_from_json(raw: str) -> Optional[dict]:
        """Build a LangChain tool-call dict from one ``<tool_call>`` payload.

        Returns ``None`` when the payload is not a JSON object with a
        ``name``, or when its arguments cannot be turned into a dict.
        """
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, dict) or "name" not in data:
            return None
        args = data.get("arguments", data.get("args", {}))
        if isinstance(args, str):
            # Arguments delivered as a JSON-encoded string: decode once more.
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                return None
        if args is None:
            args = {}
        if not isinstance(args, dict):
            return None
        return {
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "name": data["name"],
            "args": args,
            "type": "tool_call",
        }

    @staticmethod
    def _parse_tool_calls(text: str) -> tuple[str, list]:
        """Extract ``<tool_call>`` blocks; return ``(clean_text, tool_calls)``.

        Malformed payloads are dropped. Text without any complete block is
        returned unchanged together with an empty list.
        """
        payloads = _TOOL_CALL_RE.findall(text)
        if not payloads:
            return text, []
        parsed = (MlxChatModel._tool_call_from_json(raw) for raw in payloads)
        tool_calls = [call for call in parsed if call is not None]
        return _TOOL_CALL_RE.sub("", text).strip(), tool_calls

    @staticmethod
    def _partial_tag_len(text: str, tags: tuple[str, ...]) -> int:
        """Length of the longest suffix of ``text`` that could start a tag.

        Only proper prefixes of a tag count, so a complete tag at the end of
        ``text`` yields 0 for that tag.
        """
        longest = 0
        for tag in tags:
            for size in range(min(len(tag) - 1, len(text)), longest, -1):
                if text.endswith(tag[:size]):
                    longest = size
                    break
        return longest

    # ── LangChain BaseChatModel interface ─────────────────────────
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ) -> ChatResult:
        """Generate one complete reply.

        Args:
            messages: Conversation so far.
            stop: Accepted for interface compatibility; not applied.
            run_manager: Unused here.
            **kwargs: ``tools`` (schemas for the chat template) and
                ``enable_thinking`` (per-call override) are honoured.

        Returns:
            A ``ChatResult`` whose message carries the answer as content,
            parsed tool calls, and the reasoning text (if any) under
            ``additional_kwargs["thinking"]``.
        """
        from mlx_lm import generate

        tools = kwargs.get("tools")
        thinking_override = kwargs.pop("enable_thinking", None)
        prompt = self._build_prompt(messages, tools, enable_thinking=thinking_override)
        text = generate(
            self._model,
            self._tokenizer,
            prompt=prompt,
            max_tokens=self.max_tokens,
            verbose=False,
            **self._sampling_kwargs(),
        )
        thinking, answer = self._parse_thinking(text)
        content, tool_calls = self._parse_tool_calls(answer)
        extra = {"thinking": thinking} if thinking else {}
        message = AIMessage(content=content, tool_calls=tool_calls, additional_kwargs=extra)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream a reply as thinking, content and tool-call chunks.

        The raw token stream is run through a small state machine:

        * ``pre_think``  – waiting for ``<think>`` (only when thinking is on)
        * ``in_think``   – inside a think block; text goes out as
          ``additional_kwargs["thinking"]``
        * ``post_think`` – skipping whitespace after ``</think>`` or
          ``</tool_call>``
        * ``in_answer``  – visible answer text, still watching for
          ``<tool_call>``
        * ``in_tool``    – buffering a tool-call payload until its closing tag

        Text that might be the beginning of a tag is held back until the next
        token settles it, so tags are never emitted half-way.

        Args:
            messages: Conversation so far.
            stop: Accepted for interface compatibility; not applied.
            run_manager: If given, ``on_llm_new_token`` is called with each
                raw piece of generated text.
            **kwargs: ``tools`` and ``enable_thinking`` as in ``_generate``.
        """
        from mlx_lm import stream_generate

        tools = kwargs.get("tools")
        thinking_override = kwargs.pop("enable_thinking", None)
        thinking_on = self.enable_thinking if thinking_override is None else thinking_override
        prompt = self._build_prompt(messages, tools, enable_thinking=thinking_on)

        open_think, close_think = "<think>", "</think>"
        open_tool, close_tool = "<tool_call>", "</tool_call>"

        # With thinking disabled no <think> block is expected, so start
        # right after it.
        state = "pre_think" if thinking_on else "post_think"
        buf = ""
        out: list[ChatGenerationChunk] = []

        def emit_thinking(text: str) -> None:
            if text:
                out.append(ChatGenerationChunk(
                    message=AIMessageChunk(content="", additional_kwargs={"thinking": text})
                ))

        def emit_content(text: str) -> None:
            if text:
                out.append(ChatGenerationChunk(message=AIMessageChunk(content=text)))

        def emit_tool(raw_json: str) -> None:
            call = self._tool_call_from_json(raw_json)
            if call is not None:
                out.append(ChatGenerationChunk(
                    message=AIMessageChunk(content="", tool_calls=[call])
                ))

        def release(emit, tags: tuple[str, ...]) -> None:
            """Emit the buffer except a trailing fragment that may start a tag."""
            nonlocal buf
            cut = len(buf) - self._partial_tag_len(buf, tags)
            emit(buf[:cut])
            buf = buf[cut:]

        def advance() -> None:
            nonlocal state, buf
            while buf:
                if state == "pre_think":
                    think_at = buf.find(open_think)
                    tool_at = buf.find(open_tool)
                    if tool_at != -1 and (think_at == -1 or tool_at < think_at):
                        # A tool call arrived before any think block: the
                        # model skipped thinking, so parse as answer text.
                        state = "in_answer"
                        continue
                    if think_at == -1:
                        release(emit_content, (open_think, open_tool))
                        return
                    emit_content(buf[:think_at])
                    buf = buf[think_at + len(open_think):]
                    state = "in_think"
                elif state == "in_think":
                    close_at = buf.find(close_think)
                    if close_at == -1:
                        release(emit_thinking, (close_think,))
                        return
                    emit_thinking(buf[:close_at])
                    buf = buf[close_at + len(close_think):]
                    state = "post_think"
                elif state == "post_think":
                    # Skip whitespace such as "\n\n" after a closing tag.
                    buf = buf.lstrip()
                    if not buf:
                        return
                    state = "in_answer"
                elif state == "in_answer":
                    tool_at = buf.find(open_tool)
                    if tool_at == -1:
                        release(emit_content, (open_tool,))
                        return
                    emit_content(buf[:tool_at])
                    buf = buf[tool_at + len(open_tool):]
                    state = "in_tool"
                elif state == "in_tool":
                    close_at = buf.find(close_tool)
                    if close_at == -1:
                        return  # wait for the complete JSON payload
                    emit_tool(buf[:close_at].strip())
                    buf = buf[close_at + len(close_tool):]
                    state = "post_think"  # more tool calls may follow

        for raw in stream_generate(
            self._model,
            self._tokenizer,
            prompt=prompt,
            max_tokens=self.max_tokens,
            **self._sampling_kwargs(),
        ):
            if run_manager:
                run_manager.on_llm_new_token(raw.text)
            buf += raw.text
            advance()
            yield from out
            out.clear()

        # Flush whatever was held back when generation ended.
        if buf:
            if state == "in_think":
                emit_thinking(buf)
            elif state == "in_tool":
                # Closing tag never arrived; use the payload if it parses.
                emit_tool(buf.strip())
            else:
                emit_content(buf)
        yield from out

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ):
        """Async variant of ``_stream``.

        Runs ``_stream`` in a daemon thread and hands its chunks to the event
        loop through a bounded queue. If the consumer stops iterating early,
        the producer is told to stop and the queue is emptied so a pending
        ``put`` can complete instead of blocking the thread. An exception
        raised by the producer is re-raised here after the stream ends.
        """
        import asyncio
        import inspect
        import threading

        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue(maxsize=200)
        sentinel = object()
        errors: list = []
        consumer_gone = threading.Event()

        # The sync _stream cannot await an async callback; calling it would
        # only create a coroutine that is never run, so do not forward it.
        sync_manager = run_manager
        if inspect.iscoroutinefunction(getattr(run_manager, "on_llm_new_token", None)):
            sync_manager = None

        def _run() -> None:
            try:
                for chunk in self._stream(
                    messages, stop=stop, run_manager=sync_manager, **kwargs
                ):
                    if consumer_gone.is_set():
                        break
                    asyncio.run_coroutine_threadsafe(queue.put(chunk), loop).result()
            except Exception as exc:
                errors.append(exc)
            finally:
                try:
                    asyncio.run_coroutine_threadsafe(queue.put(sentinel), loop).result()
                except RuntimeError:
                    # Event loop already closed; nobody is left to notify.
                    pass

        thread = threading.Thread(target=_run, daemon=True)
        thread.start()
        try:
            while True:
                item = await queue.get()
                if item is sentinel:
                    break
                yield item
        finally:
            # Runs on normal completion and on early exit/cancellation alike.
            consumer_gone.set()
            while not queue.empty():
                queue.get_nowait()
        thread.join(timeout=5)
        if errors:
            raise errors[0]

    def bind_tools(self, tools, tool_choice=None, **kwargs):
        """Bind tools, converted to the OpenAI tool format, to this model.

        ``tool_choice`` is accepted for interface compatibility and ignored.
        """
        formatted = [convert_to_openai_tool(tool) for tool in tools]
        return self.bind(tools=formatted, **kwargs)