bizwechat/sqlcode/multi_agent.py

242 lines
8.2 KiB
Python
Raw Permalink Normal View History

2025-02-17 10:34:35 +08:00
import os
import sys
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(parent_dir)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models import BaseLanguageModel
from langgraph.graph.graph import CompiledGraph
from typing import TypedDict, List
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.utilities import SQLDatabase
from config import DatabaseConfig, QwenConfig
from langgraph.graph import END, START, StateGraph, MessagesState
from langchain_core.output_parsers import PydanticOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain_core.messages import RemoveMessage, AnyMessage
import json
# Number of generate passes after which decide_to_finish stops retrying,
# even if the last SQL check still failed.
MAX_ITERATIONS = 3
# Retry route after a failed SQL check: "reflect" sends the workflow through
# the reflect node first; any other value goes straight back to generate.
FLAG = "reflect"
class GraphState(TypedDict):
    """
    State passed between the nodes of the SQL generation graph.

    Attributes:
        error : "yes" / "no" string flag written by the check and reflect
            nodes; it may be absent on the very first generate pass.
        messages : Conversation history (user question, model answers,
            execution errors, reflections). The nodes in this file append
            plain ``(role, text)`` tuples to it.
        generation : Latest result of the generation chain. Annotated as
            ``str``, but the nodes store the object returned by
            ``parse_output`` (it exposes ``.prefix`` and ``.code``).
        iterations : Number of generate passes executed so far.
    """
    error: str
    messages: list[AnyMessage]
    generation: str
    iterations: int
class code(BaseModel):
    """SQL output: the model's full reply plus the SQL extracted from it."""
    # `prefix` receives the whole reply text and `code` the first extracted
    # SQL block (empty string when none was found) -- see parse_output.
    prefix: str = Field(description="对问题和使用方法的描述")
    code: str = Field(description="SQL代码")
import re
def extract_sql_code(text):
    """Return the contents of every ```sql fenced block found in ``text``.

    The fence tag is matched case-insensitively, so ```SQL blocks are
    recognised as well as ```sql ones.

    Args:
        text: Raw reply text; ``None`` or an empty string is treated as
            containing no blocks.

    Returns:
        A list with the whitespace-stripped body of each block, in order of
        appearance; empty when no block is present.
    """
    if not text:
        return []
    # DOTALL lets a block span several lines; the non-greedy group stops at
    # the first closing fence so consecutive blocks stay separate.
    pattern = r'```sql(.*?)```'
    matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
    return [match.strip() for match in matches]
def parse_output(message):
    """Wrap a model reply into a ``code`` object.

    The complete reply text becomes ``prefix``; ``code`` is the first SQL
    block extracted from it, or an empty string when there is none.
    """
    reply_text = message.content
    sql_blocks = extract_sql_code(reply_text)
    first_sql = sql_blocks[0] if len(sql_blocks) > 0 else ""
    return code(prefix=reply_text, code=first_sql)
def truncate_strings_from_end(string_list, max_length=6000):
    """Keep the trailing strings of ``string_list`` within ``max_length`` characters.

    Strings are taken from the end of the list. The first string that no
    longer fits is cut down to its own tail so the total is exactly
    ``max_length``, and everything before it is dropped.

    Args:
        string_list: Strings ordered oldest to newest.
        max_length: Maximum combined length of the returned strings.

    Returns:
        A new list, in the original order, whose combined length never
        exceeds ``max_length``.
    """
    remaining = max_length
    kept = []
    # Walk from the end of the list so the newest strings are kept first.
    for s in reversed(string_list):
        if len(s) > remaining:
            # Only keep a partial tail when there is room left: with
            # remaining == 0 the slice s[-0:] would return the WHOLE string
            # and blow past the limit.
            if remaining > 0:
                kept.append(s[-remaining:])
            break
        kept.append(s)
        remaining -= len(s)
    # Collected back-to-front; restore the original order.
    kept.reverse()
    return kept
def _truncate_messages(messages, max_length=6000):
    """Return the newest messages whose combined text fits in ``max_length`` characters.

    Messages are walked from newest to oldest. ``(role, text)`` tuples -- the
    shape this module appends -- are cut down to their tail when they exceed
    the remaining budget. Any other message object is kept whole (it cannot
    be trimmed safely here) and its ``content`` length, when it is a string,
    is charged against the budget.
    """
    budget = max_length
    kept = []
    for message in reversed(messages):
        if budget <= 0:
            break
        if isinstance(message, tuple) and len(message) == 2 and isinstance(message[1], str):
            role, text = message
            # Slicing from the end keeps the most recent part of the text;
            # budget > 0 here, so this never degenerates to text[-0:].
            text = text[-budget:]
            kept.append((role, text))
            budget -= len(text)
        else:
            kept.append(message)
            content = getattr(message, "content", "")
            if isinstance(content, str):
                budget -= len(content)
    kept.reverse()
    return kept


