# Listing metadata: 242 lines, 8.2 KiB, Python
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
|||
|
|
sys.path.append(parent_dir)
|
|||
|
|
|
|||
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|||
|
|
from langchain_core.language_models import BaseLanguageModel
|
|||
|
|
from langgraph.graph.graph import CompiledGraph
|
|||
|
|
from typing import TypedDict, List
|
|||
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|||
|
|
from langchain_community.utilities import SQLDatabase
|
|||
|
|
from config import DatabaseConfig, QwenConfig
|
|||
|
|
from langgraph.graph import END, START, StateGraph, MessagesState
|
|||
|
|
from langchain_core.output_parsers import PydanticOutputParser
|
|||
|
|
from langchain.output_parsers import OutputFixingParser
|
|||
|
|
from langchain_core.messages import RemoveMessage, AnyMessage
|
|||
|
|
import json
|
|||
|
|
|
|||
|
|
MAX_ITERATIONS = 3
|
|||
|
|
FLAG = "reflect"
|
|||
|
|
|
|||
|
|
class GraphState(TypedDict):
    """Shared state handed from node to node in the SQL-generation graph.

    Attributes:
        error: Control-flow flag. The nodes in this file write the strings
            "yes" (the SQL check failed) or "no".
        messages: Conversation so far: the user question, generated
            answers, execution errors and reflections.
        generation: The most recent generated solution.
        iterations: Number of generation attempts made so far.
    """

    # "yes" / "no" flag; the generate node tolerates it being absent.
    error: str
    messages: list[AnyMessage]
    # NOTE(review): annotated as ``str``, but the nodes in this file store a
    # ``code`` instance here -- confirm before tightening the annotation.
    generation: str
    iterations: int
|
|||
|
|
|
|||
|
|
class code(BaseModel):
    """Structured SQL answer: explanatory text plus the SQL statement."""

    # The descriptions below are runtime schema text and are kept verbatim.
    prefix: str = Field(description="对问题和使用方法的描述")
    code: str = Field(description="SQL代码")
|
|||
|
|
|
|||
|
|
|
|||
|
|
import re
|
|||
|
|
|
|||
|
|
def extract_sql_code(text):
    """Collect the contents of every fenced ```sql ... ``` block in ``text``.

    Args:
        text: String to scan.

    Returns:
        A list with one entry per block, in order of appearance, each with
        surrounding whitespace stripped. Empty when no block is present.
    """
    # DOTALL lets a block span several lines; the lazy group stops at the
    # first closing fence.
    fenced_sql = re.compile(r'```sql(.*?)```', re.DOTALL)
    return [hit.group(1).strip() for hit in fenced_sql.finditer(text)]
|
|||
|
|
|
|||
|
|
def parse_output(message):
    """Turn a model reply into a ``code`` object.

    Args:
        message: Object exposing the reply text as ``message.content``.

    Returns:
        A ``code`` instance whose ``prefix`` is the full reply text and whose
        ``code`` is the first fenced ``sql`` block found in it, or ``""``
        when the reply contains no such block.
    """
    reply_text = message.content
    candidates = extract_sql_code(reply_text)
    first_sql = candidates[0] if candidates else ""
    return code(prefix=reply_text, code=first_sql)
|
|||
|
|
|
|||
|
|
def truncate_strings_from_end(string_list, max_length=6000):
    """Keep the tail of ``string_list`` within a total character budget.

    Strings are taken from the end of the list towards the front. The first
    string that does not fit is cut down to its trailing characters so the
    budget is filled exactly, and every string before it is dropped.

    Args:
        string_list: Strings in chronological order.
        max_length: Maximum combined length of the returned strings.

    Returns:
        A new list, in the original order, whose combined length never
        exceeds ``max_length`` (empty when ``max_length`` is not positive
        and the last string is non-empty).
    """
    remaining = max_length
    kept = []

    # Walk from the newest string backwards.
    for s in reversed(string_list):
        if len(s) > remaining:
            # Only slice when some budget is left: ``s[-0:]`` is the whole
            # string, which would silently blow past the budget.
            if remaining > 0:
                kept.append(s[-remaining:])
            break
        kept.append(s)
        remaining -= len(s)

    # Items were collected newest-first; restore chronological order.
    kept.reverse()
    return kept
