bizwechat/sqlcode/multi_agent.py
2025-02-17 10:34:35 +08:00

242 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(parent_dir)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models import BaseLanguageModel
from langgraph.graph.graph import CompiledGraph
from typing import TypedDict, List
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.utilities import SQLDatabase
from config import DatabaseConfig, QwenConfig
from langgraph.graph import END, START, StateGraph, MessagesState
from langchain_core.output_parsers import PydanticOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain_core.messages import RemoveMessage, AnyMessage
import json
# Upper bound on the number of SQL generations; once reached, the graph
# finishes even if the last SQL check failed.
MAX_ITERATIONS: int = 3
# Retry route after a failed SQL check: "reflect" sends the state through the
# reflect node before regenerating; any other value regenerates directly.
FLAG: str = "reflect"
class GraphState(TypedDict):
    """State passed between the nodes of the SQL-generation graph.

    Attributes:
        error: ``"yes"`` when the most recent SQL execution check failed,
            ``"no"`` otherwise. The key is absent until the first check runs
            (``generate`` treats a missing key as ``"no"``), and ``reflect``
            resets it to ``"no"``.
        messages: Conversation history fed to the model: the user's question
            followed by generated answers, execution errors and reflections.
            The nodes in this module append ``(role, text)`` tuples.
        generation: The latest parsed model output. At runtime this is a
            ``code`` instance (``prefix`` + ``code``), not a plain string.
        iterations: Number of times ``generate`` has run.
    """

    error: str
    messages: list[AnyMessage]
    # NOTE(review): holds a `code` object at runtime; the `str` annotation is
    # left as-is because StateGraph builds its channels from this TypedDict —
    # confirm LangGraph does not rely on it before changing.
    generation: str
    iterations: int
class code(BaseModel):
    """SQL output: a free-text explanation plus the SQL statement itself."""

    prefix: str = Field(..., description="对问题和使用方法的描述")
    code: str = Field(..., description="SQL代码")
import re
# Captures the body of every ```sql fenced block. IGNORECASE also accepts
# ```SQL fences; \b stops tags such as ```sqlite from matching as "sql" + "ite".
_SQL_FENCE_RE = re.compile(r"```sql\b(.*?)```", re.DOTALL | re.IGNORECASE)


def extract_sql_code(text):
    """Return the contents of every ```sql fenced block in *text*.

    Args:
        text: Model reply, typically Markdown.

    Returns:
        The SQL snippets in order of appearance, each stripped of surrounding
        whitespace. Empty list when *text* contains no ```sql block.
    """
    return [block.strip() for block in _SQL_FENCE_RE.findall(text)]
def parse_output(message):
    """Wrap a model reply in a ``code`` object.

    The full reply text becomes ``prefix``; the first ```sql block found by
    ``extract_sql_code`` becomes ``code`` (empty string when there is none).
    """
    text = message.content
    blocks = extract_sql_code(text)
    first_block = blocks[0] if blocks else ""
    return code(prefix=text, code=first_block)
def _text_of(item):
    """Return the text part of *item*: ``item[1]`` for a ``(role, text)`` tuple, else *item* itself."""
    if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str):
        return item[1]
    return item


def _with_text(item, text):
    """Return *item* with its text part (see ``_text_of``) replaced by *text*."""
    if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str):
        return (item[0], text)
    return text


def truncate_strings_from_end(string_list, max_length=6000):
    """Keep the newest entries whose combined text length fits *max_length*.

    Walks the list from the end, keeping whole entries while they fit. The
    first entry that does not fit is cut down to its trailing characters to
    fill the remaining budget; it and everything older are otherwise dropped.

    Entries may be plain strings or ``(role, text)`` tuples (the shape the
    graph nodes in this module append to ``messages``). For tuples only the
    text is measured and truncated; the role is preserved.

    Args:
        string_list: Strings or ``(role, text)`` tuples, oldest first.
        max_length: Maximum total number of text characters to keep.

    Returns:
        A new list, oldest first, whose total text length is at most
        *max_length*.
    """
    budget = max_length
    kept = []
    for item in reversed(string_list):
        text = _text_of(item)
        if len(text) <= budget:
            kept.append(item)
            budget -= len(text)
            continue
        # Only keep a tail when room is left: text[-0:] would return the
        # whole string and exceed the limit.
        if budget > 0:
            kept.append(_with_text(item, text[-budget:]))
        break
    # Entries were collected newest-first; restore chronological order.
    kept.reverse()
    return kept
