commit 6a7283484a152b0fe8148746216b21ae3eca70a5
Author: XLZ <1208121887@qq.com>
Date: Mon Feb 17 10:34:35 2025 +0800
Initial commit
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..921284e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,149 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+nohup.out
+cert/
+pg-data/pgdata/
+pg-data/*.xlsx
+pg-data/*.rar
+log/
+output/
+config/wxbiz-access-token
+test/
+chroma_db/
\ No newline at end of file
diff --git a/.htaccess b/.htaccess
new file mode 100644
index 0000000..0c54fb9
--- /dev/null
+++ b/.htaccess
@@ -0,0 +1 @@
+# 请将伪静态规则或自定义Apache配置填写到此处
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..c8ba3cf
--- /dev/null
+++ b/README.md
@@ -0,0 +1,36 @@
+# qwensql
+
+#### 介绍
+
+#### 软件架构
+软件架构说明
+
+
+#### 安装教程
+
+1. xxxx
+2. xxxx
+3. xxxx
+
+#### 使用说明
+
+1. xxxx
+2. xxxx
+3. xxxx
+
+#### 参与贡献
+
+1. Fork 本仓库
+2. 新建 Feat_xxx 分支
+3. 提交代码
+4. 新建 Pull Request
+
+
+#### 特技
+
+1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md
+2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com)
+3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目
+4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目
+5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help)
+6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
diff --git a/bizwechat/__init__.py b/bizwechat/__init__.py
new file mode 100644
index 0000000..66aedcb
--- /dev/null
+++ b/bizwechat/__init__.py
@@ -0,0 +1,300 @@
+#!/usr/bin/env python
+# -*- encoding:utf-8 -*-
+
+""" 微信企业后台 接口.
+@copyright: Copyright (c) 1998-2014 Tencent Inc.
+
+"""
+# ------------------------------------------------------------------------
+import logging
+import base64
+import random
+import hashlib
+import time
+import struct
+from Crypto.Cipher import AES
+import xml.etree.cElementTree as ET
+import socket
+
+
+
+#########################################################################
+# Author: jonyqin
+# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
+# File Name: ierror.py
+# Description:定义错误码含义
+#########################################################################
+WXBizMsgCrypt_OK = 0
+WXBizMsgCrypt_ValidateSignature_Error = -40001
+WXBizMsgCrypt_ParseXml_Error = -40002
+WXBizMsgCrypt_ComputeSignature_Error = -40003
+WXBizMsgCrypt_IllegalAesKey = -40004
+WXBizMsgCrypt_ValidateCorpid_Error = -40005
+WXBizMsgCrypt_EncryptAES_Error = -40006
+WXBizMsgCrypt_DecryptAES_Error = -40007
+WXBizMsgCrypt_IllegalBuffer = -40008
+WXBizMsgCrypt_EncodeBase64_Error = -40009
+WXBizMsgCrypt_DecodeBase64_Error = -40010
+WXBizMsgCrypt_GenReturnXml_Error = -40011
+
+"""
+关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案
+请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。
+下载后,按照README中的“Installation”小节的提示进行pycrypto安装。
+"""
+
+
+class FormatException(Exception):
+ pass
+
+
+def throw_exception(message, exception_class=FormatException):
+ """my define raise exception function"""
+ raise exception_class(message)
+
+
+class SHA1:
+ """计算企业微信的消息签名接口"""
+
+ def getSHA1(self, token, timestamp, nonce, encrypt):
+ """用SHA1算法生成安全签名
+ @param token: 票据
+ @param timestamp: 时间戳
+ @param encrypt: 密文
+ @param nonce: 随机字符串
+ @return: 安全签名
+ """
+ try:
+ sortlist = [token, timestamp, nonce, encrypt]
+ sortlist.sort()
+ sha = hashlib.sha1()
+ sha.update("".join(sortlist).encode())
+ return WXBizMsgCrypt_OK, sha.hexdigest()
+ except Exception as e:
+ logger = logging.getLogger('sqlcode')
+ logger.error(e)
+ return WXBizMsgCrypt_ComputeSignature_Error, None
+
+
+class XMLParse:
+ """提供提取消息格式中的密文及生成回复消息格式的接口"""
+
+ # xml消息模板
+ AES_TEXT_RESPONSE_TEMPLATE = """
+
+
+%(timestamp)s
+
+"""
+
+ def extract(self, xmltext):
+ """提取出xml数据包中的加密消息
+ @param xmltext: 待提取的xml字符串
+ @return: 提取出的加密消息字符串
+ """
+ try:
+ xml_tree = ET.fromstring(xmltext)
+ encrypt = xml_tree.find("Encrypt")
+ return WXBizMsgCrypt_OK, encrypt.text
+ except Exception as e:
+ logger = logging.getLogger('sqlcode')
+ logger.error(e)
+ return WXBizMsgCrypt_ParseXml_Error, None
+
+ def generate(self, encrypt, signature, timestamp, nonce):
+ """生成xml消息
+ @param encrypt: 加密后的消息密文
+ @param signature: 安全签名
+ @param timestamp: 时间戳
+ @param nonce: 随机字符串
+ @return: 生成的xml字符串
+ """
+ resp_dict = {
+ 'msg_encrypt': encrypt,
+ 'msg_signaturet': signature,
+ 'timestamp': timestamp,
+ 'nonce': nonce,
+ }
+ resp_xml = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
+ return resp_xml
+
+
+class PKCS7Encoder():
+ """提供基于PKCS7算法的加解密接口"""
+
+ block_size = 32
+
+ def encode(self, text):
+ """ 对需要加密的明文进行填充补位
+ @param text: 需要进行填充补位操作的明文
+ @return: 补齐明文字符串
+ """
+ text_length = len(text)
+ # 计算需要填充的位数
+ amount_to_pad = self.block_size - (text_length % self.block_size)
+ if amount_to_pad == 0:
+ amount_to_pad = self.block_size
+ # 获得补位所用的字符
+ pad = chr(amount_to_pad)
+ return text + (pad * amount_to_pad).encode()
+
+ def decode(self, decrypted):
+ """删除解密后明文的补位字符
+ @param decrypted: 解密后的明文
+ @return: 删除补位字符后的明文
+ """
+ pad = ord(decrypted[-1])
+ if pad < 1 or pad > 32:
+ pad = 0
+ return decrypted[:-pad]
+
+
+class Prpcrypt(object):
+ """提供接收和推送给企业微信消息的加解密接口"""
+
+ def __init__(self, key):
+
+ # self.key = base64.b64decode(key+"=")
+ self.key = key
+ # 设置加解密模式为AES的CBC模式
+ self.mode = AES.MODE_CBC
+
+ def encrypt(self, text, receiveid):
+ """对明文进行加密
+ @param text: 需要加密的明文
+ @return: 加密得到的字符串
+ """
+ # 16位随机字符串添加到明文开头
+ text = text.encode()
+ text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode()
+
+ # 使用自定义的填充方式对明文进行补位填充
+ pkcs7 = PKCS7Encoder()
+ text = pkcs7.encode(text)
+ # 加密
+ cryptor = AES.new(self.key, self.mode, self.key[:16])
+ try:
+ ciphertext = cryptor.encrypt(text)
+ # 使用BASE64对加密后的字符串进行编码
+ return WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
+ except Exception as e:
+ logger = logging.getLogger('sqlcode')
+ logger.error(e)
+ return WXBizMsgCrypt_EncryptAES_Error, None
+
+ def decrypt(self, text, receiveid):
+ """对解密后的明文进行补位删除
+ @param text: 密文
+ @return: 删除填充补位后的明文
+ """
+ try:
+ cryptor = AES.new(self.key, self.mode, self.key[:16])
+ # 使用BASE64对密文进行解码,然后AES-CBC解密
+ plain_text = cryptor.decrypt(base64.b64decode(text))
+ except Exception as e:
+ logger = logging.getLogger('sqlcode')
+ logger.error(e)
+ return WXBizMsgCrypt_DecryptAES_Error, None
+ try:
+ pad = plain_text[-1]
+ # 去掉补位字符串
+ # pkcs7 = PKCS7Encoder()
+ # plain_text = pkcs7.encode(plain_text)
+ # 去除16位随机字符串
+ content = plain_text[16:-pad]
+ xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0])
+ xml_content = content[4: xml_len + 4]
+ from_receiveid = content[xml_len + 4:]
+ except Exception as e:
+ logger = logging.getLogger('sqlcode')
+ logger.error(e)
+ return WXBizMsgCrypt_IllegalBuffer, None
+
+ if from_receiveid.decode('utf8') != receiveid:
+ return WXBizMsgCrypt_ValidateCorpid_Error, None
+ return 0, xml_content
+
+ def get_random_str(self):
+ """ 随机生成16位字符串
+ @return: 16位字符串
+ """
+ return str(random.randint(1000000000000000, 9999999999999999)).encode()
+
+
+class WXBizMsgCrypt(object):
+ # 构造函数
+ def __init__(self, sToken, sEncodingAESKey, sReceiveId):
+ try:
+ self.key = base64.b64decode(sEncodingAESKey + "=")
+ assert len(self.key) == 32
+ except:
+ throw_exception("[error]: EncodingAESKey unvalid !", FormatException)
+ # return ierror.WXBizMsgCrypt_IllegalAesKey,None
+ self.m_sToken = sToken
+ self.m_sReceiveId = sReceiveId
+
+ # 验证URL
+ # @param sMsgSignature: 签名串,对应URL参数的msg_signature
+ # @param sTimeStamp: 时间戳,对应URL参数的timestamp
+ # @param sNonce: 随机串,对应URL参数的nonce
+ # @param sEchoStr: 随机串,对应URL参数的echostr
+ # @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效
+ # @return:成功0,失败返回对应的错误码
+
+ def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
+ sha1 = SHA1()
+ ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
+ if ret != 0:
+ return ret, None
+ if not signature == sMsgSignature:
+ return WXBizMsgCrypt_ValidateSignature_Error, None
+ pc = Prpcrypt(self.key)
+ ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
+ return ret, sReplyEchoStr
+
+ def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
+ # 将企业回复用户的消息加密打包
+ # @param sReplyMsg: 企业号待回复用户的消息,xml格式的字符串
+ # @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间
+ # @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce
+ # sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的xml格式的字符串,
+ # return:成功0,sEncryptMsg,失败返回对应的错误码None
+ pc = Prpcrypt(self.key)
+ ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
+ encrypt = encrypt.decode('utf8')
+ if ret != 0:
+ return ret, None
+ if timestamp is None:
+ timestamp = str(int(time.time()))
+ # 生成安全签名
+ sha1 = SHA1()
+ ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
+ if ret != 0:
+ return ret, None
+ xmlParse = XMLParse()
+ return ret, xmlParse.generate(encrypt, signature, timestamp, sNonce)
+
+ def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
+ # 检验消息的真实性,并且获取解密后的明文
+ # @param sMsgSignature: 签名串,对应URL参数的msg_signature
+ # @param sTimeStamp: 时间戳,对应URL参数的timestamp
+ # @param sNonce: 随机串,对应URL参数的nonce
+ # @param sPostData: 密文,对应POST请求的数据
+ # xml_content: 解密后的原文,当return返回0时有效
+ # @return: 成功0,失败返回对应的错误码
+ # 验证安全签名
+ xmlParse = XMLParse()
+ ret, encrypt = xmlParse.extract(sPostData)
+ if ret != 0:
+ return ret, None
+ sha1 = SHA1()
+ ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
+ if ret != 0:
+ return ret, None
+ if not signature == sMsgSignature:
+ return WXBizMsgCrypt_ValidateSignature_Error, None
+ pc = Prpcrypt(self.key)
+ ret, xml_content = pc.decrypt(encrypt, self.m_sReceiveId)
+ return ret, xml_content
+
+
diff --git a/charts.py b/charts.py
new file mode 100644
index 0000000..b99313c
--- /dev/null
+++ b/charts.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# -*- encoding:utf-8 -*-
+
+import asyncio
+from datetime import datetime
+from http.client import HTTPException
+
+from openpyxl.reader.excel import load_workbook
+from starlette.responses import JSONResponse
+from fastapi import FastAPI
+import config
+
+app = FastAPI()
+
+def json_serializable(obj):
+ if isinstance(obj, datetime):
+ return obj.isoformat() # 转换为 ISO 格式的字符串
+ raise TypeError(f"Type {type(obj)} not serializable")
+
+@app.get("/read-excel")
+async def read_excel_rows():
+ try:
+ rows="1,2,3,4,5"
+ # 加载Excel工作簿
+ file_path =config.osp.join(config.BASE_DIR, 'output/','4834ed97e0ba477b9d239560e4b12be6.xlsx')
+ workbook = load_workbook(filename=file_path)
+ sheet = workbook.active # 或者使用workbook.get_sheet_by_name('Sheet1')
+
+ # 获取要读取的行号列表
+ # row_numbers = [int(r) for r in rows.split(',') if r.isdigit()]
+
+ # 读取指定行的数据
+ data = []
+
+ for row in sheet:
+ row_data = [json_serializable(cell.value) if isinstance(cell.value, datetime) else cell.value for cell in row]
+ data.append(row_data)
+ return JSONResponse(content={"data": data})
+
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+if __name__ == '__main__':
+
+ import uvicorn
+ uvicorn.run(app, host='0.0.0.0', port=9001)
\ No newline at end of file
diff --git a/config/__init__.py b/config/__init__.py
new file mode 100644
index 0000000..caef719
--- /dev/null
+++ b/config/__init__.py
@@ -0,0 +1,134 @@
+from os import path as osp
+import rtoml
+from typing import Any, Optional
+import requests
+import time
+from threading import Lock
+from os import environ
+from dataclasses import dataclass
+
+ENVIRONMENT = environ.get('ENVIRONMENT', 'development')
+CONFIG_DIR = osp.dirname(__file__)
+BASE_DIR = osp.realpath(osp.join(CONFIG_DIR, '..'))
+
+def config_file(*paths):
+ return osp.join(CONFIG_DIR, ENVIRONMENT, *paths)
+
+def load_config(*paths) -> dict[str, Any]:
+ with open(config_file(*paths), 'r') as f:
+ return rtoml.load(f)
+
+@dataclass
+class DatabaseConfig:
+ connection_string: str
+ metadata: str
+ type: str
+ product: str
+
+def metadata(database:str):
+ d = load_config('database', database+'.conf')
+ return DatabaseConfig(**d)
+
+@dataclass
+class QwenConfig:
+ system: str
+ prompt: str
+ params: dict[str, Any]
+
+ def gen_prompt(self, database:DatabaseConfig):
+ return self.prompt.format(database=database, **self.params)
+
+def qwen_config(path: Optional[str] = None):
+ if path is not None:
+ d = load_config(path)
+ else:
+ d = load_config('myqwen.conf')
+ return QwenConfig(**d)
+
+
+@dataclass
+class ApiKeyConfig:
+ dashscope_api_key: str
+ admin: bool
+ databases: list[str]
+
+def api_key(key:str):
+ d = load_config('apikeys.conf')
+ if key in d:
+ return ApiKeyConfig(**d[key])
+ else:
+ return None
+
+@dataclass
+class BizWechatConfig:
+ corp_id:str
+ corp_secret:str
+ agent_id:int
+ token:str
+ aes_key:str
+ qgi_api_key:str
+
+@dataclass
+class AppConfig:
+ site_url:str
+
+def app_config():
+ d = load_config('app.conf')
+ d.pop('bizwechat')
+ return AppConfig(**d)
+
+def bizwechat_config():
+ d = load_config('app.conf')
+ return BizWechatConfig(**d['bizwechat'])
+
+@dataclass
+class RefineProblemConfig:
+ prompt:str
+
+def refineProblem_config(path: Optional[str] = None):
+ if path is not None:
+ d = load_config(path)
+ else:
+ d = load_config("problem_refine.conf")
+ return RefineProblemConfig(**d)
+
+import yaml
+def model_config():
+ with open(config_file("model.yaml"), 'r') as file:
+ return yaml.safe_load(file)
+
+class WxBiz:
+ lock = Lock()
+ access_token:str = None
+ expire_time:float = 0
+
+def wxbiz_token():
+ with WxBiz.lock:
+ cfg_file = osp.join(CONFIG_DIR, 'wxbiz-access-token')
+ if WxBiz.expire_time == 0 and osp.exists(cfg_file):
+ with open(cfg_file, 'r') as f:
+ token = f.read()
+ items = token.split(':')
+ WxBiz.expire_time = float(items[0])
+ WxBiz.access_token = items[1]
+
+ if WxBiz.expire_time < time.time():
+ cfg = bizwechat_config()
+ data = requests.get(f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={cfg.corp_id}&corpsecret={cfg.corp_secret}').json()
+
+ # 假设返回的token数据格式如下
+ # {
+ # "errcode": 0,
+ # "errmsg": "ok",
+ # "access_token": "ACCESS_TOKEN",
+ # "expires_in": 7200
+ # }
+ if 'errcode' in data and data['errcode'] == 0:
+ WxBiz.access_token = data['access_token']
+ WxBiz.expire_time = time.time() + data['expires_in'] - 60
+ with open(cfg_file, 'w') as f:
+ f.write(f"{WxBiz.expire_time}:{WxBiz.access_token}")
+ else:
+ raise RuntimeError("Failed to refresh token:", data)
+
+ return WxBiz.access_token
diff --git a/config/development/apikeys.conf b/config/development/apikeys.conf
new file mode 100644
index 0000000..5271729
--- /dev/null
+++ b/config/development/apikeys.conf
@@ -0,0 +1,10 @@
+[ccscc]
+dashscope_api_key = "sk-6b39b56d21aa4406b0c67061f2e31e81"
+admin = false
+databases = ["students"]
+
+["YUVietLgiGmtqzYUVIIGjrNoLMsGM0FI"]
+dashscope_api_key = "sk-6b39b56d21aa4406b0c67061f2e31e81"
+admin = false
+databases = ["contracts"]
+
diff --git a/config/development/app.conf b/config/development/app.conf
new file mode 100644
index 0000000..92a370e
--- /dev/null
+++ b/config/development/app.conf
@@ -0,0 +1,11 @@
+# 站点配置
+site_url = "http://111.230.243.127:9000"
+
+# 企业微信接口参数
+[bizwechat]
+token = "8kUGYXi"
+aes_key = "A5RyPqAu5UYBGI4QJTqLbBVyHXvevIUsaMrhct1lpxo"
+corp_id = "wwcbc2d6338dd362d0"
+corp_secret = "8tWn0YsuOdZc3xcV4HjDy2nRZJ9i9KIGHQq4vUjwYzk"
+agent_id = 1000003
+qgi_api_key = "YUVietLgiGmtqzYUVIIGjrNoLMsGM0FI"
diff --git a/config/development/database/contracts.conf b/config/development/database/contracts.conf
new file mode 100644
index 0000000..334277f
--- /dev/null
+++ b/config/development/database/contracts.conf
@@ -0,0 +1,25 @@
+# 数据库配置 0910
+
+connection_string = "mysql+pymysql://root:Ccscc_2025@10.1.12.6:3306/contracts?charset=utf8mb4"
+
+type = "MySQL"
+product = "MySQL 5.7"
+
+metadata = """
+CREATE TABLE `contracts` (
+`经办人` VARCHAR(30) comment '示例:中通服建设有限公司综合能源分公司-业务支撑中心-祝瑞敏',
+`合同形式` VARCHAR(10) comment '枚举值:单项合同、订单合同、确收单合同、框架子合同、框架合同、结算单',
+`所属分公司` VARCHAR(16) comment '枚举值:一分公司、二分公司、三分公司、四分公司、五分公司、六分公司、七分公司、北京分公司、数字基建分公司、上海分公司、智网分公司、河北分公司、综合能源分公司、本部',
+`合同名称` VARCHAR(137) comment '可以提取到项目地点、时间、客户名称、专业的相关信息(示例:2019-2020年株洲政企信息化职称技术配合服务采购订单)',
+`项目来源` VARCHAR(10) comment '枚举值:招投标、委托、邀标',
+`专业` VARCHAR(40) comment '示例:系统集成-信息系统集成服务-楼宇智能化',
+`地点` VARCHAR(35) comment '示例:中国-广东省-广州市-天河区',
+`客商类型` VARCHAR(45) comment '示例:集团客户-建筑与房地产-建筑与房地产',
+`客户名称` VARCHAR(54) comment '示例:广州铁路公安局、中国移动通信集团安徽有限公司宣城分公司',
+`合同签订金额(人民币)` float comment '“超大项目”金额大于等于1亿,“重大项目”金额大于1000万而小于1亿,“一般项目”金额大于500万而小于1000万,“小项目”金额小于500万',
+`签订日期` date comment '以CURRENT_DATE获取的时间为准作为当前日期',
+`合同有效期(结束)` date comment '示例:2001-01-01'
+)
+
+"""
+
diff --git a/config/development/database/contracts_0.conf b/config/development/database/contracts_0.conf
new file mode 100644
index 0000000..a4b1965
--- /dev/null
+++ b/config/development/database/contracts_0.conf
@@ -0,0 +1,197 @@
+# 数据库配置
+
+connection_string = "mysql+pymysql://root:H1wNPOz3@172.16.16.13:3306/contracts?charset=utf8mb4"
+
+type = "MySQL"
+product = "MySQL 5.7"
+
+metadata = """
+# 数据库表字段描述
+[table_sql]
+`经办人` VARCHAR(30),
+【描述】`经办人`是指项目经理,提问 “谁的项目”、“项目经理” 等类似字眼时,通常涉及对该字段进行筛选。
+【举例】中通服建设七分公司-湖南分公司交付项目部-蔡胜华|中通服建设一分公司-河北集客项目部-常楠|...
+
+`经办单位` VARCHAR(16),
+【描述】`经办单位`是指各个分公司的下属部门,提问 “部门” 类似字眼时,通常涉及对该字段进行筛选。
+【举例】一分集客项目部|业务支撑中心|采购管理中心|网优交付项目部|...
+
+`经办日期` date,
+
+`所属分公司` VARCHAR(16),
+【描述】提问 “一分”、“北分”、“数分”、“智网”、“四分”、“七分公司”、“综合能源分公司” 等类似字眼时,通常涉及对该字段进行筛选。
+【所有可能的值】中通服建设有限公司一分公司|中通服建设有限公司二分公司|中通服建设有限公司三分公司|中通服建设有限公司四分公司|中通服建设有限公司五分公司|中通服建设有限公司六分公司|中通服建设有限公司七分公司|中通服建设有限公司北京分公司|中通服建设有限公司数字基建分公司|中通服建设有限公司上海分公司|中通服建设有限公司智网分公司|中通服建设有限公司河北分公司|中通服建设有限公司综合能源分公司|中通服建设有限公司本部
+
+`合同形式` VARCHAR(10),
+【所有可能的值】单项合同|订单合同|确收单合同|框架子合同|框架合同|结算单
+
+`是否主合同` bool,
+【所有可能的值】0|1|
+
+`合同名称` VARCHAR(137),
+【描述】`合同名称`即项目名称,从中可能提取到项目`地点`、`时间`、`客户名称`、`最终客户名称`、`专业`的相关信息。
+
+`合同编号` VARCHAR(37),
+
+`框架合同编号` VARCHAR(23),
+
+`框架合同名称` VARCHAR(85),
+
+`主合同编号` VARCHAR(25),
+
+`主合同名称` VARCHAR(103),
+
+`项目来源` VARCHAR(10),
+【描述】提问“招投标”“委托”“邀标”相关字眼时,通常涉及对该字段进行筛选。
+【所有可能的值】招投标|委托|邀标
+
+`投标项目名称` VARCHAR(92),
+【描述】`投标项目名称`和`合同名称`描述基本一致。
+
+`编号生成时间` date,
+
+`专业` VARCHAR(40),
+【举例】系统集成-信息系统集成服务-视频监控集成|工程设计-勘察设计-其他勘查设计-其他|工程施工-设备工程-通信设备安装调试-基站|其他-其他-咨询服务|工程施工-管线工程-通信线路施工-线路|系统集成-信息系统集成服务-其他|工程施工-管线工程-通信管道施工-本地网管道|工程施工-建筑智能化-智能化及集成|工程施工-设备工程-通信设备安装调试-数据-网络交换设备|工程施工-管线工程-通信线路施工-电缆|...
+
+`地点` VARCHAR(35),
+【描述】`地点`的值只包含省市区县的内容,不会包含一些常见的地区俗称。
+【注意】涉及地区俗称时,需要分析其所在的省市区县信息进行筛选,不能用地区俗称进行筛选。如:提问“京津冀”的项目时,筛选的`地点`应该是北京、天津或河北,而不是直接筛选 “京津冀”。
+
+`是否关联交易` bool,
+【所有可能的值】0|1|
+
+`合同类型名称` VARCHAR(29),
+【描述】`合同类型名称`通常和`专业`有关。
+【举例】市场经营收入类|系统集成服务类|工程施工类|工程设计类|工程总包收入|工程分包收入|通信网络维护类|设施管理类|国际类|国际贸易服务收入
+
+`聚焦行业` VARCHAR(19),
+
+`管理分公司` VARCHAR(16),
+
+`建议实施单位` VARCHAR(55),
+
+`项目部` VARCHAR(26),
+【描述】该字段和经办单位的意思一致。
+
+`最小经营单元` VARCHAR(17),
+【描述】`最小经营单位`结合了`所属分公司`和`项目部`的内容。
+
+`省公司统一编号` VARCHAR(31),
+
+`统一编号生成时间` date,
+
+`客户名称` VARCHAR(54),
+【举例】广州铁路公安局|广东电网有限责任公司广州供电局|广东电网有限责任公司广州供电局|中国移动通信集团安徽有限公司宣城分公司|中国移动通信集团安徽有限公司宣城分公司|中国移动通信集团安徽有限公司宣城分公司|中国电信股份有限公司合肥分公司|长沙海关技术中心|中共广东省委办公厅|南方电网数字平台科技(广东)有限公司|...
+
+`运营商` VARCHAR(10),
+【所有可能的值】中国电信|中国移动|中国联通|中国广电|中国铁塔|其他
+
+`中通服客商类型` VARCHAR(45),
+【举例】集团客户-建筑与房地产-建筑与房地产|集团客户-党政-党政管理|中国电信-主业上市-广东分公司|集团客户-中小聚类-中小企业|中国广电-中国广电网络集团-股份公司-广东省广播电视网络股份有限公司 (广东广电)|集团客户-互联网与IT传媒-互联网与IT科技|中国联通-各分公司-上海市分公司|中国电信-主业存续-广东省电信公司|中国电信-实业上市-安徽通服|中国电信-主业存续-山西分公司|...
+
+`合同签订金额(人民币)` float,
+【描述】`合同签订金额(人民币)`反映了项目的规模,提问“超大项目”“重大项目”“一般项目”“小项目”“营业额”等类似字眼时,通常涉及对该字段进行筛选。
+【注意】“超大项目”金额大于等于1亿,“重大项目”金额大于1000万而小于1亿,“一般项目”金额大于200万而小于1000万,“小项目”金额小于200万。有时需要计算的是金额的总值,有时需要计算平均值。
+
+`合同签订金额(不含税)` float,
+
+`是否垫资` bool,
+【所有可能的值】0|1|
+
+`垫资金额(元)` float,
+
+`垫资说明` VARCHAR(491),
+
+`签订日期` date,
+【描述】提问到“近几年”“去年”“今年”“上个季度”与项目日期相关内容时,通常涉及对该字段进行筛选。
+【注意】以CURRENT_DATE获取的时间为准作为当前日期。
+
+`签署日期` date,
+
+`合同有效期(开始)` date,
+
+`合同有效期(结束)` date,
+
+`最终客户名称` VARCHAR(57),
+【描述】`最终客户名称`描述和`客户名称描述一致`。
+
+`最终中通服客商类型` VARCHAR(45),
+【描述】`最终中通服客商类型`和`中通服客商类型`描述一致。
+
+`税率` float,
+
+`是否通服内部合作` bool,
+【描述】提问“内部”“内部项目”“内部合作”“通服内部”等类似字眼时,通常涉及对该字段进行筛选。
+【所有可能的值】0|1|
+
+`项目组织模式` VARCHAR(13),
+【所有可能的值】非总包非全咨|总包-过程总包-PC总包|总包-过程总包-EPC总包|总包-过程总包-施工总包|总包-过程总包-DB总包|全过程咨询|非总包非全过程咨询|总包-过程总包-EP总包
+
+`合同结算金额(含税)` float,
+
+`列账收入(含税)` float,
+
+`开票金额(含税)` float,
+
+`收款金额(含税)` float,
+
+`是否业务关闭` bool,
+【所有可能的值】0|1|
+
+`业务关闭时间` date,
+
+`是否财务关闭` bool,
+【所有可能的值】0|1|
+
+`财务关闭时间` date,
+
+`甲方订单编号` VARCHAR(256),
+
+`甲方合同编号` VARCHAR(68),
+
+`框架子合同编号` VARCHAR(106),
+
+`确收类型` VARCHAR(10),
+
+`业务拓展方式` VARCHAR(10),
+【所有可能的值】合作拓展|自主拓展|联合拓展|主业总包,通服分包|LH
+
+`主实业协同` bool,
+【所有可能的值】0|1|
+
+`协同类型` VARCHAR(10),
+
+`主业合同金额` float,
+
+`对方联系人` VARCHAR(17),
+
+`对方联系电话` VARCHAR(18),
+
+`中标时间` date,
+
+`协同拓展的主业公司` VARCHAR(10),
+
+`主业合同额` float,
+
+`是否运营商政企` bool
+【所有可能的值】0|1|
+)
+
+# 数据库表不同字段要求
+[field_requirement]
+- 如果要对 '经办人' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '合同名称' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '客户名称' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '最终客户名称' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '所属分公司' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '经办单位' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '投标项目名称' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 数据库中的金额单位为人民币元,输出时必须除以一万用 ROUND 函数取整,输出结果使用万元为单位,。
+- 如果要对 '专业' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- `项目来源` 指的是获取项目的途径,如 '招投标'。
+
+# 术语或缩写
+[preliminary]
+- 如 '七分' 这种名称是 '中通服建设七分公司' 或 '中通服建设有限公司七分公司' 的简称,一般使用模糊匹配,其他分公司简称,如 '一分',同理。
+- '小项目'指的是 '合同签订金额(人民币)' 低于20万元的项目。
+"""
diff --git a/config/development/database/contracts_1.conf b/config/development/database/contracts_1.conf
new file mode 100644
index 0000000..d3432ff
--- /dev/null
+++ b/config/development/database/contracts_1.conf
@@ -0,0 +1,29 @@
+# 数据库配置 0812
+
+connection_string = "mysql+pymysql://root:H1wNPOz3@172.16.16.13:3306/contracts?charset=utf8mb4"
+
+type = "MySQL"
+product = "MySQL 5.7"
+
+metadata = """
+CREATE TABLE `contracts` (
+`经办人` VARCHAR(30) comment '示例:综合能源分公司-业务支撑中心-祝瑞敏',
+`经办单位` VARCHAR(16) comment '示例:网优交付项目部',
+`经办日期` date comment '示例:2001-01-01',
+`所属分公司` VARCHAR(16) comment '枚举值:一分公司、二分公司、三分公司、四分公司、五分公司、六分公司、七分公司、北京分公司、数字基建分公司、上海分公司、智网分公司、河北分公司、综合能源分公司、本部',
+`合同名称` VARCHAR(137) comment '可以提取到项目地点、时间、客户名称、专业的相关信息(示例:2019-2020年株洲政企信息化职称技术配合服务采购订单)',
+`项目来源` VARCHAR(10) comment '枚举值:招投标、委托、邀标',
+`专业` VARCHAR(40) comment '示例:系统集成-信息系统集成服务-楼宇智能化',
+`地点` VARCHAR(35) comment '示例:中国-广东省-广州市-天河区',
+`合同类型名称` VARCHAR(29) comment '模糊匹配枚举值:市场经营收入类、系统集成服务类合同、工程设计类合同、通信网络维护类合同、软件开发类合同、供应链服务类合同、增值服务类合同、工程施工类合同、系统营维支撑类合同、设施管理类合同、国际类、国际贸易服务收入合同、贸易服务类合同、工程监理类合同、工程总包收入合同、工程分包收入合同',
+`聚焦行业` VARCHAR(19) comment '枚举值:其他、广电、IDC、电力、厂家、市政、交通、轨道',
+`项目部` VARCHAR(26) comment '示例:湖南分公司交付项目部',
+`客户名称` VARCHAR(54) comment '示例:广州铁路公安局、中国移动通信集团安徽有限公司宣城分公司',
+`运营商` VARCHAR(10) comment '枚举值:中国电信、中国移动、中国联通、中国广电、中国铁塔、其他',
+`合同签订金额(人民币)` float comment '“超大项目”金额大于等于1亿,“重大项目”金额大于1000万而小于1亿,“一般项目”金额大于500万而小于1000万,“小项目”金额小于500万',
+`签订日期` date comment '以CURRENT_DATE获取的时间为准作为当前日期',
+`业务拓展方式` VARCHAR(10) comment '枚举值:合作拓展、自主拓展、联合拓展、主业总包,通服分包、LH',
+`主实业协同` bool,
+)
+
+"""
diff --git a/config/development/database/contracts_2.conf b/config/development/database/contracts_2.conf
new file mode 100644
index 0000000..9eb6918
--- /dev/null
+++ b/config/development/database/contracts_2.conf
@@ -0,0 +1,32 @@
+# 数据库配置 0828
+
+connection_string = "mysql+pymysql://root:H1wNPOz3@172.16.16.13:3306/contracts?charset=utf8mb4"
+
+type = "MySQL"
+product = "MySQL 5.7"
+
+metadata = """
+CREATE TABLE `contracts` (
+`经办人` VARCHAR(30) comment '示例:中通服建设有限公司综合能源分公司-业务支撑中心-祝瑞敏',
+`经办单位` VARCHAR(16) comment '示例:网优交付项目部',
+`经办日期` date comment '示例:2001-01-01',
+`合同形式` VARCHAR(10) comment '枚举值:单项合同|订单合同|确收单合同|框架子合同|框架合同|结算单',
+`所属分公司` VARCHAR(16) comment '枚举值:中通服建设有限公司一分公司、中通服建设有限公司二分公司、中通服建设有限公司三分公司、中通服建设有限公司四分公司、中通服建设有限公司五分公司、中通服建设有限公司六分公司、中通服建设有限公司七分公司、北京分公司、数字基建分公司、上海分公司、智网分公司、河北分公司、综合能源分公司、本部',
+`合同名称` VARCHAR(137) comment '可以提取到项目地点、时间、客户名称、专业的相关信息(示例:2019-2020年株洲政企信息化职称技术配合服务采购订单)',
+`项目来源` VARCHAR(10) comment '枚举值:招投标、委托、邀标',
+`专业` VARCHAR(40) comment '示例:系统集成-信息系统集成服务-楼宇智能化',
+`地点` VARCHAR(35) comment '示例:中国-广东省-广州市-天河区',
+`合同类型名称` VARCHAR(29) comment '枚举值:市场经营收入类、系统集成服务类合同、工程设计类合同、其他、通信网络维护类合同、软件开发类合同、供应链服务类合同、增值服务类合同、工程施工类合同、系统营维支撑类合同、设施管理类合同、其他、国际类、国际贸易服务收入合同、贸易服务类合同、工程监理类合同、工程总包收入合同、工程分包收入合同',
+`聚焦行业` VARCHAR(19) comment '枚举值:其他、广电、IDC、电力、厂家、市政、交通、轨道',
+`项目部` VARCHAR(26) comment '示例:湖南分公司交付项目部',
+`中通服客商类型` VARCHAR(45) comment '示例:集团客户-建筑与房地产-建筑与房地产',
+`客户名称` VARCHAR(54) comment '示例:广州铁路公安局、中国移动通信集团安徽有限公司宣城分公司',
+`运营商` VARCHAR(10) comment '枚举值:中国电信、中国移动、中国联通、中国广电、中国铁塔、其他',
+`合同签订金额(人民币)` float comment '“超大项目”金额大于等于1亿,“重大项目”金额大于1000万而小于1亿,“一般项目”金额大于500万而小于1000万,“小项目”金额小于500万',
+`签订日期` date comment '以CURRENT_DATE获取的时间为准作为当前日期',
+`业务拓展方式` VARCHAR(10) comment '枚举值:合作拓展、自主拓展、联合拓展、主业总包,通服分包、LH',
+`主实业协同` bool,
+)
+
+"""
+
diff --git a/config/development/database/students.conf b/config/development/database/students.conf
new file mode 100644
index 0000000..6d2cac9
--- /dev/null
+++ b/config/development/database/students.conf
@@ -0,0 +1,36 @@
+# 数据库配置
+
+connection_string = "mysql+pymysql://root:H1wNPOz3@mysql.local:3306/students?charset=utf8mb4"
+
+type = "MySQL"
+product = "MySQL 5.7"
+
+metadata = """
+CREATE TABLE students (
+ student_id INTEGER PRIMARY KEY,
+ student_name VARCHAR(100), -- 学生姓名
+ major VARCHAR(100), -- 专业
+ year_of_enrollment INTEGER, -- 入学年份
+ student_age INTEGER -- 学生年龄
+);
+
+CREATE TABLE courses (
+ course_id INTEGER PRIMARY KEY,
+ course_name VARCHAR(100), -- 课程名称
+ credit REAL -- 学分
+);
+
+CREATE TABLE scores (
+ student_id INTEGER,
+ course_id INTEGER,
+ score INTEGER, -- 得分
+ semester VARCHAR(50), -- 学期
+ PRIMARY KEY (student_id, course_id),
+ FOREIGN KEY (student_id) REFERENCES students(student_id),
+ FOREIGN KEY (course_id) REFERENCES courses(course_id)
+);
+
+- 数据库中 'courses' 表中 'course_name' 字段有效值为 '计算机基础','数据结构','高等物理','线性代数','微积分','编程语言','量子力学','概率论','数据库系统','计算机网络'。
+- 数据库中 'scores' 表中 'semester' 字段有效值为 '2020年秋季', '2021年春季', '2021年秋季', '2022年春季', '2020年秋季', '2021年春季', '2021年秋季', '2022年春季', '2022年秋季', '2023年春季'。
+- 数据库中 'students' 表中 'major' 字段有效值为 '计算机科学', '物理学', '数学'。
+"""
diff --git a/config/development/logging.yaml b/config/development/logging.yaml
new file mode 100644
index 0000000..0bf7d11
--- /dev/null
+++ b/config/development/logging.yaml
@@ -0,0 +1,71 @@
+version: 1
+disable_existing_loggers: false
+formatters:
+ default:
+ (): uvicorn.logging.DefaultFormatter
+ fmt: '%(asctime)s - %(levelname)s %(message)s'
+ use_colors: null
+ access:
+ (): uvicorn.logging.AccessFormatter
+ fmt: '%(asctime)s - %(levelname)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+ qa_formatter:
+ format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+handlers:
+ default:
+ formatter: default
+ class: logging.StreamHandler
+ stream: 'ext://sys.stderr'
+ access:
+ formatter: access
+ class: logging.StreamHandler
+ stream: 'ext://sys.stdout'
+ sql_file:
+ class: logging.handlers.RotatingFileHandler
+ level: INFO
+ formatter: default
+ filename: ./log/sql.log
+ maxBytes: 10485760 # 10MB
+ backupCount: 50 #保留50个log文件
+ encoding: utf8
+ err_file:
+ class: logging.handlers.RotatingFileHandler
+ level: ERROR
+ formatter: default
+ filename: ./log/err.log
+ maxBytes: 10485760 # 10MB
+ backupCount: 50 #保留50个log文件
+ encoding: utf8
+ qa_handler:
+ class: logging.FileHandler
+ level: DEBUG
+ formatter: qa_formatter
+ filename: log/test.log
+ mode: a
+ encoding: utf8
+loggers:
+ root:
+ handlers:
+ - default
+ - err_file
+ level: INFO
+ uvicorn:
+ level: INFO
+ uvicorn.error:
+ level: INFO
+ uvicorn.access:
+ handlers:
+ - access
+ level: INFO
+ propagate: false
+ sqlcode:
+ level: DEBUG
+ question:
+ handlers:
+ - sql_file
+ level: INFO
+ propagate: false
+ qa_cache:
+ level: DEBUG
+ handlers: [qa_handler]
+ propagate: false
+
diff --git a/config/development/model.yaml b/config/development/model.yaml
new file mode 100644
index 0000000..0dc8565
--- /dev/null
+++ b/config/development/model.yaml
@@ -0,0 +1,15 @@
+models:
+ gpt-3:
+ type: openai
+ api_key: "your-openai-api-key"
+ model_name: "text-davinci-003"
+ qwen:
+ type: tongyi
+ model_name: qwen-turbo
+ api_key: sk-6b39b56d21aa4406b0c67061f2e31e81
+ # 需要开通付费
+ qianfan:
+ type: qianfan
+ model_name: SQLCoder-7B
+ ak:
+ sk:
diff --git a/config/development/myqwen.conf b/config/development/myqwen.conf
new file mode 100644
index 0000000..d82e728
--- /dev/null
+++ b/config/development/myqwen.conf
@@ -0,0 +1,90 @@
+# qwen 模型配置
+
+# system role prompt
+system = "你擅长编写 SQL 代码,请结合具体问题编写正确规范的 SQL 代码"
+
+# prompt 模板及参数,在模板中可以使用 {question} {database.metadata} 以及 {params.xxx} 引用参数
+prompt = """
+### 数据库结构
+
+{database.metadata}
+
+- [table_sql]标记下的是数据库表的建表语句,它告诉我们数据库有哪些字段以及这些字段的类型。
+- [field_requirement]标记下是编写SQL语句是对不同字段的要求,模型输出SQL语句时应严格遵守这些要求。
+
+### 问题
+
+根据以上建表语句,生成一个 SQL 来回答如下问题: [QUESTION]{question}[/QUESTION]
+
+### 步骤
+
+1、结合建表语句[table_sql],分析该问题是否为一个指示了要查询某些数据库字段的 “明确提问”。如果该问题不是 “明确提问”,那么进行第二步,否则执行第三步。
+2、如果输入的问题不是 “明确提问”,那么它就是一个 “困难提问”,你需要对它进行扩充,生成一个 “明确提问”。
+3、结合建表语句,根据 “明确提问” 生成一个 SQL 语句。
+
+### 明确提问示例
+生成的 “明确提问” 格式应尽可能规范。一个 “明确问题” 通常会尽可能齐全地写明待查询的中文字段名或其相近名称,示例如下:[EXAMPLE]{example}[/EXAMPLE]。
+如果你由 “困难提问” 生成了 “明确提问”,那么输出中需要添加 “明确提问”,包含在```expanded```标记中。
+
+### 输出要求
+
+- 输出的字段名必须用中文描述。
+- 输出的 SQL 语句必须能够通过 {database.product} 验证。
+- 输出的 SQL 语句必须包含在 ```sql ``` 标记中。
+- 输出的 SQL 语句不要添加注释。
+- 输出的 SQL 语句使用反引号来引用中文字段名。
+- 输出的 SQL 语句中包含的字段名必须和上述的[table_sql]中的字段名保持一致。
+- 输出的 SQL 语句禁止使用别名。
+- 输出的 SQL 语句在 where 从句中的条件判断中的字段名应和[table_sql]中的字段名保持一致。
+
+### 输出格式
+'''expanded
+[EXPANDED]
+'''
+'''sql
+[SQL]
+'''
+[ANSWER]
+"""
+
+# params.requirements = """
+# - 生成的 “明确提问” 格式应尽可能规范。一个 “明确问题” 通常会尽可能齐全地写明待查询的中文字段名或其相近名称,示例如下:[EXAMPLE]{example}[/EXAMPLE]。
+# - 如果你由 “困难提问” 生成了 “明确提问”,那么输出中需要添加 “明确提问”,包含在```expanded```标记中。
+# """
+
+params.example = """
+示例1:
+输入:签订日期在2022年,合同名称中包含智慧城市,合同金额在300万以上的合同有哪些,列出合同名称,合同金额。
+输出:
+'''sql
+SELECT `合同名称`, `合同签订金额(人民币)` / 10000 AS 合同金额(万元)
+FROM `contracts`
+WHERE `签订日期` BETWEEN '2022-01-01' AND '2022-12-31'
+AND `合同名称` LIKE '%智慧城市%'
+AND `合同签订金额(人民币)` > 3000000;
+'''
+
+示例2:
+输入:所属分公司是七分公司,客户名称是中共广东省委办公厅的项目有哪些,列出所有信息。
+输出:
+'''sql
+SELECT *
+FROM contracts
+WHERE
+ 经办单位 LIKE '%七分公司%'
+ AND 客户名称 LIKE '%中共广东省委办公厅%';
+'''
+
+示例3:
+输入:业务拓展方式是联合拓展,所属分公司是二分公司,地点不在佛山市的项目一共有多少个。
+输出:
+'''sql
+SELECT COUNT(*)
+FROM contracts
+WHERE
+ `业务拓展方式` = '联合拓展'
+ AND `所属分公司` LIKE '%二分公司%'
+ AND `地点` NOT LIKE '佛山%'
+'''
+"""
+
diff --git a/config/development/problem_refine.conf b/config/development/problem_refine.conf
new file mode 100644
index 0000000..0dc9b4c
--- /dev/null
+++ b/config/development/problem_refine.conf
@@ -0,0 +1,26 @@
+prompt = """
+你是一位信息提取领域的专家,擅长从文本中识别关键词汇和概念,并结合上下文信息优化问题。
+
+工作流:
+ 1. 接收用户输入的问题。
+ 2. 使用自然语言处理技术提取问题中的关键词。
+ 3. 分析问题上下文,理解用户意图。
+ 4. 结合关键词和上下文信息,生成更明确的问题。
+
+优化后的问题格式:请从数据库表 contracts 中按照筛选条件【】查询字段【】
+
+例子:
+输入问题:所属分公司是七分公司,客户名称是中共广东省委办公厅的项目有哪些,列出所有信息
+关键词:所属分公司、七分公司、客户名称、中共广东省委办公厅
+优化后的问题:请从数据库表 contracts 中按照筛选条件 客户名称中包含“中共广东省委办公厅” 查询所有字段
+
+请结合以下与问题相关的背景知识:
+{context}
+
+输入问题:{question}
+
+输出格式如下:
+关键词:
+优化后的问题:
+"""
+
diff --git a/config/development/qwen.conf b/config/development/qwen.conf
new file mode 100644
index 0000000..427249d
--- /dev/null
+++ b/config/development/qwen.conf
@@ -0,0 +1,36 @@
+# qwen 模型配置
+
+# system role prompt
+system = "你擅长编写 SQL 代码,请结合具体问题编写正确规范的 SQL 代码"
+
+# prompt 模板及参数,在模板中可以使用 {question} {database.metadata} 以及 {params.xxx} 引用参数
+prompt = """
+### 数据库结构
+
+{database.metadata}
+
+
+### 问题
+
+根据以上建表语句,生成一个 SQL 来回答如下问题: [QUESTION]{question}[/QUESTION]
+
+
+### 要求
+
+- 输出的字段名必须用中文描述。
+- 输出的 SQL 语句必须能够通过 {database.product} 验证。
+- 输出的 SQL 语句必须包含在 ```sql ``` 标记中。
+- 输出的 SQL 语句不要添加注释。
+
+{requirements}
+
+### 输出格式
+'''sql
+[SQL]
+'''
+[ANSWER]
+"""
+
+params.requirements = """"
+- 简要描述你的想法
+"""
diff --git a/config/development/qwen_agent.conf b/config/development/qwen_agent.conf
new file mode 100644
index 0000000..fcf0fe8
--- /dev/null
+++ b/config/development/qwen_agent.conf
@@ -0,0 +1,54 @@
+# qwen 模型配置
+
+# system role prompt
+system = "你擅长编写 SQL 代码,请结合具体问题编写正确规范的 SQL 代码,同时你是一个中英文专家,你可以理解prompt中的中英文语句"
+
+prompt = """
+你擅长编写 SQL 代码,请结合具体问题编写正确规范的 SQL 代码,同时你是一个中英文专家,你可以理解prompt中的中英文语句
+对于输出 SQL 语句有以下要求:
+- 输出的字段名必须用中文描述。
+- 输出的 SQL 语句必须能够通过 {product} 验证。
+- 输出的 SQL 语句必须包含在 ```sql ``` 标记中。
+- 默认对输出的 SQL语句使用 LIMIT 来限制行数,默认行数为20行
+- 输出的 SQL 语句在 where 从句中的条件判断中的字段名应和数据库表中的字段名保持一致。
+
+
+### 数据库结构
+以下是一些数据库信息:
+{metadata}
+
+### 步骤
+
+按照给定的格式回答以下问题。你可以使用下面这些工具:
+{tools}
+
+你需要遵循以下步骤进行思考:
+1. 首先查询数据库中有哪几个表,以及这些表的范式
+2. 根据提供的数据库信息和数据库表范式理解问题,生成相应的SQL语句
+3. 如果问题比较复杂,可以将它拆解成多步,使用多个SQL语句进行完成
+4. 在生成最终答案前,需要对 SQL 语句检验和执行来确保它是有效的答案;如果无效,则需要继续思考
+5. 最终答案中,将SQL查询的结果和SQL语句一起返回,注意返回的SQL语句要用```sql ```包围
+
+回答时需要遵循以下用---括起来的格式:
+---
+Question: 我需要回答的问题
+Thought: 回答这个上述我需要做些什么
+Action: ”{tool_names}“ 中的其中一个工具名
+Action Input: 选择工具所需要的输入
+Observation: 选择工具返回的结果
+...(这个思考/行动/行动输入/观察可以重复N次)
+Thought: 我现在知道最终答案
+Final Answer: 原始输入问题的最终答案,同时需要你给出解决问题的 SQL 语句,格式如下:
+```sql
+[SQL]
+```
+---
+现在开始回答,记得在给出最终答案前多按照指定格式进行一步一步的推理。
+Question: {input}
+{agent_scratchpad}
+
+"""
+
+params.example = """
+"""
+
diff --git a/config/development/qwen_agent_2.conf b/config/development/qwen_agent_2.conf
new file mode 100644
index 0000000..4c1e277
--- /dev/null
+++ b/config/development/qwen_agent_2.conf
@@ -0,0 +1,111 @@
+# qwen 模型配置
+
+# system role prompt
+system = "你擅长编写 SQL 代码,请结合具体问题编写正确规范的 SQL 代码,同时你是一个中英文专家,你可以理解prompt中的中英文语句"
+
+prompt = """
+你擅长编写 SQL 代码,请结合具体问题编写正确规范的 SQL 代码,同时你是一个中英文专家,你可以理解prompt中的中英文语句
+对于输出 SQL 语句有以下要求:
+- 输出的字段名必须用中文描述。
+- 输出的 SQL 语句必须能够通过 {product} 验证。
+- 输出的 SQL 语句必须包含在 ```sql ``` 标记中。
+- 默认对输出的 SQL语句使用 LIMIT 来限制行数,默认行数为20行
+- 输出的 SQL 语句中的字段名应和数据库表中的字段名保持一致。
+
+
+### 数据库结构
+以下是一些数据库信息:
+{metadata}
+
+- [table_sql]标记下的是数据库表的建表语句,它告诉我们数据库有哪些字段以及这些字段的类型。
+
+### 步骤
+
+按照给定的格式回答以下问题。你可以使用下面这些工具:
+{tools}
+
+你需要遵循以下步骤进行思考:
+1. 首先查询数据库中有哪几个表,以及这些表的范式
+2. 根据提供的数据库信息和数据库表范式理解问题,生成相应的SQL语句
+3. 如果问题比较复杂,可以将它拆解成多步,使用多个SQL语句进行完成
+4. 在生成最终答案前,需要对 SQL 语句检验和执行来确保它是有效的答案;如果无效,则需要继续思考
+5. 最终答案中,将SQL查询的结果和SQL语句一起返回,注意返回的SQL语句要用```sql ```包围
+
+以下是一些问题的问答案例:
+{example}
+
+以下上与问题相关的上下文:
+{context}
+
+回答时需要遵循以下用---括起来的格式:
+---
+Question: 我需要回答的问题
+Refined_question: 经过大模型优化后的问题
+Thought: 回答这个上述我需要做些什么
+Action: ”{tool_names}“ 中的其中一个工具名
+Action Input: 选择工具所需要的输入
+Observation: 选择工具返回的结果
+...(这个思考/行动/行动输入/观察可以重复N次)
+Thought: 我现在知道最终答案
+Final Answer: 原始输入问题的最终答案,同时需要你给出解决问题的 SQL 语句,格式如下:
+```sql
+SELECT [字段名]
+FROM contracts
+WHERE
+ [条件语句]
+LIMIT 20;
+```
+---
+现在开始回答,记得在给出最终答案前多按照指定格式进行一步一步的推理。
+输入的问题会经过大模型进行信息提取、优化,下面会同时给出优化后的问题,请结合原始问题和优化后问题回答。
+Question: {input}
+Refined_question: {refined_question}
+{agent_scratchpad}
+
+"""
+
+params.example = """
+- 示例1
+输入:所属分公司是七分公司,客户名称是中共广东省委办公厅的项目有哪些,列出所有信息。
+输出:
+'''sql
+SELECT *
+FROM contracts
+WHERE
+ 经办单位 LIKE '%七分公司%'
+ AND 客户名称 LIKE '%中共广东省委办公厅%'
+LIMIT 20;
+
+- 示例2
+输入:广州运维重大项目。
+输出:
+'''sql
+SELECT `合同名称`,`合同签订金额(人民币)`,`所属分公司`,`项目部`,`客户名称`,`签订时间`
+FROM contracts
+WHERE
+ `合同名称` LIKE '%运维%'
+ AND `地点` LIKE '%广州%'
+ AND `合同签订金额(人民币)` > 3000000
+LIMIT 20;
+
+- 示例3
+输入:粤东的大项目。
+输出:
+'''sql
+SELECT *
+FROM contracts
+WHERE
+ (`地点` LIKE '%汕头市%'
+ OR `地点` LIKE '%潮州市%'
+ OR `地点` LIKE '%梅州市%'
+ OR `地点` LIKE '%汕尾市%'
+ OR `地点` LIKE '%揭阳市%'
+ OR `合同名称` LIKE '%汕头市%'
+ OR `合同名称` LIKE '%潮州市%'
+ OR `合同名称` LIKE '%梅州市%'
+ OR `合同名称` LIKE '%汕尾市%'
+ OR `合同名称` LIKE '%揭阳市%')
+ AND `合同签订金额(人民币)` > 10000000
+LIMIT 20;
+"""
+
diff --git a/config/development/qwen_graph.conf b/config/development/qwen_graph.conf
new file mode 100644
index 0000000..fe10cb4
--- /dev/null
+++ b/config/development/qwen_graph.conf
@@ -0,0 +1,85 @@
+# qwen 模型配置
+
+# system role prompt
+system = "你是一位专业的数据库分析师,具有将自然语言问题转化为精确SQL查询的能力"
+
+prompt = """
+- Role: SQL转换专家
+- Background: 用户需要一个能够理解自然语言问题并将其转化为SQL查询语句的智能代理。
+- Profile: 你是一位专业的数据库分析师,具有将自然语言问题转化为精确SQL查询的能力。
+- Skills: 理解自然语言、SQL语言知识、问题解析、查询构建。
+- Goals: 设计一个智能代理,能够接收自然语言问题并生成相应的SQL查询语句。
+- Constrains:
+ 1. 输出的字段名必须用中文描述。
+ 2. 输出的 SQL 语句必须能够通过 {product} 验证。
+ 3. 输出的 SQL 语句中的字段名应和数据库表中的字段名保持一致。
+ 4. 对 `所属分公司`、`专业`、`客户名称`、`合同名称`、`经办人`、`客商类型`、`地点` 等字段进行筛选时,必须使用 LIKE 语句进行模糊匹配。
+- OutputFormat: 使用```sql```标记的SQL查询语句。
+
+- Database metedata
+{metadata}
+[table_sql]标记下的是数据库表的建表语句,它告诉我们数据库有哪些字段以及这些字段的类型。
+
+- Workflow:
+1. 首先查询数据库中有哪几个表,以及这些表的范式
+2. 根据提供的数据库信息和数据库表范式理解问题,生成相应的SQL语句
+3. 如果问题比较复杂,可以将它拆解成多步,使用多个SQL语句进行完成
+4. 在生成最终答案前,需要对 SQL 语句检验和执行来确保它是有效的答案;如果无效,则需要继续思考
+5. 最终答案中,将SQL查询的结果和SQL语句一起返回,注意返回的SQL语句要用```sql ```包围
+
+- Examples:
+{example}
+
+Final Answer:
+- Prefix: 对问题和生成SQL语句的描述
+- Code:
+```sql
+SELECT [字段名1], [字段名2], ...
+FROM [表名]
+WHERE [条件语句]
+GROUP BY [字段名1], [字段名2], ...
+HAVING [条件语句]
+ORDER BY [字段名] ASC|DESC
+LIMIT [数量] OFFSET [偏移量];
+```
+---
+Question:
+"""
+
+params.example = """
+- 示例1
+输入:所属分公司是七分公司,客户名称是中共广东省委办公厅的项目有哪些,列出所有信息。
+输出:我们需要从contracts表中筛选出所属分公司字段包含'七分公司'以及客户名称中包含'中共广东省委办公厅'的项目名称。
+Prefix:
+Code:
+```sql
+SELECT *
+FROM contracts
+WHERE
+ `所属分公司` LIKE '%七分公司%'
+ AND `客户名称` LIKE '%中共广东省委办公厅%';
+```
+
+- 示例2
+输入:粤东的大项目。
+输出:
+Prefix: 粤东地区包含汕头市、潮州市、梅州市、汕尾市、揭阳市,我们需要在contracts表中筛选地点字段包含这些城市或合同名称包含这些城市的项目。
+Code:
+```sql
+SELECT *
+FROM contracts
+WHERE
+ (`地点` LIKE '%汕头市%'
+ OR `地点` LIKE '%潮州市%'
+ OR `地点` LIKE '%梅州市%'
+ OR `地点` LIKE '%汕尾市%'
+ OR `地点` LIKE '%揭阳市%'
+ OR `合同名称` LIKE '%汕头市%'
+ OR `合同名称` LIKE '%潮州市%'
+ OR `合同名称` LIKE '%梅州市%'
+ OR `合同名称` LIKE '%汕尾市%'
+ OR `合同名称` LIKE '%揭阳市%')
+ AND `合同签订金额(人民币)` > 10000000;
+```
+"""
+
diff --git a/config/development/rag.txt b/config/development/rag.txt
new file mode 100644
index 0000000..2e2d610
--- /dev/null
+++ b/config/development/rag.txt
@@ -0,0 +1,457 @@
+长三角 上海市,江苏省,浙江省,安徽省
+粤东 汕头市,梅州市,汕尾市,潮州市,揭阳市
+珠三角 广州市,深圳市,珠海市,佛山市,惠州市,东莞市,中山市,江门市,肇庆市
+京津唐 北京市,天津市,河北省唐山市
+成渝经济区 四川省成都市,重庆市
+环渤海经济圈 北京市,天津市,河北省,山东省,辽宁省,山西省,内蒙古自治区
+西南 云南省,贵州省,广西壮族自治区,四川省,重庆市,西藏自治区
+华北 北京市,天津市,河北省,山西省,内蒙古自治区
+华中 河南省,湖北省,湖南省
+华东 上海市,江苏省,浙江省,安徽省,福建省,江西省,山东省
+华南 广东省,广西壮族自治区,海南省
+西北 陕西省,甘肃省,青海省,宁夏回族自治区,新疆维吾尔自治区
+东北 黑龙江省,吉林省,辽宁省
+胶东半岛 山东省青岛市,烟台市,威海市,潍坊市,日照市
+苏锡常 江苏省苏州市,无锡市,常州市
+闽南金三角 福建省厦门市,泉州市,漳州市
+赣鄱平原 江西省南昌市,九江市,上饶市,抚州市,景德镇市,鹰潭市
+鲁西平原 山东省德州市,聊城市,菏泽市,济宁市,滨州市
+豫东平原 河南省商丘市,周口市,开封市,许昌市,漯河市
+湘中 湖南省长沙市,株洲市,湘潭市,娄底市,邵阳市
+鄂西生态文化旅游圈 湖北省恩施土家族苗族自治州,宜昌市,荆门市,荆州市
+陕北 陕西省榆林市,延安市
+关中 陕西省西安市,宝鸡市,咸阳市,渭南市,铜川市
+晋东南 山西省长治市,晋城市
+冀北山区 河北省张家口市,承德市
+辽东半岛 辽宁省大连市,丹东市,营口市,鞍山市,盘锦市
+吉黑沿边 吉林省延边朝鲜族自治州,黑龙江省牡丹江市,佳木斯市,双鸭山市,鸡西市,鹤岗市,伊春市,七台河市,绥化市,黑河市,大兴安岭
+珠江口西岸都市圈 广东省珠海市,中山市,江门市
+粤西 广东省湛江市,茂名市,阳江市
+粤北 广东省韶关市,清远市,河源市,梅州市,潮州市,云浮市
+浙东 浙江省宁波市,舟山市,台州市,温州市
+浙西 浙江省金华市,衢州市,丽水市
+皖南 安徽省黄山市,宣城市,池州市,芜湖市,铜陵市,马鞍山市
+皖北 安徽省亳州市,阜阳市,宿州市,淮北市,蚌埠市,淮南市
+闽东北 福建省福州市,宁德市,莆田市
+赣南 江西省赣州市
+赣南闽西革命老区 江西省赣州市,福建省龙岩市
+湘西 湖南省湘西土家族苗族自治州,张家界市,怀化市,常德市,邵阳市
+黔中 贵州省贵阳市,遵义市,安顺市,黔南布依族苗族自治州
+黔东南 贵州省黔东南苗族侗族自治州
+黔西南 贵州省黔西南布依族苗族自治州
+滇东北 云南省昭通市,曲靖市
+滇西 云南省大理白族自治州,保山市,德宏傣族景颇族自治州,临沧市,怒江傈僳族自治州,丽江市,迪庆藏族自治州
+滇南 云南省红河哈尼族彝族自治州,普洱市,西双版纳傣族自治州
+鲁南 山东省临沂市,枣庄市,济宁市,菏泽市,日照市
+豫南 河南省信阳市,驻马店市,南阳市
+豫北 河南省安阳市,鹤壁市,新乡市,焦作市,濮阳市
+豫中南 河南省许昌市,漯河市,平顶山市,周口市,商丘市
+晋北 山西省大同市,朔州市,忻州市
+晋中 山西省太原市,晋中市,吕梁市
+蒙东 内蒙古自治区赤峰市,通辽市,兴安盟,呼伦贝尔市,锡林郭勒盟东部
+蒙西 内蒙古自治区乌海市,巴彦淖尔市,阿拉善盟,鄂尔多斯市西部
+蒙南 内蒙古自治区包头市,呼和浩特市,乌兰察布市
+蒙北 内蒙古自治区呼伦贝尔市北部,锡林郭勒盟北部
+陇东 甘肃省庆阳市,平凉市
+陇南 甘肃省陇南市,天水市南部
+陇西 甘肃省定西市,天水市中部,平凉市南部
+陇北 甘肃省白银市,武威市,张掖市,酒泉市,嘉峪关市
+陇中 甘肃省兰州市,定西市,临夏回族自治州
+甘南 甘肃省甘南藏族自治州
+甘东南 甘肃省陇南市,天水市,平凉市南部
+青藏高原 西藏自治区拉萨市,昌都市,山南市,日喀则市,那曲市,阿里;青海省西宁市,海东市,海南藏族自治州,海北藏族自治州,黄南藏族自治州,果洛藏族自治州,玉树藏族自治州,海西蒙古族藏族自治州
+川东北 四川省广元市,巴中市,达州市,南充市,广安市
+川西北 四川省阿坝藏族羌族自治州,甘孜藏族自治州
+川南 四川省泸州市,宜宾市,内江市,自贡市
+川西 四川省雅安市,眉山市,乐山市,成都西部
+川中 四川省遂宁市,资阳市,德阳市,成都东部
+陕南 陕西省汉中市,安康市,商洛市
+关中平原 陕西省西安市,宝鸡市,咸阳市,渭南市,铜川市
+秦巴山区 陕西省汉中市,安康市,商洛市
+宁南 宁夏回族自治区固原市,吴忠市南部
+宁北 宁夏回族自治区银川市,石嘴山市,吴忠市北部,中卫市
+宁中 宁夏回族自治区银川市,石嘴山市,吴忠市,中卫市
+新北 新疆维吾尔自治区乌鲁木齐市,昌吉回族自治州,吐鲁番市,哈密市
+新南 新疆维吾尔自治区喀什,和田,阿克苏,克孜勒苏柯尔克孜自治州
+新东 新疆维吾尔自治区伊犁哈萨克自治州,塔城,阿勒泰
+新西 新疆维吾尔自治区博尔塔拉蒙古自治州,克拉玛依市,巴音郭楞蒙古自治州
+辽西 辽宁省朝阳市,葫芦岛市,阜新市
+辽南 辽宁省大连市,营口市,鞍山市南部
+辽北 辽宁省铁岭市,抚顺市,本溪市
+辽中 辽宁省沈阳市,辽阳市,鞍山市中部,盘锦市
+吉东 吉林省吉林市,延边朝鲜族自治州,通化市
+吉南 吉林省长春市,四平市,辽源市
+吉西 吉林省松原市,白城市
+吉北 吉林省白山市
+吉中 吉林省长春市,吉林市中部
+黑南 黑龙江省哈尔滨市,大庆市,齐齐哈尔市
+黑北 黑龙江省黑河市,大兴安岭
+黑东 黑龙江省佳木斯市,双鸭山市,鹤岗市,伊春市东部
+黑西 黑龙江省齐齐哈尔市,绥化市,伊春市西部
+黑中 黑龙江省哈尔滨市,牡丹江市,佳木斯市南部
+藏东 西藏自治区昌都市,林芝市
+藏南 西藏自治区山南市,日喀则市
+藏西 西藏自治区阿里
+藏北 西藏自治区那曲市
+桂北 广西壮族自治区桂林市,柳州市,贺州市
+桂中 广西壮族自治区南宁市,来宾市,河池市
+桂南 广西壮族自治区钦州市,北海市,防城港市
+桂西 广西壮族自治区百色市,崇左市
+桂东 广西壮族自治区梧州市,贵港市,玉林市
+琼北 海南省海口市,文昌市,澄迈县,定安县,屯昌县
+琼南 海南省三亚市,陵水黎族自治县,保亭黎族苗族自治县,乐东黎族自治县,东方市
+琼东 海南省琼海市,万宁市
+琼西 海南省儋州市,临高县,昌江黎族自治县
+琼中 海南省五指山市,白沙黎族自治县,琼中黎族苗族自治县
+豫东 河南省商丘市,周口市,开封市,许昌市东部
+豫中 河南省郑州市,许昌市,漯河市,平顶山市东部,驻马店市北部
+鲁东 山东省烟台市,威海市,青岛市东部
+鲁中 山东省济南市,淄博市,莱芜区,泰安市,潍坊市中部
+鲁西 山东省聊城市,德州市,菏泽市,济宁市,滨州市
+鲁北 山东省东营市,滨州市,德州市北部,聊城市北部
+鲁西南 山东省济宁市,菏泽市,枣庄市,临沂市西部
+豫西 河南省洛阳市,三门峡市,济源市,平顶山市西部
+鲁东半岛 山东省烟台市,威海市,青岛市东部
+胶莱平原 山东省青岛市,潍坊市,烟台市,威海市
+鲁西南平原 山东省菏泽市,济宁市,枣庄市,临沂市西部
+沂蒙山区 山东省临沂市,日照市
+淮北平原 安徽省淮北市,宿州市,亳州市,阜阳市,蚌埠市
+巢湖 安徽省合肥市,巢湖市,六安市东部
+淮河 安徽省淮南市,蚌埠市,阜阳市,亳州市,宿州市
+皖东北 安徽省宿州市,淮北市,蚌埠市北部
+赣东北 江西省上饶市,景德镇市,鹰潭市
+赣中 江西省南昌市,抚州市,吉安市
+赣西 江西省宜春市,萍乡市,新余市
+湘南 湖南省衡阳市,郴州市,永州市
+湘北 湖南省岳阳市,常德市,益阳市
+湘东 湖南省株洲市,湘潭市,长沙市东部
+京津冀 北京市,天津市,河北省
+东北三省 黑龙江省,吉林省,辽宁省
+江浙沪 江苏省,浙江省,上海市
+川渝 四川省,重庆市
+云贵川 云南省,贵州省,四川省
+两湖 湖北省,湖南省
+两广 广东省,广西壮族自治区
+陕甘宁 陕西省,甘肃省,宁夏回族自治区
+新青藏 新疆维吾尔自治区,青海省,西藏自治区
+蒙晋 内蒙古自治区,山西省
+鲁豫 山东省,河南省
+苏皖 江苏省,安徽省
+闽赣 福建省,江西省
+桂琼 广西壮族自治区,海南省
+华北平原 北京市,天津市,河北省,山西省,内蒙古自治区
+江南 上海市,江苏省南部,浙江省北部,安徽省南部,江西省东北部
+塞北 河北省北部,山西省北部,内蒙古自治区
+粤港澳大湾区 香港特别行政区、澳门特别行政区、广州市、深圳市、珠海市、佛山市、惠州市、东莞市、中山市、江门市、肇庆市
+环渤海 北京市、天津市、河北省、辽宁省、山东省
+长江中游城市群 武汉市、长沙市、南昌市、合肥市、南京市、芜湖市等
+中原城市群 郑州市、洛阳市、开封市、新乡市、焦作市等
+长江三角洲城市群 上海市、江苏省、浙江省、安徽省的主要城市
+成渝城市群 成都市、重庆市及周边城市
+长株潭城市群 长沙市、株洲市、湘潭市
+海峡西岸经济区 福州市、厦门市、泉州市、漳州市、宁德市、莆田市、三明市、龙岩市
+北部湾经济区 南宁市、北海市、钦州市、防城港市、海口市、三亚市等
+滇中城市经济圈 昆明市、曲靖市、玉溪市、楚雄彝族自治州等
+黔中城市群 贵阳市、遵义市、安顺市、六盘水市等
+陕甘宁革命老区 陕西省、甘肃省、宁夏回族自治区
+大别山革命老区 河南省信阳市、湖北省黄冈市、安徽省六安市等
+左右江革命老区 广西壮族自治区百色市、崇左市等
+太行山区 山西省长治市、晋城市、河北省石家庄市、保定市等
+武陵山区 重庆市、湖北省恩施土家族苗族自治州、湖南省湘西土家族苗族自治州等
+南岭山区 广东省韶关市、清远市、广西壮族自治区贺州市、湖南省永州市等
+川西北高原 四川省阿坝藏族羌族自治州、甘孜藏族自治州等
+黔东南苗族侗族自治州 贵州省黔东南苗族侗族自治州
+滇西北 云南省迪庆藏族自治州、怒江傈僳族自治州等
+陇南山区 甘肃省陇南市
+冀东 河北省唐山市、秦皇岛市、承德市等
+苏中 江苏省扬州市、泰州市、南通市等
+皖中 安徽省合肥市、六安市、滁州市等
+辽中南 辽宁省中部,包括沈阳市、大连市等
+苏北五市 江苏省淮安市、连云港市、宿迁市、盐城市、徐州市
+皖北六市 安徽省淮北市、亳州市、宿州市、蚌埠市、阜阳市、淮南市
+鄂西北 湖北省十堰市、襄阳市、荆门市等
+冀南 河北省邯郸市、邢台市、衡水市等
+粤西北 广东省韶关市、清远市、云浮市等
+桂东南 广西壮族自治区玉林市、贵港市、梧州市等
+海南北部 海南省海口市、文昌市、琼海市等
+海南南部 海南省三亚市、陵水黎族自治县、保亭黎族苗族自治县等
+川西南 四川省凉山彝族自治州、攀枝花市等
+黔南 贵州省黔南布依族苗族自治州
+滇东南 云南省红河哈尼族彝族自治州、文山壮族苗族自治州等
+滇西南 云南省德宏傣族景颇族自治州、西双版纳傣族自治州等
+冀中南 河北省石家庄市、邢台市、邯郸市等
+苏南 江苏省苏州市、无锡市、常州市等
+赣西北 江西省九江市、宜春市等
+鄂西南 湖北省恩施土家族苗族自治州、宜昌市等
+豫西南 河南省南阳市、信阳市、驻马店市等
+苏北 江苏省徐州市、连云港市、淮安市等
+皖中北 安徽省合肥市、蚌埠市、淮南市等
+桂西北 广西壮族自治区河池市、百色市等
+桂东北 广西壮族自治区贺州市、桂林市等
+海南中部 海南省琼中黎族苗族自治县、五指山市等
+山东半岛 山东省青岛市、烟台市、威海市等
+苏北地区 江苏省徐州市、连云港市、淮安市、宿迁市、盐城市
+鲁西北 山东省德州市、聊城市、滨州市等
+海南东部 海南省琼海市、万宁市、文昌市、三亚市等
+海南西部 海南省东方市、昌江黎族自治县、乐东黎族自治县等
+川东南 四川省泸州市、内江市、自贡市、宜宾市等
+黔北 贵州省遵义市、铜仁市等
+滇中 云南省昆明市、楚雄彝族自治州、玉溪市等
+冀西 河北省张家口市、承德市等
+鲁东南 山东省日照市、临沂市、枣庄市等
+皖东南 安徽省宣城市、黄山市、池州市、铜陵市等
+鄂东 湖北省黄冈市、鄂州市、黄石市等
+冀中 河北省石家庄市、保定市、衡水市等
+鲁中南 山东省济宁市、枣庄市、临沂市、日照市等
+鄂南 湖北省咸宁市、荆州市、宜昌市等
+海南东北部 海南省海口市、文昌市、琼海市等
+海南西北部 海南省澄迈县、临高县、儋州市等
+黔西北 贵州省毕节市、六盘水市等
+陇东南 甘肃省庆阳市、平凉市等
+晋南 山西省运城市、临汾市等
+湘西北 湖南省湘西土家族苗族自治州、张家界市等
+鄂中 湖北省随州市、荆门市、天门市等
+京津保核心区 北京市,天津市,河北省保定市
+环京津 河北省廊坊市、承德市、张家口市、秦皇岛市等
+环沪 江苏省苏州市、南通市、无锡市、常州市等
+环广深 广东省东莞市、惠州市、中山市、珠海市等
+环成都 四川省德阳市、眉山市、资阳市、绵阳市等
+环武汉 湖北省黄石市、鄂州市、孝感市、黄冈市等
+环长株潭城市群 湖南省岳阳市、益阳市、常德市、娄底市等
+环鄱阳湖经济圈 江西省九江市、上饶市、抚州市、鹰潭市等
+环杭州湾大湾区 浙江省嘉兴市、湖州市、绍兴市、宁波市等
+环太湖经济圈 江苏省苏州市、无锡市、常州市,浙江省湖州市等
+环青海湖地区 青海省海南藏族自治州、海北藏族自治州等
+环塔里木盆地 新疆维吾尔自治区阿克苏地区、喀什地区、和田地区等
+环渤海大湾区 辽宁省大连市、营口市、盘锦市、锦州市等
+环滇池地区 云南省昆明市、玉溪市、楚雄彝族自治州等
+环鄱阳湖城市群 江西省南昌市、九江市、上饶市、景德镇市等
+环洞庭湖经济圈 湖南省岳阳市、常德市、益阳市、长沙市等
+环珠江口西岸地区 广东省江门市、阳江市、茂名市、湛江市等
+环珠江口东岸地区 广东省惠州市、汕尾市、揭阳市、梅州市等
+环北部湾城市群 广西壮族自治区北海市、钦州市、防城港市等
+环巢湖 安徽省合肥市、芜湖市、马鞍山市、铜陵市等
+环太湖西部地区 江苏省常州市、无锡市、苏州市等
+环太湖东部地区 浙江省湖州市、嘉兴市等
+环巢湖城市群 安徽省合肥市、芜湖市、马鞍山市、六安市等
+环钱塘江城市群 浙江省杭州市、绍兴市、宁波市、嘉兴市等
+环太湖北部地区 江苏省泰州市、南通市等
+环巢湖经济圈 安徽省合肥市、芜湖市、马鞍山市、铜陵市等
+环鄱阳湖生态经济区 江西省南昌市、九江市、上饶市、抚州市等
+环洞庭湖生态经济区 湖南省岳阳市、常德市、益阳市、长沙市等
+环珠江口大湾区 广东省广州市、深圳市、珠海市、东莞市等
+环北部湾经济合作区 广西壮族自治区南宁市、北海市、钦州市、防城港市等
+环巢湖生态经济区 安徽省合肥市、芜湖市、马鞍山市、六安市等
+环太湖生态经济区 江苏省无锡市、苏州市、常州市,浙江省湖州市等
+环巢湖旅游区 安徽省合肥市、芜湖市、巢湖市等
+环钱塘江生态经济区 浙江省杭州市、绍兴市、宁波市、嘉兴市等
+环太湖旅游区 江苏省无锡市、苏州市、常州市,浙江省湖州市等
+环燕山 北京市,天津市,河北省承德市、秦皇岛市等
+环太行山 山西省长治市、晋城市,河北省石家庄市、邢台市等
+环大别山 河南省信阳市、南阳市,湖北省黄冈市、随州市等
+环武陵山 重庆市,湖南省湘西土家族苗族自治州,湖北省恩施土家族苗族自治州等
+环南岭 广东省韶关市、清远市,广西壮族自治区贺州市,湖南省永州市等
+环天山 新疆维吾尔自治区昌吉回族自治州、伊犁哈萨克自治州等
+环阿尔泰山 新疆维吾尔自治区阿勒泰、塔城等
+环帕米尔高原 新疆维吾尔自治区克孜勒苏柯尔克孜自治州等
+环喀喇昆仑山 新疆维吾尔自治区和田、喀什等
+环祁连山 青海省海北藏族自治州、海西蒙古族藏族自治州等
+环横断山 四川省甘孜藏族自治州、阿坝藏族羌族自治州等
+
+
+
+中通服建设有限公司有多个分公司,所属分公司包括中通服建设一分公司(简称一分)、中通服建设二分公司(二分),一直到有七分公司。
+小项目通常指的是项目金额小于20万元的项目
+
+信创 国产化替代、通用软硬件、办公系统软件开发、服务器、操作系统、数据库、中间件、应用软件、芯片、网络安全、云计算、大数据、自主可控、安全可靠、政务内网、政务外网、办公场所
+智慧 智慧城市、智慧交通、智慧医疗、智慧教育、智慧园区、智慧电力、智慧水务、智慧应急、智慧安防、智慧社区、智慧仓储、智慧工地、智慧文旅
+运维 运营维护、系统维护、设备维护、软件维护、网络维护、IT 运维、数据中心运维、云计算运维、应用运维
+系统集成 软件集成、硬件集成、网络集成、数据集成、应用集成、安全集成、系统开发、系统设计、项目管理、技术服务、解决方案
+应急 应急管理、应急救援、应急指挥、应急预案、应急物资、应急通信、应急演练、安全生产、公共安全、防灾减灾
+智慧应急建设内容 应急指挥平台、监测预警系统、风险评估、资源管理、预案管理、模拟演练、应急通信、智能救援装备
+智慧城市建设内容 数字基础设施、运营指挥中心、智慧交通系统、智慧医疗体系、智慧教育平台、智慧园区管理、智慧政务服务、智慧安防监控、智慧能源管理、智慧环保监测、智慧水务系统、智慧社区服务
+智慧交通建设内容 智能网联、智慧停车、智慧交管、智慧交运、车路协同、交通信号控制、电子警察、测速卡口、科技治超、智能管控、交通大数据、出行服务平台
+智慧医疗建设内容 医院建筑智能化、医疗信息化、远程医疗、医疗大数据、电子病历、移动医疗、智能医疗设备、医疗影像系统、医院信息系统、医疗物联网
+智慧教育建设内容 智慧校园、智慧教室、校园物联网、教育云平台、在线教育、教学资源管理、智能教学系统、教育大数据分析、校园安全管理、智慧图书馆
+智慧园区建设内容 园区建筑智能化、基础设施、园区可视化平台、运营服务、企业应用、智能安防、能源管理、环境监测、智能停车、物业管理
+智慧电力建设内容 电力配网工程、通信网络、电力线路管线迁改、建筑智能化、应用系统开发、智慧灯杆、光伏工程、充电桩工程、无人机巡线、电力大数据
+智慧水务建设内容 城市内涝监测、雨情监测、河江湖监测、水质在线监测、智慧水厂、智慧供水、智慧排水、管网监测、水利工程信息化
+智慧安防建设内容 视频监控、门禁系统、入侵报警、人脸识别、智能分析、安防大数据、应急处置、安防小区、安防平台
+智慧社区建设内容 物业管理系统、社区服务平台、智能门禁、车辆管理、环境监测、智能家居、社区安防、养老服务、社区电商
+智慧工地建设内容 人员管理、设备管理、环境监测、施工进度管理、质量安全管理、物料管理、远程监控、智能塔吊、BIM 技术应用
+智慧仓储建设内容 自动化货架、智能搬运设备、仓储管理系统、库存控制、货物识别、数据分析、分拣系统、物流配送、仓库监控
+智慧文旅建设内容 智能票务、景区导览、虚拟旅游、文化遗产数字化、游客大数据分析、文旅营销平台、智慧酒店、智慧民宿、沉浸式体验
+
+
+数据库中有一张 contracts 表,它的字段有<`经办人`,`经办单位`,`经办日期`,`所属分公司`,`合同形式`,`是否主合同`,`合同名称`,`合同编号`,`框架合同编号`,`框架合同名称`,`主合同编号`,`主合同名称`,`项目来源`,`投标项目名称`,`编号生成时间`,`专业`,`地点`,
+`是否关联交易`,`合同类型名称`,`聚焦行业`,`管理分公司`,`建议实施单位`,`项目部`,`最小经营单元`,`省公司统一编号`,`统一编号生成时间`,`客户名称`,`运营商`,`中通服客商类型`,`合同签订金额(人民币)`,`合同签订金额(不含税)`,`是否垫资`,`垫资金额(元)`,
+`垫资说明`,`签订日期`,`签署日期`,`合同有效期(开始)`,`合同有效期(结束)`,`最终客户名称`,`最终中通服客商类型`,`税率`,`是否通服内部合作`,`项目组织模式`,`合同结算金额(含税)`,`列账收入(含税)`,`开票金额(含税)`,`收款金额(含税)`,`是否业务关闭`,
+`业务关闭时间`,`是否财务关闭`,`财务关闭时间`,`甲方订单编号`,`甲方合同编号`,`框架子合同编号`,`确收类型`,`业务拓展方式`,`主实业协同`,`协同类型`,`主业合同金额`,`对方联系人`,`对方联系电话`,`中标时间`,`协同拓展的主业公司`,`主业合同额`,`是否运营商政企`>
+
+`经办人` VARCHAR(30),
+【描述】`经办人`是指项目经理,提问 “谁的项目”、“项目经理” 等类似字眼时,通常涉及对该字段进行筛选。
+【举例】中通服建设七分公司-湖南分公司交付项目部-蔡胜华|中通服建设一分公司-河北集客项目部-常楠|...
+
+`经办单位` VARCHAR(16),
+【描述】`经办单位`是指各个分公司的下属部门,提问 “部门” 类似字眼时,通常涉及对该字段进行筛选。
+【举例】一分集客项目部|业务支撑中心|采购管理中心|网优交付项目部|...
+
+`经办日期` date,
+
+`所属分公司` VARCHAR(16),
+【描述】提问 “一分”、“北分”、“数分”、“智网”、“四分”、“七分公司”、“综合能源分公司” 等类似字眼时,通常涉及对该字段进行筛选。
+【所有可能的值】中通服建设有限公司一分公司|中通服建设有限公司二分公司|中通服建设有限公司三分公司|中通服建设有限公司四分公司|中通服建设有限公司五分公司|中通服建设有限公司六分公司|中通服建设有限公司七分公司|中通服建设有限公司北京分公司|中通服建设有限公司数字基建分公司|中通服建设有限公司上海分公司|中通服建设有限公司智网分公司|中通服建设有限公司河北分公司|中通服建设有限公司综合能源分公司|中通服建设有限公司本部
+
+`合同形式` VARCHAR(10),
+【所有可能的值】单项合同|订单合同|确收单合同|框架子合同|框架合同|结算单
+
+`是否主合同` bool,
+【所有可能的值】0|1|
+
+`合同名称` VARCHAR(137),
+【描述】`合同名称`即项目名称,从中可能提取到项目`地点`、`时间`、`客户名称`、`最终客户名称`、`专业`的相关信息。
+
+`合同编号` VARCHAR(37),
+
+`框架合同编号` VARCHAR(23),
+
+`框架合同名称` VARCHAR(85),
+
+`主合同编号` VARCHAR(25),
+
+`主合同名称` VARCHAR(103),
+
+`项目来源` VARCHAR(10),
+【描述】提问“招投标”“委托”“邀标”相关字眼时,通常涉及对该字段进行筛选。
+【所有可能的值】招投标|委托|邀标
+
+`投标项目名称` VARCHAR(92),
+【描述】`投标项目名称`和`合同名称`描述基本一致。
+
+`编号生成时间` date,
+
+`专业` VARCHAR(40),
+【举例】系统集成-信息系统集成服务-视频监控集成|工程设计-勘察设计-其他勘查设计-其他|工程施工-设备工程-通信设备安装调试-基站|其他-其他-咨询服务|工程施工-管线工程-通信线路施工-线路|系统集成-信息系统集成服务-其他|工程施工-管线工程-通信管道施工-本地网管道|工程施工-建筑智能化-智能化及集成|工程施工-设备工程-通信设备安装调试-数据-网络交换设备|工程施工-管线工程-通信线路施工-电缆|...
+
+`地点` VARCHAR(35),
+【描述】`地点`的值只包含省市区县的内容,不会包含一些常见的地区俗称。
+【注意】涉及地区俗称时,需要分析其所在的省市区县信息进行筛选,不能用地区俗称进行筛选。如:提问“京津冀”的项目时,筛选的`地点`应该是北京、天津或河北,而不是直接筛选 “京津冀”。
+
+`是否关联交易` bool,
+【所有可能的值】0|1|
+
+`合同类型名称` VARCHAR(29),
+【描述】`合同类型名称`通常和`专业`有关。
+【举例】市场经营收入类|系统集成服务类|工程施工类|工程设计类|工程总包收入|工程分包收入|通信网络维护类|设施管理类|国际类|国际贸易服务收入
+
+`聚焦行业` VARCHAR(19),
+
+`管理分公司` VARCHAR(16),
+
+`建议实施单位` VARCHAR(55),
+
+`项目部` VARCHAR(26),
+【描述】该字段和经办单位的意思一致。
+
+`最小经营单元` VARCHAR(17),
+【描述】`最小经营单位`结合了`所属分公司`和`项目部`的内容。
+
+`省公司统一编号` VARCHAR(31),
+
+`统一编号生成时间` date,
+
+`客户名称` VARCHAR(54),
+【举例】广州铁路公安局|广东电网有限责任公司广州供电局|广东电网有限责任公司广州供电局|中国移动通信集团安徽有限公司宣城分公司|中国移动通信集团安徽有限公司宣城分公司|中国移动通信集团安徽有限公司宣城分公司|中国电信股份有限公司合肥分公司|长沙海关技术中心|中共广东省委办公厅|南方电网数字平台科技(广东)有限公司|...
+
+`运营商` VARCHAR(10),
+【所有可能的值】中国电信|中国移动|中国联通|中国广电|中国铁塔|其他
+
+`中通服客商类型` VARCHAR(45),
+【举例】集团客户-建筑与房地产-建筑与房地产|集团客户-党政-党政管理|中国电信-主业上市-广东分公司|集团客户-中小聚类-中小企业|中国广电-中国广电网络集团-股份公司-广东省广播电视网络股份有限公司 (广东广电)|集团客户-互联网与IT传媒-互联网与IT科技|中国联通-各分公司-上海市分公司|中国电信-主业存续-广东省电信公司|中国电信-实业上市-安徽通服|中国电信-主业存续-山西分公司|...
+
+`合同签订金额(人民币)` float,
+【描述】`合同签订金额(人民币)`反映了项目的规模,提问“超大项目”“重大项目”“一般项目”“小项目”“营业额”等类似字眼时,通常涉及对该字段进行筛选。
+【注意】“超大项目”金额大于等于1亿,“重大项目”金额大于1000万而小于1亿,“一般项目”金额大于200万而小于1000万,“小项目”金额小于200万。有时需要计算的是金额的总值,有时需要计算平均值。
+
+`合同签订金额(不含税)` float,
+
+`是否垫资` bool,
+【所有可能的值】0|1|
+
+`垫资金额(元)` float,
+
+`垫资说明` VARCHAR(491),
+
+`签订日期` date,
+【描述】提问到“近几年”“去年”“今年”“上个季度”与项目日期相关内容时,通常涉及对该字段进行筛选。
+【注意】以CURRENT_DATE获取的时间为准作为当前日期。
+
+`签署日期` date,
+
+`合同有效期(开始)` date,
+
+`合同有效期(结束)` date,
+
+`最终客户名称` VARCHAR(57),
+【描述】`最终客户名称`描述和`客户名称描述一致`。
+
+`最终中通服客商类型` VARCHAR(45),
+【描述】`最终中通服客商类型`和`中通服客商类型`描述一致。
+
+`税率` float,
+
+`是否通服内部合作` bool,
+【描述】提问“内部”“内部项目”“内部合作”“通服内部”等类似字眼时,通常涉及对该字段进行筛选。
+【所有可能的值】0|1|
+
+`项目组织模式` VARCHAR(13),
+【所有可能的值】非总包非全咨|总包-过程总包-PC总包|总包-过程总包-EPC总包|总包-过程总包-施工总包|总包-过程总包-DB总包|全过程咨询|非总包非全过程咨询|总包-过程总包-EP总包
+
+`合同结算金额(含税)` float,
+
+`列账收入(含税)` float,
+
+`开票金额(含税)` float,
+
+`收款金额(含税)` float,
+
+`是否业务关闭` bool,
+【所有可能的值】0|1|
+
+`业务关闭时间` date,
+
+`是否财务关闭` bool,
+【所有可能的值】0|1|
+
+`财务关闭时间` date,
+
+`甲方订单编号` VARCHAR(256),
+
+`甲方合同编号` VARCHAR(68),
+
+`框架子合同编号` VARCHAR(106),
+
+`确收类型` VARCHAR(10),
+
+`业务拓展方式` VARCHAR(10),
+【所有可能的值】合作拓展|自主拓展|联合拓展|主业总包,通服分包|LH
+
+`主实业协同` bool,
+【所有可能的值】0|1|
+
+`协同类型` VARCHAR(10),
+
+`主业合同金额` float,
+
+`对方联系人` VARCHAR(17),
+
+`对方联系电话` VARCHAR(18),
+
+`中标时间` date,
+
+`协同拓展的主业公司` VARCHAR(10),
+
+`主业合同额` float,
+
+`是否运营商政企` bool
+【所有可能的值】0|1|
\ No newline at end of file
diff --git a/config/production/apikeys.conf b/config/production/apikeys.conf
new file mode 100644
index 0000000..5271729
--- /dev/null
+++ b/config/production/apikeys.conf
@@ -0,0 +1,10 @@
+[ccscc]
+dashscope_api_key = "sk-6b39b56d21aa4406b0c67061f2e31e81"
+admin = false
+databases = ["students"]
+
+["YUVietLgiGmtqzYUVIIGjrNoLMsGM0FI"]
+dashscope_api_key = "sk-6b39b56d21aa4406b0c67061f2e31e81"
+admin = false
+databases = ["contracts"]
+
diff --git a/config/production/app.conf b/config/production/app.conf
new file mode 100644
index 0000000..25a4502
--- /dev/null
+++ b/config/production/app.conf
@@ -0,0 +1,11 @@
+# 站点配置
+site_url = "https://release.platformtest.email"
+
+# 企业微信接口参数
+[bizwechat]
+token = "8kUGYXi"
+aes_key = "A5RyPqAu5UYBGI4QJTqLbBVyHXvevIUsaMrhct1lpxo"
+corp_id = "wwcbc2d6338dd362d0"
+corp_secret = "J7fAOm4QM2HvI5rpHiKVchbcNLcKABq-T-9v1B29RWo"
+agent_id = 1000002
+qgi_api_key = "YUVietLgiGmtqzYUVIIGjrNoLMsGM0FI"
\ No newline at end of file
diff --git a/config/production/database/contracts.conf b/config/production/database/contracts.conf
new file mode 100644
index 0000000..8b4a1c6
--- /dev/null
+++ b/config/production/database/contracts.conf
@@ -0,0 +1,86 @@
+# 数据库配置
+
+connection_string = "mysql+pymysql://root:H1wNPOz3@mysql.local:3306/contracts?charset=utf8mb4"
+
+type = "MySQL"
+product = "MySQL 5.7"
+
+metadata = """
+CREATE TABLE `contracts` (
+`经办人` VARCHAR(30),
+`经办单位` VARCHAR(16),
+`经办日期` date,
+`所属分公司` VARCHAR(16),
+`合同形式` VARCHAR(10),
+`是否主合同` bool,
+`合同名称` VARCHAR(137),
+`合同编号` VARCHAR(37),
+`框架合同编号` VARCHAR(23),
+`框架合同名称` VARCHAR(85),
+`主合同编号` VARCHAR(25),
+`主合同名称` VARCHAR(103),
+`项目来源` VARCHAR(10),
+`投标项目名称` VARCHAR(92),
+`编号生成时间` date,
+`专业` VARCHAR(40),
+`地点` VARCHAR(35),
+`是否关联交易` bool,
+`合同类型名称` VARCHAR(29),
+`聚焦行业` VARCHAR(19),
+`管理分公司` VARCHAR(16),
+`建议实施单位` VARCHAR(55),
+`项目部` VARCHAR(26),
+`最小经营单元` VARCHAR(17),
+`省公司统一编号` VARCHAR(31),
+`统一编号生成时间` date,
+`客户名称` VARCHAR(54),
+`运营商` VARCHAR(10),
+`中通服客商类型` VARCHAR(45),
+`合同签订金额(人民币)` float,
+`合同签订金额(不含税)` float,
+`是否垫资` bool,
+`垫资金额(元)` float,
+`垫资说明` VARCHAR(491),
+`签订日期` date,
+`签署日期` date,
+`合同有效期(开始)` date,
+`合同有效期(结束)` date,
+`最终客户名称` VARCHAR(57),
+`最终中通服客商类型` VARCHAR(45),
+`税率` float,
+`是否通服内部合作` bool,
+`项目组织模式` VARCHAR(13),
+`合同结算金额(含税)` float,
+`列账收入(含税)` float,
+`开票金额(含税)` float,
+`收款金额(含税)` float,
+`是否业务关闭` bool,
+`业务关闭时间` date,
+`是否财务关闭` bool,
+`财务关闭时间` date,
+`甲方订单编号` VARCHAR(256),
+`甲方合同编号` VARCHAR(68),
+`框架子合同编号` VARCHAR(106),
+`确收类型` VARCHAR(10),
+`业务拓展方式` VARCHAR(10),
+`主实业协同` bool,
+`协同类型` VARCHAR(10),
+`主业合同金额` float,
+`对方联系人` VARCHAR(17),
+`对方联系电话` VARCHAR(18),
+`中标时间` date,
+`协同拓展的主业公司` VARCHAR(10),
+`主业合同额` float,
+`是否运营商政企` bool
+); --合同信息表
+
+- 如果要对 '经办人' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '合同名称' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '客户名称' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '最终客户名称' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '所属分公司' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '经办单位' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '投标项目名称' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 如果要对 '项目名称来源' 进行筛选,必须使用 LIKE 语句进行模糊匹配。
+- 数据库中的金额单位为人民币元,输出时必须除以一万用 ROUND 函数取整,输出结果使用万元为单位,。
+"""
diff --git a/config/production/database/students.conf b/config/production/database/students.conf
new file mode 100644
index 0000000..6d2cac9
--- /dev/null
+++ b/config/production/database/students.conf
@@ -0,0 +1,36 @@
+# 数据库配置
+
+connection_string = "mysql+pymysql://root:H1wNPOz3@mysql.local:3306/students?charset=utf8mb4"
+
+type = "MySQL"
+product = "MySQL 5.7"
+
+metadata = """
+CREATE TABLE students (
+ student_id INTEGER PRIMARY KEY,
+ student_name VARCHAR(100), -- 学生姓名
+ major VARCHAR(100), -- 专业
+ year_of_enrollment INTEGER, -- 入学年份
+ student_age INTEGER -- 学生年龄
+);
+
+CREATE TABLE courses (
+ course_id INTEGER PRIMARY KEY,
+ course_name VARCHAR(100), -- 课程名称
+ credit REAL -- 学分
+);
+
+CREATE TABLE scores (
+ student_id INTEGER,
+ course_id INTEGER,
+ score INTEGER, -- 得分
+ semester VARCHAR(50), -- 学期
+ PRIMARY KEY (student_id, course_id),
+ FOREIGN KEY (student_id) REFERENCES students(student_id),
+ FOREIGN KEY (course_id) REFERENCES courses(course_id)
+);
+
+- 数据库中 'courses' 表中 'course_name' 字段有效值为 '计算机基础','数据结构','高等物理','线性代数','微积分','编程语言','量子力学','概率论','数据库系统','计算机网络'。
+- 数据库中 'scores' 表中 'semester' 字段有效值为 '2020年秋季', '2021年春季', '2021年秋季', '2022年春季', '2020年秋季', '2021年春季', '2021年秋季', '2022年春季', '2022年秋季', '2023年春季'。
+- 数据库中 'students' 表中 'major' 字段有效值为 '计算机科学', '物理学', '数学'。
+"""
diff --git a/config/production/logging.yaml b/config/production/logging.yaml
new file mode 100644
index 0000000..cee30fd
--- /dev/null
+++ b/config/production/logging.yaml
@@ -0,0 +1,58 @@
+version: 1
+disable_existing_loggers: false
+formatters:
+ default:
+ (): uvicorn.logging.DefaultFormatter
+ fmt: '%(asctime)s - %(levelname)s %(message)s'
+ use_colors: null
+ access:
+ (): uvicorn.logging.AccessFormatter
+ fmt: '%(asctime)s - %(levelname)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+handlers:
+ default:
+ formatter: default
+ class: logging.StreamHandler
+ stream: 'ext://sys.stderr'
+ access:
+ formatter: access
+ class: logging.StreamHandler
+ stream: 'ext://sys.stdout'
+ sql_file:
+ class: logging.handlers.RotatingFileHandler
+ level: INFO
+ formatter: default
+ filename: ./log/sql.log
+ maxBytes: 10485760 # 10MB
+ backupCount: 50 #保留50个log文件
+ encoding: utf8
+ err_file:
+ class: logging.handlers.RotatingFileHandler
+ level: ERROR
+ formatter: default
+ filename: ./log/err.log
+ maxBytes: 10485760 # 10MB
+ backupCount: 50 #保留50个log文件
+ encoding: utf8
+loggers:
+ root:
+ handlers:
+ - default
+ - err_file
+ level: INFO
+ uvicorn:
+ level: INFO
+ uvicorn.error:
+ level: INFO
+ uvicorn.access:
+ handlers:
+ - access
+ level: INFO
+ propagate: false
+ sqlcode:
+ level: DEBUG
+ question:
+ handlers:
+ - sql_file
+ level: INFO
+ propagate: false
+
diff --git a/config/production/qwen.conf b/config/production/qwen.conf
new file mode 100644
index 0000000..e51bcfc
--- /dev/null
+++ b/config/production/qwen.conf
@@ -0,0 +1,36 @@
+# qwen 模型配置
+
+# system role prompt
+system = "你擅长编写 SQL 代码,请结合具体问题编写正确规范的 SQL 代码"
+
+# prompt 模板及参数,在模板中可以使用 {question} {database.metadata} 以及 {params.xxx} 引用参数
+prompt = """
+### 数据库结构
+
+{database.metadata}
+
+
+### 问题
+
+根据以上建表语句,生成一个 SQL 来回答如下问题: [QUESTION]{question}[/QUESTION]
+
+
+### 要求
+
+- 输出的字段名必须用中文描述。
+- 输出的 SQL 语句必须能够通过 {database.product} 验证。
+- 输出的 SQL 语句必须包含在 ```sql ``` 标记中。
+- 输出的 SQL 语句不要添加注释。
+
+{requirements}
+
+### 输出格式
+'''sql
+[SQL]
+'''
+[ANSWER]
+"""
+
+params.requirements = """"
+- 除了 SELECT 语句,不要输出任何其他内容 。
+"""
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..b337734
--- /dev/null
+++ b/main.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+# -*- encoding:utf-8 -*-
+
+import fastapi.staticfiles
+import uvicorn.config
+from http import HTTPStatus
+import fastapi
+import logging
+import config
+
+app = fastapi.FastAPI()
+
+app.mount("/output", fastapi.staticfiles.StaticFiles(directory="output"), name="output")
+
+logger = logging.getLogger('sqlcode')
+
+logger.info('ENVIRONMENT: %s', config.ENVIRONMENT)
+logger.debug('LOG DEBUG : ON')
+
+@app.get('/health', include_in_schema=False)
+def health():
+ return fastapi.Response('OK', HTTPStatus.OK)
+
+from query import app as query_app
+app.mount('/query', query_app)
+
+from wechat import app as wechat_app
+app.mount('/wechat', wechat_app)
+
+from tgi_app import app as tgi_app
+app.mount('/tgi', tgi_app)
+
+if __name__ == '__main__':
+ import uvicorn
+ uvicorn.run(app, host='0.0.0.0', port=8000)
diff --git a/pg-data/contracts-mysql.sql b/pg-data/contracts-mysql.sql
new file mode 100644
index 0000000..d089cd5
--- /dev/null
+++ b/pg-data/contracts-mysql.sql
@@ -0,0 +1,67 @@
+CREATE TABLE `contracts` (
+`经办人` varchar(50),
+`经办单位` varchar(20),
+`经办日期` date,
+`所属分公司` varchar(20),
+`合同形式` varchar(10),
+`是否主合同` tinyint(1),
+`合同名称` varchar(255),
+`合同编号` varchar(50),
+`框架合同编号` varchar(50),
+`框架合同名称` varchar(100),
+`主合同编号` varchar(50),
+`主合同名称` varchar(255),
+`项目来源` varchar(10),
+`投标项目名称` varchar(255),
+`编号生成时间` date,
+`专业` varchar(50),
+`地点` varchar(50),
+`是否关联交易` tinyint(1),
+`合同类型名称` varchar(50),
+`聚焦行业` varchar(20),
+`管理分公司` varchar(50),
+`建议实施单位` varchar(100),
+`项目部` varchar(50),
+`最小经营单元` varchar(50),
+`省公司统一编号` varchar(50),
+`统一编号生成时间` date,
+`客户名称` varchar(100),
+`运营商` varchar(10),
+`中通服客商类型` varchar(50),
+`合同签订金额(人民币)` double,
+`合同签订金额(不含税)` double,
+`是否垫资` tinyint(1),
+`垫资金额(元)` double,
+`垫资说明` varchar(999),
+`签订日期` date,
+`签署日期` date,
+`合同有效期(开始)` date,
+`合同有效期(结束)` date,
+`最终客户名称` varchar(100),
+`最终中通服客商类型` varchar(50),
+`税率` double,
+`是否通服内部合作` varchar(20),
+`项目组织模式` varchar(20),
+`合同结算金额(含税)` double,
+`列账收入(含税)` double,
+`开票金额(含税)` double,
+`收款金额(含税)` double,
+`是否业务关闭` tinyint(1),
+`业务关闭时间` date,
+`是否财务关闭` tinyint(1),
+`财务关闭时间` date,
+`甲方订单编号` varchar(999),
+`甲方合同编号` varchar(100),
+`框架子合同编号` varchar(255),
+`确收类型` varchar(10),
+`业务拓展方式` varchar(10),
+`主实业协同` tinyint(1),
+`协同类型` varchar(10),
+`主业合同金额` double,
+`对方联系人` varchar(50),
+`对方联系电话` varchar(50),
+`中标时间` double,
+`协同拓展的主业公司` double,
+`主业合同额` double,
+`是否运营商政企` tinyint(1)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='合同信息表';
diff --git a/pg-data/contracts-pg.sql b/pg-data/contracts-pg.sql
new file mode 100644
index 0000000..c547329
--- /dev/null
+++ b/pg-data/contracts-pg.sql
@@ -0,0 +1,68 @@
+CREATE TABLE contracts (
+ -- id INT PRIMARY KEY, -- 合同ID,自增主键
+ handler VARCHAR(1000), -- 经办人
+ handling_unit VARCHAR(1000), -- 经办单位
+ handling_date TIMESTAMP, -- 经办日期
+ branch_company VARCHAR(1000), -- 所属分公司
+ contract_form VARCHAR(1000), -- 合同形式
+ is_main_contract BOOLEAN, -- 是否主合同
+ contract_name VARCHAR(1000), -- 合同名称
+ contract_number VARCHAR(1000), -- 合同编号
+ framework_contract_number VARCHAR(1000), -- 框架合同编号
+ framework_contract_name VARCHAR(1000), -- 框架合同名称
+ main_contract_number VARCHAR(1000), -- 主合同编号
+ main_contract_name VARCHAR(1000), -- 主合同名称
+ project_source VARCHAR(1000), -- 项目来源
+ bidding_project_name VARCHAR(1000), -- 投标项目名称
+ number_generated_time TIMESTAMP, -- 编号生成时间
+ specialty VARCHAR(1000), -- 专业
+ location VARCHAR(1000), -- 地点
+ is_related_transaction BOOLEAN, -- 是否关联交易
+ contract_type_name VARCHAR(1000), -- 合同类型名称
+ focused_industry VARCHAR(1000), -- 聚焦行业
+ management_branch VARCHAR(1000), -- 管理分公司
+ suggested_implementation_unit VARCHAR(1000), -- 建议实施单位
+ project_department VARCHAR(1000), -- 项目部
+ min_business_unit VARCHAR(1000), -- 最小经营单元
+ unified_number_province VARCHAR(1000), -- 省公司统一编号
+ unified_number_generated_time TIMESTAMP, -- 统一编号生成时间
+ customer_name VARCHAR(1000), -- 客户名称
+ operator VARCHAR(1000), -- 运营商
+ zhongtongfu_customer_type VARCHAR(1000), -- 中通服客商类型
+ contract_amount_rmb DECIMAL(15,2), -- 合同签订金额(人民币)
+ contract_amount_ex_tax DECIMAL(15,2), -- 合同签订金额(不含税)
+ is_advance_fund BOOLEAN, -- 是否垫资
+ advance_fund_amount DECIMAL(15,2), -- 垫资金额(元)
+ advance_fund_description TEXT, -- 垫资说明
+ signing_date DATE, -- 签订日期
+ execution_date DATE, -- 签署日期
+ contract_start_date DATE, -- 合同有效期(开始)
+ contract_end_date DATE, -- 合同有效期(结束)
+ final_customer_name VARCHAR(1000), -- 最终客户名称
+ final_zhongtongfu_customer_type VARCHAR(1000), -- 最终中通服客商类型
+ tax_rate DECIMAL(5,2), -- 税率
+ is_internal_cooperation BOOLEAN, -- 是否通服内部合作
+ project_organization_mode VARCHAR(1000), -- 项目组织模式
+ contract_settlement_amount_inc_tax DECIMAL(15,2), -- 合同结算金额(含税)
+ book_revenue_inc_tax DECIMAL(15,2), -- 列账收入(含税)
+ invoice_amount_inc_tax DECIMAL(15,2), -- 开票金额(含税)
+ collection_amount_inc_tax DECIMAL(15,2), -- 收款金额(含税)
+ is_business_closed BOOLEAN, -- 是否业务关闭
+ business_close_time DATE, -- 业务关闭时间
+ is_financial_closed BOOLEAN, -- 是否财务关闭
+ financial_close_time DATE, -- 财务关闭时间
+ party_a_order_number VARCHAR(1000), -- 甲方订单编号
+ party_a_contract_number VARCHAR(1000), -- 甲方合同编号
+ framework_subcontract_number VARCHAR(1000), -- 框架子合同编号
+ revenue_recognition_type VARCHAR(1000), -- 确收类型
+ business_development_method VARCHAR(1000), -- 业务拓展方式
+ main_industry_collaboration VARCHAR(1000), -- 主实业协同
+ collaboration_type VARCHAR(1000), -- 协同类型
+ main_industry_contract_amount DECIMAL(15,2), -- 主业合同金额
+ counterparty_contact VARCHAR(1000), -- 对方联系人
+ counterparty_phone VARCHAR(20), -- 对方联系电话
+ bidding_win_time DATE, -- 中标时间
+ collaborative_expansion_company VARCHAR(1000), -- 协同拓展的主业公司
+ main_industry_contract_value DECIMAL(15,2), -- 主业合同额
+ is_operator_enterprise BOOLEAN -- 是否运营商政企
+);
diff --git a/pg-data/contracts-struct.txt b/pg-data/contracts-struct.txt
new file mode 100644
index 0000000..d73e6a9
--- /dev/null
+++ b/pg-data/contracts-struct.txt
@@ -0,0 +1,66 @@
+column dtype value
+经办人 text 中通服建设五分公司-惠州项目部-陈思敏
+经办单位 text 惠州项目部
+经办日期 date 2020-06-29 10:38:56
+所属分公司 text 中通服建设有限公司五分公司
+合同形式 text 单项合同
+是否主合同 bool 是
+合同名称 text 远程视频会议及扩音系统项目合同
+合同编号 text SGC-SRHT-CJ-2020-01125
+框架合同编号 text
+框架合同名称 text
+主合同编号 text
+主合同名称 text
+项目来源 text 邀标
+投标项目名称 text 远程视频会议及扩音系统项目
+编号生成时间 date 2020-06-29
+专业 text 系统集成-信息系统集成服务-视频监控集成
+地点 text 中国-广东省-惠州市-市辖区
+是否关联交易 bool 否
+合同类型名称 text 市场经营收入类-系统集成服务类合同(A-8)
+聚焦行业 text 其他
+管理分公司 text 中通服建设有限公司五分公司
+建议实施单位 text 中通服建设有限公司五分公司
+项目部 text 惠州项目部
+最小经营单元 text 五分公司-惠州项目部
+省公司统一编号 text S4412-2020-006338
+统一编号生成时间 date 2020-06-29
+客户名称 text 惠州市建设工程质量检测中心
+运营商 text 其他
+中通服客商类型 text 集团客户-建筑与房地产-建筑与房地产
+合同签订金额(人民币) float 259805
+合同签订金额(不含税) float 238353.21
+是否垫资 bool
+垫资金额(元) float 0
+垫资说明 text
+签订日期 date 2020-06-29
+签署日期 date 2020-06-23
+合同有效期(开始) date 2020-06-29
+合同有效期(结束) date 2022-07-31
+最终客户名称 text 惠州市建设工程质量检测中心
+最终中通服客商类型 text 集团客户-建筑与房地产-建筑与房地产
+税率 float 9
+是否通服内部合作 bool 否
+项目组织模式 text 非总包非全咨
+合同结算金额(含税) float 259805
+列账收入(含税) float 259805
+开票金额(含税) float
+收款金额(含税) float
+是否业务关闭 bool 是
+业务关闭时间 date 2020-09-27
+是否财务关闭 bool 是
+财务关闭时间 date 2020-11-18
+甲方订单编号 text
+甲方合同编号 text 0
+框架子合同编号 text
+确收类型 text 全额
+业务拓展方式 text 合作拓展
+主实业协同 bool N
+协同类型 text
+主业合同金额 float
+对方联系人 text 朱光南
+对方联系电话 text 15220678621
+中标时间 date
+协同拓展的主业公司 text
+主业合同额 float
+是否运营商政企 bool
diff --git a/pg-data/student-manage.sql b/pg-data/student-manage.sql
new file mode 100644
index 0000000..5f29ab1
--- /dev/null
+++ b/pg-data/student-manage.sql
@@ -0,0 +1,66 @@
+CREATE TABLE students (
+ student_id INTEGER PRIMARY KEY,
+ student_name VARCHAR(100), -- 学生姓名
+ major VARCHAR(100), -- 专业
+ year_of_enrollment INTEGER, -- 入学年份
+ student_age INTEGER -- 学生年龄
+);
+
+CREATE TABLE courses (
+ course_id INTEGER PRIMARY KEY,
+ course_name VARCHAR(100), -- 课程名称
+ credit REAL -- 学分
+);
+
+CREATE TABLE scores (
+ student_id INTEGER,
+ course_id INTEGER,
+ score INTEGER, -- 得分
+ semester VARCHAR(50), -- 学期
+ PRIMARY KEY (student_id, course_id),
+ FOREIGN KEY (student_id) REFERENCES students(student_id),
+ FOREIGN KEY (course_id) REFERENCES courses(course_id)
+);
+
+INSERT INTO students (student_id, student_name, major, year_of_enrollment, student_age) VALUES
+(1, '张三', '计算机科学', 2020, 20),
+(2, '李四', '计算机科学', 2021, 19),
+(3, '王五', '物理学', 2020, 21),
+(4, '赵六', '数学', 2021, 19),
+(5, '周七', '计算机科学', 2022, 18),
+(6, '吴八', '物理学', 2020, 21),
+(7, '郑九', '数学', 2021, 19),
+(8, '孙十', '计算机科学', 2022, 18),
+(9, '刘十一', '物理学', 2020, 21),
+(10, '陈十二', '数学', 2021, 19);
+
+INSERT INTO courses (course_id, course_name, credit) VALUES
+(1, '计算机基础', 3),
+(2, '数据结构', 4),
+(3, '高等物理', 3),
+(4, '线性代数', 4),
+(5, '微积分', 5),
+(6, '编程语言', 4),
+(7, '量子力学', 3),
+(8, '概率论', 4),
+(9, '数据库系统', 4),
+(10, '计算机网络', 4);
+
+INSERT INTO scores (student_id, course_id, score, semester) VALUES
+(1, 1, 90, '2020年秋季'),
+(1, 2, 85, '2021年春季'),
+(2, 1, 88, '2021年秋季'),
+(2, 2, 90, '2022年春季'),
+(3, 3, 92, '2020年秋季'),
+(3, 4, 85, '2021年春季'),
+(4, 3, 88, '2021年秋季'),
+(4, 4, 86, '2022年春季'),
+(5, 1, 90, '2022年秋季'),
+(5, 2, 87, '2023年春季');
+
+
+-- students.student_id can be joined with scores.student_id
+-- courses.course_id can be joined with scores.course_id
+-- 专业名称包括 计算机科学,物理学,数学
+-- 课程包括 计算机基础,数据结构,高等物理,线性代数,微积分,编程语言,量子力学,概率论,数据库系统,计算机网络
+-- 学期包括 2020年秋季,2021年春季,2021年秋季,2022年春季,2022年秋季,2023年春季
diff --git a/query.py b/query.py
new file mode 100644
index 0000000..9f2bd60
--- /dev/null
+++ b/query.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+# -*- encoding:utf-8 -*-
+
+import fastapi.staticfiles
+from http import HTTPStatus
+import fastapi
+import config
+import asyncio
+from fastapi.middleware.cors import CORSMiddleware
+
+# 设置允许的源,可以是单个源或多个源
+
+app = fastapi.FastAPI()
+
+app.mount("/output",fastapi.staticfiles.StaticFiles(directory="output"), name="output")
+
+origins = [
+ "*"
+]
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+from sqlcode.qgi import Executor, NewFormatter
+from sqlcode.langchain_model import Generator
+from sqlcode.modelloader import ModelLoader, ModelManager
+from sqlcode.qa_cache import QACache, load_qa_pairs
+from pydantic import BaseModel
+from sqlcode.qgi import ReturnType, QueryResult
+from sqlcode.multi_agent import create_sql_graph
+
+class QueryRequest(BaseModel):
+ '''
+ * question: str 表示要查询的问题。
+ * return_type: str 表示要返回的结果类型。
+ '''
+ question: str
+ return_type: ReturnType = ReturnType.TEXT
+
+qa_dict = load_qa_pairs('log/test.log')
+qa_cache = QACache(similarity_threshold=0.999)
+qa_cache.add(qa_dict)
+
+def try_search_in_cache(question:str) -> QueryResult:
+ cached_answer = qa_cache.find_similar(question)
+ if cached_answer != None:
+ return QueryResult(status=HTTPStatus.OK, result=cached_answer, sql='', thought='从缓存得到结果', error=None)
+ else:
+ return None
+
+
+qwen_cfg = config.qwen_config("qwen_graph.conf")
+re_cfg = config.refineProblem_config()
+
+model_cfg = config.model_config()
+modelLoader = ModelLoader(model_cfg)
+modelManager = ModelManager(modelLoader)
+
+@app.post('/{model_name}/{database}')
+async def query(model_name:str, database:str, apikey:str, req:QueryRequest) -> QueryResult:
+ '''
+ 根据请求执行查询并返回查询结果。
+
+ - :param model: str 表示要使用的大模型的名称。例如 qwen-turbo 等
+ - :param database: str 表示要查询的数据库的名称。
+ - :param req: QueryRequest 包含查询的具体问题和期望的返回类型。
+ '''
+ print('---------进入----------')
+ # 完全匹配查找缓存结果
+ searched_result = try_search_in_cache(req.question)
+ if searched_result is not None:
+ print('从缓存找到结果')
+ print(searched_result.result)
+ return searched_result
+
+ # 调用模型生成sql和答案
+ client = config.api_key(apikey)
+ if not client:
+ return QueryResult(status=HTTPStatus.UNAUTHORIZED, error='invalid apikey')
+
+ if not database in client.databases:
+ return QueryResult(status=HTTPStatus.FORBIDDEN, error='database permission denied')
+
+ metadata = config.metadata(database)
+ modelManager.switch_model(model_name)
+ model = modelManager.get_model()
+
+ agent_executor = create_sql_graph(model=model, qwen_cfg=qwen_cfg, data_cfg=metadata)
+
+ generator = Generator(agentExcutor=agent_executor,
+ messages=[{'role':'system', 'content': qwen_cfg.system}],
+ apikey=client.dashscope_api_key,
+ seed=0,
+ )
+
+ input = {"messages": [("human", req.question)], "iterations": 0}
+
+ if req.return_type == ReturnType.SQL:
+ return generator.generate(input)
+
+ formatter = NewFormatter(format=req.return_type,
+ tranlate=None,
+ output_dir=config.osp.join(config.BASE_DIR, 'output'),
+ site_url=config.app_config().site_url + '/output',
+ )
+
+ executor = Executor(generator=generator, formatter=formatter)
+
+ query_result = executor.query(connection_string=metadata.connection_string,
+ input=input,
+ )
+
+ qa_cache.add({req.question: query_result.result})
+ return query_result
+
+
+
+if __name__ == '__main__':
+ import sys
+ if len(sys.argv) > 1 and sys.argv[1] == 'test':
+ import logging
+ logging.basicConfig(level=logging.DEBUG)
+ ret = asyncio.run(query(apikey='ccscc',
+ model='qwen-max',
+ database='students',
+ req=QueryRequest(
+ question="2022年有哪几门课程",
+ return_type=ReturnType.TEXT
+ ),
+ ))
+ print(ret)
+ else:
+ import uvicorn
+ uvicorn.run(app, host='0.0.0.0', port=9000)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..8e7a02f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+dashscope
+fastapi
+sqlalchemy
+tencentcloud-sdk-python
+pandas
+psycopg2-binary
+openpyxl
+tabulate
+wcwidth
+loguru
+markdown
+pycryptodome
+pymysql
+uvicorn
+rtoml
\ No newline at end of file
diff --git a/sqlcode/baidu_qianfan_endpoint.py b/sqlcode/baidu_qianfan_endpoint.py
new file mode 100644
index 0000000..c8c84c3
--- /dev/null
+++ b/sqlcode/baidu_qianfan_endpoint.py
@@ -0,0 +1,237 @@
+from __future__ import annotations
+
+import logging
+from typing import (
+ Any,
+ AsyncIterator,
+ Dict,
+ Iterator,
+ List,
+ Optional,
+)
+
+from langchain_core.callbacks import (
+ AsyncCallbackManagerForLLMRun,
+ CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import Field, SecretStr
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
+
+logger = logging.getLogger(__name__)
+
+
+class QianfanLLMEndpoint(LLM):
+ """Baidu Qianfan hosted open source or customized models.
+
+ To use, you should have the ``qianfan`` python package installed, and
+ the environment variable ``qianfan_ak`` and ``qianfan_sk`` set with
+ your API key and Secret Key.
+
+ ak, sk are required parameters which you could get from
+ https://cloud.baidu.com/product/wenxinworkshop
+
+ Example:
+ .. code-block:: python
+
+ from langchain_community.llms import QianfanLLMEndpoint
+ qianfan_model = QianfanLLMEndpoint(model="ERNIE-Bot",
+ endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
+ """
+
+ init_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """init kwargs for qianfan client init, such as `query_per_second` which is
+ associated with qianfan resource object to limit QPS"""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """extra params for model invoke using with `do`."""
+
+ client: Any
+
+ qianfan_ak: Optional[SecretStr] = None
+ qianfan_sk: Optional[SecretStr] = None
+
+ streaming: Optional[bool] = False
+ """Whether to stream the results or not."""
+
+ model: str = "ERNIE-Bot-turbo"
+ """Model name.
+ you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
+
+ preset models are mapping to an endpoint.
+ `model` will be ignored if `endpoint` is set
+ """
+
+ endpoint: Optional[str] = None
+ """Endpoint of the Qianfan LLM, required if custom model used."""
+
+ request_timeout: Optional[int] = 60
+ """request timeout for chat http requests"""
+
+ top_p: Optional[float] = 0.8
+ temperature: Optional[float] = 0.95
+ penalty_score: Optional[float] = 1
+ """Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
+ In the case of other model, passing these params will not affect the result.
+ """
+
+ @pre_init
+ def validate_environment(cls, values: Dict) -> Dict:
+ values["qianfan_ak"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "qianfan_ak",
+ "QIANFAN_AK",
+ default="",
+ )
+ )
+ values["qianfan_sk"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "qianfan_sk",
+ "QIANFAN_SK",
+ default="",
+ )
+ )
+
+ params = {
+ **values.get("init_kwargs", {}),
+ "model": values["model"],
+ }
+ if values["qianfan_ak"].get_secret_value() != "":
+ params["ak"] = values["qianfan_ak"].get_secret_value()
+ if values["qianfan_sk"].get_secret_value() != "":
+ params["sk"] = values["qianfan_sk"].get_secret_value()
+ if values["endpoint"] is not None and values["endpoint"] != "":
+ params["endpoint"] = values["endpoint"]
+ try:
+ import qianfan
+
+ values["client"] = qianfan.Completion(**params)
+ except ImportError:
+ raise ImportError(
+ "qianfan package not found, please install it with "
+ "`pip install qianfan`"
+ )
+ return values
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ return {
+ **{"endpoint": self.endpoint, "model": self.model},
+ **super()._identifying_params,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "baidu-qianfan-endpoint"
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Qianfan API."""
+ normal_params = {
+ "model": self.model,
+ "endpoint": self.endpoint,
+ "stream": self.streaming,
+ "request_timeout": self.request_timeout,
+ "top_p": self.top_p,
+ "temperature": self.temperature,
+ "penalty_score": self.penalty_score,
+ }
+
+ return {**normal_params, **self.model_kwargs}
+
+ def _convert_prompt_msg_params(
+ self,
+ prompt: str,
+ **kwargs: Any,
+ ) -> dict:
+ if "streaming" in kwargs:
+ kwargs["stream"] = kwargs.pop("streaming")
+ return {
+ **{"prompt": prompt, "model": self.model},
+ **self._default_params,
+ **kwargs,
+ }
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """Call out to an qianfan models endpoint for each generation with a prompt.
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating.
+ Returns:
+ The string generated by the model.
+
+ Example:
+ .. code-block:: python
+ response = qianfan_model.invoke("Tell me a joke.")
+ """
+ if self.streaming:
+ completion = ""
+ for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+ completion += chunk.text
+ return completion
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
+ params["stop"] = stop
+ response_payload = self.client.do(**params)
+
+ return response_payload["result"]
+
+ async def _acall(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ if self.streaming:
+ completion = ""
+ async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
+ completion += chunk.text
+ return completion
+
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
+ params["stop"] = stop
+ response_payload = await self.client.ado(**params)
+
+ return response_payload["result"]
+
+ def _stream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[GenerationChunk]:
+ params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
+ params["stop"] = stop
+ for res in self.client.do(**params):
+ if res:
+ chunk = GenerationChunk(text=res["result"])
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.text)
+ yield chunk
+
+ async def _astream(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[GenerationChunk]:
+ params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
+ params["stop"] = stop
+ async for res in await self.client.ado(**params):
+ if res:
+ chunk = GenerationChunk(text=res["result"])
+ if run_manager:
+ await run_manager.on_llm_new_token(chunk.text)
+ yield chunk
diff --git a/sqlcode/langchain_model.py b/sqlcode/langchain_model.py
new file mode 100644
index 0000000..49c3ede
--- /dev/null
+++ b/sqlcode/langchain_model.py
@@ -0,0 +1,70 @@
+'''
+Author: scutlzc scutlzc@gmail.com
+Date: 2024-07-11 16:36:34
+LastEditors: scutlzc scutlzc@gmail.com
+LastEditTime: 2024-07-12 11:03:24
+FilePath: \bizwechat\sqlcode\langchain_model.py
+Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
+'''
+from dashscope import Generation
+import random
+from http import HTTPStatus
+from enum import Enum
+from .qgi import Generator as BaseGenerator
+from langchain_community.llms import Tongyi
+from langchain_core.language_models import BaseLLM
+from langchain.prompts import PromptTemplate
+from langchain.chains.llm import LLMChain
+from enum import StrEnum
+from .modelloader import ModelManager, ModelLoader
+from langchain.agents import AgentExecutor
+from langchain_core.exceptions import OutputParserException
+from sqlcode.utils import enrich_input
+from typing import Optional
+from langgraph.graph.graph import CompiledGraph
+
+class Generator(BaseGenerator):
+ def __init__(self, agentExcutor:CompiledGraph, messages:list[dict[str,str]]|None, apikey:str, seed:int=0, max_retry: Optional[int] = 2) -> None:
+ self.message = messages
+ self.apikey = apikey
+ self.seed = seed
+ self.agentExcutor = agentExcutor
+ self.max_retry = max_retry
+
+ def _generate(self, input: dict) -> tuple[HTTPStatus]:
+ seed = self.seed
+ if seed == 0:
+ seed = random.randint(1, 10000)
+
+ cnt = 0
+ err = None
+ while cnt < self.max_retry:
+ try:
+ if cnt == 0:
+ response = self.agentExcutor.invoke(input)["messages"][-1][1]
+ res = {"input": input, "output": response, "err": None}
+ else:
+ new_input = enrich_input("之前agent执行流程发生错误,请模型输出严格按照prompt要求", input)
+ response = self.agentExcutor.invoke(new_input)["messages"][-1][1]
+ res = {"input": input, "output": response, "err": err}
+ status_code = HTTPStatus.OK
+ except ConnectionError as connectionError:
+ status_code = HTTPStatus.REQUEST_TIMEOUT
+ err = str(connectionError)
+ except OutputParserException as outputException:
+ cnt+=1
+ err = str(outputException)
+ print("outputException: {}".format(outputException))
+ except KeyError as keyError:
+ cnt+=1
+ err = str(keyError)
+ print("KeyError: {}".format(keyError))
+ else:
+ break
+
+ if cnt == self.max_retry:
+ status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+ res = {"input": input, "output": "模型处理异常,建议您将问题表述更加清晰,再重新尝试", "err": err}
+
+ return status_code, res
+
\ No newline at end of file
diff --git a/sqlcode/modelloader.py b/sqlcode/modelloader.py
new file mode 100644
index 0000000..ae0b76e
--- /dev/null
+++ b/sqlcode/modelloader.py
@@ -0,0 +1,48 @@
+'''
+Author: scutlzc scutlzc@gmail.com
+Date: 2024-07-12 09:23:37
+LastEditors: scutlzc scutlzc@gmail.com
+LastEditTime: 2024-07-12 09:51:30
+FilePath: \bizwechat\sqlcode\modelloader.py
+Description:
+
+Copyright (c) 2024 by ${git_name_email}, All Rights Reserved.
+'''
+import yaml
+# from langchain.llms import OpenAI
+# from langchain.chains import LLamaChain, CustomChain
+from langchain_community.llms import Tongyi
+from langchain_community.chat_models import ChatTongyi
+from sqlcode.baidu_qianfan_endpoint import QianfanLLMEndpoint
+
+class ModelLoader:
+ def __init__(self, config):
+ # with open(config_path, 'r') as file:
+ # self.config = yaml.safe_load(file)
+ self.config = config
+
+ def load_model(self, model_name):
+ model_config = self.config['models'][model_name]
+ model_type = model_config['type']
+ if model_type == 'tongyi':
+ # return Tongyi(dashscope_api_key=model_config['api_key'], model_name=model_config['model_name'])
+ return ChatTongyi(dashscope_api_key=model_config['api_key'], model=model_config['model_name'])
+ elif model_type == 'qianfan':
+ return QianfanLLMEndpoint(qianfan_ak=model_config['ak'], qianfan_sk=model_config['sk'], model=model_config['model_name'])
+ else:
+ raise ValueError(f"Unknown model type: {model_type}")
+
+
+class ModelManager:
+ def __init__(self, model_loader):
+ self.model_loader = model_loader
+ self.current_model = None
+
+ def switch_model(self, model_name):
+ self.current_model = self.model_loader.load_model(model_name)
+ print(f"Switched to model: {model_name}")
+
+ def get_model(self):
+ return self.current_model
+
+
diff --git a/sqlcode/multi_agent.py b/sqlcode/multi_agent.py
new file mode 100644
index 0000000..ff99ef0
--- /dev/null
+++ b/sqlcode/multi_agent.py
@@ -0,0 +1,242 @@
+import os
+import sys
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(parent_dir)
+
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.language_models import BaseLanguageModel
+from langgraph.graph.graph import CompiledGraph
+from typing import TypedDict, List
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_community.utilities import SQLDatabase
+from config import DatabaseConfig, QwenConfig
+from langgraph.graph import END, START, StateGraph, MessagesState
+from langchain_core.output_parsers import PydanticOutputParser
+from langchain.output_parsers import OutputFixingParser
+from langchain_core.messages import RemoveMessage, AnyMessage
+import json
+
+MAX_ITERATIONS = 3
+FLAG = "reflect"
+
+class GraphState(TypedDict):
+ """
+ Represents the state of our graph.
+
+ Attributes:
+ error : Binary flag for control flow to indicate whether test error was tripped
+ messages : With user question, error messages, reasoning
+ generation : Code solution
+ iterations : Number of tries
+ """
+
+ error: str
+ messages: list[AnyMessage]
+ generation: str
+ iterations: int
+
+class code(BaseModel):
+ """SQL输出"""
+
+ prefix: str = Field(description="对问题和使用方法的描述")
+ code: str = Field(description="SQL代码")
+
+
+import re
+
+def extract_sql_code(text):
+ # Regular expression to match ```sql``` blocks
+ pattern = r'```sql(.*?)```'
+
+ # Find all matches and extract the SQL code
+ matches = re.findall(pattern, text, re.DOTALL)
+
+ # Strip leading and trailing whitespace from each match
+ sql_code_blocks = [match.strip() for match in matches]
+
+ return sql_code_blocks
+
+def parse_output(message):
+ content = message.content
+ sql_codes = extract_sql_code(content)
+ if len(sql_codes) > 0:
+ sql_code = sql_codes[0]
+ else:
+ sql_code = ""
+ return code(prefix=content, code=sql_code)
+
+def truncate_strings_from_end(string_list, max_length=6000):
+ current_length = 0
+ truncated_list = []
+
+ # 从列表末尾开始遍历
+ for s in reversed(string_list):
+ if current_length + len(s) > max_length:
+ # 如果加上当前字符串长度后超过了max_length,计算可以保留的部分
+ truncated_length = max_length - current_length
+ truncated_list.append(s[-truncated_length:])
+ break
+ else:
+ truncated_list.append(s)
+ current_length += len(s)
+
+ # 由于是从后往前添加的,最后需要逆序返回
+ return list(reversed(truncated_list))
+
+def create_sql_graph(model:BaseLanguageModel, qwen_cfg: QwenConfig, data_cfg: DatabaseConfig) -> CompiledGraph:
+ # parser = PydanticOutputParser(pydantic_object=code)
+
+ sql_gen_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", qwen_cfg.prompt),
+ # ("human", "回答用户问题,并用`json`标签包装输出 {format_instructions}"),
+ ("placeholder", "{messages}")
+ ]
+ )
+ sql_gen_prompt = sql_gen_prompt.partial(metadata=data_cfg.metadata, product=data_cfg.product, example=qwen_cfg.params["example"])
+
+ sql_gen_chain = sql_gen_prompt | model | parse_output
+ # 生成SQL语句
+ def generate(state: GraphState):
+ messages = state["messages"]
+ iterations = state["iterations"]
+ if "error" in state:
+ error = state["error"]
+ else:
+ error = "no"
+ # error = state["error"]
+ if error == 'yes':
+ messages += [("human", "请重新生成SQL,使用 code 工具确保输出内容格式化,包括前缀(prefix)、sql代码(code):")]
+
+ try:
+ sql_solution = sql_gen_chain.invoke({"messages": messages})
+ except ValueError as e:
+ messages = truncate_strings_from_end(messages)
+ sql_solution = sql_gen_chain.invoke({"messages": messages})
+
+ messages += [
+ (
+ "ai",
+ f"{sql_solution.prefix} \n Code: {sql_solution.code}"
+ )
+ ]
+ iterations = iterations + 1
+ return {"generation": sql_solution, "messages": messages, "iterations": iterations}
+ # 检查SQL是否合法
+ def sql_check(state: GraphState):
+ print("---检查SQL---")
+ messages = state["messages"]
+ sql_solution = state["generation"]
+ iterations = state["iterations"]
+
+ code = sql_solution.code
+ sql_executor = SQLDatabase.from_uri(data_cfg.connection_string)
+
+ try:
+ sql_executor.run(code)
+ except Exception as e:
+ print("---SQL检查错误---")
+ error_msg = [("human", f"生成的SQL代码无法通过执行测试: {e}")]
+ messages += error_msg
+ return {
+ "generation": sql_solution,
+ "messages": messages,
+ "iterations": iterations,
+ "error": "yes",
+ }
+
+ print("---SQL检测通过---")
+ return {
+ "generation": sql_solution,
+ "messages": messages,
+ "iterations": iterations,
+ "error": "no",
+ }
+ # SQL若错误,反思
+ def reflect(state: GraphState):
+ messages = state["messages"]
+ iterations = state["iterations"]
+ sql_solution = state["generation"]
+
+ try:
+ reflections = sql_gen_chain.invoke({"messages": messages}).prefix
+ except ValueError as e:
+ messages = truncate_strings_from_end(messages)
+ reflections = sql_gen_chain.invoke({"messages": messages}).prefix
+
+ messages += [("ai", f"这里是对于错误的反思:{reflections}")]
+ return {
+ "generation": sql_solution,
+ "messages": messages,
+ "iterations": iterations,
+ "error": "no",
+ }
+ # 删除消息,保持消息列表的长度不超过3条
+ def delete_messages(state: GraphState):
+ messages = state["messages"]
+ if len(messages) > 3:
+ new_messages = messages[-3:]
+ state["messages"] = new_messages
+ return state
+
+ # 根据错误状态和迭代次数决定是否结束工作流
+ def decide_to_finish(state: GraphState):
+ error = state["error"]
+ iterations = state["iterations"]
+
+ if error == "no" or iterations == MAX_ITERATIONS:
+ print("---DECISION: FINISH---")
+ return "end"
+ else:
+ print("---DECISION: RE-TRY SOLUTION---")
+ if FLAG == "reflect":
+ return "reflect"
+ else:
+ return "generate"
+
+ # 状态图
+ workflow = StateGraph(GraphState)
+
+ workflow.add_node("generate", generate)
+ workflow.add_node("check_code", sql_check) # check code
+ workflow.add_node("reflect", reflect) # reflect
+ workflow.add_node("delete_messages",delete_messages)
+
+ # Build graph
+ workflow.add_edge(START, "generate")
+ workflow.add_edge("generate", "delete_messages")
+ workflow.add_edge("delete_messages", "check_code")
+ workflow.add_conditional_edges(
+ "check_code",
+ decide_to_finish,
+ {
+ "end": END,
+ "reflect": "reflect",
+ "generate": "generate",
+ },
+ )
+ workflow.add_edge("reflect", "generate")
+ app = workflow.compile()
+
+ return app
+
+
+import config
+if __name__ == '__main__':
+ # model_name = "qwen"
+ # model_cfg = config.model_config()
+ # modelLoader = ModelLoader(model_cfg)
+ # modelManager = ModelManager(modelLoader)
+ # modelManager.switch_model(model_name)
+ # model = modelManager.get_model()
+
+ from langchain_community.chat_models import ChatTongyi
+ model = ChatTongyi(model="qwen-turbo")
+ data_cfg = config.metadata("contracts")
+ qwen_cfg = config.qwen_config("qwen_graph.conf")
+ sql_graph = create_sql_graph(model, qwen_cfg, data_cfg)
+ input = {"messages": [("human", "教育、工业互联网、物联网相关行业的项目都有些什么,列举合同金额、项目名称、部门、签订时间、分公司")], "iterations": 0}
+ res = sql_graph.invoke(
+ input
+ )
+ print(input)
\ No newline at end of file
diff --git a/sqlcode/qa_cache.py b/sqlcode/qa_cache.py
new file mode 100644
index 0000000..ae1ff32
--- /dev/null
+++ b/sqlcode/qa_cache.py
@@ -0,0 +1,114 @@
+from cachetools import LRUCache
+from sklearn.feature_extraction.text import TfidfVectorizer
+import numpy as np
+import os
+
+class QACache:
+ def __init__(self, maxsize=100, similarity_threshold=0.5):
+ self.cache = LRUCache(maxsize=maxsize)
+ self.vectorizer = TfidfVectorizer()
+ self.questions = []
+ self.question_vectors = []
+ self.similarity_threshold = similarity_threshold # 设置相似度阈值
+
+ def add(self, question_and_answer:dict):
+ """添加问答对到缓存"""
+ if not isinstance(question_and_answer, dict):
+ raise TypeError("Invalid input. Expected a dictionary.")
+ if question_and_answer == {}:
+ print('传入的问答对为空。')
+ return
+ # 遍历字典并添加每个键值对
+ for question, answer in question_and_answer.items():
+ is_useful_answer = self._check_answer(answer)
+ if not is_useful_answer:
+ continue
+ self.cache[question] = answer
+ self.questions.append(question)
+ # 只有在问题添加后才进行向量化
+ print("question为:", question)
+ if len(self.questions) > 0:
+ self._vectorize_questions()
+ else:
+ print('没有有效的问题添加到缓存。')
+ # self._vectorize_questions()
+
+ def _check_answer(self, answer:str)->bool:
+ if not isinstance(answer, str) or not answer:
+ return False
+
+ """对回答进行过滤"""
+ if answer == "没有符合条件的记录" or answer == "没有符合条件的记录\n":
+ return False
+
+ # 目前 answer 的类型只是 Dataframe 类型的字符串,要考虑后面存储的回答是否有其他格式
+ if answer.find('varchar') != -1:
+ return False
+ str_list = answer.split('\n')
+ for key in ['经办人', '所属分公司', '合同形式', '合同名称', '项目来源', '专业', '地点', '客户名称', '客商类型', '合同签订金额(人民币)', '签订日期', '合同有效期(结束)']:
+ if key in str_list[0]:
+ return True
+ return False
+
+ def _vectorize_questions(self):
+ """向量化所有问题"""
+ self.question_vectors = self.vectorizer.fit_transform(self.questions)
+
+ def find_similar(self, question):
+ """找到最相似的问题和答案"""
+ if len(self.questions) == 0:
+ print('未缓存问答对。')
+ return None
+ question_vector = self.vectorizer.transform([question])
+ cosine_similarities = np.dot(self.question_vectors, question_vector.T).toarray().flatten()
+ most_similar_idx = cosine_similarities.argmax()
+ similarity_score = cosine_similarities[most_similar_idx]
+ most_similar_question = self.questions[most_similar_idx]
+ print('similarity_score:{}'.format(similarity_score))
+ print('most_similar_question:{}'.format(most_similar_question))
+ # 只有当相似度超过阈值时才返回问题和答案
+ if similarity_score > self.similarity_threshold:
+ return self.cache[most_similar_question]
+ else:
+ return None
+
+def load_qa_pairs(qa_filename:str) -> None:
+ if not os.path.exists(qa_filename):
+ file = open(qa_filename, 'w')
+ file.close()
+ qa_dict = {}
+ with open(qa_filename, 'r') as file:
+ lines = file.readlines()
+ positions = []
+ for idx, line in enumerate(lines):
+ if line.find('qa_cache') != -1:
+ positions.append(idx)
+ positions.append(len(lines))
+ chunks = []
+ for i in range(0, len(positions)-1):
+ chunk = lines[positions[i]:positions[i+1]]
+ if '[SUCCESS]' in lines[positions[i]]:
+ chunks.append(chunk)
+ question = ''
+ answer = ''
+ for chunk in chunks:
+ question = chunk[2].strip()
+ answer = ''.join(chunk[4:])
+ # print(question, answer)
+ qa_dict[question] = answer
+ return qa_dict
+
+if __name__ == '__main__':
+
+ qa_dict = load_qa_pairs('log/test1111.log')
+ qa_cache = QACache(similarity_threshold=1)
+ qa_cache.add(qa_dict)
+ # 使用示例
+ # qa_cache = QACache(maxsize=10)
+ # qa_cache.add("What is machine learning?", "Machine learning is a type of artificial intelligence.")
+ # qa_cache.add("How does machine learning work?", "It uses statistical techniques to give computers the ability to learn from data.")
+
+ # 用户提问
+ user_question = "机房的项目有哪些"
+ similar_answer = qa_cache.find_similar(user_question)
+ # print(similar_answer)
\ No newline at end of file
diff --git a/sqlcode/qgi.py b/sqlcode/qgi.py
new file mode 100644
index 0000000..43dbace
--- /dev/null
+++ b/sqlcode/qgi.py
@@ -0,0 +1,378 @@
+from http import HTTPStatus
+import os
+import sqlalchemy
+import pandas as pd
+from enum import StrEnum, IntEnum
+from pydantic import BaseModel, Field
+import markdown
+import logging, logging.config
+import uuid
+from typing import Callable, Any, Optional, Union
+from sqlcode.utils import enrich_input
+import yaml
+
+logger = logging.getLogger(__name__)
+qustion_logger = logging.getLogger('question')
+
+def load_logging_config(config_path) -> None:
+ with open(config_path, 'rt') as f:
+ config = yaml.safe_load(f.read())
+ logging.config.dictConfig(config)
+
+# enum for return type
+class ReturnType(StrEnum):
+ '''
+ - SQL: 返回 SQL 语句
+ - JSON: 返回 JSON 格式的数据
+ - TEXT: 返回文本格式的数据
+ - HTML: 返回 HTML 格式的数据
+ - WX_MD: 返回企业微信 Markdown 格式的数据
+ '''
+ SQL = "sql"
+ '''返回 SQL 语句'''
+
+ JSON = "json"
+ '''返回 JSON 格式的数据'''
+
+ TEXT = "text"
+ '''返回文本格式的数据'''
+
+ HTML = 'html'
+ '''返回 HTML 格式的数据'''
+
+ WX_MD = 'wx_markdown'
+ '''返回企业微信 Markdown 格式的数据'''
+
+
+class QueryResult(BaseModel):
+ status: int = Field(default=200, description='HTTP 状态码, 200 表示成功, 其他表示失败')
+ result: str | dict[str, Any] | None = Field(default=None, description='查询结果')
+ sql: str | None = Field(default=None, description='查询的 SQL 语句')
+ thought: str | None = Field(default=None, description='大模型产生的推理')
+ error: str | None = Field(default=None, description='错误信息')
+
+
+def _retrieve_sql(text) -> Union[str|QueryResult]:
+ sql = text['output']
+ if (n:=sql.find('```sql')) >= 0:
+ n += 6
+ sql = sql[n:sql.index('```', n)]
+ elif (n:=sql.find('```\n')) >= 0:
+ n += 3
+ sql = sql[n:sql.index('```', n)]
+ elif (n:=sql.find('```\r\n')) >= 0:
+ n += 3
+ sql = sql[n:sql.index('```', n)]
+ elif (n:=sql.find('```')) >= 0:
+ sql = sql[sql.index('\n', n): sql.index('```', n+3)]
+ else:
+ return QueryResult(status=HTTPStatus.FAILED_DEPENDENCY, error='从结果中无法找到有效的查询语句。', thought=text['output'])
+
+ sql = sql.strip()
+ return sql
+
+class Generator:
+ max_retry: Optional[int] = 2
+
+ def generate(self, input:dict) -> QueryResult:
+ '''
+ 抽象生成器类
+
+ params:
+ - prompt: 输入文本
+ returns:
+ - code: HTTP 状态码, 200 表示成功, 其他表示失败
+ - sql: 生成的 SQL 语句,如果 code 不是 200,则 sql 中可能包含错误信息
+ - text: 生成的完全文本,内容可能包含 SQL 语句,如果 code 不是 200,则 text 中可能包含错误信息
+ '''
+ cnt=0
+ sql = ""
+ while cnt < self.max_retry:
+ code, text = self._generate(input)
+
+ if code != HTTPStatus.OK:
+ return QueryResult(status=code, error=text['err'], thought=text['output'])
+ res = _retrieve_sql(text)
+ if isinstance(res, QueryResult):
+ cnt+=1
+ print("Error: 从结果中无法找到有效的查询语句, retry again...")
+ continue
+ else:
+ sql = res
+ break
+
+ if sql == "":
+ return QueryResult(status=HTTPStatus.FAILED_DEPENDENCY, error='从结果中无法找到有效的查询语句。', thought=text['output'])
+
+ logger.info('[QUESTION] %s\n%s\n%s', input['messages'][0], sql, text['output'])
+ qustion_logger.info('%s\n--------\n%s\n--------------------\n', sql, text['output'])
+
+ return QueryResult(status=HTTPStatus.OK, sql=sql, thought=text['output'])
+
+
+ def _generate(self, input:dict) -> tuple[HTTPStatus,str]:
+ """
+ 根据给定的提示生成SQL语句。
+
+ 参数:
+ - prompt: str - 提供给模型的提示,用于生成SQL语句。
+
+ 返回:
+ - HTTPStatus - 请求的状态码,200 表示成功。
+ - str - full text, 从大模型的返回的完整回复,如果 code 不是 200,则 text 中可能包含错误信息。
+ 如果 code 是 200,则 text 中的 SQL 语句应包含在围栏式代码块 ```sql ``` 中。
+ """
+ raise NotImplementedError()
+
+
+class FormatSuggest(IntEnum):
+ '''
+ 对结果集的格式化建议
+ '''
+
+ NoResult = 0
+ '''没有查询结果'''
+
+ SingleRow = 1
+ '''结果为一行'''
+
+ SingleColumn = 2
+ '''结果为一列'''
+
+ MultiRow = 3
+ '''结果为多行,但行数不多,建议直接列举结果'''
+
+ ExportToExcel = 4
+ '''结果集较大,建议导出到 Excel 文件'''
+
+
+def _suggest_format(df:pd.DataFrame) -> FormatSuggest:
+ row_count = len(df.index)
+
+ if row_count == 0:
+ return FormatSuggest.NoResult
+ elif row_count == 1:
+ return FormatSuggest.SingleRow
+ elif row_count < 20:
+ if len(df.columns) == 1:
+ return FormatSuggest.SingleColumn
+ else:
+ return FormatSuggest.MultiRow
+ else:
+ return FormatSuggest.ExportToExcel
+
+def _translate_column_names(df:pd.DataFrame, translate:Callable[[str],str]):
+ if translate is None:
+ return
+
+ # translate column names
+ col_names = df.columns.to_list()
+ zh_names = translate(' | '.join(col_names).replace('_', ' ')).split('|')
+ zh_names = list(map(lambda s: s.strip(), zh_names))
+ if len(col_names) == len(zh_names):
+ df.columns = zh_names
+ else:
+ logger.error('[TRANSLATE] count not match (%d->%d) %s %s', len(col_names), len(zh_names), col_names, zh_names)
+
+
+# pandas 数据转为 markdown表格
+def df_to_markdown(df):
+ headers = list(df.columns)
+ # 计算每列内容的总长度
+ total_length = sum(df[col].astype(str).str.len().sum() + len(col) for col in headers)
+ # 计算每列相对比例
+ proportions = [((df[col].astype(str).str.len().sum() + len(col)) / total_length) for col in headers]
+ # 假设总宽度为 100,按比例分配宽度
+ widths = [int(100 * prop) for prop in proportions]
+ # 创建 Markdown 表格的表头
+ markdown_table = "| " + " | ".join([f"{header:<{width}}" for header, width in zip(headers, widths)]) + " |\n"
+ markdown_table += "| " + " | ".join(["---" * width for width in widths]) + " |\n"
+ # 添加数据行
+ for _, row in df.iterrows():
+ markdown_table += "| " + " | ".join([f"{str(value):<{width}}" for value, width in zip(row, widths)]) + " |\n"
+ return markdown_table
+
+class Formatter:
+ '''
+ 抽象格式化器类
+ '''
+ def __init__(self, tranlate) -> None:
+ self.translate = tranlate
+
+ def format(self, df:pd.DataFrame) -> str:
+ '''
+ 格式化指定的 DataFrame。
+ '''
+ raise NotImplementedError()
+
+class _JsonFormatter(Formatter):
+ def __init__(self) -> None:
+ super().__init__(tranlate=None)
+
+ def format(self, df:pd.DataFrame):
+ return df.to_dict(orient='split', index=False)
+
+
+class _MarkdownFormatter(Formatter):
+ def __init__(self, tranlate, output_dir:str, site_url:str) -> None:
+ super().__init__(tranlate=tranlate)
+ self.output_dir = output_dir
+ self.site_url = site_url
+
+ def format(self, df:pd.DataFrame):
+ suggest = _suggest_format(df)
+
+ if suggest == FormatSuggest.NoResult:
+ return '没有符合条件的记录'
+
+ _translate_column_names(df, self.translate)
+
+ if suggest == FormatSuggest.SingleRow:
+ return ', '.join([f'{k}:{v}' for k,v in df.to_dict(orient='records')[0].items()])
+ elif suggest == FormatSuggest.SingleColumn:
+ return df[df.columns[0]].str.cat(sep='\n')
+ elif suggest == FormatSuggest.MultiRow:
+ output_file = uuid.uuid4().hex + '.xlsx'
+ df.to_excel(os.path.join(self.output_dir, output_file), index=False)
+ return df_to_markdown(df.head(10)) + \
+ '\n\n查询结果已导出到Excel文件,请点击链接下载文件:{}/{}'.format(self.site_url, output_file)
+ else:
+ # TODO 测试阶段仅保存前 10 条记录
+ output_file = uuid.uuid4().hex + '.xlsx'
+ df.to_excel(os.path.join(self.output_dir, output_file), index=False)
+ return df_to_markdown(df.head(10)) + \
+ '\n\n结果记录数或信息量太多,查询结果已导出到Excel文件(测试阶段仅导出前10行),请点击链接下载文件:{}/{}'.format(self.site_url, output_file)
+
+
+
+class _HtmlFormatter(_MarkdownFormatter):
+ def __init__(self, tranlate, output_dir: str, site_url: str) -> None:
+ super().__init__(tranlate, output_dir, site_url)
+
+ def format(self, df:pd.DataFrame):
+ answer = super().format(df)
+ return markdown.markdown(answer, extensions=['markdown.extensions.tables'])
+
+
+class _WechatFormatter(_MarkdownFormatter):
+ def __init__(self, tranlate, output_dir: str, site_url: str) -> None:
+ super().__init__(tranlate, output_dir, site_url)
+
+ def format(self, df: pd.DataFrame):
+ suggest = _suggest_format(df)
+
+ if suggest == FormatSuggest.NoResult:
+ return '没有符合条件的记录'
+
+ _translate_column_names(df, self.translate)
+
+ if suggest == FormatSuggest.ExportToExcel:
+ # TODO 测试阶段仅保存前 10 条记录
+ output_file = uuid.uuid4().hex + '.xlsx'
+ # df.head(10).to_excel(os.path.join(self.output_dir, output_file), index=False)
+ # return f'结果记录数或信息量太多,查询结果已导出到Excel文件,请下载查看 (测试阶段仅导出前10行)'+df.head(10).to_string()
+ # return df.head(10).to_html(index=False)
+ return df_to_markdown(df.head(10))
+ elif suggest == FormatSuggest.MultiRow:
+ if len(df.columns) <= 3:
+ return '\n'.join([', '.join([str(v) for v in row.values()]) for row in df.to_dict(orient='records')])
+ else:
+ return '\n'.join([', '.join([f'{k}:{v}' for k,v in row.items()]) for row in df.to_dict(orient='records')])
+ else:
+ return super().format(df)
+
+
+def NewFormatter(format:str, tranlate:Callable[[str],str], output_dir:str, site_url:str) -> Formatter:
+ if format == ReturnType.JSON:
+ return _JsonFormatter()
+ elif format == ReturnType.TEXT:
+ return _MarkdownFormatter(tranlate, output_dir, site_url)
+ elif format == ReturnType.HTML:
+ return _HtmlFormatter(tranlate, output_dir, site_url)
+ elif format == ReturnType.WX_MD:
+ return _WechatFormatter(tranlate, output_dir, site_url)
+ else:
+ raise ValueError('不支持的格式化器类型')
+
+
+class Executor:
+ '''
+ 抽象执行器类
+
+ params:
+ - sql: SQL 语句
+ returns:
+ - code: HTTP 状态码, 200 表示成功, 其他表示失败
+ - text: 生成的完全文本,内容可能包含 SQL 语句,如果 code 不是 200,则 text 中可能包含错误信息
+ '''
+ def __init__(self, generator:Generator, formatter:Formatter, max_retry:Optional[int] = 2) -> None:
+ self.generator = generator
+ self.formatter = formatter
+ self.max_retry = max_retry
+
+ def get_sql(self, cnt:int, input:dict) -> dict:
+ """
+ 根据提供的问题生成SQL
+
+ 参数:
+ - cnt (int): 重试次数。
+ - input (dict): 主要包含用户查询的问题。
+
+ 返回:
+ - dict: 生成sql的返回体。
+ """
+ print('第{}次尝试'.format(cnt+1))
+ if cnt == 0:
+ ret = self.generator.generate(input)
+ else:
+ new_input = enrich_input("之前生成的SQL语句有误或查询结果为空,请重新生成一个不同的SQL语句, question:", input)
+ ret = self.generator.generate(new_input)
+ return ret
+
+ def query(self, connection_string:str, input:dict) -> QueryResult:
+ """
+ 根据提供的模型、数据库和问题生成并执行SQL查询,然后根据返回类型返回结果。
+
+ 参数:
+ - database: str - 数据库的连接字符串。
+ - prompt: str - 提示文本。
+ - question: str - 用户的查询问题。
+ - input: 问题,input = {"messages": [("human", req.question)], "iterations": 0}
+
+ 返回:
+ - status_code - HTTP状态码。
+ - result 查询结果根据return_type的不同可以是字典、字符串或DataFrame的markdown表示。
+ - sql - 生成的SQL语句。
+ - thought - 模型生成的推理
+ """
+ error: Union[Exception|str|None] = None
+ # 使用logger记录问答对
+ load_logging_config('config/development/logging.yaml')
+ qa_logger = logging.getLogger('qa_cache')
+ #数据库连接
+ sql_engine = sqlalchemy.create_engine(connection_string)
+ with sql_engine.connect() as connection:
+ cnt = 0
+ ret = None
+ while cnt < self.max_retry:
+ ret = self.get_sql(cnt, input)
+ if ret.status != HTTPStatus.OK:
+ qa_logger.error('[FAIL] %s\n%s', input['messages'][0][1], ret.error)
+ return ret
+ try:
+ #查询数据
+ result = connection.execute(sqlalchemy.text(ret.sql))
+ df = pd.DataFrame(result, columns=result.keys())
+ print('******************\n', df.head(30),'\n***************\n')
+ ret.result = self.formatter.format(df)
+
+ qa_logger.info('[SUCCESS] \nInput:\n%s\nResult:\n%s', input['messages'][0][1], ret.result)
+ break
+ except Exception as e:
+ error = e
+ qa_logger.error('[QUERY] %s\n%s', input['messages'][0][1], error)
+ ret.status=HTTPStatus.INTERNAL_SERVER_ERROR
+ ret.error=str(error)
+ print("Error: SQL执行错误,retry agin...{}".format(error))
+ cnt+=1
+ return ret
\ No newline at end of file
diff --git a/sqlcode/qwenapi.py b/sqlcode/qwenapi.py
new file mode 100644
index 0000000..47edc74
--- /dev/null
+++ b/sqlcode/qwenapi.py
@@ -0,0 +1,57 @@
+from dashscope import Generation
+import random
+from http import HTTPStatus
+from enum import Enum
+from .qgi import Generator as BaseGenerator
+
+class Model(str, Enum):
+ # 通义千问超大规模语言模型,支持中文、英文等不同语言输入。
+ # 模型支持 8,000 tokens上下文,为了保证正常的使用和输出,API限定用户输入为 6,000 tokens。
+ QWEN_TURBO = "qwen-turbo"
+
+ # 通义千问超大规模语言模型增强版,支持中文、英文等不同语言输入。
+ # 模型支持 32,000 tokens上下文,为了保证正常的使用和输出,API限定用户输入为30,000 tokens。
+ QWEN_PLUS = "qwen-plus"
+
+ # 通义千问千亿级别超大规模语言模型,支持中文、英文等不同语言输入。
+ # 随着模型的升级,qwen-max将滚动更新升级,如果希望使用固定版本,请使用下面的历史快照版本。
+ # 当前qwen-max模型与qwen-max-0428快照版本等价,均为最新版本的qwen-max模型,也是当前通义千问2.5产品版本背后的API模型。
+ # 模型支持 8,000 tokens上下文,为了保证正常的使用和输出,API限定用户输入为 6,000 tokens。
+ QWEN_MAX = "qwen-max"
+
+ # 通义千问千亿级别超大规模语言模型,支持中文、英文等不同语言输入。
+ # 模型支持 30,000 tokens上下文,为了保证正常的使用和输出,API限定用户输入为 28,000 tokens。
+ QWEN_MAX_LONGCONTEXT = "qwen-max-longcontext"
+
+
+class Generator(BaseGenerator):
+ def __init__(self, model:str, messages:list[dict[str,str]]|None, apikey:str, seed:int=0) -> None:
+ '''
+ 参数:
+ - model: str - 使用的大模型名称,例如 qwen-turbo。
+ - messages: list[dict[str,str]] - 提供给大模型的 messages 参数
+ - apikey: str - 用于调用模型的 dashscope API密钥。
+ - seed: int - 随机种子,用于确保生成结果的可复现性,如果忽略或为0,则生成一个随机数。
+ '''
+ self.model = model
+ self.messages = messages
+ self.apikey = apikey
+ self.seed = seed
+
+ def _generate(self, prompt:str|None) -> tuple[HTTPStatus, str]:
+ # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
+ seed = self.seed
+ if seed == 0:
+ seed = random.randint(1, 10000)
+
+ ret = Generation.call(model=self.model,
+ messages=self.messages,
+ prompt=prompt,
+ api_key=self.apikey,
+ seed=seed,
+ result_format='text')
+
+ if ret.status_code != HTTPStatus.OK:
+ return ret.status_code, ret.message
+ else:
+ return ret.status_code, ret.output.text
diff --git a/sqlcode/sql_agent.py b/sqlcode/sql_agent.py
new file mode 100644
index 0000000..4efaece
--- /dev/null
+++ b/sqlcode/sql_agent.py
@@ -0,0 +1,324 @@
+"""SQL agent."""
+
+from __future__ import annotations
+
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Union,
+ cast,
+)
+
+from langchain_core.messages import AIMessage, SystemMessage
+from langchain_core.prompts import BasePromptTemplate, PromptTemplate
+from langchain_core.prompts.chat import (
+ ChatPromptTemplate,
+ HumanMessagePromptTemplate,
+ MessagesPlaceholder,
+)
+
+from langchain_community.agent_toolkits.sql.prompt import (
+ SQL_FUNCTIONS_SUFFIX,
+ SQL_PREFIX,
+)
+from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
+from langchain_community.tools.sql_database.tool import (
+ InfoSQLDatabaseTool,
+ ListSQLDatabaseTool,
+)
+
+if TYPE_CHECKING:
+ from langchain.agents.agent import AgentExecutor
+ from langchain.agents.agent_types import AgentType
+ from langchain_core.callbacks import BaseCallbackManager
+ from langchain_core.language_models import BaseLanguageModel
+ from langchain_core.tools import BaseTool
+
+ from langchain_community.utilities.sql_database import SQLDatabase
+
+
+def create_sql_agent(
+ llm: BaseLanguageModel,
+ toolkit: Optional[SQLDatabaseToolkit] = None,
+ agent_type: Optional[
+ Union[AgentType, Literal["openai-tools", "tool-calling"]]
+ ] = None,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ prefix: Optional[str] = None,
+ suffix: Optional[str] = None,
+ format_instructions: Optional[str] = None,
+ input_variables: Optional[List[str]] = None,
+ top_k: int = 10,
+ max_iterations: Optional[int] = 15,
+ max_execution_time: Optional[float] = None,
+ early_stopping_method: str = "force",
+ verbose: bool = False,
+ agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+ extra_tools: Sequence[BaseTool] = (),
+ *,
+ db: Optional[SQLDatabase] = None,
+ prompt: Optional[BasePromptTemplate] = None,
+ **kwargs: Any,
+) -> AgentExecutor:
+ """Construct a SQL agent from an LLM and toolkit or database.
+
+ Args:
+ llm: Language model to use for the agent. If agent_type is "tool-calling" then
+ llm is expected to support tool calling.
+ toolkit: SQLDatabaseToolkit for the agent to use. Must provide exactly one of
+ 'toolkit' or 'db'. Specify 'toolkit' if you want to use a different model
+ for the agent and the toolkit.
+ agent_type: One of "tool-calling", "openai-tools", "openai-functions", or
+ "zero-shot-react-description". Defaults to "zero-shot-react-description".
+ "tool-calling" is recommended over the legacy "openai-tools" and
+ "openai-functions" types.
+ callback_manager: DEPRECATED. Pass "callbacks" key into 'agent_executor_kwargs'
+ instead to pass constructor callbacks to AgentExecutor.
+ prefix: Prompt prefix string. Must contain variables "top_k" and "dialect".
+ suffix: Prompt suffix string. Default depends on agent type.
+ format_instructions: Formatting instructions to pass to
+ ZeroShotAgent.create_prompt() when 'agent_type' is
+ "zero-shot-react-description". Otherwise ignored.
+ input_variables: DEPRECATED.
+ top_k: Number of rows to query for by default.
+ max_iterations: Passed to AgentExecutor init.
+ max_execution_time: Passed to AgentExecutor init.
+ early_stopping_method: Passed to AgentExecutor init.
+ verbose: AgentExecutor verbosity.
+ agent_executor_kwargs: Arbitrary additional AgentExecutor args.
+ extra_tools: Additional tools to give to agent on top of the ones that come with
+ SQLDatabaseToolkit.
+ db: SQLDatabase from which to create a SQLDatabaseToolkit. Toolkit is created
+ using 'db' and 'llm'. Must provide exactly one of 'db' or 'toolkit'.
+ prompt: Complete agent prompt. prompt and {prefix, suffix, format_instructions,
+ input_variables} are mutually exclusive.
+ **kwargs: Arbitrary additional Agent args.
+
+ Returns:
+ An AgentExecutor with the specified agent_type agent.
+
+ Example:
+
+ .. code-block:: python
+
+ from langchain_openai import ChatOpenAI
+ from langchain_community.agent_toolkits import create_sql_agent
+ from langchain_community.utilities import SQLDatabase
+
+ db = SQLDatabase.from_uri("sqlite:///Chinook.db")
+ llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
+ agent_executor = create_sql_agent(llm, db=db, agent_type="tool-calling", verbose=True)
+
+ """ # noqa: E501
+ from langchain.agents import (
+ create_openai_functions_agent,
+ create_openai_tools_agent,
+ create_react_agent,
+ create_tool_calling_agent,
+ )
+ from langchain.agents.agent import (
+ AgentExecutor,
+ RunnableAgent,
+ RunnableMultiActionAgent,
+ )
+ from langchain.agents.agent_types import AgentType
+
+ if toolkit is None and db is None:
+ raise ValueError(
+ "Must provide exactly one of 'toolkit' or 'db'. Received neither."
+ )
+ if toolkit and db:
+ raise ValueError(
+ "Must provide exactly one of 'toolkit' or 'db'. Received both."
+ )
+
+ toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db) # type: ignore[arg-type]
+ agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
+ tools = toolkit.get_tools() + list(extra_tools)
+ if prompt is None:
+ prefix = prefix or SQL_PREFIX
+ prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
+ else:
+ if "top_k" in prompt.input_variables:
+ prompt = prompt.partial(top_k=str(top_k))
+ if "dialect" in prompt.input_variables:
+ prompt = prompt.partial(dialect=toolkit.dialect)
+ if any(key in prompt.input_variables for key in ["table_info", "table_names"]):
+ db_context = toolkit.get_context()
+ if "table_info" in prompt.input_variables:
+ prompt = prompt.partial(table_info=db_context["table_info"])
+ tools = [
+ tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
+ ]
+ if "table_names" in prompt.input_variables:
+ prompt = prompt.partial(table_names=db_context["table_names"])
+ tools = [
+ tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
+ ]
+
+ if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
+ if prompt is None:
+ from langchain.agents.mrkl import prompt as react_prompt
+
+ format_instructions = (
+ format_instructions or react_prompt.FORMAT_INSTRUCTIONS
+ )
+ template = "\n\n".join(
+ [
+ react_prompt.PREFIX,
+ "{tools}",
+ format_instructions,
+ react_prompt.SUFFIX,
+ ]
+ )
+ prompt = PromptTemplate.from_template(template)
+ agent = RunnableAgent(
+ runnable=create_react_agent(llm, tools, prompt, output_parser=ReActSingleInputOutputParserWithOutMarkDown()),
+ input_keys_arg=["input"],
+ return_keys_arg=["output"],
+ **kwargs,
+ )
+
+ elif agent_type == AgentType.OPENAI_FUNCTIONS:
+ if prompt is None:
+ messages: List = [
+ SystemMessage(content=cast(str, prefix)),
+ HumanMessagePromptTemplate.from_template("{input}"),
+ AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
+ ]
+ prompt = ChatPromptTemplate.from_messages(messages)
+ agent = RunnableAgent(
+ runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore
+ input_keys_arg=["input"],
+ return_keys_arg=["output"],
+ **kwargs,
+ )
+ elif agent_type in ("openai-tools", "tool-calling"):
+ if prompt is None:
+ messages = [
+ SystemMessage(content=cast(str, prefix)),
+ HumanMessagePromptTemplate.from_template("{input}"),
+ AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
+ ]
+ prompt = ChatPromptTemplate.from_messages(messages)
+ if agent_type == "openai-tools":
+ runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore
+ else:
+ runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore
+ agent = RunnableMultiActionAgent( # type: ignore[assignment]
+ runnable=runnable,
+ input_keys_arg=["input"],
+ return_keys_arg=["output"],
+ **kwargs,
+ )
+
+ else:
+ raise ValueError(
+ f"Agent type {agent_type} not supported at the moment. Must be one of "
+ "'tool-calling', 'openai-tools', 'openai-functions', or "
+ "'zero-shot-react-description'."
+ )
+
+ return AgentExecutor(
+ name="SQL Agent Executor",
+ agent=agent,
+ tools=tools,
+ callback_manager=callback_manager,
+ verbose=verbose,
+ max_iterations=max_iterations,
+ max_execution_time=max_execution_time,
+ early_stopping_method=early_stopping_method,
+ **(agent_executor_kwargs or {}),
+ )
+
+import re
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.exceptions import OutputParserException
+
+from langchain.agents.agent import AgentOutputParser
+from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
+
+FINAL_ANSWER_ACTION = "Final Answer:"
+MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
+ "Invalid Format: Missing 'Action:' after 'Thought:"
+)
+MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
+ "Invalid Format: Missing 'Action Input:' after 'Action:'"
+)
+FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
+ "Parsing LLM output produced both a final answer and a parse-able action:"
+)
+class ReActSingleInputOutputParserWithOutMarkDown(AgentOutputParser):
+ def get_format_instructions(self) -> str:
+ return FORMAT_INSTRUCTIONS
+
+ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
+ includes_answer = FINAL_ANSWER_ACTION in text
+ regex = (
+ r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
+ )
+ print("text: {}".format(text))
+ action_match = re.search(regex, text, re.DOTALL)
+ if action_match:
+ if includes_answer:
+ raise OutputParserException(
+ f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
+ )
+ action = action_match.group(1).strip()
+ action_input = action_match.group(2)
+ tool_input = action_input.strip(" ")
+ tool_input = tool_input.strip('"')
+
+ # Remove markdown code block markers if present
+ tool_input = re.sub(r'(^```\s*sql\s*|^\s*```$)', '', tool_input, flags=re.MULTILINE).strip()
+ tool_input = re.sub(r'(^`\s*|`\s*$)', '', tool_input).strip()
+
+ return AgentAction(action, tool_input, text)
+
+ elif includes_answer:
+ return AgentFinish(
+ {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
+ )
+
+ if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
+ raise OutputParserException(
+ f"Could not parse LLM output: `{text}`",
+ observation=MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
+ llm_output=text,
+ send_to_llm=True,
+ )
+ elif not re.search(
+ r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
+ ):
+ raise OutputParserException(
+ f"Could not parse LLM output: `{text}`",
+ observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
+ llm_output=text,
+ send_to_llm=True,
+ )
+ else:
+ raise OutputParserException(f"Could not parse LLM output: `{text}`")
+
+ @property
+ def _type(self) -> str:
+ return "react-single-input-without-markdown"
+
+def remove_markdown_code_block(text: str):
+ pattern = re.compile(r"```sql\s*(.*?)\s*```", re.DOTALL)
+
+ # 查找匹配的内容
+ match = pattern.search(text)
+ if match:
+ # 提取 SQL 语句
+ sql_query = match.group(1)
+ return sql_query.strip()
+ else:
+ return "No SQL code block found."
\ No newline at end of file
diff --git a/sqlcode/store_vecstore.py b/sqlcode/store_vecstore.py
new file mode 100644
index 0000000..c16c4bd
--- /dev/null
+++ b/sqlcode/store_vecstore.py
@@ -0,0 +1,17 @@
+from langchain_community.document_loaders import TextLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_chroma import Chroma
+import os
+from langchain_community.embeddings import DashScopeEmbeddings
+
+ENVIRONMENT = os.environ.get('ENVIRONMENT', 'development')
+DIR_NAME = os.path.dirname(__file__)
+path = os.path.join(DIR_NAME, '..','config',ENVIRONMENT,"rag.txt")
+loader = TextLoader(file_path=path)
+
+docs = loader.load()
+text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
+splits = text_splitter.split_documents(docs)
+out_dir = os.path.join(DIR_NAME, '..', 'chroma_db')
+vectorstore = Chroma.from_documents(documents=splits, embedding=DashScopeEmbeddings(), persist_directory=out_dir)
+print("向量数据库更新完毕")
\ No newline at end of file
diff --git a/sqlcode/tt_tencent.py b/sqlcode/tt_tencent.py
new file mode 100644
index 0000000..be53f01
--- /dev/null
+++ b/sqlcode/tt_tencent.py
@@ -0,0 +1,50 @@
+from tencentcloud.common import credential
+from tencentcloud.common.profile.client_profile import ClientProfile
+from tencentcloud.common.profile.http_profile import HttpProfile
+from tencentcloud.tmt.v20180321 import tmt_client, models
+
+class Translator:
+ def __init__(self, secret_id:str, secret_key:str, *,
+ region: str = 'ap-guangzhou'):
+ # 实例化一个认证对象,入参需要传入腾讯云账户 SecretId 和 SecretKey,此处还需注意密钥对的保密
+ # 代码泄露可能会导致 SecretId 和 SecretKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考,建议采用更安全的方式来使用密钥,请参见:https://cloud.tencent.com/document/product/1278/85305
+ # 密钥可前往官网控制台 https://console.cloud.tencent.com/cam/capi 进行获取
+ # 实例化一个http选项,可选的,没有特殊需求可以跳过
+ httpProfile = HttpProfile()
+ httpProfile.endpoint = "tmt.tencentcloudapi.com"
+
+ # 实例化一个client选项,可选的,没有特殊需求可以跳过
+ clientProfile = ClientProfile()
+ clientProfile.httpProfile = httpProfile
+ cred = credential.Credential(secret_id=secret_id, secret_key=secret_key)
+ self.client = tmt_client.TmtClient(credential=cred, region=region, profile=clientProfile)
+
+
+ def translate(self, text: str, *,
+ source: str|None = 'auto',
+ target: str|None = 'zh',
+ project_id: int|None = None,
+ untranslated_text: str|None = None
+ ) -> str:
+ # 实例化一个请求对象,每个接口都会对应一个request对象
+ req = models.TextTranslateRequest()
+ req.SourceText = text
+ req.Source = source
+ req.Target = target
+ req.ProjectId = project_id or 0
+ req.UntranslatedText = untranslated_text
+
+ # 返回的resp是一个TextTranslateResponse的实例,与请求对象对应
+ resp = self.client.TextTranslate(req)
+ # 输出json格式的字符串回包
+ return resp.TargetText
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--secret_id', type=str, required=True)
+ parser.add_argument('--secret_key', type=str, required=True)
+ args = parser.parse_args()
+ t = Translator(args.secret_id, args.secret_key)
+ r = t.translate('branch_company | number_of_contracts | total_contract_amount'.replace('_', ' '), untranslated_text='|')
+ print(r)
\ No newline at end of file
diff --git a/sqlcode/utils.py b/sqlcode/utils.py
new file mode 100644
index 0000000..30448c2
--- /dev/null
+++ b/sqlcode/utils.py
@@ -0,0 +1,15 @@
+
+def enrich_input(message: str, input: dict) -> dict:
+ pre_messages = input["messages"]
+ problem = input["messages"][0][1]
+ print("human", message + " " + problem)
+ input["messages"] = pre_messages + [("human", message + " " + problem)]
+ input["iterations"] = 0
+ return input
+
+def parse(res: str) -> str:
+ """将优化后的问题提取出来"""
+ return res.split("优化后的问题:",1)[-1]
+
+def format_docs(docs):
+ return "\n\n".join(doc.page_content for doc in docs)
\ No newline at end of file
diff --git a/start-in-docker.sh b/start-in-docker.sh
new file mode 100644
index 0000000..a15d38d
--- /dev/null
+++ b/start-in-docker.sh
@@ -0,0 +1,28 @@
+#!bash
+
+set -xue
+
+venv=/root/miniconda3/envs/bizwechat
+workdir=$(dirname $(realpath $0))
+name=$(basename $workdir)
+
+[ -d "$venv" ] || (echo "venv $venv not exists" && exit 1)
+
+cd $workdir || (echo "cd $workdir failed" && exit 1)
+
+. "$workdir/.env"
+
+[ -e "log" ] || mkdir log
+
+docker run -d --stop-signal INT --replace \
+ --ip $UVICORN_HOST \
+ -v "$venv:$venv" \
+ -v "$workdir:$workdir" \
+ --name "$name" \
+ --env-file "$workdir/.env" \
+ --workdir "$workdir" \
+ --restart unless-stopped \
+ debian:bookworm \
+ "$venv/bin/uvicorn" main:app --log-config "config/$ENVIRONMENT/logging.yaml" $@
+
+docker logs -f "$name"
diff --git a/test.sh b/test.sh
new file mode 100644
index 0000000..49bf86d
--- /dev/null
+++ b/test.sh
@@ -0,0 +1,11 @@
+curl localhost:9000/qwen/contracts?apikey=YUVietLgiGmtqzYUVIIGjrNoLMsGM0FI \
+ -H "Content-Type: application/json" \
+ -d '{"return_type":"text","question":"2024年项目的平均金额为?"}'
+
+# -d '{"return_type":"text","question":"2023公司承接的“合同名称”或“专业”包含工程施工的项目有哪些?列出“合同名称”, “签订日期”, “专业”。使用模糊匹配,参考SQL语句:SELECT `合同名称`, `签订日期`, `专业` FROM contracts WHERE `合同名称` LIKE '%工程施工%' OR `专业` LIKE '%工程施工%'(AND 根据2023设置`签订时间`的年份或月份范围) LIMIT 15"}'
+
+# -d '{"return_type":"text","question":"2022公司承接的“客商类型”或“客户名称”包含中国电信的项目有哪些?列出“合同名称”“签订时间”“客商类型”。使用模糊匹配,参考SQL语句:SELECT `合同名称`, `签订日期`, `客商类型`FROM contracts Where `客户名称` LIKE '%中国电信%' OR `客商类型` LIKE '%中国电信%' (AND 根据2022设置`签订时间`的年份或月份范围"}'
+
+# -d '{"return_type":"text","question":"2020到2024年承接的地点在“四川省”的“合同签订金额”100万到1000万的项目有哪些?今年是2024年。参考sql语句: SELECT `合同名称`, `地点`, `签订日期`, `合同签订金额(人民币)` FROM contracts WHERE `地点` LIKE '%四川省%' AND `合同签订金额(人民币)` >= money AND `签订日期` BETWEEN time1 AND time2"}'
+
+
diff --git a/test_acc.py b/test_acc.py
new file mode 100644
index 0000000..677560d
--- /dev/null
+++ b/test_acc.py
@@ -0,0 +1,319 @@
+import requests
+import json
+import pandas as pd
+import random
+from itertools import product
+
+template_list = [
+ { # 时间、客户、公司
+ "keys": ['time', 'customer', 'company'],
+ "sql": """
+ SELECT `合同名称`, `签订日期`, `客户名称`, `所属分公司`
+ FROM contracts
+ WHERE `所属分公司` LIKE '%{company}%'
+ AND `客户名称` LIKE '%{customer}%' OR `客商类型` LIKE '%{customer}%' OR `合同名称` LIKE '%{customer}%'
+ (AND 根据{time}设置`签订时间`的筛选范围)""",
+ "question": "{time}{company}承接的“客户名称”或“客商类型”或“合同名称”包含{customer}的项目有哪些?参考sql语句:[sql]",
+ },
+ { # 时间、行业(包括专业)
+ "keys": ['time', 'field'],
+ "sql": """
+ SELECT `合同名称`, `签订日期`, `聚焦行业`
+ FROM contracts
+ WHERE `合同名称` LIKE '%{field}%' OR `聚焦行业` LIKE '%{field}%' OR `专业` LIKE '%{field}%'
+ (AND 根据{time}设置`签订时间`的筛选范围)""",
+ "question": "{time}承接的“聚焦行业”或“行业”包含{field}的项目有哪些?参考sql语句:[sql]"
+ },
+ { # 时间、公司、专业
+ "keys": ['time', 'company', 'major'],
+ "sql": """
+ SELECT `合同名称`, `签订日期`, `专业`, `所属分公司`
+ FROM contracts
+ WHERE `合同名称` LIKE '%{major}%' OR `专业` LIKE '%{major}%'
+ AND `所属分公司` LIKE '%{company}%'
+ (AND 根据{time}设置`签订时间`的筛选范围)""",
+ "question": "{time}{company}承接的“专业”包含{major}的项目有哪些?参考sql语句:[sql]"
+ },
+ { # 时间、地点
+ "keys": ['time', 'area'],
+ "sql": """
+ SELECT `合同名称`, `地点`
+ FROM contracts
+ WHERE `地点` LIKE '%{area}%'
+ (AND 根据{time}设置`签订时间`的筛选范围)""",
+ "question": "{time}承接的地点在“{area}”的项目有哪些?参考sql语句:[sql]",
+ },
+ { # 关键字查询(专业、行业、公司、客户、地点)
+ "keys": ['keyword'],
+ "sql": """
+ SELECT `合同名称`, `所属分公司`, `专业`, `聚焦行业`, `客商类型`, `地点`
+ FROM contracts
+ WHERE `合同名称` LIKE '%{keyword}%' OR `聚焦行业` LIKE '%{keyword}%'
+ OR `所属分公司` LIKE '%{keyword}%' OR `客商类型` LIKE '%{keyword}%'
+ OR `地点` LIKE '%{keyword}%' OR `客户名称` LIKE '%{keyword}%'
+ OR `专业` LIKE '%{keyword}%'""",
+ "question": "“地点”“专业”“聚焦行业”“所属分公司”“客商类型”“客户名称”或“合同名称”包含{keyword}的项目有哪些?参考sql语句:[sql]",
+ },
+ { # 公司、金额
+ "keys": ['company', 'money'],
+ "sql": """
+ SELECT `合同名称`, `所属分公司`, `合同签订金额(人民币)`
+ FROM contracts
+ WHERE `公司` LIKE '%{company}%'
+ (AND 根据{money}筛选`合同签订金额(人民币)`的筛选范围)""",
+ "question": "{company}的合同金额{money}的项目有哪些?参考sql语句:[sql]"
+ }]
+
+raw_data = {
+ 'time': list(pd.read_excel("./valueSets.xlsx", sheet_name="时间取值")["时间取值"]),
+ 'major': list(pd.read_excel("./valueSets.xlsx", sheet_name="专业取值")["专业取值"]),
+ 'company': list(pd.read_excel("./valueSets.xlsx", sheet_name="公司取值")["公司取值"]),
+ 'field': list(pd.read_excel("./valueSets.xlsx", sheet_name="行业取值")["行业取值"]),
+ 'customer': list(pd.read_excel("./valueSets.xlsx", sheet_name="客户取值")["客户取值"]),
+ 'area': list(pd.read_excel("./valueSets.xlsx", sheet_name="地区取值")["地区取值"]),
+ 'money': [ "100万以上的", "50万以上的", "2000万以上", "400万以上", "10万以上", "100万到1000万",], # 金额范围取值
+}
+raw_data['keyword'] = [item for key in ['area', 'company', 'customer', 'major', 'field'] for item in raw_data[key]]
+
+def combine_key_val(keys: list, k=4) -> dict:
+ """
+ 计算不同字段的排列组合
+ Args:
+ keys (_type_): 字段名列表
+ k (int, optional): 每个字段的取值数量. Defaults to 4.
+
+ Returns:
+ list: [{key1: value1}]字典列表
+
+ """
+ tmp = {}
+ for key in keys:
+ value = raw_data.get(key, [])
+ # if key == 'company':
+ # tmp[key] = ['一分公司']
+ # else:
+ # k = len(value) if k > len(value) else k
+ # tmp[key] = random.sample(value, k=k)
+ k = len(value) if k > len(value) else k
+ tmp[key] = random.sample(value, k=k)
+ print(tmp)
+ # 计算笛卡尔积
+ keys = list(tmp.keys())
+ values = list(tmp.values())
+ combinations = list(product(*values))
+ # 将每个组合转换为字典
+ return [{key: value for key, value in zip(keys, combination)} for combination in combinations]
+
+
+def form_question(keys, k) -> list:
+ """根据选填字段生成问题
+
+ Args:
+ keys (_type_): _description_
+
+ Returns:
+ list: _description_
+ """
+ combination_dict = combine_key_val(keys, k=k)
+ print(len(combination_dict))
+ template_idx = 0
+ for index, item in enumerate(template_list):
+ if set(item['keys']) == set(keys):
+ template_idx = index
+ break
+ template = template_list[template_idx]
+ question_template = template['question'].replace('[sql]', template['sql'])
+
+ question_list = []
+ tmp = ''
+ for combination in combination_dict:
+ tmp = question_template
+ tmp = tmp.format_map(combination)
+ question_list.append(tmp)
+ return question_list
+
+
+def test(keys, k=3):
+ question_list = form_question(keys, k=k)
+ try:
+ res = [["问题", "sql", "回答", "thought"]]
+ for question in question_list:
+ if question == "":
+ break
+ payload = {"return_type": "text", "question": question}
+ response = requests.post(url, headers=headers, data=json.dumps(payload))
+ response_data = response.json()
+ print("data_res", response_data)
+ res.append(
+ [
+ question,
+ response_data["sql"],
+ response_data["result"],
+ response_data["thought"],
+ ]
+ )
+ data = pd.DataFrame(res)
+ data.to_excel("./output/{}.xlsx".format('-'.join(keys)), index=False)
+ except Exception as error:
+ print("Error:", error)
+ return "Error"
+
+url = "http://localhost:8001/qwen/contracts?apikey=YUVietLgiGmtqzYUVIIGjrNoLMsGM0FI"
+headers = {
+ "Content-Type": "application/json",
+ # Uncomment the following line if needed and provide the appropriate value for data['apiKey']
+ # 'Authorization': f'Bearer {data["apiKey"]}'
+}
+
+
+# # def some_function():
+# # print(pdata["问题"])
+
+# ls = pdata["问题"]
+# qs = []
+# tmp = []
+# res_q = []
+# # for ll in ["数据库字段中的“合同名称”或“聚焦行业”或“专业”包含“key”的项目有哪些?"]:
+# # for ll in ["time公司承接的money的项目有哪些/多少?今年是2024年。"]:
+# # for ll in ["time公司承接的customer的项目有哪些/多少?其中,值为“customer”是数据库中“客商类型”字段的部分内容。今年是2024年。"]:
+# for ll in [
+# "company time承接的field的项目有哪些?其中,值为“company”是数据库中“所属分公司”字段的部分内容,“field”是数据库中“合同名称”或“聚焦行业”字段的部分内容。今年是2024年。"
+# ]:
+# # for ll in ["time公司承接的field的项目有哪些/多少?其中,值为“field”是数据库中“合同名称”或“聚焦行业”字段的部分内容。"]:
+
+# qs = []
+# tmp = []
+# tmp.append(ll)
+# if "time" in ll:
+# for ts in time:
+# t0 = ll.replace("time", ts)
+# qs.append(t0)
+# tmp = qs
+# qs = []
+
+# for t1 in tmp:
+# if "major" in t1:
+# for sp in major:
+# t0 = t1.replace("major", sp)
+# qs.append(t0)
+# else:
+# qs.append(t1)
+
+# tmp = qs
+# qs = []
+
+# for t1 in tmp:
+# if "company" in t1:
+# for sp in company:
+# t0 = t1.replace("company", sp)
+# qs.append(t0)
+# else:
+# qs.append(t1)
+
+# tmp = qs
+# qs = []
+
+# for t1 in tmp:
+# if "field" in t1:
+# for sp in field:
+# t0 = t1.replace("field", sp)
+# qs.append(t0)
+# else:
+# qs.append(t1)
+
+# tmp = qs
+# qs = []
+
+# for t1 in tmp:
+# if "key" in t1:
+# for sp in key:
+# t0 = t1.replace("key", sp)
+# qs.append(t0)
+# else:
+# qs.append(t1)
+
+# tmp = qs
+# qs = []
+
+# # money
+# for t1 in tmp:
+# if "money" in t1:
+# for sp in money:
+# t0 = t1.replace("money", sp)
+# qs.append(t0)
+# else:
+# qs.append(t1)
+
+# tmp = qs
+# qs = []
+
+# for t1 in tmp:
+# if "customer" in t1:
+# for sp in customer:
+# t0 = t1.replace("customer", sp)
+# qs.append(t0)
+# else:
+# qs.append(t1)
+
+# tmp = qs
+# qs = []
+
+# for t1 in tmp:
+# if "area" in t1:
+# for sp in area:
+# t0 = t1.replace("area", sp)
+# qs.append(t0)
+# else:
+# qs.append(t1)
+
+# # tmp = qs
+# # qs = []
+# # print(len(qs))
+# res_q = [*res_q, *qs]
+# print(len(res_q))
+# # print(res_q)
+
+# # return
+# try:
+# res = [["问题", "sql", "回答", "thought"]]
+
+# for ll in res_q:
+# if ll == "":
+# break
+# payload = {"return_type": "text", "question": ll}
+# response = requests.post(url, headers=headers, data=json.dumps(payload))
+# response_data = response.json()
+# print("data_res", response_data)
+# res.append(
+# [
+# ll,
+# response_data["sql"],
+# response_data["result"],
+# response_data["thought"],
+# ]
+# )
+# data = pd.DataFrame(res)
+# data.to_excel("./行业-时间.xlsx", index=False)
+# except Exception as error:
+# print("Error:", error)
+# return "Error"
+
+
+if __name__ == "__main__":
+ keys_combination_list = [
+ ['time', 'customer', 'company'],
+ ['time', 'field'],
+ ['time', 'company', 'major'],
+ ['time', 'area'],
+ ['keyword'],
+ ['company', 'money']
+ ]
+ # print(form_question(keys_combination_list[0]))
+ # form_question(keys_combination_list[2],k=3)
+ # test(keys_combination_list[0], k=3)
+ # test(keys_combination_list[1], k=3)
+ # test(keys_combination_list[2], k=3)
+ # test(keys_combination_list[3], k=3)
+ # test(keys_combination_list[4], k=20)
+ test(keys_combination_list[5], k=3)
+
diff --git a/tgi_app.py b/tgi_app.py
new file mode 100644
index 0000000..d84a538
--- /dev/null
+++ b/tgi_app.py
@@ -0,0 +1,43 @@
+import fastapi
+import config
+from http import HTTPStatus
+from pydantic import BaseModel, Field
+from typing import Any
+from sqlcode.qwenapi import Generator
+
+app = fastapi.FastAPI()
+
+class Request(BaseModel):
+ inputs: str = Field(description='Prompt')
+ parameters: Any | None = Field(default=None, description='Generation parameters')
+ stream: bool = Field(default=False, description='Whether to stream output tokens')
+
+class Response(BaseModel):
+ generated_text: str = Field(description='Generated text')
+ details: Any = Field(default=None, description='Generation details')
+
+@app.post('/{model}')
+async def generate(model:str, apikey:str, req:Request) -> Response:
+ client = config.api_key(apikey)
+ if not client:
+ raise fastapi.HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail='invalid apikey')
+
+ message = [{'role': 'system',
+ 'content': '你擅长编写 MySQL 的 SQL 代码,请结合具体问题编写正确规范的 SQL 代码'
+ }]
+
+ generator = Generator(model=model,
+ messages=message,
+ apikey=client.dashscope_api_key,
+ seed=0,
+ )
+ prompt = req.inputs
+
+ if (n1:=prompt.find('[QUESTION]')) >= 0 and (n2:=prompt.find('[/QUESTION]')) >= n1:
+ question = prompt[n1+10:n2]
+ else:
+ question = 'UNKNOWN'
+
+ ret = generator.generate(question, prompt)
+
+ return Response(generated_text=(ret.sql or ''))
\ No newline at end of file
diff --git a/valueSets.xlsx b/valueSets.xlsx
new file mode 100644
index 0000000..ff314a5
Binary files /dev/null and b/valueSets.xlsx differ
diff --git a/wechat.py b/wechat.py
new file mode 100644
index 0000000..b89b74a
--- /dev/null
+++ b/wechat.py
@@ -0,0 +1,233 @@
+#!/usr/bin/env python
+# -*- encoding:utf-8 -*-
+
+"""
+微信企业后台 接口.
+"""
+
+import xml.etree.cElementTree as ET
+from bizwechat import WXBizMsgCrypt
+from http import HTTPStatus
+import fastapi
+import logging
+import config
+import asyncio
+import aiohttp
+from sqlcode.sql_agent import create_sql_agent, remove_markdown_code_block
+from langchain_community.agent_toolkits import SQLDatabaseToolkit
+from langchain_community.utilities import SQLDatabase
+from langchain.agents.agent_types import AgentType
+from langchain_core.tools import StructuredTool
+from langchain.prompts import PromptTemplate
+from langchain_chroma import Chroma
+from langchain_community.embeddings import DashScopeEmbeddings
+from sqlcode.utils import parse, format_docs
+from langchain_core.runnables import RunnablePassthrough
+from sqlcode.multi_agent import create_sql_graph
+
+app = fastapi.FastAPI()
+logger = logging.getLogger('sqlcode')
+
+def get_wxcpt():
+ '''
+ 创建微信企业后台接口的加密解密对象
+ '''
+ wxbiz_config = config.bizwechat_config()
+ return WXBizMsgCrypt(sToken=wxbiz_config.token,
+ sEncodingAESKey=wxbiz_config.aes_key,
+ sReceiveId=wxbiz_config.corp_id)
+
+
+@app.get('/', include_in_schema=False)
+async def verify_url(request:fastapi.Request):
+ '''
+ 验证微信公众号的请求URL的合法性
+ '''
+ signature = request.query_params.get('msg_signature')
+ timestamp = request.query_params.get('timestamp')
+ nonce = request.query_params.get('nonce')
+ echostr = request.query_params.get('echostr')
+
+ if not signature or not timestamp or not nonce or not echostr:
+ logger.error('verify_url failed, missing parameters')
+ return fastapi.Response('', HTTPStatus.BAD_REQUEST)
+
+ code, echostr = get_wxcpt().VerifyURL(signature, timestamp, nonce, echostr)
+
+ if code == 0:
+ logger.info('verify_url success, echostr: %s', echostr)
+ return fastapi.Response(echostr, HTTPStatus.OK)
+ else:
+ logger.error('verify_url failed, error code: %s', code)
+ return fastapi.Response('', HTTPStatus.BAD_REQUEST)
+
+
+@app.post('/', include_in_schema=False)
+async def receive_message(request:fastapi.Request):
+ '''
+ 接收微信公众号消息,并回复
+ '''
+ signature = request.query_params.get('msg_signature')
+ timestamp = request.query_params.get('timestamp')
+ nonce = request.query_params.get('nonce')
+
+ if not signature or not timestamp or not nonce:
+ logger.error('receive_message failed, missing parameters')
+ return '', HTTPStatus.BAD_REQUEST
+
+ logger.info('receive_message timestamp: %s, nonce: %s', timestamp, nonce)
+
+ post_data = await request.body()
+ code, post_data = get_wxcpt().DecryptMsg(post_data, signature, timestamp, nonce)
+
+ if code != 0:
+ logger.error('receive_message failed, error code: %s', code)
+ return '', HTTPStatus.BAD_REQUEST
+
+ xml = ET.fromstring(post_data)
+ content = xml.find('Content').text
+ from_user = xml.find('FromUserName').text
+
+ asyncio.create_task(async_query_and_reply(
+ apikey=config.bizwechat_config().qgi_api_key,
+ model_name='qwen',
+ database='contracts',
+ question=content,
+ to_user=from_user,
+ ))
+
+ logger.info('receive_message success, timestamp: %s, nonce: %s', timestamp, nonce)
+ return ''
+
+
+from sqlcode.qgi import Executor, NewFormatter, ReturnType
+# from sqlcode.qwenapi import Generator
+from sqlcode.langchain_model import Generator
+from sqlcode.modelloader import ModelLoader, ModelManager
+
+async def async_query_and_reply(apikey:str, model_name:str, database:str, question:str, to_user:str):
+ """
+ 根据请求执行查询并返回结果。
+
+ 该函数接收一个模型名称、数据库名称和一个请求对象,通过这些信息来执行特定的查询操作。
+ 查询的结果会根据请求中指定的返回类型进行处理和包装。
+
+ * :param model: 字符串类型,表示要使用的大模型的名称。例如 qwen-turbo 等
+ * :param database: 字符串类型,表示要查询的数据库的名称。
+ * :param req: Request 类型的对象,包含查询的具体问题和期望的返回类型。
+ """
+ logger.info('开始处理问题:%s', question)
+ metadata = config.metadata(database)
+ qwen_cfg = config.qwen_config("qwen_graph.conf")
+ re_cfg = config.refineProblem_config()
+
+ modelLoader = ModelLoader(config.model_config())
+ modelManager = ModelManager(modelLoader)
+ modelManager.switch_model(model_name)
+ model = modelManager.get_model()
+
+ # db_dir = "chroma_db"
+ # # 需要先检查向量数据库是否为最新,执行sqlcode/store_vecstore.py来更新向量数据库
+ # vectorstore = Chroma(persist_directory=db_dir, embedding_function=DashScopeEmbeddings())
+ # retriever = vectorstore.as_retriever(search_kwargs={'k': 2})
+ # re_prompt = PromptTemplate.from_template(re_cfg.prompt)
+ # context_chain = retriever | format_docs
+ # context = context_chain.invoke(question)
+ # rag_chain = (
+ # {"context": lambda x: context, "question": RunnablePassthrough()}
+ # | re_prompt
+ # | model
+ # | parse
+ # )
+ # refine_question = rag_chain.invoke(question)
+
+ # db = SQLDatabase.from_uri(metadata.connection_string)
+ # toolkit = SQLDatabaseToolkit(db=db, llm=model)
+ # remove_markdown_tool = StructuredTool.from_function(
+ # func=remove_markdown_code_block,
+ # description="当发生因为markdown标记导致的函数输入错误时,可以调用此函数来删除标记",
+ # name="remove_markdown_code_block"
+ # )
+ # prompt = PromptTemplate.from_template(qwen_cfg.prompt)
+ # prompt = prompt.partial(metadata=metadata.metadata, product=metadata.product, example=qwen_cfg.params["example"], context=context)
+ # agent_executor = create_sql_agent(
+ # llm=model,
+ # prompt=prompt,
+ # toolkit=toolkit,
+ # verbose=True,
+ # agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+ # extra_tools=[remove_markdown_tool],
+ # agent_executor_kwargs={"handle_parsing_errors": True}
+ # )
+
+ agent_executor = create_sql_graph(model=model, qwen_cfg=qwen_cfg, data_cfg=metadata)
+
+ formatter = NewFormatter(format=ReturnType.WX_MD,
+ tranlate=None,
+ output_dir=config.osp.join(config.BASE_DIR, 'output'),
+ site_url=config.app_config().site_url + '/output',
+ )
+
+ generator = Generator(agentExcutor=agent_executor,
+ messages=[{'role': 'system', 'content': qwen_cfg.system}],
+ apikey=config.api_key(apikey).dashscope_api_key,
+ seed=0,
+ )
+
+ executor = Executor(generator=generator, formatter=formatter)
+
+ # input = {"input": question, "refined_question": refine_question}
+ input = {"messages": [("human", question)], "iterations": 0}
+ ret = executor.query(connection_string=metadata.connection_string,
+ input=input,
+ )
+
+ # thought and result 分开发送因为消息大小限制为 2048 字节(utf-8)
+ await send_msg(to_user, question, ret.error)
+ if ret.thought.startswith('```') and ret.thought.endswith('```'):
+ await send_msg(to_user, question, '针对这个问题,采用 SQL 查询:\n' + ret.thought)
+ else:
+ await send_msg(to_user, question, ret.thought)
+ await send_msg(to_user, question, ret.result)
+
+async def send_msg(to_user:str, question:str, content:str):
+ '''
+ 发送问题的回复给指定用户
+ '''
+ if not content:
+ return
+
+ msg = {
+ "touser" : to_user,
+ "msgtype": "markdown",
+ "agentid" : config.bizwechat_config().agent_id,
+ "markdown": {"content": content,},
+ }
+ access_token = config.wxbiz_token()
+ url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}"
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(url, json=msg) as resp:
+ resp_text = await resp.text()
+ logger.debug('send_msg to %s %s: %s', to_user, question, resp_text)
+
+
+if __name__ == '__main__':
+ import sys
+ if len(sys.argv) > 1 and sys.argv[1] == 'test':
+ logging.basicConfig(level=logging.DEBUG)
+ asyncio.run(async_query_and_reply(apikey=config.bizwechat_config().qgi_api_key,
+ model='qwen-max',
+ database='contracts',
+ question="2020年业绩最好的分公司",
+ to_user='SunHaiWen',
+ ))
+ elif len(sys.argv) > 1 and sys.argv[1] == 'send':
+ logging.basicConfig(level=logging.DEBUG)
+ asyncio.run(send_msg(to_user='SunHaiWen',
+ question="2020年业绩最好的分公司",
+ content="2020年业绩最好的分公司是北京分公司",
+ ))
+ else:
+ import uvicorn
+ uvicorn.run(app, host='0.0.0.0', port=9000)