"""LangGraph workflow that turns a user question into SQL, test-executes it and retries on failure.

Graph flow: generate -> delete_messages -> check_code -> END, or back through
reflect -> generate (or straight to generate) when the check fails.
"""
import json
import os
import re
import sys
from typing import List, TypedDict

# Make the parent directory importable so the local ``config`` module resolves.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(parent_dir)

from langchain.output_parsers import OutputFixingParser
from langchain_community.utilities import SQLDatabase
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AnyMessage, RemoveMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.graph import CompiledGraph

import config
from config import DatabaseConfig, QwenConfig

# Maximum number of generate/check rounds before the workflow stops retrying.
MAX_ITERATIONS = 3
# Retry strategy after a failed check: "reflect" routes through the reflect node first;
# any other value goes straight back to generate.
FLAG = "reflect"
# Upper bound on the number of messages kept in state between rounds.
MAX_MESSAGES = 3
# Fenced ```sql ... ``` blocks; DOTALL lets the SQL body span several lines.
_SQL_BLOCK_RE = re.compile(r'```sql(.*?)```', re.DOTALL)


class GraphState(TypedDict):
    """State shared between the nodes of the SQL graph.

    Attributes:
        error: "yes" if the last SQL check failed, "no" if it passed; absent until
            the first check has run.
        messages: Conversation history. In this module these are ``(role, content)``
            tuples: the user question, generated answers, execution errors and
            reflections.
        generation: The latest parsed ``code`` solution (annotated as ``str`` but
            holds a ``code`` instance at runtime).
        iterations: Number of generation attempts so far.
    """
    error: str
    messages: list[AnyMessage]
    generation: str
    iterations: int


class code(BaseModel):
    """SQL output: an explanation (prefix) and the SQL statement (code)."""
    prefix: str = Field(description="对问题和使用方法的描述")
    code: str = Field(description="SQL代码")


def extract_sql_code(text):
    """Return the stripped contents of every ```sql fenced block in ``text``, in order."""
    return [match.strip() for match in _SQL_BLOCK_RE.findall(text)]


def parse_output(message):
    """Wrap a model reply as ``code``.

    The full reply text becomes ``prefix``; the first ```sql block (or "" when the
    reply contains none) becomes ``code``.
    """
    content = message.content
    sql_codes = extract_sql_code(content)
    return code(prefix=content, code=sql_codes[0] if sql_codes else "")


def truncate_strings_from_end(string_list, max_length=6000):
    """Keep the tail of ``string_list`` whose combined length is at most ``max_length``.

    Strings are taken from the end of the list. The first string that does not fit
    entirely is clipped to its last characters, or dropped when no room is left.
    The relative order of the kept strings is preserved.
    """
    kept = []
    remaining = max_length
    for s in reversed(string_list):
        if len(s) > remaining:
            # Only clip when there is room left: ``s[-0:]`` would return the whole string.
            if remaining > 0:
                kept.append(s[-remaining:])
            break
        kept.append(s)
        remaining -= len(s)
    # Collected back-to-front, so restore the original order.
    kept.reverse()
    return kept


def _message_text(message) -> str:
    """Return the text of a ``(role, content)`` tuple or of a message object's ``content``."""
    if isinstance(message, tuple):
        return str(message[1])
    return str(getattr(message, "content", message))


def _truncate_messages(messages, max_length=6000):
    """Keep the most recent messages whose combined text fits in ``max_length`` characters.

    Unlike ``truncate_strings_from_end`` this understands ``(role, content)`` tuples:
    the role is preserved and only the content of the oldest kept tuple is clipped.
    Non-tuple message objects are never clipped. If nothing fits, the newest message
    is kept whole so the model never receives an empty history.
    """
    kept = []
    remaining = max_length
    for message in reversed(messages):
        text = _message_text(message)
        if len(text) <= remaining:
            kept.append(message)
            remaining -= len(text)
            continue
        if remaining > 0 and isinstance(message, tuple):
            kept.append((message[0], text[-remaining:]))
        break
    if not kept and messages:
        kept.append(messages[-1])
    kept.reverse()
    return kept


