"""大语言模型工厂:支持 OpenAI 兼容的云端 API、Anthropic 兼容 API 和本地 Ollama。""" import os import time from typing import Any from dotenv import load_dotenv from backend.logger import get_logger load_dotenv(override=True) _llm_log = get_logger("llm") class _BaseLLM: """LLM 统一接口基类 — 所有后端都提供 invoke() 和 stream()。""" def invoke(self, prompt: str) -> Any: raise NotImplementedError def stream(self, prompt: str): raise NotImplementedError class _LLMLoggingWrapper(_BaseLLM): """包装任何 LLM 后端,自动记录输入/输出到 llm.log。""" def __init__(self, inner: _BaseLLM, model: str, backend: str, caller: str = ""): self._inner = inner self._model = model self._backend = backend self._caller = caller def invoke(self, prompt: str) -> Any: t0 = time.time() prompt_len = len(prompt) _llm_log.debug( "LLM invoke 请求", extra={ "direction": "request", "model": self._model, "backend": self._backend, "caller": self._caller, "prompt_length": prompt_len, "prompt_preview": prompt[:500], }, ) try: result = self._inner.invoke(prompt) elapsed = round((time.time() - t0) * 1000) content = getattr(result, "content", str(result)) resp_len = len(content) resp_preview = content[:500] _llm_log.info( "LLM invoke 完成", extra={ "direction": "response", "model": self._model, "backend": self._backend, "caller": self._caller, "duration_ms": elapsed, "response_length": resp_len, "response_preview": resp_preview, }, ) return result except Exception as e: elapsed = round((time.time() - t0) * 1000) _llm_log.error( "LLM invoke 异常", extra={ "direction": "error", "model": self._model, "backend": self._backend, "caller": self._caller, "duration_ms": elapsed, "error": str(e), "prompt_preview": prompt[:500], }, ) raise def stream(self, prompt: str): t0 = time.time() prompt_len = len(prompt) prompt_preview = prompt[:500] _llm_log.debug( "LLM stream 请求", extra={ "direction": "request", "model": self._model, "backend": self._backend, "caller": self._caller, "prompt_length": prompt_len, "prompt_preview": prompt[:500], }, ) full = [] try: for chunk in self._inner.stream(prompt): full.append(chunk) yield chunk elapsed = round((time.time() - t0) * 1000) resp_text = "".join(full) resp_len = len(resp_text) resp_preview = resp_text[:500] stop_reason = getattr(self._inner, '_last_stop_reason', None) self._last_stop_reason = stop_reason if stop_reason == "max_tokens": _llm_log.warning( "LLM stream 截断 (max_tokens),输出可能不完整", extra={ "direction": "response", "model": self._model, "backend": self._backend, "caller": self._caller, "duration_ms": elapsed, "response_length": resp_len, "stop_reason": stop_reason, }, ) else: _llm_log.info( "LLM stream 完成", extra={ "direction": "response", "model": self._model, "backend": self._backend, "caller": self._caller, "duration_ms": elapsed, "response_length": resp_len, "response_preview": resp_preview, "stop_reason": stop_reason, }, ) except Exception as e: elapsed = round((time.time() - t0) * 1000) _llm_log.error( "LLM stream 异常", extra={ "direction": "error", "model": self._model, "backend": self._backend, "caller": self._caller, "duration_ms": elapsed, "error": str(e), "prompt_preview": prompt[:500], }, ) raise DEFAULT_MAX_TOKENS = int(os.getenv("LLM_MAX_TOKENS", "8192")) def _build_raw_llm(caller: str = "", max_tokens: int | None = None) -> tuple[_BaseLLM, str, str]: """构造原始 LLM 实例,返回 (实例, model名, backend名)。 max_tokens: 覆盖默认输出 token 数。None 使用 LLM_MAX_TOKENS 环境变量或 8192。 """ backend = os.getenv("LLM_BACKEND", "cloud") if backend == "local": from langchain_ollama import ChatOllama model = os.getenv("LOCAL_LLM_MODEL", "qwen2.5-coder:7b") raw = ChatOllama(model=model, temperature=0.1) class OllamaWrapper(_BaseLLM): def invoke(self, prompt): return raw.invoke(prompt) def stream(self, prompt): for chunk in raw.stream(prompt): yield chunk.content return OllamaWrapper(), model, f"local/{model}" provider = os.getenv("LLM_PROVIDER", "openai") if provider == "anthropic": from anthropic import Anthropic api_key = os.getenv("ANTHROPIC_API_KEY") or os.getenv("OPENAI_API_KEY", "") base_url = os.getenv("ANTHROPIC_BASE_URL") or os.getenv("OPENAI_BASE_URL", "https://api.minimaxi.com/anthropic") model = os.getenv("LLM_MODEL", "MiniMax-M2.7") temperature = 0.1 _default_max_tokens = max_tokens if max_tokens is not None else DEFAULT_MAX_TOKENS client = Anthropic(api_key=api_key, base_url=base_url, timeout=120) class MiniMaxLLM(_BaseLLM): def __init__(self): self._last_stop_reason = None self._max_tokens = _default_max_tokens def invoke(self, prompt: str) -> Any: resp = client.messages.create( model=model, max_tokens=self._max_tokens, temperature=temperature, messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}], ) for block in resp.content: block_type = getattr(block, "type", "") if block_type == "text": return type("Response", (), {"content": block.text})() return type("Response", (), {"content": ""})() def stream(self, prompt: str): self._last_stop_reason = None with client.messages.stream( model=model, max_tokens=self._max_tokens, temperature=temperature, messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}], ) as s: for text in s.text_stream: yield text try: final_msg = s.get_final_message() self._last_stop_reason = getattr(final_msg, 'stop_reason', None) except Exception: pass def get_num_tokens(self, text: str) -> int: resp = client.messages.count_tokens( model=model, messages=[{"role": "user", "content": [{"type": "text", "text": text}]}], ) return resp.input_tokens return MiniMaxLLM(), model, f"cloud/anthropic/{model}" else: from langchain_openai import ChatOpenAI model = os.getenv("LLM_MODEL", "gpt-4o") raw = ChatOpenAI( model=model, api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"), temperature=0.1, ) class OpenAIWrapper(_BaseLLM): def invoke(self, prompt): return raw.invoke(prompt) def stream(self, prompt): for chunk in raw.stream(prompt): yield chunk.content return OpenAIWrapper(), model, f"cloud/openai/{model}" def get_llm(caller: str = "", max_tokens: int | None = None) -> _BaseLLM: """返回带日志的 LLM 实例。caller 用于标识调用来源(如 generate、classify_intent)。 max_tokens: 覆盖默认输出 token 数。用于骨架生成等需要大量输出的节点。 """ inner, model, backend = _build_raw_llm(caller, max_tokens=max_tokens) return _LLMLoggingWrapper(inner, model=model, backend=backend, caller=caller) def get_llm_for_correction(): return get_llm(caller="correction")