|
|||
|
|
|
|||
|
|
def create_sql_graph(model: BaseLanguageModel, qwen_cfg: QwenConfig, data_cfg: DatabaseConfig) -> CompiledGraph:
    """Build and compile the generate / check / reflect SQL workflow.

    The graph runs ``generate -> delete_messages -> check_code``. After the
    check it either ends or, while the SQL keeps failing and fewer than
    ``MAX_ITERATIONS`` attempts were made, loops back through ``reflect``
    (when ``FLAG == "reflect"``) or straight to ``generate``.

    Args:
        model: Language model placed in the middle of the generation chain.
        qwen_cfg: Supplies the system prompt (``prompt``) and the
            ``params["example"]`` value bound into it.
        data_cfg: Supplies ``metadata`` and ``product`` for the prompt and
            the ``connection_string`` used to execute the generated SQL.

    Returns:
        The compiled graph. Its input state needs ``messages`` and
        ``iterations``; ``error`` is optional.
    """
    sql_gen_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qwen_cfg.prompt),
            ("placeholder", "{messages}"),
        ]
    )
    sql_gen_prompt = sql_gen_prompt.partial(
        metadata=data_cfg.metadata,
        product=data_cfg.product,
        example=qwen_cfg.params["example"],
    )

    sql_gen_chain = sql_gen_prompt | model | parse_output

    def _message_text(message):
        """Return the text of a ``(role, text)`` pair or a message object."""
        if isinstance(message, (tuple, list)):
            return str(message[1])
        return str(getattr(message, "content", message))

    def _shrink_history(messages):
        """Drop / cut the oldest messages so the text fits the size budget."""
        # Measure the message *text*: len() of a (role, text) pair is always
        # 2, and message objects have no len() at all.
        texts = [_message_text(m) for m in messages]
        kept_texts = truncate_strings_from_end(texts)
        first_kept = len(messages) - len(kept_texts)

        shrunk = []
        for message, text in zip(messages[first_kept:], kept_texts):
            is_pair = isinstance(message, (tuple, list))
            if is_pair and isinstance(message[1], str) and text != message[1]:
                # Boundary message: keep only the tail that still fits.
                message = (message[0], text)
            # Message objects are passed through whole; this module does not
            # rebuild them.
            shrunk.append(message)
        return shrunk

    def _invoke_chain(messages):
        """Run the chain, retrying once with a shortened history.

        Returns the parsed reply and the message list that was actually sent.
        """
        try:
            return sql_gen_chain.invoke({"messages": messages}), messages
        except ValueError:
            # NOTE(review): presumably raised when the prompt is too long
            # for the model -- confirm against the model client.
            messages = _shrink_history(messages)
            return sql_gen_chain.invoke({"messages": messages}), messages

    # Generate a SQL statement.
    def generate(state: GraphState):
        # Work on a copy so the caller's input list is never mutated.
        messages = list(state["messages"])

        # ``error`` is absent on the very first pass.
        if state.get("error", "no") == "yes":
            messages.append(("human", "请重新生成SQL,使用 code 工具确保输出内容格式化,包括前缀(prefix)、sql代码(code):"))

        sql_solution, messages = _invoke_chain(messages)

        messages = [
            *messages,
            (
                "ai",
                f"{sql_solution.prefix} \n Code: {sql_solution.code}"
            ),
        ]
        return {
            "generation": sql_solution,
            "messages": messages,
            "iterations": state["iterations"] + 1,
        }

    # Check that the SQL is valid by executing it.
    def sql_check(state: GraphState):
        print("---检查SQL---")
        messages = state["messages"]
        sql_solution = state["generation"]
        iterations = state["iterations"]

        sql_text = sql_solution.code
        # NOTE(review): the model-written statement is executed as-is; make
        # sure the connection has no write/DDL privileges.
        sql_executor = SQLDatabase.from_uri(data_cfg.connection_string)

        try:
            sql_executor.run(sql_text)
        except Exception as e:
            # Any failure is fed back to the model as a new human message.
            print("---SQL检查错误---")
            return {
                "generation": sql_solution,
                "messages": [*messages, ("human", f"生成的SQL代码无法通过执行测试: {e}")],
                "iterations": iterations,
                "error": "yes",
            }

        print("---SQL检测通过---")
        return {
            "generation": sql_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "no",
        }

    # When the SQL failed, ask the model to reflect on the error.
    def reflect(state: GraphState):
        reply, messages = _invoke_chain(state["messages"])

        return {
            "generation": state["generation"],
            "messages": [*messages, ("ai", f"这里是对于错误的反思:{reply.prefix}")],
            "iterations": state["iterations"],
            # NOTE(review): resetting the flag here means ``generate`` does
            # not add its "regenerate" prompt after a reflection -- confirm
            # this is intended.
            "error": "no",
        }

    # Trim the history so that at most the last 3 messages are kept.
    def delete_messages(state: GraphState):
        return {**state, "messages": state["messages"][-3:]}

    # Decide from the error flag and attempt count whether to stop.
    def decide_to_finish(state: GraphState):
        # ``>=`` so a caller that starts above the limit still terminates.
        if state["error"] == "no" or state["iterations"] >= MAX_ITERATIONS:
            print("---DECISION: FINISH---")
            return "end"

        print("---DECISION: RE-TRY SOLUTION---")
        return "reflect" if FLAG == "reflect" else "generate"

    # State graph
    workflow = StateGraph(GraphState)

    workflow.add_node("generate", generate)
    workflow.add_node("check_code", sql_check)
    workflow.add_node("reflect", reflect)
    workflow.add_node("delete_messages", delete_messages)

    # Build graph
    workflow.add_edge(START, "generate")
    workflow.add_edge("generate", "delete_messages")
    workflow.add_edge("delete_messages", "check_code")
    workflow.add_conditional_edges(
        "check_code",
        decide_to_finish,
        {
            "end": END,
            "reflect": "reflect",
            "generate": "generate",
        },
    )
    workflow.add_edge("reflect", "generate")

    return workflow.compile()
|
|||
|
|
|
|||
|
|
|
|||
|
|
import config
|
|||
|
|
if __name__ == '__main__':
    # Manual smoke run: build the graph with a Tongyi chat model and ask a
    # single question.
    from langchain_community.chat_models import ChatTongyi

    model = ChatTongyi(model="qwen-turbo")
    data_cfg = config.metadata("contracts")
    qwen_cfg = config.qwen_config("qwen_graph.conf")
    sql_graph = create_sql_graph(model, qwen_cfg, data_cfg)

    # Named ``graph_input`` so the ``input`` builtin is not shadowed.
    graph_input = {"messages": [("human", "教育、工业互联网、物联网相关行业的项目都有些什么,列举合同金额、项目名称、部门、签订时间、分公司")], "iterations": 0}
    res = sql_graph.invoke(graph_input)

    # Show the final graph state (the result), not the request.
    print(res)
|