#!/usr/bin/env python
# -*- encoding:utf-8 -*-
"""WeChat Work (enterprise WeChat) backend callback interface.

Exposes two HTTP endpoints on ``/``:

* ``GET``  - callback URL verification handshake.
* ``POST`` - receives an encrypted message, acknowledges it immediately and
  answers the user asynchronously through the "message/send" API.
"""
# NOTE: ``xml.etree.cElementTree`` was removed in Python 3.9; the plain
# ``ElementTree`` module uses the C accelerator automatically when available.
import xml.etree.ElementTree as ET
from bizwechat import WXBizMsgCrypt
from http import HTTPStatus
import fastapi
import logging
import config
import asyncio
import aiohttp
from sqlcode.sql_agent import create_sql_agent, remove_markdown_code_block
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.agents.agent_types import AgentType
from langchain_core.tools import StructuredTool
from langchain.prompts import PromptTemplate
from langchain_chroma import Chroma
from langchain_community.embeddings import DashScopeEmbeddings
from sqlcode.utils import parse, format_docs
from langchain_core.runnables import RunnablePassthrough
from sqlcode.multi_agent import create_sql_graph

app = fastapi.FastAPI()
logger = logging.getLogger('sqlcode')

# Strong references to in-flight background tasks. The event loop only keeps
# weak references, so a task whose result is discarded may be garbage
# collected before it finishes.
_background_tasks = set()


