import asyncio
import concurrent.futures
import json
import logging
import re
import threading
import uuid
from typing import Any, Iterator, List, Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import PrivateAttr, model_validator

logger = logging.getLogger(__name__)

# Markers that delimit reasoning and tool invocations in the generated text.
# NOTE(review): these follow the Qwen3 chat-template convention; confirm they
# match the template of the model actually being loaded.
_OPEN_THINK = "<think>"
_CLOSE_THINK = "</think>"
_OPEN_TOOL = "<tool_call>"
_CLOSE_TOOL = "</tool_call>"

_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def _partial_tag_suffix(text: str, tag: str) -> int:
    """Return the length of the longest proper prefix of *tag* ending *text*.

    Used while streaming to hold back characters that may turn out to be the
    beginning of a marker once the next chunk arrives. Returns 0 when *text*
    does not end with any proper prefix of *tag*.
    """
    for size in range(min(len(tag) - 1, len(text)), 0, -1):
        if text.endswith(tag[:size]):
            return size
    return 0


def _tool_call_from_json(raw: str) -> Optional[dict]:
    """Convert the JSON payload of one tool-call block into a tool-call dict.

    The payload must be a JSON object with a ``name`` key and an optional
    ``arguments`` (or ``args``) key. Arguments that arrive as a JSON-encoded
    string are decoded. Returns ``None`` (and logs a warning) when the payload
    cannot be interpreted, so that one malformed block does not abort the
    whole response.
    """
    try:
        data = json.loads(raw)
        name = data["name"]
        args = data.get("arguments", data.get("args", {}))
        if isinstance(args, str):
            args = json.loads(args)
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
        logger.warning("Ignoring malformed tool call payload: %r", raw)
        return None
    if args is None:
        args = {}
    if not isinstance(args, dict):
        logger.warning("Ignoring tool call with non-object arguments: %r", raw)
        return None
    return {
        "id": f"call_{uuid.uuid4().hex[:8]}",
        "name": name,
        "args": args,
        "type": "tool_call",
    }


