bizwechat/wechat.py
2025-02-17 10:34:35 +08:00

234 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- encoding:utf-8 -*-
"""
微信企业后台 接口.
"""
import xml.etree.cElementTree as ET
from bizwechat import WXBizMsgCrypt
from http import HTTPStatus
import fastapi
import logging
import config
import asyncio
import aiohttp
from sqlcode.sql_agent import create_sql_agent, remove_markdown_code_block
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.agents.agent_types import AgentType
from langchain_core.tools import StructuredTool
from langchain.prompts import PromptTemplate
from langchain_chroma import Chroma
from langchain_community.embeddings import DashScopeEmbeddings
from sqlcode.utils import parse, format_docs
from langchain_core.runnables import RunnablePassthrough
from sqlcode.multi_agent import create_sql_graph
# ASGI application; the WeChat Work callback routes below are registered on it.
# Log records go to the 'sqlcode' logger (not this module's __name__).
app, logger = fastapi.FastAPI(), logging.getLogger('sqlcode')
def get_wxcpt():
    """Build the ``WXBizMsgCrypt`` used to verify/decrypt WeChat Work callbacks.

    Token, AES key and corp id are read from ``config.bizwechat_config()``;
    a fresh instance is created on every call.
    """
    cfg = config.bizwechat_config()
    crypt_kwargs = {
        'sToken': cfg.token,
        'sEncodingAESKey': cfg.aes_key,
        'sReceiveId': cfg.corp_id,
    }
    return WXBizMsgCrypt(**crypt_kwargs)
@app.get('/', include_in_schema=False)
async def verify_url(request:fastapi.Request):
    """Answer the WeChat Work callback-URL verification handshake.

    Reads ``msg_signature``, ``timestamp``, ``nonce`` and ``echostr`` from the
    query string and passes them to ``WXBizMsgCrypt.VerifyURL``. When it
    returns code 0, the value it yields is echoed back with status 200.
    A missing parameter or a non-zero code produces an empty 400 response.
    """
    params = request.query_params
    signature, timestamp, nonce, encrypted_echo = (
        params.get(key) for key in ('msg_signature', 'timestamp', 'nonce', 'echostr')
    )
    if not all((signature, timestamp, nonce, encrypted_echo)):
        logger.error('verify_url failed, missing parameters')
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    code, plain_echo = get_wxcpt().VerifyURL(signature, timestamp, nonce, encrypted_echo)
    if code != 0:
        logger.error('verify_url failed, error code: %s', code)
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    logger.info('verify_url success, echostr: %s', plain_echo)
    return fastapi.Response(plain_echo, HTTPStatus.OK)
# Strong references to in-flight reply tasks. asyncio keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_background_tasks = set()