def create_sql_graph(model: BaseLanguageModel, qwen_cfg: QwenConfig, data_cfg: DatabaseConfig) -> CompiledGraph:
    """Build and compile the text-to-SQL workflow.

    Args:
        model: Chat model used both to generate SQL and to reflect on failures.
        qwen_cfg: Supplies the system prompt (``prompt``) and ``params["example"]``.
        data_cfg: Supplies ``metadata``/``product`` for the prompt and the
            ``connection_string`` used to test-execute generated SQL.

    Returns:
        The compiled graph. Invoke it with ``{"messages": [...], "iterations": 0}``.
    """
    sql_gen_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qwen_cfg.prompt),
            ("placeholder", "{messages}"),
        ]
    )
    sql_gen_prompt = sql_gen_prompt.partial(
        metadata=data_cfg.metadata,
        product=data_cfg.product,
        example=qwen_cfg.params["example"],
    )
    sql_gen_chain = sql_gen_prompt | model | parse_output

    # ``SQLDatabase.from_uri`` constructs a fresh database wrapper on every call, so
    # build it once (lazily, on the first check) and reuse it.
    sql_executor = None

    def get_executor():
        """Return the shared ``SQLDatabase``, creating it on first use."""
        nonlocal sql_executor
        if sql_executor is None:
            sql_executor = SQLDatabase.from_uri(data_cfg.connection_string)
        return sql_executor

    def invoke_chain(messages):
        """Run the generation chain, retrying once on a truncated history on ``ValueError``.

        Returns ``(solution, messages_used)`` so callers keep the (possibly truncated)
        history in state.
        """
        try:
            return sql_gen_chain.invoke({"messages": messages}), messages
        except ValueError:
            # NOTE(review): presumably the model rejecting an over-long input (the
            # 6000-char budget suggests a provider limit) — confirm the error source.
            messages = _truncate_messages(messages)
            return sql_gen_chain.invoke({"messages": messages}), messages

    def generate(state: GraphState):
        """Generate a SQL solution, asking for a regeneration if the last check failed."""
        # Copy so the caller's / previous state's list is never mutated in place.
        messages = list(state["messages"])
        if state.get("error", "no") == 'yes':
            messages.append(("human", "请重新生成SQL,使用 code 工具确保输出内容格式化,包括前缀(prefix)、sql代码(code):"))
        sql_solution, messages = invoke_chain(messages)
        messages = messages + [("ai", f"{sql_solution.prefix} \n Code: {sql_solution.code}")]
        return {
            "generation": sql_solution,
            "messages": messages,
            "iterations": state["iterations"] + 1,
        }

    def sql_check(state: GraphState):
        """Test-execute the generated SQL and record whether it failed.

        On failure, the error text is appended to the history as a human message so
        the next round can react to it.
        """
        print("---检查SQL---")
        messages = state["messages"]
        sql_solution = state["generation"]
        iterations = state["iterations"]
        sql = sql_solution.code

        failure = None
        if not sql.strip():
            # parse_output yields "" when the reply had no ```sql block; executing an
            # empty statement must not count as a pass.
            failure = "回答中没有找到```sql代码块"
        else:
            executor = get_executor()
            # NOTE(review): this executes model-generated SQL verbatim against the
            # configured database. Use a read-only connection/user to stop DML/DDL.
            try:
                executor.run(sql)
            except Exception as e:  # any driver/SQL error is fed back to the model
                failure = e

        if failure is not None:
            print("---SQL检查错误---")
            messages = messages + [("human", f"生成的SQL代码无法通过执行测试: {failure}")]
            return {
                "generation": sql_solution,
                "messages": messages,
                "iterations": iterations,
                "error": "yes",
            }

        print("---SQL检测通过---")
        return {
            "generation": sql_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "no",
        }

    def reflect(state: GraphState):
        """Ask the model to reflect on the failed SQL and append the reflection.

        The ``error`` flag is left as set by the check ("yes"), so the following
        ``generate`` call adds its explicit regeneration prompt.
        """
        reflections, messages = invoke_chain(state["messages"])
        messages = messages + [("ai", f"这里是对于错误的反思:{reflections.prefix}")]
        return {
            "generation": state["generation"],
            "messages": messages,
            "iterations": state["iterations"],
        }

    def delete_messages(state: GraphState):
        """Cap the history at ``MAX_MESSAGES``, always keeping the first message.

        The first message is the user's question. Dropping it would leave the model
        without the task after a retry round.
        """
        messages = state["messages"]
        if len(messages) > MAX_MESSAGES:
            tail = max(MAX_MESSAGES - 1, 0)
            messages = messages[:1] + (messages[-tail:] if tail else [])
        return {"messages": messages}

    def decide_to_finish(state: GraphState):
        """Route after the check: finish on success or when out of attempts, else retry."""
        if state["error"] == "no" or state["iterations"] >= MAX_ITERATIONS:
            print("---DECISION: FINISH---")
            return "end"
        print("---DECISION: RE-TRY SOLUTION---")
        return "reflect" if FLAG == "reflect" else "generate"

    workflow = StateGraph(GraphState)
    workflow.add_node("generate", generate)
    workflow.add_node("check_code", sql_check)
    workflow.add_node("reflect", reflect)
    workflow.add_node("delete_messages", delete_messages)

    workflow.add_edge(START, "generate")
    workflow.add_edge("generate", "delete_messages")
    workflow.add_edge("delete_messages", "check_code")
    workflow.add_conditional_edges(
        "check_code",
        decide_to_finish,
        {
            "end": END,
            "reflect": "reflect",
            "generate": "generate",
        },
    )
    workflow.add_edge("reflect", "generate")
    return workflow.compile()


if __name__ == '__main__':
    # Alternative: load the model through the project's ModelLoader/ModelManager
    # (config.model_config(), switch_model("qwen"), get_model()).
    from langchain_community.chat_models import ChatTongyi

    model = ChatTongyi(model="qwen-turbo")
    data_cfg = config.metadata("contracts")
    qwen_cfg = config.qwen_config("qwen_graph.conf")
    sql_graph = create_sql_graph(model, qwen_cfg, data_cfg)
    inputs = {"messages": [("human", "教育、工业互联网、物联网相关行业的项目都有些什么,列举合同金额、项目名称、部门、签订时间、分公司")], "iterations": 0}
    res = sql_graph.invoke(inputs)
    print(res)