def _on_background_task_done(task):
    """Drop the finished task from the registry and log any failure."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error('background task failed: %s', exc, exc_info=exc)


def get_wxcpt():
    """Create the encrypt/decrypt helper for the WeChat Work callback API.

    The token, AES key and corp id are read from ``config.bizwechat_config()``
    on every call.
    """
    wxbiz_config = config.bizwechat_config()
    return WXBizMsgCrypt(
        sToken=wxbiz_config.token,
        sEncodingAESKey=wxbiz_config.aes_key,
        sReceiveId=wxbiz_config.corp_id,
    )


@app.get('/', include_in_schema=False)
async def verify_url(request: fastapi.Request):
    """Verify the callback URL.

    Reads ``msg_signature``, ``timestamp``, ``nonce`` and ``echostr`` from the
    query string and passes them to ``VerifyURL``. Responds 200 with the echo
    string returned by ``VerifyURL`` on success, or 400 with an empty body
    when a parameter is missing or verification fails.
    """
    signature = request.query_params.get('msg_signature')
    timestamp = request.query_params.get('timestamp')
    nonce = request.query_params.get('nonce')
    echostr = request.query_params.get('echostr')
    if not signature or not timestamp or not nonce or not echostr:
        logger.error('verify_url failed, missing parameters')
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    code, echostr = get_wxcpt().VerifyURL(signature, timestamp, nonce, echostr)
    if code != 0:
        logger.error('verify_url failed, error code: %s', code)
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    logger.info('verify_url success, echostr: %s', echostr)
    return fastapi.Response(echostr, HTTPStatus.OK)


@app.post('/', include_in_schema=False)
async def receive_message(request: fastapi.Request):
    """Receive an encrypted message and schedule the reply.

    The body is decrypted with ``DecryptMsg``; the ``Content`` and
    ``FromUserName`` elements are extracted and the question is answered in a
    background task so that this handler can return right away.

    Responds 400 with an empty body when a query parameter is missing,
    decryption fails or the decrypted payload is not valid XML; otherwise
    200 with an empty body.
    """
    signature = request.query_params.get('msg_signature')
    timestamp = request.query_params.get('timestamp')
    nonce = request.query_params.get('nonce')
    if not signature or not timestamp or not nonce:
        logger.error('receive_message failed, missing parameters')
        # A bare ``'', status`` tuple is a Flask idiom; FastAPI would send it
        # as a JSON array with status 200, so build the Response explicitly.
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    logger.info('receive_message timestamp: %s, nonce: %s', timestamp, nonce)
    post_data = await request.body()
    code, post_data = get_wxcpt().DecryptMsg(post_data, signature, timestamp, nonce)
    if code != 0:
        logger.error('receive_message failed, error code: %s', code)
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    try:
        xml = ET.fromstring(post_data)
    except ET.ParseError as err:
        logger.error('receive_message failed, invalid xml: %s', err)
        return fastapi.Response('', HTTPStatus.BAD_REQUEST)

    # ``findtext`` returns None instead of raising when the element is absent
    # (payloads without a Content element, for instance).
    content = xml.findtext('Content')
    from_user = xml.findtext('FromUserName')
    if not content or not from_user:
        logger.info('receive_message ignored, no Content or FromUserName, '
                    'timestamp: %s, nonce: %s', timestamp, nonce)
        return fastapi.Response('', HTTPStatus.OK)

    task = asyncio.create_task(async_query_and_reply(
        apikey=config.bizwechat_config().qgi_api_key,
        model_name='qwen',
        database='contracts',
        question=content,
        to_user=from_user,
    ))
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)

    logger.info('receive_message success, timestamp: %s, nonce: %s', timestamp, nonce)
    # Explicit Response: returning '' would be JSON-encoded by FastAPI as the
    # two-character body `""` rather than an empty body.
    return fastapi.Response('', HTTPStatus.OK)


from sqlcode.qgi import Executor, NewFormatter, ReturnType
# Alternative Generator implementation (disabled):
# from sqlcode.qwenapi import Generator
from sqlcode.langchain_model import Generator
from sqlcode.modelloader import ModelLoader, ModelManager


async def async_query_and_reply(apikey: str, model_name: str, database: str,
                                question: str, to_user: str):
    """Run the SQL agent for *question* and send the outcome to *to_user*.

    :param apikey: key passed to ``config.api_key`` to look up the DashScope
        API key handed to the generator.
    :param model_name: name passed to ``ModelManager.switch_model`` to select
        the model.
    :param database: name passed to ``config.metadata`` to load the database
        metadata (including its connection string).
    :param question: the user's question; becomes the single human message.
    :param to_user: recipient id used when sending the replies.

    Up to three messages are sent, in order: the error, the "thought" and the
    result. Empty ones are skipped by ``send_msg``.
    """
    logger.info('开始处理问题:%s', question)
    metadata = config.metadata(database)
    qwen_cfg = config.qwen_config("qwen_graph.conf")
    # Only consumed by the retired refine-question step; the call is kept so
    # that any configuration loading it performs still happens.
    re_cfg = config.refineProblem_config()

    model_loader = ModelLoader(config.model_config())
    model_manager = ModelManager(model_loader)
    model_manager.switch_model(model_name)
    model = model_manager.get_model()

    # NOTE: an earlier pipeline lived here as commented-out code and was
    # replaced by ``create_sql_graph``. It (1) retrieved context from a Chroma
    # store ("chroma_db", DashScope embeddings, k=2) and refined the question
    # with ``re_cfg.prompt``, then (2) built a ZERO_SHOT_REACT_DESCRIPTION
    # agent via ``create_sql_agent`` with ``SQLDatabaseToolkit`` and the
    # ``remove_markdown_code_block`` tool. See version control for the code.
    agent_executor = create_sql_graph(model=model, qwen_cfg=qwen_cfg, data_cfg=metadata)

    formatter = NewFormatter(
        format=ReturnType.WX_MD,
        tranlate=None,  # keyword spelling (sic) kept exactly as the original call
        output_dir=config.osp.join(config.BASE_DIR, 'output'),
        site_url=config.app_config().site_url + '/output',
    )
    generator = Generator(
        agentExcutor=agent_executor,  # keyword spelling (sic) kept as the original call
        messages=[{'role': 'system', 'content': qwen_cfg.system}],
        apikey=config.api_key(apikey).dashscope_api_key,
        seed=0,
    )
    executor = Executor(generator=generator, formatter=formatter)

    query_input = {"messages": [("human", question)], "iterations": 0}
    # NOTE(review): ``executor.query`` is called synchronously, so the event
    # loop is blocked while it runs. Consider ``asyncio.to_thread`` once the
    # thread-safety of the agent stack has been confirmed.
    ret = executor.query(
        connection_string=metadata.connection_string,
        input=query_input,
    )

    # Thought and result are sent separately because a single message is
    # limited to 2048 bytes (UTF-8).
    await send_msg(to_user, question, ret.error)

    thought = ret.thought or ''  # tolerate a missing thought instead of crashing
    if thought.startswith('```') and thought.endswith('```'):
        thought = '针对这个问题,采用 SQL 查询:\n' + thought
    await send_msg(to_user, question, thought)

    await send_msg(to_user, question, ret.result)


async def send_msg(to_user: str, question: str, content: str):
    """Send *content* to *to_user* as a markdown message.

    Does nothing when *content* is empty. *question* is only used in the
    debug log line that records the API response text.
    """
    if not content:
        return

    msg = {
        "touser": to_user,
        "msgtype": "markdown",
        "agentid": config.bizwechat_config().agent_id,
        "markdown": {"content": content},
    }
    access_token = config.wxbiz_token()
    url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}"
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=msg) as resp:
            resp_text = await resp.text()
            logger.debug('send_msg to %s %s: %s', to_user, question, resp_text)


if __name__ == '__main__':
    import sys

    mode = sys.argv[1] if len(sys.argv) > 1 else None
    if mode == 'test':
        # Run one query end to end without starting the web server.
        logging.basicConfig(level=logging.DEBUG)
        asyncio.run(async_query_and_reply(
            apikey=config.bizwechat_config().qgi_api_key,
            model_name='qwen-max',  # was ``model=``, which raised TypeError
            database='contracts',
            question="2020年业绩最好的分公司",
            to_user='SunHaiWen',
        ))
    elif mode == 'send':
        # Send one fixed message to check the outbound API path.
        logging.basicConfig(level=logging.DEBUG)
        asyncio.run(send_msg(
            to_user='SunHaiWen',
            question="2020年业绩最好的分公司",
            content="2020年业绩最好的分公司是北京分公司",
        ))
    else:
        import uvicorn
        uvicorn.run(app, host='0.0.0.0', port=9000)