234 lines
9.2 KiB
Python
234 lines
9.2 KiB
Python
|
|
#!/usr/bin/env python
|
|||
|
|
# -*- encoding:utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
微信企业后台 接口.
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import xml.etree.cElementTree as ET
|
|||
|
|
from bizwechat import WXBizMsgCrypt
|
|||
|
|
from http import HTTPStatus
|
|||
|
|
import fastapi
|
|||
|
|
import logging
|
|||
|
|
import config
|
|||
|
|
import asyncio
|
|||
|
|
import aiohttp
|
|||
|
|
from sqlcode.sql_agent import create_sql_agent, remove_markdown_code_block
|
|||
|
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|||
|
|
from langchain_community.utilities import SQLDatabase
|
|||
|
|
from langchain.agents.agent_types import AgentType
|
|||
|
|
from langchain_core.tools import StructuredTool
|
|||
|
|
from langchain.prompts import PromptTemplate
|
|||
|
|
from langchain_chroma import Chroma
|
|||
|
|
from langchain_community.embeddings import DashScopeEmbeddings
|
|||
|
|
from sqlcode.utils import parse, format_docs
|
|||
|
|
from langchain_core.runnables import RunnablePassthrough
|
|||
|
|
from sqlcode.multi_agent import create_sql_graph
|
|||
|
|
|
|||
|
|
app = fastapi.FastAPI()
|
|||
|
|
logger = logging.getLogger('sqlcode')
|
|||
|
|
|
|||
|
|
def get_wxcpt():
    """Build the WXBizMsgCrypt object used for the WeCom callback interface.

    The token, AES key and corp id are read from ``config.bizwechat_config()``
    on every call, and a new ``WXBizMsgCrypt`` instance is returned each time.
    """
    cfg = config.bizwechat_config()
    crypt_kwargs = {
        'sToken': cfg.token,
        'sEncodingAESKey': cfg.aes_key,
        'sReceiveId': cfg.corp_id,
    }
    return WXBizMsgCrypt(**crypt_kwargs)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get('/', include_in_schema=False)
async def verify_url(request: fastapi.Request):
    """Answer the callback-URL verification request (HTTP GET).

    Reads ``msg_signature``, ``timestamp``, ``nonce`` and ``echostr`` from the
    query string and hands them to ``VerifyURL``.

    Returns:
        A 200 response whose body is the echo string returned by
        ``VerifyURL`` when its code is 0; otherwise a 400 response with an
        empty body (also used when any query parameter is missing).
    """
    params = request.query_params
    signature, timestamp, nonce, echostr = (
        params.get(key)
        for key in ('msg_signature', 'timestamp', 'nonce', 'echostr')
    )

    # All four parameters are mandatory; reject early if any is absent/empty.
    if not all((signature, timestamp, nonce, echostr)):
        logger.error('verify_url failed, missing parameters')
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    code, echostr = get_wxcpt().VerifyURL(signature, timestamp, nonce, echostr)

    if code != 0:
        logger.error('verify_url failed, error code: %s', code)
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    logger.info('verify_url success, echostr: %s', echostr)
    return fastapi.Response(echostr, HTTPStatus.OK)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Strong references to in-flight background tasks. The event loop only keeps
# weak references to tasks, so a task nobody references may be
# garbage-collected before it finishes.
_background_tasks = set()


@app.post('/', include_in_schema=False)
async def receive_message(request: fastapi.Request):
    """Receive an encrypted callback message and schedule the reply.

    The request body is decrypted with ``DecryptMsg`` using the
    ``msg_signature``, ``timestamp`` and ``nonce`` query parameters. The
    ``Content`` and ``FromUserName`` elements of the decrypted XML are then
    passed to ``async_query_and_reply``, which runs as a background task so
    this handler can return immediately.

    Returns:
        An empty 200 response when the message was accepted (or carried no
        text content and was ignored); an empty 400 response when query
        parameters are missing or decryption fails.
    """
    signature = request.query_params.get('msg_signature')
    timestamp = request.query_params.get('timestamp')
    nonce = request.query_params.get('nonce')

    if not signature or not timestamp or not nonce:
        logger.error('receive_message failed, missing parameters')
        # Must be a Response object: FastAPI would serialise a
        # ``('', status)`` tuple as a JSON array with status 200.
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    logger.info('receive_message timestamp: %s, nonce: %s', timestamp, nonce)

    post_data = await request.body()
    code, post_data = get_wxcpt().DecryptMsg(post_data, signature, timestamp, nonce)

    if code != 0:
        logger.error('receive_message failed, error code: %s', code)
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    xml = ET.fromstring(post_data)
    # findtext() returns None instead of raising when the element is absent,
    # which happens for callbacks that carry no <Content> element.
    content = xml.findtext('Content')
    from_user = xml.findtext('FromUserName')

    if not content or not from_user:
        logger.info('receive_message ignored, no text content, timestamp: %s, nonce: %s',
                    timestamp, nonce)
        return fastapi.Response('', HTTPStatus.OK)

    task = asyncio.create_task(async_query_and_reply(
        apikey=config.bizwechat_config().qgi_api_key,
        model_name='qwen',
        database='contracts',
        question=content,
        to_user=from_user,
    ))
    # Keep the task alive until it completes, then drop the reference.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    logger.info('receive_message success, timestamp: %s, nonce: %s', timestamp, nonce)
    # Explicit empty body; returning '' would be JSON-encoded as '""'.
    return fastapi.Response('', HTTPStatus.OK)
