bizwechat/query.py

139 lines
4.5 KiB
Python
Raw Normal View History

2025-02-17 10:34:35 +08:00
#!/usr/bin/env python
# -*- encoding:utf-8 -*-
import fastapi.staticfiles
from http import HTTPStatus
import fastapi
import config
import asyncio
from fastapi.middleware.cors import CORSMiddleware
# FastAPI application serving the text-to-SQL query endpoint.
app = fastapi.FastAPI()
# Serve files generated by NewFormatter under /output. Resolve the directory against
# config.BASE_DIR (the same base query() passes to NewFormatter as output_dir) instead of
# the process working directory, so the served directory is the one written to no matter
# where the server is launched from. StaticFiles raises at construction time if the
# directory does not exist.
app.mount(
    "/output",
    fastapi.staticfiles.StaticFiles(directory=config.osp.join(config.BASE_DIR, "output")),
    name="output",
)
# Allowed CORS origins: a single origin or several; "*" allows every origin.
origins = [
    "*"
]
# NOTE(review): with a wildcard origin plus allow_credentials=True, Starlette echoes the
# requesting origin back on credentialed requests, i.e. any site may make credentialed
# calls. If browser clients do not rely on cookies/auth headers (the apikey here travels
# as a query parameter), consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
from sqlcode.qgi import Executor, NewFormatter
from sqlcode.langchain_model import Generator
from sqlcode.modelloader import ModelLoader, ModelManager
from sqlcode.qa_cache import QACache, load_qa_pairs
from pydantic import BaseModel
from sqlcode.qgi import ReturnType, QueryResult
from sqlcode.multi_agent import create_sql_graph
class QueryRequest(BaseModel):
    '''
    Request body of the query endpoint.

    * question: str, the natural-language question to answer
    * return_type: ReturnType, the kind of result to return (defaults to ReturnType.TEXT)
    '''
    question: str
    # ReturnType.SQL makes query() return the generator's output directly;
    # any other value goes through NewFormatter/Executor.
    return_type: ReturnType = ReturnType.TEXT
# Seed the shared question/answer cache at import time from a previous log.
# NOTE(review): relative path — presumably resolved against the process working
# directory, not config.BASE_DIR; confirm.
qa_dict = load_qa_pairs('log/test.log')
# A threshold this close to 1 presumably means only (near-)identical questions hit,
# matching the "exact match" intent in query(); semantics depend on QACache — confirm.
qa_cache = QACache(similarity_threshold=0.999)
# add() takes a mapping of question -> answer (same shape query() passes later).
qa_cache.add(qa_dict)
def try_search_in_cache(question: str) -> 'QueryResult | None':
    '''
    Look up a previously answered question in the module-level QA cache.

    :param question: the user's natural-language question
    :return: a QueryResult with status OK, the cached answer as `result` and an
             empty `sql`; or None when the cache has no matching entry
    '''
    cached_answer = qa_cache.find_similar(question)
    # Identity check rather than `!= None`: if an answer is ever an object with
    # element-wise comparison (e.g. a DataFrame or array), `!=` would not yield a bool.
    if cached_answer is None:
        return None
    return QueryResult(status=HTTPStatus.OK, result=cached_answer, sql='', thought='从缓存得到结果', error=None)
# Configuration and model objects created once at import time and shared by every request.
# Agent configuration; query() uses qwen_cfg.system as the system prompt and passes
# qwen_cfg to create_sql_graph().
qwen_cfg = config.qwen_config("qwen_graph.conf")
# NOTE(review): re_cfg is not referenced in the visible code — confirm it is still needed.
re_cfg = config.refineProblem_config()
model_cfg = config.model_config()
modelLoader = ModelLoader(model_cfg)
# Single shared manager: query() calls switch_model() then get_model() on it per request,
# so the selected model is presumably process-wide state.
modelManager = ModelManager(modelLoader)
@app.post('/{model_name}/{database}')
async def query(model_name:str, database:str, apikey:str, req:QueryRequest) -> QueryResult:
    '''
    Answer a natural-language question against a database and return the result.

    - :param model_name: str, name of the LLM to use, e.g. qwen-turbo
    - :param database: str, name of the database to query
    - :param apikey: str, client API key; must be valid and grant access to `database`
    - :param req: QueryRequest with the question and the desired return type
    - :return: QueryResult. status UNAUTHORIZED for an unknown apikey, FORBIDDEN when the
      key may not access `database`; for ReturnType.SQL the generator's output is returned
      directly; otherwise the Executor's result.
    '''
    print('---------进入----------')
    # Authenticate and authorize BEFORE consulting the cache; otherwise any caller,
    # even one without a valid key, could read cached answers.
    client = config.api_key(apikey)
    if not client:
        return QueryResult(status=HTTPStatus.UNAUTHORIZED, error='invalid apikey')
    if database not in client.databases:
        return QueryResult(status=HTTPStatus.FORBIDDEN, error='database permission denied')
    # Cache lookup (keyed by the question text only).
    # NOTE(review): the key ignores `database` and `return_type`, so an answer cached for one
    # database / return type can be served for another — confirm this is intended.
    searched_result = try_search_in_cache(req.question)
    if searched_result is not None:
        print('从缓存找到结果')
        print(searched_result.result)
        return searched_result
    # Cache miss: let the model generate the SQL and the answer.
    metadata = config.metadata(database)
    # NOTE(review): the calls below are synchronous inside an async handler, so they run on
    # the event loop and block other requests while they execute.
    modelManager.switch_model(model_name)
    model = modelManager.get_model()
    agent_executor = create_sql_graph(model=model, qwen_cfg=qwen_cfg, data_cfg=metadata)
    generator = Generator(agentExcutor=agent_executor,
                          messages=[{'role': 'system', 'content': qwen_cfg.system}],
                          apikey=client.dashscope_api_key,
                          seed=0,
                          )
    # Initial graph state: the user's question as the first human message.
    graph_input = {"messages": [("human", req.question)], "iterations": 0}
    if req.return_type == ReturnType.SQL:
        return generator.generate(graph_input)
    formatter = NewFormatter(format=req.return_type,
                             tranlate=None,  # (sic) presumably NewFormatter's own parameter name
                             output_dir=config.osp.join(config.BASE_DIR, 'output'),
                             site_url=config.app_config().site_url + '/output',
                             )
    executor = Executor(generator=generator, formatter=formatter)
    query_result = executor.query(connection_string=metadata.connection_string,
                                  input=graph_input,
                                  )
    # Only cache successful answers so a failed run is not replayed to later callers.
    if query_result.status == HTTPStatus.OK and query_result.result is not None:
        qa_cache.add({req.question: query_result.result})
    return query_result
if __name__ == '__main__':
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == 'test':
        # Smoke test: run a single query in-process without starting the HTTP server.
        import logging
        logging.basicConfig(level=logging.DEBUG)
        ret = asyncio.run(query(apikey='ccscc',
                                # The endpoint's parameter is `model_name`; passing
                                # `model=` raised TypeError.
                                model_name='qwen-max',
                                database='students',
                                req=QueryRequest(
                                    question="2022年有哪几门课程",
                                    return_type=ReturnType.TEXT
                                ),
                                ))
        print(ret)
    else:
        # Serve the API on all interfaces, port 9000.
        import uvicorn
        uvicorn.run(app, host='0.0.0.0', port=9000)