from __future__ import annotations

import networkx as nx

from services.db.mysql_service import DatabaseService


class GraphService:
    """NetworkX-based knowledge graph.

    Relation triples ``(subject, relation, object)`` are persisted in the MySQL
    table ``td_knowledge_graph`` and loaded into an in-memory
    ``nx.MultiDiGraph`` per user so graph queries are served from memory.

    NOTE(review): each user's graph is loaded on first access and then cached
    for the lifetime of this instance (there is no eviction or reload), so rows
    written by other processes/instances after that first load are not seen by
    the in-memory queries.
    """

    def __init__(self, db: DatabaseService):
        self._db = db
        # user_id -> cached graph, populated lazily by _graph().
        self._graphs: dict[str, nx.MultiDiGraph] = {}

    def _load(self, user_id: str) -> nx.MultiDiGraph:
        """Build a fresh graph from every stored triple belonging to *user_id*."""
        g = nx.MultiDiGraph()
        rows = self._db.execute(
            "SELECT subject, relation, object FROM td_knowledge_graph WHERE user_id = %s",
            (user_id,),
        )
        # Treat a falsy result (None / empty) as "no rows", consistent with the
        # ``if not rows`` check in add_relation.
        for row in rows or ():
            g.add_edge(row["subject"], row["object"], relation=row["relation"])
        return g

    def _graph(self, user_id: str) -> nx.MultiDiGraph:
        """Return the cached graph for *user_id*, loading it from MySQL on first use."""
        if user_id not in self._graphs:
            self._graphs[user_id] = self._load(user_id)
        return self._graphs[user_id]

    def _edge_exists(self, g: nx.MultiDiGraph, subject: str, relation: str, obj: str) -> bool:
        """Return True if *g* already has an edge ``subject -[relation]→ obj``."""
        # get_edge_data returns the {key: attrs} dict of parallel edges between
        # exactly these two nodes, or None if either node or the edge is missing.
        # Do not use g.out_edges(subject) here: for a subject that is not a node,
        # networkx treats the string as an iterable of nodes and would inspect
        # the edges of its individual characters instead.
        parallel = g.get_edge_data(subject, obj)
        if not parallel:
            return False
        return any(d.get("relation") == relation for d in parallel.values())

    def add_relation(self, subject: str, relation: str, obj: str, user_id: str) -> str:
        """Store a relation triple for *user_id*, skipping it if it already exists.

        The cached graph is checked first. On a miss the database is checked as
        well, because rows inserted by another writer after this user's graph
        was cached are not visible in memory.

        Returns:
            A user-facing message saying the triple was saved or already existed.
        """
        g = self._graph(user_id)
        if self._edge_exists(g, subject, relation, obj):
            return f"이미 저장된 관계입니다: {subject} -[{relation}]→ {obj}"

        rows = self._db.execute(
            "SELECT id FROM td_knowledge_graph "
            "WHERE user_id=%s AND subject=%s AND relation=%s AND object=%s",
            (user_id, subject, relation, obj),
        )
        if rows:
            # Already persisted; only the cache was stale. Sync it and skip.
            g.add_edge(subject, obj, relation=relation)
            return f"이미 저장된 관계입니다: {subject} -[{relation}]→ {obj}"

        self._db.execute_write(
            "INSERT INTO td_knowledge_graph (user_id, subject, relation, object) "
            "VALUES (%s, %s, %s, %s)",
            (user_id, subject, relation, obj),
        )
        # Update the cache only after the write succeeded.
        g.add_edge(subject, obj, relation=relation)
        return f"'{subject} -[{relation}]→ {obj}' 관계를 저장했습니다."

    def query_entity(self, entity: str, user_id: str) -> str:
        """Describe every relation touching *entity*, both outgoing and incoming.

        Returns a "no stored information" message when *entity* is unknown.
        """
        g = self._graph(user_id)
        # Membership check first: out_edges/in_edges would otherwise treat an
        # unknown string as an iterable of (single-character) nodes.
        if entity not in g:
            return f"'{entity}'에 대해 저장된 정보가 없습니다."

        lines = [
            f" {entity} -[{data['relation']}]→ {target}"
            for _, target, data in g.out_edges(entity, data=True)
        ]
        # A self-loop (entity -> entity) is both an out-edge and an in-edge and
        # was already listed above; skip it here to avoid a duplicate line.
        lines.extend(
            f" {source} -[{data['relation']}]→ {entity}"
            for source, _, data in g.in_edges(entity, data=True)
            if source != entity
        )
        if not lines:
            return f"'{entity}'에 대해 저장된 정보가 없습니다."
        return f"'{entity}' 관련 정보:\n" + "\n".join(lines)

    def get_summary(self, user_id: str) -> str:
        """Return every relation of *user_id*, one per line, for system-prompt injection.

        Returns an empty string when the user has no stored relations.
        """
        g = self._graph(user_id)
        if not g.edges:
            return ""
        return "\n".join(
            f" {s} -[{d['relation']}]→ {t}" for s, t, d in g.edges(data=True)
        )