|
|||
|
|
|
|||
|
|
|
|||
|
|
from sqlcode.qgi import Executor, NewFormatter, ReturnType
|
|||
|
|
# from sqlcode.qwenapi import Generator
|
|||
|
|
from sqlcode.langchain_model import Generator
|
|||
|
|
from sqlcode.modelloader import ModelLoader, ModelManager
|
|||
|
|
|
|||
|
|
def _run_query(apikey: str, model_name: str, database: str, question: str):
    """Run the blocking model/SQL pipeline for one question and return its result.

    Builds the model, the SQL graph, the formatter, the generator and the
    executor from configuration, then calls ``Executor.query``. Everything
    here is synchronous, so callers on an event loop should run it in a
    worker thread.
    """
    metadata = config.metadata(database)
    qwen_cfg = config.qwen_config("qwen_graph.conf")

    model_manager = ModelManager(ModelLoader(config.model_config()))
    model_manager.switch_model(model_name)
    model = model_manager.get_model()

    # NOTE: an earlier, commented-out pipeline (Chroma retriever for question
    # refinement + create_sql_agent) used to live here; it has been replaced
    # by the graph built below.
    agent_executor = create_sql_graph(model=model, qwen_cfg=qwen_cfg, data_cfg=metadata)

    formatter = NewFormatter(format=ReturnType.WX_MD,
                             tranlate=None,
                             output_dir=config.osp.join(config.BASE_DIR, 'output'),
                             site_url=config.app_config().site_url + '/output',
                             )

    generator = Generator(agentExcutor=agent_executor,
                          messages=[{'role': 'system', 'content': qwen_cfg.system}],
                          apikey=config.api_key(apikey).dashscope_api_key,
                          seed=0,
                          )

    executor = Executor(generator=generator, formatter=formatter)

    graph_input = {"messages": [("human", question)], "iterations": 0}
    return executor.query(connection_string=metadata.connection_string,
                          input=graph_input,
                          )


async def async_query_and_reply(apikey: str, model_name: str, database: str, question: str, to_user: str):
    """Answer a question against a database and send the reply to a user.

    The query pipeline is synchronous, so it is executed in the event loop's
    default executor; otherwise it would block the loop (and every other
    request) for the whole duration of the query.

    * :param apikey: key passed to ``config.api_key`` to look up the DashScope API key.
    * :param model_name: name of the model to switch to, e.g. ``qwen``.
    * :param database: name of the database whose metadata/connection is used.
    * :param question: the user's question.
    * :param to_user: id of the user who receives the reply messages.
    """
    logger.info('开始处理问题:%s', question)

    loop = asyncio.get_running_loop()
    ret = await loop.run_in_executor(None, _run_query, apikey, model_name, database, question)

    # Error, thought and result are sent as separate messages because a
    # single message is limited to 2048 bytes (UTF-8). send_msg skips empty
    # content, so absent parts are simply not sent.
    await send_msg(to_user, question, ret.error)

    thought = ret.thought or ''  # tolerate a missing thought instead of raising
    if thought.startswith('```') and thought.endswith('```'):
        await send_msg(to_user, question, '针对这个问题,采用 SQL 查询:\n' + thought)
    else:
        await send_msg(to_user, question, thought)

    await send_msg(to_user, question, ret.result)
|
|||
|
|
|
|||
|
|
async def send_msg(to_user: str, question: str, content: str):
    """Send one markdown message replying to ``question`` to ``to_user``.

    Does nothing when ``content`` is empty. The message is POSTed to the
    WeCom ``message/send`` endpoint; a non-200 HTTP status or a response
    whose ``errcode`` is not 0 is logged as an error so delivery failures
    are visible without enabling DEBUG logging.

    * :param to_user: id of the recipient.
    * :param question: the original question; only used in log output.
    * :param content: markdown text to send.
    """
    import json  # local import: only needed to inspect the API response

    if not content:
        return

    msg = {
        "touser": to_user,
        "msgtype": "markdown",
        "agentid": config.bizwechat_config().agent_id,
        "markdown": {"content": content},
    }
    access_token = config.wxbiz_token()
    url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}"

    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=msg) as resp:
            status = resp.status
            resp_text = await resp.text()

    logger.debug('send_msg to %s %s: %s', to_user, question, resp_text)

    try:
        errcode = json.loads(resp_text).get('errcode')
    except (ValueError, AttributeError):
        # Body was not a JSON object; treat it as a failure below.
        errcode = None

    if status != 200 or errcode != 0:
        logger.error('send_msg to %s failed, status: %s, response: %s',
                     to_user, status, resp_text)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    import sys

    # The optional first CLI argument selects a manual check:
    #   test -> run one full query-and-reply round trip
    #   send -> send a single fixed message
    # Anything else (or no argument) starts the HTTP server.
    mode = sys.argv[1] if len(sys.argv) > 1 else None

    if mode == 'test':
        logging.basicConfig(level=logging.DEBUG)
        # The keyword must be ``model_name`` to match the function signature;
        # passing ``model=`` raises TypeError.
        asyncio.run(async_query_and_reply(apikey=config.bizwechat_config().qgi_api_key,
                                          model_name='qwen-max',
                                          database='contracts',
                                          question="2020年业绩最好的分公司",
                                          to_user='SunHaiWen',
                                          ))
    elif mode == 'send':
        logging.basicConfig(level=logging.DEBUG)
        asyncio.run(send_msg(to_user='SunHaiWen',
                             question="2020年业绩最好的分公司",
                             content="2020年业绩最好的分公司是北京分公司",
                             ))
    else:
        import uvicorn
        uvicorn.run(app, host='0.0.0.0', port=9000)
|