def create_sql_graph(model:BaseLanguageModel, qwen_cfg: QwenConfig, data_cfg: DatabaseConfig) -> CompiledGraph:
    """Build and compile the generate -> check -> (reflect) -> generate SQL workflow.

    Args:
        model: Language model piped after the prompt.
        qwen_cfg: Supplies the system prompt (``prompt``) and
            ``params["example"]`` for the prompt template.
        data_cfg: Supplies ``metadata`` and ``product`` for the prompt and
            ``connection_string`` for executing the generated SQL.

    Returns:
        The compiled graph. It is invoked with a state dict holding at least
        ``messages`` and ``iterations``.
    """
    sql_gen_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qwen_cfg.prompt),
            ("placeholder", "{messages}"),
        ]
    ).partial(
        metadata=data_cfg.metadata,
        product=data_cfg.product,
        example=qwen_cfg.params["example"],
    )
    sql_gen_chain = sql_gen_prompt | model | parse_output

    def invoke_chain(messages):
        """Run the chain; on ValueError retry once with a trimmed history.

        Returns the chain result together with the message list that was
        actually sent, so callers continue from the trimmed history.
        """
        try:
            return sql_gen_chain.invoke({"messages": messages}), messages
        except ValueError:
            # NOTE(review): presumably the model raises ValueError when the
            # input is too long -- confirm against the model wrapper.
            # The history holds (role, text) tuples, so it must be trimmed by
            # message text; truncate_strings_from_end would measure len() of
            # each tuple (always 2) and never shorten anything.
            messages = _truncate_messages(messages)
            return sql_gen_chain.invoke({"messages": messages}), messages

    # Generate the SQL statement.
    def generate(state: GraphState):
        # Work on a copy so the caller's list is never mutated in place.
        messages = list(state["messages"])
        # `error` is absent on the first pass; treat that as "no".
        if state.get("error", "no") == "yes":
            messages.append(("human", "请重新生成SQL使用 code 工具确保输出内容格式化包括前缀prefix、sql代码code:"))
        sql_solution, messages = invoke_chain(messages)
        messages = [
            *messages,
            ("ai", f"{sql_solution.prefix} \n Code: {sql_solution.code}"),
        ]
        return {
            "generation": sql_solution,
            "messages": messages,
            "iterations": state["iterations"] + 1,
        }

    # Check whether the SQL is valid by executing it.
    def sql_check(state: GraphState):
        print("---检查SQL---")
        messages = list(state["messages"])
        sql_solution = state["generation"]
        iterations = state["iterations"]
        sql = sql_solution.code

        failure = None
        if not sql.strip():
            # parse_output yields "" when the reply had no ```sql block;
            # report that to the model instead of executing an empty query.
            failure = "未能从回复中提取到SQL代码"
        else:
            # NOTE(review): this executes model-generated SQL as-is; the
            # connection should use a read-only account -- confirm.
            # Kept outside the try so connection/configuration errors
            # propagate instead of being fed back to the model.
            sql_executor = SQLDatabase.from_uri(data_cfg.connection_string)
            try:
                sql_executor.run(sql)
            except Exception as e:
                # Any execution failure is feedback for the next attempt.
                failure = e

        if failure is not None:
            print("---SQL检查错误---")
            messages.append(("human", f"生成的SQL代码无法通过执行测试: {failure}"))
            return {
                "generation": sql_solution,
                "messages": messages,
                "iterations": iterations,
                "error": "yes",
            }
        print("---SQL检测通过---")
        return {
            "generation": sql_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "no",
        }

    # Reflect on the error when the SQL failed.
    def reflect(state: GraphState):
        reflection_result, messages = invoke_chain(list(state["messages"]))
        reflections = reflection_result.prefix
        messages = [*messages, ("ai", f"这里是对于错误的反思:{reflections}")]
        # `error` is reset to "no" here, so the following generate pass does
        # not add its "regenerate" prompt after a reflection.
        return {
            "generation": state["generation"],
            "messages": messages,
            "iterations": state["iterations"],
            "error": "no",
        }

    # Trim the history so it never holds more than the last 3 messages.
    def delete_messages(state: GraphState):
        # Return a fresh dict instead of mutating the incoming state; the
        # slice is a no-op copy when there are 3 messages or fewer.
        return {**state, "messages": state["messages"][-3:]}

    # Decide from the error flag and iteration count whether to stop.
    def decide_to_finish(state: GraphState):
        error = state["error"]
        iterations = state["iterations"]
        # >= rather than == so a counter that starts above the cap still stops.
        if error == "no" or iterations >= MAX_ITERATIONS:
            print("---DECISION: FINISH---")
            return "end"
        print("---DECISION: RE-TRY SOLUTION---")
        return "reflect" if FLAG == "reflect" else "generate"

    # State graph: nodes.
    workflow = StateGraph(GraphState)
    workflow.add_node("generate", generate)
    workflow.add_node("check_code", sql_check)
    workflow.add_node("reflect", reflect)
    workflow.add_node("delete_messages", delete_messages)

    # State graph: edges.
    workflow.add_edge(START, "generate")
    workflow.add_edge("generate", "delete_messages")
    workflow.add_edge("delete_messages", "check_code")
    workflow.add_conditional_edges(
        "check_code",
        decide_to_finish,
        {
            "end": END,
            "reflect": "reflect",
            "generate": "generate",
        },
    )
    workflow.add_edge("reflect", "generate")
    return workflow.compile()
import config
if __name__ == '__main__':
    # Manual smoke test: build the graph against the "contracts" metadata and
    # run a single question through it.
    from langchain_community.chat_models import ChatTongyi

    model = ChatTongyi(model="qwen-turbo")
    data_cfg = config.metadata("contracts")
    qwen_cfg = config.qwen_config("qwen_graph.conf")
    sql_graph = create_sql_graph(model, qwen_cfg, data_cfg)
    # Named graph_input so the builtin input() is not shadowed.
    graph_input = {"messages": [("human", "教育、工业互联网、物联网相关行业的项目都有些什么,列举合同金额、项目名称、部门、签订时间、分公司")], "iterations": 0}
    res = sql_graph.invoke(graph_input)
    # Print the final graph state (generated SQL and history), not the request.
    print(res)