2025-02-17 10:34:35 +08:00

378 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from http import HTTPStatus
import os
import sqlalchemy
import pandas as pd
from enum import StrEnum, IntEnum
from pydantic import BaseModel, Field
import markdown
import logging, logging.config
import uuid
from typing import Callable, Any, Optional, Union
from sqlcode.utils import enrich_input
import yaml
logger = logging.getLogger(__name__)
qustion_logger = logging.getLogger('question')
def load_logging_config(config_path) -> None:
with open(config_path, 'rt') as f:
config = yaml.safe_load(f.read())
logging.config.dictConfig(config)
# enum for return type
class ReturnType(StrEnum):
'''
- SQL: 返回 SQL 语句
- JSON: 返回 JSON 格式的数据
- TEXT: 返回文本格式的数据
- HTML: 返回 HTML 格式的数据
- WX_MD: 返回企业微信 Markdown 格式的数据
'''
SQL = "sql"
'''返回 SQL 语句'''
JSON = "json"
'''返回 JSON 格式的数据'''
TEXT = "text"
'''返回文本格式的数据'''
HTML = 'html'
'''返回 HTML 格式的数据'''
WX_MD = 'wx_markdown'
'''返回企业微信 Markdown 格式的数据'''
class QueryResult(BaseModel):
status: int = Field(default=200, description='HTTP 状态码, 200 表示成功, 其他表示失败')
result: str | dict[str, Any] | None = Field(default=None, description='查询结果')
sql: str | None = Field(default=None, description='查询的 SQL 语句')
thought: str | None = Field(default=None, description='大模型产生的推理')
error: str | None = Field(default=None, description='错误信息')
def _retrieve_sql(text) -> Union[str|QueryResult]:
sql = text['output']
if (n:=sql.find('```sql')) >= 0:
n += 6
sql = sql[n:sql.index('```', n)]
elif (n:=sql.find('```\n')) >= 0:
n += 3
sql = sql[n:sql.index('```', n)]
elif (n:=sql.find('```\r\n')) >= 0:
n += 3
sql = sql[n:sql.index('```', n)]
elif (n:=sql.find('```')) >= 0:
sql = sql[sql.index('\n', n): sql.index('```', n+3)]
else:
return QueryResult(status=HTTPStatus.FAILED_DEPENDENCY, error='从结果中无法找到有效的查询语句。', thought=text['output'])
sql = sql.strip()
return sql
class Generator:
max_retry: Optional[int] = 2
def generate(self, input:dict) -> QueryResult:
'''
抽象生成器类
params:
- prompt: 输入文本
returns:
- code: HTTP 状态码, 200 表示成功, 其他表示失败
- sql: 生成的 SQL 语句,如果 code 不是 200则 sql 中可能包含错误信息
- text: 生成的完全文本,内容可能包含 SQL 语句,如果 code 不是 200则 text 中可能包含错误信息
'''
cnt=0
sql = ""
while cnt < self.max_retry:
code, text = self._generate(input)
if code != HTTPStatus.OK:
return QueryResult(status=code, error=text['err'], thought=text['output'])
res = _retrieve_sql(text)
if isinstance(res, QueryResult):
cnt+=1
print("Error: 从结果中无法找到有效的查询语句, retry again...")
continue
else:
sql = res
break
if sql == "":
return QueryResult(status=HTTPStatus.FAILED_DEPENDENCY, error='从结果中无法找到有效的查询语句。', thought=text['output'])
logger.info('[QUESTION] %s\n%s\n%s', input['messages'][0], sql, text['output'])
qustion_logger.info('%s\n--------\n%s\n--------------------\n', sql, text['output'])
return QueryResult(status=HTTPStatus.OK, sql=sql, thought=text['output'])
def _generate(self, input:dict) -> tuple[HTTPStatus,str]:
"""
根据给定的提示生成SQL语句。
参数:
- prompt: str - 提供给模型的提示用于生成SQL语句。
返回:
- HTTPStatus - 请求的状态码200 表示成功。
- str - full text, 从大模型的返回的完整回复,如果 code 不是 200则 text 中可能包含错误信息。
如果 code 是 200则 text 中的 SQL 语句应包含在围栏式代码块 ```sql ``` 中。
"""
raise NotImplementedError()
class FormatSuggest(IntEnum):
'''
对结果集的格式化建议
'''
NoResult = 0
'''没有查询结果'''
SingleRow = 1
'''结果为一行'''
SingleColumn = 2
'''结果为一列'''
MultiRow = 3
'''结果为多行,但行数不多,建议直接列举结果'''
ExportToExcel = 4
'''结果集较大,建议导出到 Excel 文件'''
def _suggest_format(df:pd.DataFrame) -> FormatSuggest:
row_count = len(df.index)
if row_count == 0:
return FormatSuggest.NoResult
elif row_count == 1:
return FormatSuggest.SingleRow
elif row_count < 20:
if len(df.columns) == 1:
return FormatSuggest.SingleColumn
else:
return FormatSuggest.MultiRow
else:
return FormatSuggest.ExportToExcel
def _translate_column_names(df:pd.DataFrame, translate:Callable[[str],str]):
if translate is None:
return
# translate column names
col_names = df.columns.to_list()
zh_names = translate(' | '.join(col_names).replace('_', ' ')).split('|')
zh_names = list(map(lambda s: s.strip(), zh_names))
if len(col_names) == len(zh_names):
df.columns = zh_names
else:
logger.error('[TRANSLATE] count not match (%d->%d) %s %s', len(col_names), len(zh_names), col_names, zh_names)
# pandas 数据转为 markdown表格
def df_to_markdown(df):
headers = list(df.columns)
# 计算每列内容的总长度
total_length = sum(df[col].astype(str).str.len().sum() + len(col) for col in headers)
# 计算每列相对比例
proportions = [((df[col].astype(str).str.len().sum() + len(col)) / total_length) for col in headers]
# 假设总宽度为 100按比例分配宽度
widths = [int(100 * prop) for prop in proportions]
# 创建 Markdown 表格的表头
markdown_table = "| " + " | ".join([f"{header:<{width}}" for header, width in zip(headers, widths)]) + " |\n"
markdown_table += "| " + " | ".join(["---" * width for width in widths]) + " |\n"
# 添加数据行
for _, row in df.iterrows():
markdown_table += "| " + " | ".join([f"{str(value):<{width}}" for value, width in zip(row, widths)]) + " |\n"
return markdown_table
class Formatter:
'''
抽象格式化器类
'''
def __init__(self, tranlate) -> None:
self.translate = tranlate
def format(self, df:pd.DataFrame) -> str:
'''
格式化指定的 DataFrame。
'''
raise NotImplementedError()
class _JsonFormatter(Formatter):
def __init__(self) -> None:
super().__init__(tranlate=None)
def format(self, df:pd.DataFrame):
return df.to_dict(orient='split', index=False)
class _MarkdownFormatter(Formatter):
def __init__(self, tranlate, output_dir:str, site_url:str) -> None:
super().__init__(tranlate=tranlate)
self.output_dir = output_dir
self.site_url = site_url
def format(self, df:pd.DataFrame):
suggest = _suggest_format(df)
if suggest == FormatSuggest.NoResult:
return '没有符合条件的记录'
_translate_column_names(df, self.translate)
if suggest == FormatSuggest.SingleRow:
return ', '.join([f'{k}:{v}' for k,v in df.to_dict(orient='records')[0].items()])
elif suggest == FormatSuggest.SingleColumn:
return df[df.columns[0]].str.cat(sep='\n')
elif suggest == FormatSuggest.MultiRow:
output_file = uuid.uuid4().hex + '.xlsx'
df.to_excel(os.path.join(self.output_dir, output_file), index=False)
return df_to_markdown(df.head(10)) + \
'\n\n查询结果已导出到Excel文件,请点击链接下载文件:{}/{}'.format(self.site_url, output_file)
else:
# TODO 测试阶段仅保存前 10 条记录
output_file = uuid.uuid4().hex + '.xlsx'
df.to_excel(os.path.join(self.output_dir, output_file), index=False)
return df_to_markdown(df.head(10)) + \
'\n\n结果记录数或信息量太多查询结果已导出到Excel文件(测试阶段仅导出前10行),请点击链接下载文件:{}/{}'.format(self.site_url, output_file)
class _HtmlFormatter(_MarkdownFormatter):
def __init__(self, tranlate, output_dir: str, site_url: str) -> None:
super().__init__(tranlate, output_dir, site_url)
def format(self, df:pd.DataFrame):
answer = super().format(df)
return markdown.markdown(answer, extensions=['markdown.extensions.tables'])
class _WechatFormatter(_MarkdownFormatter):
def __init__(self, tranlate, output_dir: str, site_url: str) -> None:
super().__init__(tranlate, output_dir, site_url)
def format(self, df: pd.DataFrame):
suggest = _suggest_format(df)
if suggest == FormatSuggest.NoResult:
return '没有符合条件的记录'
_translate_column_names(df, self.translate)
if suggest == FormatSuggest.ExportToExcel:
# TODO 测试阶段仅保存前 10 条记录
output_file = uuid.uuid4().hex + '.xlsx'
# df.head(10).to_excel(os.path.join(self.output_dir, output_file), index=False)
# return f'结果记录数或信息量太多查询结果已导出到Excel文件请<a href="{self.site_url}/{output_file}">下载查看</a> (测试阶段仅导出前10行)'+df.head(10).to_string()
# return df.head(10).to_html(index=False)
return df_to_markdown(df.head(10))
elif suggest == FormatSuggest.MultiRow:
if len(df.columns) <= 3:
return '\n'.join([', '.join([str(v) for v in row.values()]) for row in df.to_dict(orient='records')])
else:
return '\n'.join([', '.join([f'{k}:{v}' for k,v in row.items()]) for row in df.to_dict(orient='records')])
else:
return super().format(df)
def NewFormatter(format:str, tranlate:Callable[[str],str], output_dir:str, site_url:str) -> Formatter:
if format == ReturnType.JSON:
return _JsonFormatter()
elif format == ReturnType.TEXT:
return _MarkdownFormatter(tranlate, output_dir, site_url)
elif format == ReturnType.HTML:
return _HtmlFormatter(tranlate, output_dir, site_url)
elif format == ReturnType.WX_MD:
return _WechatFormatter(tranlate, output_dir, site_url)
else:
raise ValueError('不支持的格式化器类型')
class Executor:
'''
抽象执行器类
params:
- sql: SQL 语句
returns:
- code: HTTP 状态码, 200 表示成功, 其他表示失败
- text: 生成的完全文本,内容可能包含 SQL 语句,如果 code 不是 200则 text 中可能包含错误信息
'''
def __init__(self, generator:Generator, formatter:Formatter, max_retry:Optional[int] = 2) -> None:
self.generator = generator
self.formatter = formatter
self.max_retry = max_retry
def get_sql(self, cnt:int, input:dict) -> dict:
"""
根据提供的问题生成SQL
参数:
- cnt (int): 重试次数。
- input (dict): 主要包含用户查询的问题。
返回:
- dict: 生成sql的返回体。
"""
print('{}次尝试'.format(cnt+1))
if cnt == 0:
ret = self.generator.generate(input)
else:
new_input = enrich_input("之前生成的SQL语句有误或查询结果为空请重新生成一个不同的SQL语句, question:", input)
ret = self.generator.generate(new_input)
return ret
def query(self, connection_string:str, input:dict) -> QueryResult:
"""
根据提供的模型、数据库和问题生成并执行SQL查询然后根据返回类型返回结果。
参数:
- database: str - 数据库的连接字符串。
- prompt: str - 提示文本。
- question: str - 用户的查询问题。
- input: 问题input = {"messages": [("human", req.question)], "iterations": 0}
返回:
- status_code - HTTP状态码。
- result 查询结果根据return_type的不同可以是字典、字符串或DataFrame的markdown表示。
- sql - 生成的SQL语句。
- thought - 模型生成的推理
"""
error: Union[Exception|str|None] = None
# 使用logger记录问答对
load_logging_config('config/development/logging.yaml')
qa_logger = logging.getLogger('qa_cache')
#数据库连接
sql_engine = sqlalchemy.create_engine(connection_string)
with sql_engine.connect() as connection:
cnt = 0
ret = None
while cnt < self.max_retry:
ret = self.get_sql(cnt, input)
if ret.status != HTTPStatus.OK:
qa_logger.error('[FAIL] %s\n%s', input['messages'][0][1], ret.error)
return ret
try:
#查询数据
result = connection.execute(sqlalchemy.text(ret.sql))
df = pd.DataFrame(result, columns=result.keys())
print('******************\n', df.head(30),'\n***************\n')
ret.result = self.formatter.format(df)
qa_logger.info('[SUCCESS] \nInput:\n%s\nResult:\n%s', input['messages'][0][1], ret.result)
break
except Exception as e:
error = e
qa_logger.error('[QUERY] %s\n%s', input['messages'][0][1], error)
ret.status=HTTPStatus.INTERNAL_SERVER_ERROR
ret.error=str(error)
print("Error: SQL执行错误retry agin...{}".format(error))
cnt+=1
return ret