242 lines
8.2 KiB
Python
242 lines
8.2 KiB
Python
import os
|
||
import sys
|
||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||
sys.path.append(parent_dir)
|
||
|
||
from langchain_core.prompts import ChatPromptTemplate
|
||
from langchain_core.language_models import BaseLanguageModel
|
||
from langgraph.graph.graph import CompiledGraph
|
||
from typing import TypedDict, List
|
||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||
from langchain_community.utilities import SQLDatabase
|
||
from config import DatabaseConfig, QwenConfig
|
||
from langgraph.graph import END, START, StateGraph, MessagesState
|
||
from langchain_core.output_parsers import PydanticOutputParser
|
||
from langchain.output_parsers import OutputFixingParser
|
||
from langchain_core.messages import RemoveMessage, AnyMessage
|
||
import json
|
||
|
||
# Cap on generation attempts: the workflow finishes once the state's
# `iterations` counter reaches this value, even if the SQL still fails.
MAX_ITERATIONS = 3
|
||
# Retry routing after a failed SQL check: "reflect" sends the run through the
# reflect node first; any other value goes straight back to generate.
FLAG = "reflect"
|
||
|
||
class GraphState(TypedDict):
    """State dictionary shared by every node of the SQL workflow graph.

    Keys:
        error: "yes" / "no" marker written by the check and reflect nodes
            and read when deciding whether to retry.
        messages: Conversation so far (user question, model answers,
            SQL execution errors, reflections).
        generation: Most recent generated solution.
        iterations: Number of times the generate node has run.
    """

    error: str
    messages: list[AnyMessage]
    # NOTE(review): the nodes in this file store a `code` instance here
    # (the result of parse_output), not a plain str.
    generation: str
    iterations: int
|
||
|
||
# Container for one model answer: the full reply text plus the SQL taken
# from it.
# NOTE(review): the docstring and Field descriptions below are kept in their
# original language on purpose -- pydantic presumably exposes them as schema
# descriptions at runtime, so they are treated as program text, not comments.
class code(BaseModel):
    """SQL输出"""

    # Free-text part of the answer (parse_output stores the whole reply here).
    prefix: str = Field(description="对问题和使用方法的描述")
    # The SQL statement itself; parse_output uses "" when no SQL block exists.
    code: str = Field(description="SQL代码")
|
||
|
||
|
||
import re
|
||
|
||
# Matches fenced ```sql ... ``` blocks. DOTALL lets a block span several
# lines; IGNORECASE also accepts ```SQL / ```Sql fence tags.
_SQL_FENCE_RE = re.compile(r'```sql(.*?)```', re.DOTALL | re.IGNORECASE)


def extract_sql_code(text):
    """Return the contents of every fenced ```sql block found in ``text``.

    The fence tag is matched case-insensitively, so blocks opened with
    ```SQL are extracted as well.

    Args:
        text: String to search (typically a model reply).

    Returns:
        A list with the body of each block, in order of appearance, with
        leading and trailing whitespace stripped. Empty when no block exists.
    """
    return [block.strip() for block in _SQL_FENCE_RE.findall(text)]
|
||
|
||
def parse_output(message):
    """Turn a model reply into a ``code`` object.

    Args:
        message: Object whose ``content`` attribute holds the reply text.

    Returns:
        A ``code`` instance whose ``prefix`` is the full reply text and whose
        ``code`` is the first SQL block extracted from it, or an empty string
        when the reply contains no SQL block.
    """
    reply_text = message.content
    blocks = extract_sql_code(reply_text)
    first_block = blocks[0] if blocks else ""
    return code(prefix=reply_text, code=first_block)
|
||
|
||
def truncate_strings_from_end(string_list, max_length=6000):
    """Keep the tail of ``string_list`` within a total length budget.

    Items are taken from the end of the list towards the front until their
    combined length would exceed ``max_length``. The first item that does not
    fit is cut down to its last remaining characters, and everything before
    it is dropped.

    Args:
        string_list: Sequence of strings, oldest first.
        max_length: Maximum combined length of the returned strings.

    Returns:
        A new list, in the original order, whose combined length never
        exceeds ``max_length``. Empty when ``max_length`` is zero or negative
        and the last item is non-empty.
    """
    kept = []
    remaining = max_length

    # Walk from the end so the most recent strings are kept.
    for s in reversed(string_list):
        if len(s) > remaining:
            # Only slice when there is budget left: ``s[-0:]`` would return
            # the whole string, and a negative budget would slice from the
            # wrong side.
            if remaining > 0:
                kept.append(s[-remaining:])
            break
        kept.append(s)
        remaining -= len(s)

    # Items were collected back to front; restore the original order.
    kept.reverse()
    return kept
|
||
|
||
def _truncate_messages(messages, max_length=6000):
    """Bound the total text length of ``messages`` to ``max_length`` characters.

    The nodes of the SQL graph store messages as ``(role, content)`` pairs.
    Passing those pairs straight to ``truncate_strings_from_end`` measures
    ``len(pair) == 2`` instead of the text length, so nothing is ever cut.
    This helper truncates the contents and re-attaches the matching roles.
    Any other message shape is handed to ``truncate_strings_from_end``
    unchanged.
    """
    are_pairs = all(
        isinstance(m, (tuple, list)) and len(m) == 2 and isinstance(m[1], str)
        for m in messages
    )
    if not are_pairs:
        return truncate_strings_from_end(messages, max_length)

    kept_texts = truncate_strings_from_end([m[1] for m in messages], max_length)
    # The kept texts are always a suffix of the input, so the roles line up
    # with the last len(kept_texts) messages.
    kept_roles = [m[0] for m in messages][len(messages) - len(kept_texts):]
    return [(role, text) for role, text in zip(kept_roles, kept_texts)]


