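#!/usr/bin/env python
# -*- encoding:utf-8 -*-

"""
微信企业后台 接口.

WeChat Work (enterprise WeChat) backend callback service: verifies the
callback URL, receives encrypted user messages, answers them with the
text-to-SQL pipeline and pushes the answers back as markdown messages.
"""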
|
||
|
||
import asyncio
import logging
from http import HTTPStatus

try:
    import xml.etree.cElementTree as ET
except ImportError:
    # xml.etree.cElementTree was removed in Python 3.9; ElementTree uses the
    # same C accelerator automatically.
    import xml.etree.ElementTree as ET

import aiohttp
import fastapi
from langchain.agents.agent_types import AgentType
from langchain.prompts import PromptTemplate
from langchain_chroma import Chroma
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.utilities import SQLDatabase
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import StructuredTool

import config
from bizwechat import WXBizMsgCrypt
from sqlcode.multi_agent import create_sql_graph
from sqlcode.sql_agent import create_sql_agent, remove_markdown_code_block
from sqlcode.utils import parse, format_docs
|
||
|
||
# Logger shared by every handler in this module (logger name: 'sqlcode').
logger = logging.getLogger('sqlcode')

# ASGI application object; the WeChat callback routes are registered on it.
app = fastapi.FastAPI()
|
||
|
||
def get_wxcpt():
    """Build the WeChat Work message encrypt/decrypt helper.

    A fresh ``WXBizMsgCrypt`` is created on every call from the token,
    AES key and corp id returned by ``config.bizwechat_config()``.

    :return: a ``WXBizMsgCrypt`` instance.
    """
    biz_cfg = config.bizwechat_config()
    crypt_kwargs = dict(
        sToken=biz_cfg.token,
        sEncodingAESKey=biz_cfg.aes_key,
        sReceiveId=biz_cfg.corp_id,
    )
    return WXBizMsgCrypt(**crypt_kwargs)
|
||
|
||
|
||
@app.get('/', include_in_schema=False)
async def verify_url(request:fastapi.Request):
    """Answer the WeChat callback-URL verification request.

    Reads ``msg_signature``, ``timestamp``, ``nonce`` and ``echostr`` from the
    query string and passes them to ``WXBizMsgCrypt.VerifyURL``.

    :return: 200 with the decrypted echo string on success; 400 with an empty
        body when a parameter is missing or verification returns a non-zero code.
    """
    query = request.query_params
    values = [query.get(key) for key in ('msg_signature', 'timestamp', 'nonce', 'echostr')]
    if not all(values):
        logger.error('verify_url failed, missing parameters')
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)
    signature, timestamp, nonce, echostr = values

    code, reply = get_wxcpt().VerifyURL(signature, timestamp, nonce, echostr)
    if code != 0:
        logger.error('verify_url failed, error code: %s', code)
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    logger.info('verify_url success, echostr: %s', reply)
    return fastapi.Response(reply, HTTPStatus.OK)
|
||
|
||
|
||
# Strong references to the fire-and-forget reply tasks. The event loop keeps
# only weak references to tasks, so an unreferenced task may be garbage
# collected before it finishes.
_background_tasks = set()


@app.post('/', include_in_schema=False)
async def receive_message(request:fastapi.Request):
    """Receive an encrypted WeChat Work message and schedule an asynchronous reply.

    The message is decrypted with ``WXBizMsgCrypt.DecryptMsg``. Its ``Content``
    is answered in a background task (``async_query_and_reply``), so this
    handler returns immediately.

    :return: an empty 200 response once the message is accepted, or when it
        carries no text content; an empty 400 response for missing query
        parameters, a decryption failure or malformed XML.
    """
    signature = request.query_params.get('msg_signature')
    timestamp = request.query_params.get('timestamp')
    nonce = request.query_params.get('nonce')

    if not signature or not timestamp or not nonce:
        logger.error('receive_message failed, missing parameters')
        # A returned tuple would be JSON-serialised with status 200 by FastAPI;
        # an explicit Response is needed to send a real 400.
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    logger.info('receive_message timestamp: %s, nonce: %s', timestamp, nonce)

    post_data = await request.body()
    code, post_data = get_wxcpt().DecryptMsg(post_data, signature, timestamp, nonce)

    if code != 0:
        logger.error('receive_message failed, error code: %s', code)
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    try:
        root = ET.fromstring(post_data)
    except ET.ParseError:
        logger.error('receive_message failed, malformed XML')
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    content = root.findtext('Content')
    from_user = root.findtext('FromUserName')
    if not content or not from_user:
        # Messages without a <Content> element (e.g. events or non-text
        # messages) previously raised AttributeError and produced a 500.
        # Acknowledge them without replying.
        logger.info('receive_message ignored, no text content (MsgType: %s)',
                    root.findtext('MsgType'))
        return fastapi.Response('', HTTPStatus.OK)

    task = asyncio.create_task(async_query_and_reply(
        apikey=config.bizwechat_config().qgi_api_key,
        model_name='qwen',
        database='contracts',
        question=content,
        to_user=from_user,
    ))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    logger.info('receive_message success, timestamp: %s, nonce: %s', timestamp, nonce)
    # Returning '' from a FastAPI route would send the JSON text `""`; send a
    # truly empty body instead (no passive reply).
    return fastapi.Response('', HTTPStatus.OK)
|
||
|
||
|
||
from sqlcode.qgi import Executor, NewFormatter, ReturnType
|
||
# from sqlcode.qwenapi import Generator
|
||
from sqlcode.langchain_model import Generator
|
||
from sqlcode.modelloader import ModelLoader, ModelManager
|
||
|
||
def _build_executor(apikey, model_name, database):
    """Assemble the text-to-SQL pipeline for *database* (blocking).

    :param apikey: key passed to ``config.api_key`` to obtain the DashScope API key.
    :param model_name: model activated through ``ModelManager.switch_model``.
    :param database: database name passed to ``config.metadata``.
    :return: ``(executor, metadata)``: the ``Executor`` to run, and the
        database metadata whose ``connection_string`` the query needs.
    """
    metadata = config.metadata(database)
    qwen_cfg = config.qwen_config("qwen_graph.conf")

    model_manager = ModelManager(ModelLoader(config.model_config()))
    model_manager.switch_model(model_name)
    model = model_manager.get_model()

    # Graph-based SQL agent from sqlcode.multi_agent. It replaced an earlier
    # pipeline (vector-store question refinement + SQLDatabaseToolkit agent)
    # that lived here as commented-out code; see version control history.
    agent_executor = create_sql_graph(model=model, qwen_cfg=qwen_cfg, data_cfg=metadata)

    formatter = NewFormatter(format=ReturnType.WX_MD,
                             # NOTE(review): 'tranlate' looks misspelled but must
                             # match NewFormatter's parameter name; confirm first.
                             tranlate=None,
                             output_dir=config.osp.join(config.BASE_DIR, 'output'),
                             site_url=config.app_config().site_url + '/output',
                             )

    generator = Generator(agentExcutor=agent_executor,
                          messages=[{'role': 'system', 'content': qwen_cfg.system}],
                          apikey=config.api_key(apikey).dashscope_api_key,
                          seed=0,
                          )

    return Executor(generator=generator, formatter=formatter), metadata


def _run_query(apikey, model_name, database, question):
    """Build the pipeline and answer *question* synchronously (blocking).

    :return: the result of ``Executor.query``; the caller reads its ``error``,
        ``thought`` and ``result`` attributes.
    """
    executor, metadata = _build_executor(apikey, model_name, database)
    agent_input = {"messages": [("human", question)], "iterations": 0}
    return executor.query(connection_string=metadata.connection_string,
                          input=agent_input,
                          )


async def async_query_and_reply(apikey:str, model_name:str, database:str, question:str, to_user:str):
    """Answer *question* against *database* and push the reply to *to_user*.

    The reply is sent as up to three markdown messages: the error (if any),
    the agent's reasoning/SQL, and the formatted result.

    :param apikey: key used to look up the DashScope API key via ``config.api_key``.
    :param model_name: name of the large model to use, e.g. ``qwen``.
    :param database: name of the database to query.
    :param question: the user's natural-language question.
    :param to_user: WeChat Work user id that receives the reply.
    """
    logger.info('开始处理问题:%s', question)

    loop = asyncio.get_running_loop()
    try:
        # Model loading, LLM calls and SQL execution are all blocking. Run them
        # in the default thread pool so the event loop keeps answering WeChat
        # callbacks meanwhile.
        # NOTE(review): concurrent questions now build pipelines in parallel
        # threads; confirm config/ModelManager/Executor tolerate that.
        ret = await loop.run_in_executor(None, _run_query, apikey, model_name, database, question)
    except Exception:
        # When scheduled with asyncio.create_task (as receive_message does),
        # nobody awaits this coroutine, so an escaping exception would be lost.
        logger.exception('处理问题失败:%s', question)
        try:
            await send_msg(to_user, question, '抱歉,处理该问题时出错,请稍后重试。')
        except Exception:
            logger.exception('发送失败通知失败:%s', to_user)
        return

    # error, thought and result are sent as separate messages because a
    # WeChat Work markdown message is limited to 2048 bytes (UTF-8).
    thought = ret.thought or ''
    if thought.startswith('```') and thought.endswith('```'):
        thought = '针对这个问题,采用 SQL 查询:\n' + thought
    try:
        await send_msg(to_user, question, ret.error)
        await send_msg(to_user, question, thought)
        await send_msg(to_user, question, ret.result)
    except Exception:
        logger.exception('发送回复失败:%s', to_user)
|
||
|
||
async def send_msg(to_user:str, question:str, content:str):
|
||
'''
|
||
发送问题的回复给指定用户
|
||
'''
|
||
if not content:
|
||
return
|
||
|
||
msg = {
|
||
"touser" : to_user,
|
||
"msgtype": "markdown",
|
||
"agentid" : config.bizwechat_config().agent_id,
|
||
"markdown": {"content": content,},
|
||
}
|
||
access_token = config.wxbiz_token()
|
||
url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}"
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(url, json=msg) as resp:
|
||
resp_text = await resp.text()
|
||
logger.debug('send_msg to %s %s: %s', to_user, question, resp_text)
|
||
|
||
|
||
if __name__ == '__main__':
    import sys

    # Manual entry points:
    #   <script> test  -> answer one sample question end-to-end and push it to a test user
    #   <script> send  -> push a fixed markdown message to the test user
    #   <script>       -> serve the WeChat callback API on 0.0.0.0:9000
    mode = sys.argv[1] if len(sys.argv) > 1 else None

    if mode == 'test':
        logging.basicConfig(level=logging.DEBUG)
        # The parameter is model_name; passing model= raised TypeError.
        asyncio.run(async_query_and_reply(apikey=config.bizwechat_config().qgi_api_key,
                                          model_name='qwen-max',
                                          database='contracts',
                                          question="2020年业绩最好的分公司",
                                          to_user='SunHaiWen',
                                          ))
    elif mode == 'send':
        logging.basicConfig(level=logging.DEBUG)
        asyncio.run(send_msg(to_user='SunHaiWen',
                             question="2020年业绩最好的分公司",
                             content="2020年业绩最好的分公司是北京分公司",
                             ))
    else:
        import uvicorn
        uvicorn.run(app, host='0.0.0.0', port=9000)
|