def create_sql_graph(model: BaseLanguageModel, qwen_cfg: QwenConfig, data_cfg: DatabaseConfig) -> CompiledGraph:
    """Build and compile the LangGraph workflow that generates and self-corrects SQL.

    Flow: ``generate -> delete_messages -> check_code``. If the generated SQL
    fails to execute and fewer than ``MAX_ITERATIONS`` generations have been
    made, the graph loops back through ``reflect`` (when ``FLAG == "reflect"``)
    or directly to ``generate``; otherwise it ends.

    Args:
        model: Chat model used for both SQL generation and reflection.
        qwen_cfg: Prompt configuration; ``qwen_cfg.prompt`` is the system
            prompt template and ``qwen_cfg.params["example"]`` fills its
            ``example`` variable.
        data_cfg: Database configuration; supplies ``metadata`` and
            ``product`` for the prompt and ``connection_string`` for the
            execution check.

    Returns:
        The compiled graph, invoked with e.g.
        ``{"messages": [("human", question)], "iterations": 0}``.
    """
    sql_gen_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qwen_cfg.prompt),
            ("placeholder", "{messages}"),
        ]
    )
    sql_gen_prompt = sql_gen_prompt.partial(
        metadata=data_cfg.metadata,
        product=data_cfg.product,
        example=qwen_cfg.params["example"],
    )
    sql_gen_chain = sql_gen_prompt | model | parse_output

    # Maximum number of messages kept by delete_messages.
    max_history = 3

    # The SQLDatabase is created on first use and reused by later checks
    # instead of being rebuilt on every call.
    sql_executor = None

    def _get_sql_executor():
        """Return the SQLDatabase for data_cfg, creating it on first call."""
        nonlocal sql_executor
        if sql_executor is None:
            sql_executor = SQLDatabase.from_uri(data_cfg.connection_string)
        return sql_executor

    def _invoke_chain(messages):
        """Run sql_gen_chain on *messages*; return (parsed output, messages actually sent).

        On ValueError the call is retried once with the history shortened by
        truncate_strings_from_end -- presumably the model raises ValueError
        when the prompt is too long (TODO confirm).
        """
        try:
            return sql_gen_chain.invoke({"messages": messages}), messages
        except ValueError:
            messages = truncate_strings_from_end(messages)
            return sql_gen_chain.invoke({"messages": messages}), messages

    # Generate SQL.
    def generate(state: GraphState):
        # Copy so the list held by the caller / previous state is never mutated.
        messages = list(state["messages"])
        iterations = state.get("iterations", 0)
        if state.get("error", "no") == "yes":
            messages.append(("human", "请重新生成SQL使用 code 工具确保输出内容格式化包括前缀prefix、sql代码code:"))
        sql_solution, messages = _invoke_chain(messages)
        messages.append(("ai", f"{sql_solution.prefix} \n Code: {sql_solution.code}"))
        return {"generation": sql_solution, "messages": messages, "iterations": iterations + 1}

    # Check that the SQL executes.
    def sql_check(state: GraphState):
        print("---检查SQL---")
        messages = list(state["messages"])
        sql_solution = state["generation"]
        iterations = state["iterations"]
        sql_code = sql_solution.code
        executor = _get_sql_executor()
        if not sql_code.strip():
            # No ```sql block was found; executing "" may silently succeed on
            # some databases, so treat it as a failure explicitly.
            failure = "响应中未找到```sql```代码块"
        else:
            failure = None
            try:
                # NOTE(review): executes model-generated SQL verbatim against the
                # configured database — make sure the connection is read-only.
                executor.run(sql_code)
            except Exception as e:  # any DB error is fed back to the model
                failure = e
        if failure is not None:
            print("---SQL检查错误---")
            messages.append(("human", f"生成的SQL代码无法通过执行测试: {failure}"))
            error = "yes"
        else:
            print("---SQL检测通过---")
            error = "no"
        return {
            "generation": sql_solution,
            "messages": messages,
            "iterations": iterations,
            "error": error,
        }

    # Reflect on a failed SQL attempt.
    def reflect(state: GraphState):
        messages = list(state["messages"])
        reflection, messages = _invoke_chain(messages)
        messages.append(("ai", f"这里是对于错误的反思:{reflection.prefix}"))
        return {
            "generation": state["generation"],
            "messages": messages,
            "iterations": state["iterations"],
            "error": "no",
        }

    # Trim history to at most max_history messages.
    def delete_messages(state: GraphState):
        messages = state["messages"]
        if len(messages) > max_history:
            # Keep the first message (assumed to be the user's question, as in
            # the __main__ example) plus the newest ones, so retries still see
            # what was asked.
            messages = [messages[0], *messages[-(max_history - 1):]]
        return {"messages": messages}

    # Finish on success or once the retry budget is spent.
    def decide_to_finish(state: GraphState):
        if state["error"] == "no" or state["iterations"] >= MAX_ITERATIONS:
            print("---DECISION: FINISH---")
            return "end"
        print("---DECISION: RE-TRY SOLUTION---")
        return "reflect" if FLAG == "reflect" else "generate"

    workflow = StateGraph(GraphState)
    workflow.add_node("generate", generate)
    workflow.add_node("check_code", sql_check)
    workflow.add_node("reflect", reflect)
    workflow.add_node("delete_messages", delete_messages)

    workflow.add_edge(START, "generate")
    workflow.add_edge("generate", "delete_messages")
    workflow.add_edge("delete_messages", "check_code")
    workflow.add_conditional_edges(
        "check_code",
        decide_to_finish,
        {
            "end": END,
            "reflect": "reflect",
            "generate": "generate",
        },
    )
    workflow.add_edge("reflect", "generate")
    return workflow.compile()
import config
if __name__ == '__main__':
    # Manual smoke run: build the graph for the "contracts" database config
    # with a hosted Qwen model and push one sample question through it.
    from langchain_community.chat_models import ChatTongyi

    model = ChatTongyi(model="qwen-turbo")
    data_cfg = config.metadata("contracts")
    qwen_cfg = config.qwen_config("qwen_graph.conf")
    sql_graph = create_sql_graph(model, qwen_cfg, data_cfg)
    # Named `inputs` to avoid shadowing the builtin `input`.
    inputs = {"messages": [("human", "教育、工业互联网、物联网相关行业的项目都有些什么,列举合同金额、项目名称、部门、签订时间、分公司")], "iterations": 0}
    res = sql_graph.invoke(inputs)
    # Show the graph's final state (generation, messages, iterations, error).
    print(res)