class MlxChatModel(BaseChatModel):
    """LangChain ``BaseChatModel`` backed by mlx-lm.

    Supports tool calling, streaming and ``bind_tools()``. Reasoning emitted
    inside ``<think>`` blocks is separated from the answer and stored under
    ``additional_kwargs["thinking"]``.

    NOTE: the ``stop`` argument of ``_generate`` / ``_stream`` / ``_astream``
    is accepted for interface compatibility but is not applied to the output.
    """

    model_id: str  # passed to mlx_lm.load()
    max_tokens: int = 1024  # generation budget forwarded to mlx-lm
    temp: float = 0.0  # sampling temperature; values <= 0 use mlx-lm's default sampler
    enable_thinking: bool = True  # forwarded to the chat template as enable_thinking

    _model: Any = PrivateAttr(default=None)
    _tokenizer: Any = PrivateAttr(default=None)

    @model_validator(mode="after")
    def _load(self) -> "MlxChatModel":
        """Load the model and tokenizer once the fields have been validated."""
        # Imported lazily so that importing this module does not require mlx-lm.
        from mlx_lm import load

        print(f"모델 로딩 중: {self.model_id}")
        self._model, self._tokenizer = load(self.model_id)
        return self

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this model type."""
        return "mlx-chat"

    # ── messages → chat dicts ────────────────────────────────────

    def _to_chat_dicts(self, messages: List[BaseMessage]) -> List[dict]:
        """Convert LangChain messages to the dict format taken by the chat template.

        System, human, AI and tool messages are supported; any other message
        type is skipped with a warning.
        """
        result = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                result.append({"role": "system", "content": str(msg.content)})
            elif isinstance(msg, HumanMessage):
                result.append({"role": "user", "content": str(msg.content)})
            elif isinstance(msg, AIMessage):
                if msg.tool_calls:
                    result.append({
                        "role": "assistant",
                        "content": str(msg.content) if msg.content else "",
                        "tool_calls": [
                            {
                                "id": tc["id"],
                                "type": "function",
                                "function": {
                                    "name": tc["name"],
                                    # Keep non-ASCII arguments readable in the
                                    # prompt instead of \uXXXX escapes.
                                    "arguments": json.dumps(
                                        tc["args"], ensure_ascii=False
                                    ),
                                },
                            }
                            for tc in msg.tool_calls
                        ],
                    })
                else:
                    result.append({"role": "assistant", "content": str(msg.content)})
            elif isinstance(msg, ToolMessage):
                result.append({
                    "role": "tool",
                    "content": str(msg.content),
                    "tool_call_id": msg.tool_call_id,
                })
            else:
                logger.warning(
                    "Skipping unsupported message type: %s", type(msg).__name__
                )
        return result

    def _build_prompt(self, messages: List[BaseMessage], tools: Optional[list] = None) -> str:
        """Render *messages* (and optional tool schemas) into a prompt string."""
        # Convert once, outside the try, so a TypeError raised by the
        # conversion itself is not mistaken for an unsupported template kwarg.
        chat = self._to_chat_dicts(messages)
        kwargs: dict = {
            "tokenize": False,
            "add_generation_prompt": True,
        }
        if tools:
            kwargs["tools"] = tools
        try:
            return self._tokenizer.apply_chat_template(
                chat, enable_thinking=self.enable_thinking, **kwargs
            )
        except TypeError:
            # Tokenizer rejected enable_thinking; retry without it.
            return self._tokenizer.apply_chat_template(chat, **kwargs)

    def _sampling_kwargs(self) -> dict:
        """Return extra mlx-lm generation kwargs that apply ``self.temp``.

        Returns an empty dict when ``temp`` is not positive, or when the
        installed mlx-lm does not provide ``sample_utils.make_sampler``.
        """
        if self.temp <= 0:
            return {}
        try:
            from mlx_lm.sample_utils import make_sampler
        except ImportError:
            logger.warning(
                "mlx_lm.sample_utils.make_sampler is unavailable; temp=%s is ignored",
                self.temp,
            )
            return {}
        return {"sampler": make_sampler(temp=self.temp)}

    # ── <think> block parsing ────────────────────────────────────

    @staticmethod
    def _parse_thinking(text: str) -> tuple[str, str]:
        """Split ``<think>`` blocks out of *text*; return ``(thinking, clean_text)``.

        All complete blocks are collected. An opening tag without a closing
        tag (for example when generation stopped at ``max_tokens``) is treated
        as thinking up to the end of the text, matching the streaming path.
        When there is no think block at all, *text* is returned unchanged.
        """
        blocks = _THINK_RE.findall(text)
        remainder = _THINK_RE.sub("", text)
        open_idx = remainder.find(_OPEN_THINK)
        if open_idx != -1:
            blocks.append(remainder[open_idx + len(_OPEN_THINK):])
            remainder = remainder[:open_idx]
        if not blocks:
            return "", text
        thinking = "\n\n".join(part.strip() for part in blocks if part.strip())
        return thinking, remainder.strip()

    # ── tool-call parsing ────────────────────────────────────────

    @staticmethod
    def _parse_tool_calls(text: str) -> tuple[str, list]:
        """Extract ``<tool_call>`` blocks; return ``(clean_text, tool_calls)``.

        Every block is removed from the returned text, including blocks whose
        payload could not be parsed (those are dropped with a warning). When
        there is no block, *text* is returned unchanged.
        """
        matches = _TOOL_CALL_RE.findall(text)
        if not matches:
            return text, []
        parsed = (_tool_call_from_json(raw) for raw in matches)
        tool_calls = [tc for tc in parsed if tc is not None]
        clean = _TOOL_CALL_RE.sub("", text).strip()
        return clean, tool_calls

    # ── LangChain BaseChatModel interface ────────────────────────

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ) -> ChatResult:
        """Generate one complete response.

        Tool schemas are read from ``kwargs["tools"]``. The returned message
        carries the answer in ``content``, parsed tool calls in ``tool_calls``
        and any reasoning in ``additional_kwargs["thinking"]``.
        """
        from mlx_lm import generate

        tools = kwargs.get("tools")
        prompt = self._build_prompt(messages, tools)
        text = generate(
            self._model,
            self._tokenizer,
            prompt=prompt,
            max_tokens=self.max_tokens,
            verbose=False,
            **self._sampling_kwargs(),
        )
        thinking, after_think = self._parse_thinking(text)
        clean_text, tool_calls = self._parse_tool_calls(after_think)
        extra = {"thinking": thinking} if thinking else {}
        message = AIMessage(content=clean_text, tool_calls=tool_calls, additional_kwargs=extra)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream a response as chunks of thinking, content and tool calls.

        Generated text is fed through a small state machine:

        * ``pre_think``  – waiting for ``<think>``; other text is content
        * ``in_think``   – inside a think block; text is emitted as thinking
        * ``post_think`` – skipping whitespace after a think or tool block
        * ``in_answer``  – answer text, watching for ``<tool_call>``
        * ``in_tool``    – buffering a tool-call payload until its closing tag

        Characters that could be the start of a marker are held back until
        the next chunk disambiguates them.
        """
        from mlx_lm import stream_generate

        tools = kwargs.get("tools")
        prompt = self._build_prompt(messages, tools)

        # With thinking disabled no think block is expected, so skip straight
        # to the post-think state.
        state = "pre_think" if self.enable_thinking else "post_think"
        buf = ""
        out: list[ChatGenerationChunk] = []

        def _think(text: str) -> None:
            out.append(ChatGenerationChunk(
                message=AIMessageChunk(content="", additional_kwargs={"thinking": text})
            ))

        def _content(text: str) -> None:
            out.append(ChatGenerationChunk(message=AIMessageChunk(content=text)))

        def _tool(raw_json: str) -> None:
            tc = _tool_call_from_json(raw_json)
            if tc is not None:
                out.append(ChatGenerationChunk(
                    message=AIMessageChunk(content="", tool_calls=[tc])
                ))

        def advance() -> None:
            """Consume as much of ``buf`` as can be classified right now."""
            nonlocal state, buf
            while buf:
                if state == "pre_think":
                    idx = buf.find(_OPEN_THINK)
                    if idx != -1:
                        if idx > 0:
                            _content(buf[:idx])
                        buf = buf[idx + len(_OPEN_THINK):]
                        state = "in_think"
                        continue
                    if _OPEN_TOOL in buf:
                        # The model skipped the think block and went straight
                        # to a tool call; let in_answer handle it.
                        state = "in_answer"
                        continue
                    hold = max(
                        _partial_tag_suffix(buf, _OPEN_THINK),
                        _partial_tag_suffix(buf, _OPEN_TOOL),
                    )
                    cut = len(buf) - hold
                    if cut > 0:
                        _content(buf[:cut])
                    buf = buf[cut:]
                    return
                elif state == "in_think":
                    idx = buf.find(_CLOSE_THINK)
                    if idx == -1:
                        cut = len(buf) - _partial_tag_suffix(buf, _CLOSE_THINK)
                        if cut > 0:
                            _think(buf[:cut])
                        buf = buf[cut:]
                        return
                    if idx > 0:
                        _think(buf[:idx])
                    buf = buf[idx + len(_CLOSE_THINK):]
                    state = "post_think"
                elif state == "post_think":
                    # Skip whitespace such as "\n\n" that follows a closing tag.
                    buf = buf.lstrip()
                    if buf:
                        state = "in_answer"
                elif state == "in_answer":
                    idx = buf.find(_OPEN_TOOL)
                    if idx == -1:
                        cut = len(buf) - _partial_tag_suffix(buf, _OPEN_TOOL)
                        if cut > 0:
                            _content(buf[:cut])
                        buf = buf[cut:]
                        return
                    if idx > 0:
                        _content(buf[:idx])
                    buf = buf[idx + len(_OPEN_TOOL):]
                    state = "in_tool"
                else:  # in_tool
                    idx = buf.find(_CLOSE_TOOL)
                    if idx == -1:
                        return  # wait for the complete JSON payload
                    _tool(buf[:idx].strip())
                    buf = buf[idx + len(_CLOSE_TOOL):]
                    state = "post_think"  # more tool calls or text may follow

        for raw in stream_generate(
            self._model,
            self._tokenizer,
            prompt=prompt,
            max_tokens=self.max_tokens,
            **self._sampling_kwargs(),
        ):
            if run_manager:
                run_manager.on_llm_new_token(raw.text)
            buf += raw.text
            advance()
            yield from out
            out.clear()

        # Flush whatever is still buffered when generation ends.
        if buf:
            if state == "in_think":
                _think(buf)
            elif state == "in_tool":
                _tool(buf.strip())
            else:
                _content(buf)
        yield from out

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ):
        """Async streaming: run ``_stream`` in a worker thread and relay chunks.

        Chunks travel through a bounded ``asyncio.Queue``. If the consumer
        stops iterating early, the worker is told to stop instead of staying
        blocked on a full queue. An exception raised by ``_stream`` is
        re-raised here after the stream ends.
        """
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue(maxsize=200)
        sentinel = object()
        exc_holder: list = []
        cancelled = threading.Event()

        # _stream invokes callbacks synchronously, so give it the sync view of
        # an async run manager; otherwise on_llm_new_token() would return a
        # coroutine that is never awaited.
        sync_manager = run_manager
        if run_manager is not None and hasattr(run_manager, "get_sync"):
            sync_manager = run_manager.get_sync()

        def _put(item: Any) -> bool:
            """Enqueue *item* from the worker thread; False once cancelled."""
            if cancelled.is_set():
                return False
            future = asyncio.run_coroutine_threadsafe(queue.put(item), loop)
            while True:
                try:
                    future.result(timeout=0.1)
                    return True
                except concurrent.futures.TimeoutError:
                    # Queue is full; give up if the consumer has gone away.
                    if cancelled.is_set():
                        future.cancel()
                        return False
                except concurrent.futures.CancelledError:
                    return False

        def _run() -> None:
            try:
                for chunk in self._stream(
                    messages, stop=stop, run_manager=sync_manager, **kwargs
                ):
                    if not _put(chunk):
                        return
            except Exception as exc:
                exc_holder.append(exc)
            finally:
                _put(sentinel)

        thread = threading.Thread(target=_run, daemon=True)
        thread.start()

        try:
            while True:
                item = await queue.get()
                if item is sentinel:
                    break
                yield item
        finally:
            # Reached on normal completion and on early close alike.
            cancelled.set()

        thread.join(timeout=5)
        if exc_holder:
            raise exc_holder[0]

    def bind_tools(self, tools, tool_choice=None, **kwargs):
        """Bind tool definitions, converted to OpenAI tool schemas, to the model.

        NOTE: ``tool_choice`` is accepted for interface compatibility but is
        not forwarded; the model decides on its own whether to call a tool.
        """
        formatted = [convert_to_openai_tool(t) for t in tools]
        return self.bind(tools=formatted, **kwargs)