378 lines
14 KiB
Python
378 lines
14 KiB
Python
from http import HTTPStatus
|
||
import os
|
||
import sqlalchemy
|
||
import pandas as pd
|
||
from enum import StrEnum, IntEnum
|
||
from pydantic import BaseModel, Field
|
||
import markdown
|
||
import logging, logging.config
|
||
import uuid
|
||
from typing import Callable, Any, Optional, Union
|
||
from sqlcode.utils import enrich_input
|
||
import yaml
|
||
|
||
# Module-level logger used for general diagnostics in this file.
logger = logging.getLogger(__name__)
|
||
# Dedicated logger (channel name 'question') that records each generated SQL
# statement together with the model's full reply.
# NOTE(review): the identifier is misspelled ("qustion"); it is kept as-is
# because other code in this module refers to it by this name.
qustion_logger = logging.getLogger('question')
|
||
|
||
def load_logging_config(config_path) -> None:
    """Load a YAML logging configuration and apply it with ``dictConfig``.

    Args:
        config_path: Path of a YAML file whose content follows the
            ``logging.config`` dictionary schema.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError, TypeError, AttributeError, ImportError: If
            ``logging.config.dictConfig`` rejects the configuration.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(config_path, 'rt', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    logging.config.dictConfig(config)
|
||
|
||
# enum for return type
|
||
class ReturnType(StrEnum):
|
||
'''
|
||
- SQL: 返回 SQL 语句
|
||
- JSON: 返回 JSON 格式的数据
|
||
- TEXT: 返回文本格式的数据
|
||
- HTML: 返回 HTML 格式的数据
|
||
- WX_MD: 返回企业微信 Markdown 格式的数据
|
||
'''
|
||
SQL = "sql"
|
||
'''返回 SQL 语句'''
|
||
|
||
JSON = "json"
|
||
'''返回 JSON 格式的数据'''
|
||
|
||
TEXT = "text"
|
||
'''返回文本格式的数据'''
|
||
|
||
HTML = 'html'
|
||
'''返回 HTML 格式的数据'''
|
||
|
||
WX_MD = 'wx_markdown'
|
||
'''返回企业微信 Markdown 格式的数据'''
|
||
|
||
|
||
class QueryResult(BaseModel):
    """Outcome of one question-to-SQL round trip.

    Carries the status code, the formatted query result, the SQL that was
    generated, the model's reasoning text and, on failure, an error message.
    Every field except ``status`` defaults to ``None``.
    """

    status: int = Field(
        default=200,
        description='HTTP 状态码, 200 表示成功, 其他表示失败',
    )
    result: str | dict[str, Any] | None = Field(
        default=None,
        description='查询结果',
    )
    sql: str | None = Field(
        default=None,
        description='查询的 SQL 语句',
    )
    thought: str | None = Field(
        default=None,
        description='大模型产生的推理',
    )
    error: str | None = Field(
        default=None,
        description='错误信息',
    )
|
||
|
||
|
||
def _retrieve_sql(text) -> Union[str|QueryResult]:
|
||
sql = text['output']
|
||
if (n:=sql.find('```sql')) >= 0:
|
||
n += 6
|
||
sql = sql[n:sql.index('```', n)]
|
||
elif (n:=sql.find('```\n')) >= 0:
|
||
n += 3
|
||
sql = sql[n:sql.index('```', n)]
|
||
elif (n:=sql.find('```\r\n')) >= 0:
|
||
n += 3
|
||
sql = sql[n:sql.index('```', n)]
|
||
elif (n:=sql.find('```')) >= 0:
|
||
sql = sql[sql.index('\n', n): sql.index('```', n+3)]
|
||
else:
|
||
return QueryResult(status=HTTPStatus.FAILED_DEPENDENCY, error='从结果中无法找到有效的查询语句。', thought=text['output'])
|
||
|
||
sql = sql.strip()
|
||
return sql
|
||
|
||
class Generator:
    """Abstract SQL generator.

    Subclasses implement :meth:`_generate`, which calls the model.
    :meth:`generate` wraps it with extraction of the SQL code block and a
    bounded retry when no usable SQL can be extracted.
    """

    # Upper bound on model calls made by one ``generate`` invocation.
    # ``None`` or a value below 1 is treated as a single attempt.
    max_retry: Optional[int] = 2

    def generate(self, input: dict) -> QueryResult:
        """Generate a SQL statement for ``input``.

        Args:
            input: Request passed unchanged to :meth:`_generate`;
                ``input['messages'][0]`` is written to the log on success.

        Returns:
            A ``QueryResult`` whose fields are:

            - ``status``: 200 on success; the status returned by
              ``_generate`` when the model call failed; ``FAILED_DEPENDENCY``
              when no SQL could be extracted after all attempts.
            - ``sql``: the extracted SQL statement (success only).
            - ``thought``: the model's full reply.
            - ``error``: an error message (failure only).
        """
        # Always make at least one attempt so that ``text`` is bound below.
        attempts = max(1, self.max_retry or 1)
        sql = ""
        text = None

        for _ in range(attempts):
            code, text = self._generate(input)

            if code != HTTPStatus.OK:
                return QueryResult(status=code, error=text['err'], thought=text['output'])

            res = _retrieve_sql(text)
            # An empty statement is as unusable as a missing code block.
            if not isinstance(res, QueryResult) and res:
                sql = res
                break
            logger.warning("Error: 从结果中无法找到有效的查询语句, retry again...")

        if sql == "":
            return QueryResult(status=HTTPStatus.FAILED_DEPENDENCY, error='从结果中无法找到有效的查询语句。', thought=text['output'])

        logger.info('[QUESTION] %s\n%s\n%s', input['messages'][0], sql, text['output'])
        qustion_logger.info('%s\n--------\n%s\n--------------------\n', sql, text['output'])

        return QueryResult(status=HTTPStatus.OK, sql=sql, thought=text['output'])

    def _generate(self, input: dict) -> tuple[HTTPStatus, dict]:
        """Call the model and return its raw reply. Must be overridden.

        Args:
            input: The request forwarded by :meth:`generate`.

        Returns:
            A ``(status, payload)`` pair. ``status`` is 200 on success.
            ``generate`` reads ``payload['output']`` (the model's full reply)
            and, when ``status`` is not 200, ``payload['err']`` (the error
            message). On success the SQL statement should be enclosed in a
            fenced code block tagged ``sql`` inside ``payload['output']``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
|
||
|
||
|
||
class FormatSuggest(IntEnum):
    """Suggested way of presenting a result set."""

    # The query returned no rows.
    NoResult = 0

    # The result is a single row.
    SingleRow = 1

    # The result is a single column.
    SingleColumn = 2

    # Several rows, but few enough to list them directly.
    MultiRow = 3

    # The result set is large; exporting to an Excel file is suggested.
    ExportToExcel = 4
|
||
|
||
|
||
def _suggest_format(df: pd.DataFrame) -> FormatSuggest:
    """Pick a presentation suggestion from the shape of ``df``.

    0 rows -> NoResult; 1 row -> SingleRow; 2-19 rows -> SingleColumn when
    the frame has exactly one column, otherwise MultiRow; 20 rows or more ->
    ExportToExcel.
    """
    n_rows = len(df.index)

    if n_rows >= 20:
        return FormatSuggest.ExportToExcel
    if n_rows == 0:
        return FormatSuggest.NoResult
    if n_rows == 1:
        return FormatSuggest.SingleRow

    # Between 2 and 19 rows: the column count decides.
    has_single_column = len(df.columns) == 1
    return FormatSuggest.SingleColumn if has_single_column else FormatSuggest.MultiRow
|
||
|
||
def _translate_column_names(df:pd.DataFrame, translate:Callable[[str],str]):
|
||
if translate is None:
|
||
return
|
||
|
||
# translate column names
|
||
col_names = df.columns.to_list()
|
||
zh_names = translate(' | '.join(col_names).replace('_', ' ')).split('|')
|
||
zh_names = list(map(lambda s: s.strip(), zh_names))
|
||
if len(col_names) == len(zh_names):
|
||
df.columns = zh_names
|
||
else:
|
||
logger.error('[TRANSLATE] count not match (%d->%d) %s %s', len(col_names), len(zh_names), col_names, zh_names)
|
||
|
||
|
||
def _markdown_cell(value) -> str:
    """Render one value as text that is safe inside a Markdown table cell."""
    # A literal pipe would start a new column and a line break would end the
    # row, so escape the former and flatten the latter.
    text = str(value).replace('|', '\\|')
    return text.replace('\r\n', ' ').replace('\n', ' ').replace('\r', ' ')


# Convert a pandas DataFrame into a Markdown table.
def df_to_markdown(df):
    """Render ``df`` as a Markdown table with proportionally padded columns.

    Each column receives a share of a nominal width of 100 characters,
    proportional to the total text length of its header and cells. Header and
    data cells are left-aligned and padded to that width.

    Args:
        df: The frame to render. Column names need not be strings or unique.

    Returns:
        The table as a string: a header line, a separator line and one line
        per row, each terminated by a newline.
    """
    headers = [_markdown_cell(col) for col in df.columns]
    # itertuples keeps each column's own dtype (iterrows would upcast ints to
    # floats in mixed numeric frames) and works with duplicate column names.
    rows = [
        [_markdown_cell(value) for value in record]
        for record in df.itertuples(index=False, name=None)
    ]

    # Total text length per column: header plus all of its cells.
    sizes = [
        len(header) + sum(len(row[i]) for row in rows)
        for i, header in enumerate(headers)
    ]
    total = sum(sizes)
    # Distribute a nominal width of 100 in proportion to those lengths.
    widths = [int(100 * (size / total)) if total else 0 for size in sizes]

    def render(cells):
        return "| " + " | ".join(cells) + " |\n"

    lines = [
        render(f"{header:<{width}}" for header, width in zip(headers, widths)),
        # A width of 0 must still produce dashes, otherwise the separator
        # cell is empty and the table is not valid Markdown.
        render("---" * max(width, 1) for width in widths),
    ]
    lines.extend(
        render(f"{cell:<{width}}" for cell, width in zip(row, widths))
        for row in rows
    )
    return "".join(lines)
|
||
|
||
class Formatter:
    """Abstract base class for result formatters."""

    def __init__(self, tranlate) -> None:
        # The (misspelled) parameter name is part of the constructor's
        # keyword interface; the value is stored under the correct spelling.
        self.translate = tranlate

    def format(self, df: pd.DataFrame) -> str:
        """Format the given DataFrame. Subclasses must override this."""
        raise NotImplementedError
|
||
|
||
class _JsonFormatter(Formatter):
    """Formatter that returns the frame as a plain dictionary."""

    def __init__(self) -> None:
        # No translator is installed: column names are returned unchanged.
        super().__init__(tranlate=None)

    def format(self, df: pd.DataFrame):
        """Return ``df`` in pandas' ``split`` orientation, without the index."""
        payload = df.to_dict(orient='split', index=False)
        return payload
|
||
|
||
|
||
class _MarkdownFormatter(Formatter):
    """Formatter producing plain text / Markdown, with an Excel export.

    Args:
        tranlate: Callable used to translate column names, or ``None``.
        output_dir: Directory into which exported ``.xlsx`` files are written.
        site_url: URL prefix under which exported files can be downloaded.
    """

    def __init__(self, tranlate, output_dir: str, site_url: str) -> None:
        super().__init__(tranlate=tranlate)
        self.output_dir = output_dir
        self.site_url = site_url

    def _export_excel(self, df: pd.DataFrame) -> str:
        """Write ``df`` to a uniquely named .xlsx file and return its name."""
        output_file = uuid.uuid4().hex + '.xlsx'
        df.to_excel(os.path.join(self.output_dir, output_file), index=False)
        return output_file

    def format(self, df: pd.DataFrame):
        """Format ``df`` according to the suggestion from ``_suggest_format``.

        - NoResult: a fixed "no matching records" message.
        - SingleRow: ``name:value`` pairs joined by ``', '``.
        - SingleColumn: the non-null values, one per line.
        - MultiRow / ExportToExcel: the frame is exported to Excel and the
          first 10 rows are returned as a Markdown table followed by the
          download link.

        Column names are translated in place before anything is rendered.
        """
        suggest = _suggest_format(df)

        if suggest == FormatSuggest.NoResult:
            return '没有符合条件的记录'

        _translate_column_names(df, self.translate)

        if suggest == FormatSuggest.SingleRow:
            return ', '.join([f'{k}:{v}' for k, v in df.to_dict(orient='records')[0].items()])

        if suggest == FormatSuggest.SingleColumn:
            # Plain str() works for every dtype; the ``.str`` accessor raises
            # AttributeError on non-string columns such as numeric results.
            # Null values are skipped.
            return '\n'.join(str(value) for value in df.iloc[:, 0].dropna())

        if suggest == FormatSuggest.MultiRow:
            output_file = self._export_excel(df)
            return df_to_markdown(df.head(10)) + \
                '\n\n查询结果已导出到Excel文件,请点击链接下载文件:{}/{}'.format(self.site_url, output_file)

        # TODO: during the test phase only the first 10 records are meant to
        # be saved.
        # NOTE(review): the message below says only the first 10 rows are
        # exported, but the complete frame is written to the Excel file.
        output_file = self._export_excel(df)
        return df_to_markdown(df.head(10)) + \
            '\n\n结果记录数或信息量太多,查询结果已导出到Excel文件(测试阶段仅导出前10行),请点击链接下载文件:{}/{}'.format(self.site_url, output_file)
|
||
|
||
|
||
|
||
class _HtmlFormatter(_MarkdownFormatter):
    """Formatter that renders the Markdown formatter's output as HTML."""

    def __init__(self, tranlate, output_dir: str, site_url: str) -> None:
        super().__init__(tranlate, output_dir, site_url)

    def format(self, df: pd.DataFrame):
        """Build the Markdown answer, then convert it to HTML with tables enabled."""
        markdown_text = super().format(df)
        html = markdown.markdown(markdown_text, extensions=['markdown.extensions.tables'])
        return html
|
||
|
||
|
||
class _WechatFormatter(_MarkdownFormatter):
    """Formatter for WeCom (Enterprise WeChat) Markdown output."""

    def __init__(self, tranlate, output_dir: str, site_url: str) -> None:
        super().__init__(tranlate, output_dir, site_url)

    def format(self, df: pd.DataFrame):
        """Format ``df`` for WeCom.

        - ExportToExcel: the first 10 rows as a Markdown table (no file is
          written).
        - MultiRow: one line per row; bare values when there are at most
          three columns, ``name:value`` pairs otherwise.
        - Anything else (no result, single row, single column): delegated to
          the Markdown formatter.
        """
        suggest = _suggest_format(df)

        if suggest not in (FormatSuggest.ExportToExcel, FormatSuggest.MultiRow):
            # Delegate BEFORE translating: the parent translates the column
            # names itself, and translating here as well would run the
            # translator twice on the same frame.
            return super().format(df)

        _translate_column_names(df, self.translate)

        if suggest == FormatSuggest.ExportToExcel:
            # TODO: during the test phase only the first 10 records are
            # shown; exporting to Excel is currently disabled for WeCom.
            return df_to_markdown(df.head(10))

        records = df.to_dict(orient='records')
        if len(df.columns) <= 3:
            return '\n'.join([', '.join([str(v) for v in row.values()]) for row in records])
        return '\n'.join([', '.join([f'{k}:{v}' for k, v in row.items()]) for row in records])
|
||
|
||
|
||
def NewFormatter(format: str, tranlate: Callable[[str], str], output_dir: str, site_url: str) -> Formatter:
    """Create the formatter matching the requested return type.

    Args:
        format: One of the ``ReturnType`` values ``json``, ``text``, ``html``
            or ``wx_markdown``.
        tranlate: Column-name translator handed to the Markdown-based
            formatters (the JSON formatter takes none).
        output_dir: Directory for exported files.
        site_url: URL prefix for download links.

    Raises:
        ValueError: If ``format`` is none of the supported values.
    """
    if format == ReturnType.JSON:
        return _JsonFormatter()

    # The remaining formatters all share the same constructor arguments.
    markdown_based = (
        (ReturnType.TEXT, _MarkdownFormatter),
        (ReturnType.HTML, _HtmlFormatter),
        (ReturnType.WX_MD, _WechatFormatter),
    )
    for kind, formatter_cls in markdown_based:
        if format == kind:
            return formatter_cls(tranlate, output_dir, site_url)

    raise ValueError('不支持的格式化器类型')
|
||
|
||
|
||
class Executor:
    """Generate SQL for a question, run it, and format the result.

    Args:
        generator: Produces the SQL statement for a request.
        formatter: Turns the resulting DataFrame into the returned payload.
        max_retry: Upper bound on generate-and-execute attempts per query.
            ``None`` or a value below 1 is treated as a single attempt.
    """

    def __init__(self, generator: Generator, formatter: Formatter, max_retry: Optional[int] = 2) -> None:
        self.generator = generator
        self.formatter = formatter
        self.max_retry = max_retry

    def get_sql(self, cnt: int, input: dict) -> QueryResult:
        """Ask the generator for SQL answering ``input``.

        Args:
            cnt: Zero-based attempt number. From the second attempt on, the
                request is first passed through ``enrich_input`` together
                with a message asking for a different statement.
            input: The request holding the user's question.

        Returns:
            The ``QueryResult`` produced by the generator.
        """
        logger.info('第%d次尝试', cnt + 1)
        if cnt > 0:
            input = enrich_input("之前生成的SQL语句有误或查询结果为空,请重新生成一个不同的SQL语句, question:", input)
        return self.generator.generate(input)

    def query(self, connection_string: str, input: dict) -> QueryResult:
        """Generate SQL for the question, execute it and format the result.

        Args:
            connection_string: SQLAlchemy connection string of the database.
            input: The request, e.g.
                ``{"messages": [("human", question)], "iterations": 0}``;
                ``input['messages'][0][1]`` is logged as the question.

        Returns:
            A ``QueryResult``. On success ``result`` holds the formatter's
            output and ``sql`` / ``thought`` come from the generator. If
            generation fails, the generator's result is returned as is. If
            every execution attempt fails, the last attempt is returned with
            status ``INTERNAL_SERVER_ERROR`` and ``error`` set.
        """
        # NOTE(review): logging is reconfigured on every call from a
        # hard-coded, working-directory-relative development path — confirm
        # this is intended.
        load_logging_config('config/development/logging.yaml')
        # Logger that records question/answer pairs.
        qa_logger = logging.getLogger('qa_cache')

        # Always make at least one attempt so a QueryResult is returned.
        attempts = max(1, self.max_retry or 1)
        ret = None

        sql_engine = sqlalchemy.create_engine(connection_string)
        try:
            with sql_engine.connect() as connection:
                for cnt in range(attempts):
                    ret = self.get_sql(cnt, input)
                    if ret.status != HTTPStatus.OK:
                        qa_logger.error('[FAIL] %s\n%s', input['messages'][0][1], ret.error)
                        return ret
                    try:
                        # Run the query and load the rows into a DataFrame.
                        result = connection.execute(sqlalchemy.text(ret.sql))
                        df = pd.DataFrame(result, columns=result.keys())
                        logger.debug('query result preview:\n%s', df.head(30))
                        ret.result = self.formatter.format(df)

                        qa_logger.info('[SUCCESS] \nInput:\n%s\nResult:\n%s', input['messages'][0][1], ret.result)
                        break
                    except Exception as e:
                        # Broad on purpose: any execution or formatting
                        # failure triggers a regeneration of the SQL.
                        # NOTE(review): the same connection is reused for the
                        # retry; confirm whether the backend needs a rollback
                        # after a failed statement.
                        qa_logger.error('[QUERY] %s\n%s', input['messages'][0][1], e)
                        ret.status = HTTPStatus.INTERNAL_SERVER_ERROR
                        ret.error = str(e)
                        logger.warning('SQL执行错误,retry again... %s', e)
        finally:
            # The engine is created per call; dispose it so its connection
            # pool does not outlive the query.
            sql_engine.dispose()

        return ret