def create_sql_graph(model: BaseLanguageModel, qwen_cfg: QwenConfig, data_cfg: DatabaseConfig) -> CompiledGraph:
    """Build and compile the SQL generation workflow.

    Flow: generate -> delete_messages -> check_code, then either END or a
    retry through reflect (or straight back to generate, depending on FLAG).

    Args:
        model: Language model used for both generation and reflection.
        qwen_cfg: Supplies the system prompt (``prompt``) and the example
            inserted into it (``params["example"]``).
        data_cfg: Supplies the ``metadata`` and ``product`` prompt variables
            and the ``connection_string`` used to test the generated SQL.

    Returns:
        The compiled graph. Invoke it with a state holding ``messages`` and,
        optionally, ``iterations`` (treated as 0 when absent).
    """
    sql_gen_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qwen_cfg.prompt),
            ("placeholder", "{messages}"),
        ]
    ).partial(
        metadata=data_cfg.metadata,
        product=data_cfg.product,
        example=qwen_cfg.params["example"],
    )

    sql_gen_chain = sql_gen_prompt | model | parse_output

    def invoke_chain(messages):
        """Run the chain; on ValueError retry once with truncated messages.

        Returns the parsed solution together with the message list that was
        actually sent, so callers keep building on the truncated history.
        """
        try:
            return sql_gen_chain.invoke({"messages": messages}), messages
        except ValueError:
            # NOTE(review): presumably raised when the input is too long for
            # the model -- confirm against the model wrapper in use.
            messages = _truncate_messages(messages)
            return sql_gen_chain.invoke({"messages": messages}), messages

    def generate(state: GraphState):
        """Generate a SQL solution and append it to the conversation."""
        # Work on a copy so the caller's input list is never mutated in place.
        messages = list(state["messages"])
        iterations = state.get("iterations", 0)

        if state.get("error", "no") == "yes":
            messages.append(("human", "请重新生成SQL,使用 code 工具确保输出内容格式化,包括前缀(prefix)、sql代码(code):"))

        sql_solution, messages = invoke_chain(messages)
        messages = messages + [
            (
                "ai",
                f"{sql_solution.prefix} \n Code: {sql_solution.code}",
            )
        ]
        return {
            "generation": sql_solution,
            "messages": messages,
            "iterations": iterations + 1,
        }

    def sql_check(state: GraphState):
        """Execute the generated SQL and record whether it ran without error."""
        print("---检查SQL---")
        messages = list(state["messages"])
        sql_solution = state["generation"]
        iterations = state["iterations"]

        # NOTE(review): this runs model-generated SQL verbatim against the
        # configured database; the connection should have read-only rights.
        sql_executor = SQLDatabase.from_uri(data_cfg.connection_string)

        try:
            sql_executor.run(sql_solution.code)
        except Exception as e:
            # Deliberately broad: any execution failure is fed back to the
            # model as a message instead of aborting the graph.
            print("---SQL检查错误---")
            messages.append(("human", f"生成的SQL代码无法通过执行测试: {e}"))
            return {
                "generation": sql_solution,
                "messages": messages,
                "iterations": iterations,
                "error": "yes",
            }

        print("---SQL检测通过---")
        return {
            "generation": sql_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "no",
        }

    def reflect(state: GraphState):
        """Ask the model to reflect on the failed SQL and append its answer."""
        messages = list(state["messages"])

        reflection, messages = invoke_chain(messages)
        messages = messages + [("ai", f"这里是对于错误的反思:{reflection.prefix}")]
        return {
            "generation": state["generation"],
            "messages": messages,
            "iterations": state["iterations"],
            "error": "no",
        }

    def delete_messages(state: GraphState):
        """Trim the conversation to its three most recent messages."""
        return {"messages": state["messages"][-3:]}

    def decide_to_finish(state: GraphState):
        """Route after the SQL check: finish, reflect, or regenerate."""
        # ">=" rather than "==" so a run whose counter starts above the limit
        # still terminates.
        if state["error"] == "no" or state["iterations"] >= MAX_ITERATIONS:
            print("---DECISION: FINISH---")
            return "end"

        print("---DECISION: RE-TRY SOLUTION---")
        return "reflect" if FLAG == "reflect" else "generate"

    # State graph
    workflow = StateGraph(GraphState)

    workflow.add_node("generate", generate)
    workflow.add_node("check_code", sql_check)
    workflow.add_node("reflect", reflect)
    workflow.add_node("delete_messages", delete_messages)

    # Wire the nodes together
    workflow.add_edge(START, "generate")
    workflow.add_edge("generate", "delete_messages")
    workflow.add_edge("delete_messages", "check_code")
    workflow.add_conditional_edges(
        "check_code",
        decide_to_finish,
        {
            "end": END,
            "reflect": "reflect",
            "generate": "generate",
        },
    )
    workflow.add_edge("reflect", "generate")

    return workflow.compile()
|
||
|
||
|
||
import config
|
||
if __name__ == '__main__':
    # Manual smoke test: build the graph with a Tongyi chat model and run one
    # question through it.
    from langchain_community.chat_models import ChatTongyi

    model = ChatTongyi(model="qwen-turbo")
    data_cfg = config.metadata("contracts")
    qwen_cfg = config.qwen_config("qwen_graph.conf")
    sql_graph = create_sql_graph(model, qwen_cfg, data_cfg)

    # Named graph_input so the builtin input() is not shadowed.
    graph_input = {
        "messages": [("human", "教育、工业互联网、物联网相关行业的项目都有些什么,列举合同金额、项目名称、部门、签订时间、分公司")],
        "iterations": 0,
    }
    res = sql_graph.invoke(graph_input)
    # Show the final graph state; printing the input dict left the result
    # unused.
    print(res)