def _on_reply_done(task):
    """Forget a finished reply *task* and log any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error('async_query_and_reply failed', exc_info=exc)


@app.post('/', include_in_schema=False)
async def receive_message(request:fastapi.Request):
    """Receive an encrypted WeChat Work callback message and schedule the reply.

    Validates the signature query parameters and decrypts the body with
    ``WXBizMsgCrypt.DecryptMsg``. When the decrypted XML carries both
    ``<Content>`` and ``<FromUserName>``, it schedules ``async_query_and_reply``
    as a background task. The handler then returns immediately, so WeChat is
    acknowledged before the slow query finishes.

    The response body is always empty. The status is 400 when parameters are
    missing or decryption fails, and 200 otherwise.
    """
    signature = request.query_params.get('msg_signature')
    timestamp = request.query_params.get('timestamp')
    nonce = request.query_params.get('nonce')
    if not signature or not timestamp or not nonce:
        logger.error('receive_message failed, missing parameters')
        # A bare (body, status) tuple is a Flask idiom; FastAPI would serialise
        # it as JSON with status 200, so build the Response explicitly.
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)
    logger.info('receive_message timestamp: %s, nonce: %s', timestamp, nonce)

    post_data = await request.body()
    code, post_data = get_wxcpt().DecryptMsg(post_data, signature, timestamp, nonce)
    if code != 0:
        logger.error('receive_message failed, error code: %s', code)
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    root = ET.fromstring(post_data)
    content = root.findtext('Content')
    from_user = root.findtext('FromUserName')
    if not content or not from_user:
        # Callbacks without a <Content> element (presumably events or non-text
        # messages) carry no question to answer: acknowledge and stop.
        logger.info('receive_message ignored, MsgType: %s', root.findtext('MsgType'))
        return fastapi.Response('', HTTPStatus.OK)

    task = asyncio.create_task(async_query_and_reply(
        apikey=config.bizwechat_config().qgi_api_key,
        model_name='qwen',
        database='contracts',
        question=content,
        to_user=from_user,
    ))
    _background_tasks.add(task)
    task.add_done_callback(_on_reply_done)
    logger.info('receive_message success, timestamp: %s, nonce: %s', timestamp, nonce)
    # Empty body (not the JSON string '""') tells WeChat there is no passive reply.
    return fastapi.Response('', HTTPStatus.OK)
from sqlcode.qgi import Executor, NewFormatter, ReturnType
# from sqlcode.qwenapi import Generator
from sqlcode.langchain_model import Generator
from sqlcode.modelloader import ModelLoader, ModelManager
def _run_query(apikey, model_name, database, question):
    """Build the SQL-graph pipeline for *database* and run *question* through it.

    Fully synchronous: it loads configuration and the model, then calls
    ``Executor.query``. Returns whatever ``Executor.query`` returns; the caller
    reads its ``error``, ``thought`` and ``result`` attributes.
    """
    metadata = config.metadata(database)
    qwen_cfg = config.qwen_config("qwen_graph.conf")
    model_manager = ModelManager(ModelLoader(config.model_config()))
    model_manager.switch_model(model_name)
    model = model_manager.get_model()
    agent_executor = create_sql_graph(model=model, qwen_cfg=qwen_cfg, data_cfg=metadata)
    formatter = NewFormatter(format=ReturnType.WX_MD,
                             tranlate=None,
                             output_dir=config.osp.join(config.BASE_DIR, 'output'),
                             site_url=config.app_config().site_url + '/output',
                             )
    generator = Generator(agentExcutor=agent_executor,
                          messages=[{'role': 'system', 'content': qwen_cfg.system}],
                          apikey=config.api_key(apikey).dashscope_api_key,
                          seed=0,
                          )
    executor = Executor(generator=generator, formatter=formatter)
    graph_input = {"messages": [("human", question)], "iterations": 0}
    return executor.query(connection_string=metadata.connection_string,
                          input=graph_input,
                          )


async def async_query_and_reply(apikey:str, model_name:str, database:str, question:str, to_user:str):
    """Answer *question* against *database* and send the reply to *to_user*.

    ``receive_message`` schedules this as a background task. The blocking
    query pipeline (``_run_query``) runs in the event loop's default
    thread-pool executor, so the loop can keep serving WeChat callbacks while
    the query runs. The result's ``error``, ``thought`` and ``result`` are then
    sent through three separate ``send_msg`` calls; ``send_msg`` skips empty
    content.

    :param apikey: key passed to ``config.api_key`` to obtain the DashScope API key.
    :param model_name: model selected via ``ModelManager.switch_model`` (e.g. ``'qwen'``).
    :param database: database name passed to ``config.metadata``.
    :param question: the user's natural-language question.
    :param to_user: recipient user id for the replies.
    """
    logger.info('开始处理问题:%s', question)
    loop = asyncio.get_running_loop()
    ret = await loop.run_in_executor(None, _run_query, apikey, model_name, database, question)

    # thought and result are sent separately because a single message is
    # limited to 2048 bytes (UTF-8).
    await send_msg(to_user, question, ret.error)
    thought = ret.thought or ''
    if thought.startswith('```') and thought.endswith('```'):
        # A bare fenced block is the generated SQL; give it a short lead-in.
        thought = '针对这个问题,采用 SQL 查询:\n' + thought
    await send_msg(to_user, question, thought)
    await send_msg(to_user, question, ret.result)
# Maximum UTF-8 byte length of one markdown message's content accepted by the
# WeChat Work message/send API.
_MAX_MSG_BYTES = 2048


def _split_utf8(content, limit=_MAX_MSG_BYTES):
    """Split *content* into consecutive chunks of at most *limit* UTF-8 bytes.

    Chunks break at line boundaries where possible. A single line longer than
    *limit* is cut between characters, never inside a multi-byte character.
    Joining the returned chunks reproduces *content* exactly.
    """
    # First turn the text into pieces that each fit in *limit* bytes.
    pieces = []
    for line in content.splitlines(keepends=True):
        if len(line.encode('utf-8')) <= limit:
            pieces.append(line)
            continue
        part, part_size = [], 0
        for ch in line:
            ch_size = len(ch.encode('utf-8'))
            if part and part_size + ch_size > limit:
                pieces.append(''.join(part))
                part, part_size = [], 0
            part.append(ch)
            part_size += ch_size
        if part:
            pieces.append(''.join(part))

    # Then greedily pack pieces into as few chunks as possible.
    chunks, current, current_size = [], [], 0
    for piece in pieces:
        piece_size = len(piece.encode('utf-8'))
        if current and current_size + piece_size > limit:
            chunks.append(''.join(current))
            current, current_size = [], 0
        current.append(piece)
        current_size += piece_size
    if current:
        chunks.append(''.join(current))
    return chunks


async def send_msg(to_user:str, question:str, content:str):
    """Send *content* to *to_user* as WeChat Work markdown message(s).

    Content longer than ``_MAX_MSG_BYTES`` UTF-8 bytes is split into several
    consecutive messages; a fenced code block may be cut across messages.
    Every API response is logged at DEBUG. A response whose JSON ``errcode``
    is not 0 (or that is not JSON at all) is also logged at WARNING. Nothing
    is sent when *content* is empty or None.

    :param to_user: recipient user id (the ``touser`` field).
    :param question: the originating question; used only in log messages.
    :param content: markdown text to send.
    """
    if not content:
        return
    import json  # only needed to inspect the API's JSON reply

    text = content if isinstance(content, str) else str(content)
    agent_id = config.bizwechat_config().agent_id
    access_token = config.wxbiz_token()
    url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}"
    async with aiohttp.ClientSession() as session:
        for chunk in _split_utf8(text):
            msg = {
                "touser": to_user,
                "msgtype": "markdown",
                "agentid": agent_id,
                "markdown": {"content": chunk},
            }
            async with session.post(url, json=msg) as resp:
                status = resp.status
                resp_text = await resp.text()
            logger.debug('send_msg to %s %s: %s', to_user, question, resp_text)
            try:
                errcode = json.loads(resp_text).get('errcode')
            except (ValueError, AttributeError):
                errcode = None
            if errcode != 0:
                logger.warning('send_msg to %s failed (HTTP %s): %s', to_user, status, resp_text)
if __name__ == '__main__':
    import sys

    # Usage: <script> [test|send]
    #   test          run one sample question end to end and reply to a fixed user
    #   send          send one fixed sample message only
    #   anything else serve the FastAPI app with uvicorn on port 9000
    command = sys.argv[1] if len(sys.argv) > 1 else None
    if command == 'test':
        logging.basicConfig(level=logging.DEBUG)
        # The keyword is model_name (passing model= raised TypeError).
        asyncio.run(async_query_and_reply(apikey=config.bizwechat_config().qgi_api_key,
                                          model_name='qwen-max',
                                          database='contracts',
                                          question="2020年业绩最好的分公司",
                                          to_user='SunHaiWen',
                                          ))
    elif command == 'send':
        logging.basicConfig(level=logging.DEBUG)
        asyncio.run(send_msg(to_user='SunHaiWen',
                             question="2020年业绩最好的分公司",
                             content="2020年业绩最好的分公司是北京分公司",
                             ))
    else:
        import uvicorn
        uvicorn.run(app, host='0.0.0.0', port=9000)