#!/usr/bin/env python
# -*- encoding:utf-8 -*-
'''HTTP service that answers natural-language questions against configured databases.

A request is first looked up in a question/answer cache; on a miss, an
LLM-driven SQL agent generates SQL (and, unless only SQL was requested, runs
it and formats the result).
'''
import asyncio
import os
from http import HTTPStatus
from typing import Optional

import fastapi
import fastapi.staticfiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import config
from sqlcode.langchain_model import Generator
from sqlcode.modelloader import ModelLoader, ModelManager
from sqlcode.multi_agent import create_sql_graph
from sqlcode.qa_cache import QACache, load_qa_pairs
from sqlcode.qgi import Executor, NewFormatter, QueryResult, ReturnType

# Directory the formatter writes generated files into. The same directory is
# mounted under /output so those files can be served back. It is resolved
# against config.BASE_DIR (not the current working directory) so the writer
# and the static mount always agree, wherever the server is launched from.
OUTPUT_DIR = config.osp.join(config.BASE_DIR, 'output')
# StaticFiles refuses to start if its directory does not exist.
os.makedirs(OUTPUT_DIR, exist_ok=True)

app = fastapi.FastAPI()
app.mount("/output", fastapi.staticfiles.StaticFiles(directory=OUTPUT_DIR), name="output")

# Allowed CORS origins; may be a single origin or several.
origins = [
    "*"
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class QueryRequest(BaseModel):
    '''Request body for the query endpoint.

    * question: the natural-language question to answer.
    * return_type: the format in which the result should be returned.
    '''
    question: str
    return_type: ReturnType = ReturnType.TEXT


# Seed the QA cache with question/answer pairs loaded from the log file. The
# 0.999 threshold is meant to accept only (near-)exact question matches.
qa_dict = load_qa_pairs('log/test.log')
qa_cache = QACache(similarity_threshold=0.999)
qa_cache.add(qa_dict)


def try_search_in_cache(question: str) -> Optional[QueryResult]:
    '''Return a cached QueryResult for *question*, or None on a cache miss.'''
    cached_answer = qa_cache.find_similar(question)
    if cached_answer is None:
        return None
    return QueryResult(status=HTTPStatus.OK, result=cached_answer, sql='',
                       thought='从缓存得到结果', error=None)


qwen_cfg = config.qwen_config("qwen_graph.conf")
re_cfg = config.refineProblem_config()
model_cfg = config.model_config()
modelLoader = ModelLoader(model_cfg)
modelManager = ModelManager(modelLoader)


@app.post('/{model_name}/{database}')
async def query(model_name: str, database: str, apikey: str, req: QueryRequest) -> QueryResult:
    '''Run the query described by *req* and return its result.

    :param model_name: name of the LLM to use, e.g. ``qwen-turbo``.
    :param database: name of the database to query; the API key must be
        authorised for it.
    :param apikey: client API key (passed as a query-string parameter).
    :param req: the question and the desired return type.
    :returns: the QueryResult. Status is UNAUTHORIZED for an unknown key and
        FORBIDDEN when the key may not access *database*. For
        ``ReturnType.SQL`` the generator's result is returned directly.
    '''
    # NOTE(review): model generation and SQL execution below are synchronous
    # calls inside an ``async def``, so they block the event loop for the
    # whole request. Consider offloading them (e.g. asyncio.to_thread) once
    # the model / executor objects are confirmed to be thread-safe.
    print('---------进入----------')

    # The cache stores one answer per question, not per format, so it is only
    # used for plain-text requests. This stops e.g. a SQL request from
    # receiving a cached text answer whose ``sql`` field is empty.
    use_cache = req.return_type == ReturnType.TEXT
    if use_cache:
        searched_result = try_search_in_cache(req.question)
        if searched_result is not None:
            print('从缓存找到结果')
            print(searched_result.result)
            return searched_result

    # Authenticate the client and check its database permission.
    client = config.api_key(apikey)
    if not client:
        return QueryResult(status=HTTPStatus.UNAUTHORIZED, error='invalid apikey')
    if database not in client.databases:
        return QueryResult(status=HTTPStatus.FORBIDDEN, error='database permission denied')

    # Build the agent that calls the model to generate the SQL and the answer.
    metadata = config.metadata(database)
    modelManager.switch_model(model_name)
    model = modelManager.get_model()
    agent_executor = create_sql_graph(model=model, qwen_cfg=qwen_cfg, data_cfg=metadata)
    generator = Generator(agentExcutor=agent_executor,
                          messages=[{'role': 'system', 'content': qwen_cfg.system}],
                          apikey=client.dashscope_api_key,
                          seed=0,
                          )
    graph_input = {"messages": [("human", req.question)], "iterations": 0}

    if req.return_type == ReturnType.SQL:
        return generator.generate(graph_input)

    formatter = NewFormatter(format=req.return_type,
                             tranlate=None,
                             output_dir=OUTPUT_DIR,
                             site_url=config.app_config().site_url + '/output',
                             )
    executor = Executor(generator=generator, formatter=formatter)
    query_result = executor.query(connection_string=metadata.connection_string,
                                  input=graph_input,
                                  )

    # Only cache answers that did not fail. Caching a failure would keep
    # serving it for every later identical question.
    if use_cache and not query_result.error and query_result.result is not None:
        qa_cache.add({req.question: query_result.result})
    return query_result


if __name__ == '__main__':
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == 'test':
        # Smoke test: call the endpoint coroutine directly, bypassing HTTP.
        import logging
        logging.basicConfig(level=logging.DEBUG)
        ret = asyncio.run(query(apikey='ccscc',
                                model_name='qwen-max',
                                database='students',
                                req=QueryRequest(
                                    question="2022年有哪几门课程",
                                    return_type=ReturnType.TEXT
                                ),
                                ))
        print(ret)
    else:
        import uvicorn
        uvicorn.run(app, host='0.0.0.0', port=9000)