from http import HTTPStatus import os import sqlalchemy import pandas as pd from enum import StrEnum, IntEnum from pydantic import BaseModel, Field import markdown import logging, logging.config import uuid from typing import Callable, Any, Optional, Union from sqlcode.utils import enrich_input import yaml logger = logging.getLogger(__name__) qustion_logger = logging.getLogger('question') def load_logging_config(config_path) -> None: with open(config_path, 'rt') as f: config = yaml.safe_load(f.read()) logging.config.dictConfig(config) # enum for return type class ReturnType(StrEnum): ''' - SQL: 返回 SQL 语句 - JSON: 返回 JSON 格式的数据 - TEXT: 返回文本格式的数据 - HTML: 返回 HTML 格式的数据 - WX_MD: 返回企业微信 Markdown 格式的数据 ''' SQL = "sql" '''返回 SQL 语句''' JSON = "json" '''返回 JSON 格式的数据''' TEXT = "text" '''返回文本格式的数据''' HTML = 'html' '''返回 HTML 格式的数据''' WX_MD = 'wx_markdown' '''返回企业微信 Markdown 格式的数据''' class QueryResult(BaseModel): status: int = Field(default=200, description='HTTP 状态码, 200 表示成功, 其他表示失败') result: str | dict[str, Any] | None = Field(default=None, description='查询结果') sql: str | None = Field(default=None, description='查询的 SQL 语句') thought: str | None = Field(default=None, description='大模型产生的推理') error: str | None = Field(default=None, description='错误信息') def _retrieve_sql(text) -> Union[str|QueryResult]: sql = text['output'] if (n:=sql.find('```sql')) >= 0: n += 6 sql = sql[n:sql.index('```', n)] elif (n:=sql.find('```\n')) >= 0: n += 3 sql = sql[n:sql.index('```', n)] elif (n:=sql.find('```\r\n')) >= 0: n += 3 sql = sql[n:sql.index('```', n)] elif (n:=sql.find('```')) >= 0: sql = sql[sql.index('\n', n): sql.index('```', n+3)] else: return QueryResult(status=HTTPStatus.FAILED_DEPENDENCY, error='从结果中无法找到有效的查询语句。', thought=text['output']) sql = sql.strip() return sql class Generator: max_retry: Optional[int] = 2 def generate(self, input:dict) -> QueryResult: ''' 抽象生成器类 params: - prompt: 输入文本 returns: - code: HTTP 状态码, 200 表示成功, 其他表示失败 - sql: 生成的 SQL 语句,如果 code 不是 200,则 sql 中可能包含错误信息 - text: 生成的完全文本,内容可能包含 SQL 语句,如果 code 不是 200,则 text 中可能包含错误信息 ''' cnt=0 sql = "" while cnt < self.max_retry: code, text = self._generate(input) if code != HTTPStatus.OK: return QueryResult(status=code, error=text['err'], thought=text['output']) res = _retrieve_sql(text) if isinstance(res, QueryResult): cnt+=1 print("Error: 从结果中无法找到有效的查询语句, retry again...") continue else: sql = res break if sql == "": return QueryResult(status=HTTPStatus.FAILED_DEPENDENCY, error='从结果中无法找到有效的查询语句。', thought=text['output']) logger.info('[QUESTION] %s\n%s\n%s', input['messages'][0], sql, text['output']) qustion_logger.info('%s\n--------\n%s\n--------------------\n', sql, text['output']) return QueryResult(status=HTTPStatus.OK, sql=sql, thought=text['output']) def _generate(self, input:dict) -> tuple[HTTPStatus,str]: """ 根据给定的提示生成SQL语句。 参数: - prompt: str - 提供给模型的提示,用于生成SQL语句。 返回: - HTTPStatus - 请求的状态码,200 表示成功。 - str - full text, 从大模型的返回的完整回复,如果 code 不是 200,则 text 中可能包含错误信息。 如果 code 是 200,则 text 中的 SQL 语句应包含在围栏式代码块 ```sql ``` 中。 """ raise NotImplementedError() class FormatSuggest(IntEnum): ''' 对结果集的格式化建议 ''' NoResult = 0 '''没有查询结果''' SingleRow = 1 '''结果为一行''' SingleColumn = 2 '''结果为一列''' MultiRow = 3 '''结果为多行,但行数不多,建议直接列举结果''' ExportToExcel = 4 '''结果集较大,建议导出到 Excel 文件''' def _suggest_format(df:pd.DataFrame) -> FormatSuggest: row_count = len(df.index) if row_count == 0: return FormatSuggest.NoResult elif row_count == 1: return FormatSuggest.SingleRow elif row_count < 20: if len(df.columns) == 1: return FormatSuggest.SingleColumn else: return FormatSuggest.MultiRow else: return FormatSuggest.ExportToExcel def _translate_column_names(df:pd.DataFrame, translate:Callable[[str],str]): if translate is None: return # translate column names col_names = df.columns.to_list() zh_names = translate(' | '.join(col_names).replace('_', ' ')).split('|') zh_names = list(map(lambda s: s.strip(), zh_names)) if len(col_names) == len(zh_names): df.columns = zh_names else: logger.error('[TRANSLATE] count not match (%d->%d) %s %s', len(col_names), len(zh_names), col_names, zh_names) # pandas 数据转为 markdown表格 def df_to_markdown(df): headers = list(df.columns) # 计算每列内容的总长度 total_length = sum(df[col].astype(str).str.len().sum() + len(col) for col in headers) # 计算每列相对比例 proportions = [((df[col].astype(str).str.len().sum() + len(col)) / total_length) for col in headers] # 假设总宽度为 100,按比例分配宽度 widths = [int(100 * prop) for prop in proportions] # 创建 Markdown 表格的表头 markdown_table = "| " + " | ".join([f"{header:<{width}}" for header, width in zip(headers, widths)]) + " |\n" markdown_table += "| " + " | ".join(["---" * width for width in widths]) + " |\n" # 添加数据行 for _, row in df.iterrows(): markdown_table += "| " + " | ".join([f"{str(value):<{width}}" for value, width in zip(row, widths)]) + " |\n" return markdown_table class Formatter: ''' 抽象格式化器类 ''' def __init__(self, tranlate) -> None: self.translate = tranlate def format(self, df:pd.DataFrame) -> str: ''' 格式化指定的 DataFrame。 ''' raise NotImplementedError() class _JsonFormatter(Formatter): def __init__(self) -> None: super().__init__(tranlate=None) def format(self, df:pd.DataFrame): return df.to_dict(orient='split', index=False) class _MarkdownFormatter(Formatter): def __init__(self, tranlate, output_dir:str, site_url:str) -> None: super().__init__(tranlate=tranlate) self.output_dir = output_dir self.site_url = site_url def format(self, df:pd.DataFrame): suggest = _suggest_format(df) if suggest == FormatSuggest.NoResult: return '没有符合条件的记录' _translate_column_names(df, self.translate) if suggest == FormatSuggest.SingleRow: return ', '.join([f'{k}:{v}' for k,v in df.to_dict(orient='records')[0].items()]) elif suggest == FormatSuggest.SingleColumn: return df[df.columns[0]].str.cat(sep='\n') elif suggest == FormatSuggest.MultiRow: output_file = uuid.uuid4().hex + '.xlsx' df.to_excel(os.path.join(self.output_dir, output_file), index=False) return df_to_markdown(df.head(10)) + \ '\n\n查询结果已导出到Excel文件,请点击链接下载文件:{}/{}'.format(self.site_url, output_file) else: # TODO 测试阶段仅保存前 10 条记录 output_file = uuid.uuid4().hex + '.xlsx' df.to_excel(os.path.join(self.output_dir, output_file), index=False) return df_to_markdown(df.head(10)) + \ '\n\n结果记录数或信息量太多,查询结果已导出到Excel文件(测试阶段仅导出前10行),请点击链接下载文件:{}/{}'.format(self.site_url, output_file) class _HtmlFormatter(_MarkdownFormatter): def __init__(self, tranlate, output_dir: str, site_url: str) -> None: super().__init__(tranlate, output_dir, site_url) def format(self, df:pd.DataFrame): answer = super().format(df) return markdown.markdown(answer, extensions=['markdown.extensions.tables']) class _WechatFormatter(_MarkdownFormatter): def __init__(self, tranlate, output_dir: str, site_url: str) -> None: super().__init__(tranlate, output_dir, site_url) def format(self, df: pd.DataFrame): suggest = _suggest_format(df) if suggest == FormatSuggest.NoResult: return '没有符合条件的记录' _translate_column_names(df, self.translate) if suggest == FormatSuggest.ExportToExcel: # TODO 测试阶段仅保存前 10 条记录 output_file = uuid.uuid4().hex + '.xlsx' # df.head(10).to_excel(os.path.join(self.output_dir, output_file), index=False) # return f'结果记录数或信息量太多,查询结果已导出到Excel文件,请下载查看 (测试阶段仅导出前10行)'+df.head(10).to_string() # return df.head(10).to_html(index=False) return df_to_markdown(df.head(10)) elif suggest == FormatSuggest.MultiRow: if len(df.columns) <= 3: return '\n'.join([', '.join([str(v) for v in row.values()]) for row in df.to_dict(orient='records')]) else: return '\n'.join([', '.join([f'{k}:{v}' for k,v in row.items()]) for row in df.to_dict(orient='records')]) else: return super().format(df) def NewFormatter(format:str, tranlate:Callable[[str],str], output_dir:str, site_url:str) -> Formatter: if format == ReturnType.JSON: return _JsonFormatter() elif format == ReturnType.TEXT: return _MarkdownFormatter(tranlate, output_dir, site_url) elif format == ReturnType.HTML: return _HtmlFormatter(tranlate, output_dir, site_url) elif format == ReturnType.WX_MD: return _WechatFormatter(tranlate, output_dir, site_url) else: raise ValueError('不支持的格式化器类型') class Executor: ''' 抽象执行器类 params: - sql: SQL 语句 returns: - code: HTTP 状态码, 200 表示成功, 其他表示失败 - text: 生成的完全文本,内容可能包含 SQL 语句,如果 code 不是 200,则 text 中可能包含错误信息 ''' def __init__(self, generator:Generator, formatter:Formatter, max_retry:Optional[int] = 2) -> None: self.generator = generator self.formatter = formatter self.max_retry = max_retry def get_sql(self, cnt:int, input:dict) -> dict: """ 根据提供的问题生成SQL 参数: - cnt (int): 重试次数。 - input (dict): 主要包含用户查询的问题。 返回: - dict: 生成sql的返回体。 """ print('第{}次尝试'.format(cnt+1)) if cnt == 0: ret = self.generator.generate(input) else: new_input = enrich_input("之前生成的SQL语句有误或查询结果为空,请重新生成一个不同的SQL语句, question:", input) ret = self.generator.generate(new_input) return ret def query(self, connection_string:str, input:dict) -> QueryResult: """ 根据提供的模型、数据库和问题生成并执行SQL查询,然后根据返回类型返回结果。 参数: - database: str - 数据库的连接字符串。 - prompt: str - 提示文本。 - question: str - 用户的查询问题。 - input: 问题,input = {"messages": [("human", req.question)], "iterations": 0} 返回: - status_code - HTTP状态码。 - result 查询结果根据return_type的不同可以是字典、字符串或DataFrame的markdown表示。 - sql - 生成的SQL语句。 - thought - 模型生成的推理 """ error: Union[Exception|str|None] = None # 使用logger记录问答对 load_logging_config('config/development/logging.yaml') qa_logger = logging.getLogger('qa_cache') #数据库连接 sql_engine = sqlalchemy.create_engine(connection_string) with sql_engine.connect() as connection: cnt = 0 ret = None while cnt < self.max_retry: ret = self.get_sql(cnt, input) if ret.status != HTTPStatus.OK: qa_logger.error('[FAIL] %s\n%s', input['messages'][0][1], ret.error) return ret try: #查询数据 result = connection.execute(sqlalchemy.text(ret.sql)) df = pd.DataFrame(result, columns=result.keys()) print('******************\n', df.head(30),'\n***************\n') ret.result = self.formatter.format(df) qa_logger.info('[SUCCESS] \nInput:\n%s\nResult:\n%s', input['messages'][0][1], ret.result) break except Exception as e: error = e qa_logger.error('[QUERY] %s\n%s', input['messages'][0][1], error) ret.status=HTTPStatus.INTERNAL_SERVER_ERROR ret.error=str(error) print("Error: SQL执行错误,retry agin...{}".format(error)) cnt+=1 return ret