139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding:utf-8 -*-
|
|
|
|
import fastapi.staticfiles
|
|
from http import HTTPStatus
|
|
import fastapi
|
|
import config
|
|
import asyncio
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
app = fastapi.FastAPI()

# Serve generated artifacts (files produced by the query endpoint's formatter)
# under /output. query() passes config.osp.join(config.BASE_DIR, 'output') as
# the formatter's output_dir, so mount that same absolute directory instead of
# a CWD-relative "output", which only matched when the server was launched
# from BASE_DIR.
app.mount(
    "/output",
    fastapi.staticfiles.StaticFiles(directory=config.osp.join(config.BASE_DIR, 'output')),
    name="output",
)

# Allowed CORS origins: may be a single origin or several; "*" allows any.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for a deployment that takes API keys.
origins = [
    "*",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
from sqlcode.qgi import Executor, NewFormatter
|
|
from sqlcode.langchain_model import Generator
|
|
from sqlcode.modelloader import ModelLoader, ModelManager
|
|
from sqlcode.qa_cache import QACache, load_qa_pairs
|
|
from pydantic import BaseModel
|
|
from sqlcode.qgi import ReturnType, QueryResult
|
|
from sqlcode.multi_agent import create_sql_graph
|
|
|
|
class QueryRequest(BaseModel):
    '''
    Request body for the query endpoint.

    * question: str, the natural-language question to answer.
    * return_type: ReturnType, the desired result format; defaults to ReturnType.TEXT.
    '''
    # The natural-language question sent to the model.
    question: str
    # Result format; ReturnType.SQL makes the endpoint return generated SQL
    # instead of executing it.
    return_type: ReturnType = ReturnType.TEXT
|
|
|
|
# Seed the shared answer cache at import time from a log of past question/answer
# pairs (presumably a question -> answer mapping; confirm against load_qa_pairs).
# NOTE(review): the path is relative, so it is resolved against the process
# working directory if load_qa_pairs opens it directly — verify.
_QA_LOG_PATH = 'log/test.log'
# Threshold handed to QACache; a value this close to 1 presumably means only
# near-identical questions count as hits (semantics live in QACache — confirm).
_QA_SIMILARITY_THRESHOLD = 0.999

qa_dict = load_qa_pairs(_QA_LOG_PATH)
qa_cache = QACache(similarity_threshold=_QA_SIMILARITY_THRESHOLD)
qa_cache.add(qa_dict)
|
|
|
|
def try_search_in_cache(question: str) -> 'QueryResult | None':
    '''
    Look up ``question`` in the shared answer cache.

    :param question: the user's natural-language question.
    :return: on a hit, a QueryResult with status OK, the cached answer as
        ``result`` and an empty ``sql``; on a miss, ``None``.
    '''
    cached_answer = qa_cache.find_similar(question)
    # Identity check: ``!= None`` dispatches to the answer's __eq__, which can be
    # overridden (or ambiguous for array-like values).
    if cached_answer is None:
        return None
    return QueryResult(status=HTTPStatus.OK, result=cached_answer, sql='', thought='从缓存得到结果', error=None)
|
|
|
|
|
|
# Agent/LLM configuration loaded from "qwen_graph.conf"; query() passes it to
# create_sql_graph and uses its `.system` field as the system prompt.
qwen_cfg = config.qwen_config("qwen_graph.conf")
# NOTE(review): re_cfg is not referenced in the visible code of this file —
# confirm whether it is still needed.
re_cfg = config.refineProblem_config()

# Model registry built once at import time; query() calls
# modelManager.switch_model(model_name) and then get_model() per request.
model_cfg = config.model_config()
modelLoader = ModelLoader(model_cfg)
modelManager = ModelManager(modelLoader)
|
|
|
|
@app.post('/{model_name}/{database}')
async def query(model_name: str, database: str, apikey: str, req: QueryRequest) -> QueryResult:
    '''
    Answer a natural-language question against a database and return the result.

    - :param model_name: str, name of the LLM to use, e.g. qwen-turbo.
    - :param database: str, name of the database to query.
    - :param apikey: str, client API key (a query parameter); it must grant
      access to ``database``.
    - :param req: QueryRequest, the question and the desired return type.
    - :return: QueryResult. Authentication/permission failures are reported
      in-band (status UNAUTHORIZED / FORBIDDEN in the body).
    '''
    print('---------进入----------')

    # Authenticate and authorize BEFORE consulting the cache; otherwise any
    # caller, even without a valid key, could read cached answers.
    client = config.api_key(apikey)
    if not client:
        return QueryResult(status=HTTPStatus.UNAUTHORIZED, error='invalid apikey')
    if database not in client.databases:
        return QueryResult(status=HTTPStatus.FORBIDDEN, error='database permission denied')

    # (Near-)exact-match lookup in the answer cache. Cache hits carry only the
    # final answer (sql=''), so SQL requests must bypass the cache.
    # NOTE(review): the cache key is the question alone — not the database or
    # return_type — so an answer cached for one database can be served for
    # another; consider scoping the key.
    if req.return_type != ReturnType.SQL:
        searched_result = try_search_in_cache(req.question)
        if searched_result is not None:
            print('从缓存找到结果')
            print(searched_result.result)
            return searched_result

    # Generate SQL and an answer with the model.
    metadata = config.metadata(database)
    # NOTE(review): the calls below are not awaited, so the event loop is
    # blocked while they run. That also serializes requests, which keeps this
    # switch_model/get_model pair on the shared modelManager from interleaving;
    # keep that in mind before offloading them to threads.
    modelManager.switch_model(model_name)
    model = modelManager.get_model()

    agent_executor = create_sql_graph(model=model, qwen_cfg=qwen_cfg, data_cfg=metadata)
    generator = Generator(
        agentExcutor=agent_executor,
        messages=[{'role': 'system', 'content': qwen_cfg.system}],
        apikey=client.dashscope_api_key,
        seed=0,
    )

    # Initial state for the agent graph (named to avoid shadowing builtin input).
    graph_input = {"messages": [("human", req.question)], "iterations": 0}

    # SQL-only requests return the generated SQL without executing it.
    if req.return_type == ReturnType.SQL:
        return generator.generate(graph_input)

    formatter = NewFormatter(
        format=req.return_type,
        tranlate=None,
        output_dir=config.osp.join(config.BASE_DIR, 'output'),
        site_url=config.app_config().site_url + '/output',
    )
    executor = Executor(generator=generator, formatter=formatter)
    query_result = executor.query(
        connection_string=metadata.connection_string,
        input=graph_input,
    )

    # Cache only successful answers: cache hits are re-served with status OK,
    # so caching a failure would hide the error on every later identical question.
    if query_result.status == HTTPStatus.OK and query_result.result is not None:
        qa_cache.add({req.question: query_result.result})
    return query_result
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == 'test':
        # Smoke test: run a single query in-process, without the HTTP layer.
        import logging

        logging.basicConfig(level=logging.DEBUG)
        ret = asyncio.run(query(
            apikey='ccscc',
            # Keyword must match query()'s `model_name` parameter.
            model_name='qwen-max',
            database='students',
            req=QueryRequest(
                question="2022年有哪几门课程",
                return_type=ReturnType.TEXT,
            ),
        ))
        print(ret)
    else:
        import uvicorn

        uvicorn.run(app, host='0.0.0.0', port=9000)