# bizwechat/test_acc.py
#
# 320 lines
# 12 KiB
# Python
# Raw Normal View History
#
# 2025-02-17 10:34:35 +08:00
import json
import os
import random
from itertools import product

import pandas as pd
import requests
# Question templates, one per supported combination of placeholder keys.
#   keys     -- placeholder names the template expects (matched as a set).
#   sql      -- reference SQL (partly natural-language instructions) that is
#               spliced into the question in place of the "[sql]" marker.
#   question -- prompt text; "{name}" placeholders are filled via format_map.
#
# OR-groups in the reference SQL are parenthesised explicitly: AND binds
# tighter than OR, so without the parentheses the company/time filters would
# only apply to some of the OR branches.
#
# NOTE(review): the time templates select `签订日期` but describe the filter in
# terms of `签订时间` -- confirm which column name the contracts table uses.
template_list = [
    {  # time, customer, company
        "keys": ['time', 'customer', 'company'],
        "sql": """
SELECT `合同名称`, `签订日期`, `客户名称`, `所属分公司`
FROM contracts
WHERE `所属分公司` LIKE '%{company}%'
AND (`客户名称` LIKE '%{customer}%' OR `客商类型` LIKE '%{customer}%' OR `合同名称` LIKE '%{customer}%')
AND 根据{time}设置`签订时间`的筛选范围""",
        "question": "{time}{company}承接的“客户名称”或“客商类型”或“合同名称”包含{customer}的项目有哪些参考sql语句[sql]",
    },
    {  # time, industry (including major)
        "keys": ['time', 'field'],
        "sql": """
SELECT `合同名称`, `签订日期`, `聚焦行业`
FROM contracts
WHERE (`合同名称` LIKE '%{field}%' OR `聚焦行业` LIKE '%{field}%' OR `专业` LIKE '%{field}%')
AND 根据{time}设置`签订时间`的筛选范围""",
        "question": "{time}承接的“聚焦行业”或“行业”包含{field}的项目有哪些参考sql语句[sql]"
    },
    {  # time, company, major
        "keys": ['time', 'company', 'major'],
        "sql": """
SELECT `合同名称`, `签订日期`, `专业`, `所属分公司`
FROM contracts
WHERE (`合同名称` LIKE '%{major}%' OR `专业` LIKE '%{major}%')
AND `所属分公司` LIKE '%{company}%'
AND 根据{time}设置`签订时间`的筛选范围""",
        "question": "{time}{company}承接的“专业”包含{major}的项目有哪些参考sql语句[sql]"
    },
    {  # time, location
        "keys": ['time', 'area'],
        "sql": """
SELECT `合同名称`, `地点`
FROM contracts
WHERE `地点` LIKE '%{area}%'
AND 根据{time}设置`签订时间`的筛选范围""",
        "question": "{time}承接的地点在“{area}”的项目有哪些参考sql语句[sql]",
    },
    {  # keyword search (major, industry, company, customer, location)
        "keys": ['keyword'],
        "sql": """
SELECT `合同名称`, `所属分公司`, `专业`, `聚焦行业`, `客商类型`, `地点`
FROM contracts
WHERE `合同名称` LIKE '%{keyword}%' OR `聚焦行业` LIKE '%{keyword}%'
OR `所属分公司` LIKE '%{keyword}%' OR `客商类型` LIKE '%{keyword}%'
OR `地点` LIKE '%{keyword}%' OR `客户名称` LIKE '%{keyword}%'
OR `专业` LIKE '%{keyword}%'""",
        "question": "“地点”“专业”“聚焦行业”“所属分公司”“客商类型”“客户名称”或“合同名称”包含{keyword}的项目有哪些参考sql语句[sql]",
    },
    {  # company, contract amount
        "keys": ['company', 'money'],
        # The company filter uses `所属分公司`, the column this statement
        # selects and that every other template filters {company} on.
        "sql": """
SELECT `合同名称`, `所属分公司`, `合同签订金额人民币`
FROM contracts
WHERE `所属分公司` LIKE '%{company}%'
AND 根据{money}筛选`合同签订金额人民币`的筛选范围""",
        "question": "{company}的合同金额{money}的项目有哪些参考sql语句[sql]"
    },
]
# Workbook holding the candidate values for every placeholder.
VALUE_SETS_PATH = "./valueSets.xlsx"

# Placeholder name -> worksheet name. Each worksheet keeps its values in a
# column that carries the same name as the sheet itself.
_VALUE_SHEETS = {
    'time': "时间取值",
    'major': "专业取值",
    'company': "公司取值",
    'field': "行业取值",
    'customer': "客户取值",
    'area': "地区取值",
}


def _load_value_sets(path, sheets):
    """Read all candidate-value sheets from the workbook in a single pass.

    Args:
        path: Location of the Excel workbook.
        sheets: Mapping of placeholder name to worksheet name; the values are
            read from the column named like the worksheet.

    Returns:
        dict: placeholder name -> list of candidate values, blank cells
        dropped so that no NaN ends up inside a generated question.
    """
    # Passing a list of sheet names makes read_excel parse the workbook once
    # and return a {sheet name: DataFrame} mapping.
    frames = pd.read_excel(path, sheet_name=list(sheets.values()))
    return {
        key: frames[sheet][sheet].dropna().tolist()
        for key, sheet in sheets.items()
    }


raw_data = _load_value_sets(VALUE_SETS_PATH, _VALUE_SHEETS)
# Candidate contract-amount ranges (not stored in the workbook).
raw_data['money'] = [ "100万以上的", "50万以上的", "2000万以上", "400万以上", "10万以上", "100万到1000万",]
# Keyword search draws from the union of all textual value sets.
raw_data['keyword'] = [
    item
    for key in ['area', 'company', 'customer', 'major', 'field']
    for item in raw_data[key]
]
def combine_key_val(keys: list, k=4, value_sets=None) -> list:
    """Build every combination of randomly sampled values for the given fields.

    For each field up to ``k`` candidate values are drawn at random (without
    replacement); the Cartesian product of those samples is returned.

    Args:
        keys (list): Field names to combine.
        k (int, optional): Maximum number of values sampled per field. A field
            with fewer than ``k`` candidates contributes all of them, without
            reducing the sample size of the other fields. Defaults to 4.
        value_sets (dict, optional): Mapping of field name to candidate
            values. Defaults to the module-level ``raw_data``.

    Returns:
        list: One ``{field: value}`` dict per combination. Empty when any
        field has no candidate values.
    """
    source = raw_data if value_sets is None else value_sets
    sampled = {}
    for key in keys:
        candidates = list(source.get(key, []))
        # Clamp per field: the limit must not leak into later iterations,
        # otherwise one short value list would shrink every following sample.
        sampled[key] = random.sample(candidates, k=min(k, len(candidates)))
    print(sampled)
    # Cartesian product of the sampled values, one dict per combination.
    return [dict(zip(sampled, combination)) for combination in product(*sampled.values())]
def form_question(keys, k) -> list:
    """Generate the questions for one combination of placeholder fields.

    The template whose ``keys`` equal ``keys`` (compared as sets) is selected,
    its ``[sql]`` marker is replaced by the reference SQL, and the result is
    formatted once per value combination from ``combine_key_val``.

    Args:
        keys (list): Placeholder names; must match one entry of
            ``template_list``.
        k (int): Maximum number of sampled values per placeholder.

    Returns:
        list: The fully formatted question strings.

    Raises:
        ValueError: If no template is defined for ``keys``.
    """
    wanted = set(keys)
    template = next(
        (item for item in template_list if set(item['keys']) == wanted),
        None,
    )
    if template is None:
        # Fail loudly instead of silently falling back to the first template,
        # whose placeholders would not match the generated combinations.
        raise ValueError("no question template defined for keys: {!r}".format(list(keys)))

    combinations = combine_key_val(keys, k=k)
    print(len(combinations))
    # Splice the SQL in first so its placeholders are filled by the same
    # format_map call as the question text.
    question_template = template['question'].replace('[sql]', template['sql'])
    return [question_template.format_map(combination) for combination in combinations]
def test(keys, k=3, timeout=300):
    """Send the generated questions to the service and save the answers.

    Each question from ``form_question`` is POSTed to ``url``; the ``sql``,
    ``result`` and ``thought`` fields of the JSON reply are collected and
    written to ``./output/<keys joined by '-'>.xlsx`` with the column headers
    问题 / sql / 回答 / thought.

    If a request fails, the answers gathered so far are still written, so a
    late failure does not discard the whole run.

    Args:
        keys (list): Placeholder names selecting the question template.
        k (int, optional): Maximum number of sampled values per placeholder.
            Defaults to 3.
        timeout (float, optional): Seconds to wait for each HTTP response
            before giving up. Defaults to 300.

    Returns:
        "Error" if a request or the file export failed, otherwise None.
    """
    question_list = form_question(keys, k=k)
    columns = ["问题", "sql", "回答", "thought"]
    rows = []
    status = None

    try:
        for question in question_list:
            if question == "":
                break
            payload = {"return_type": "text", "question": question}
            response = requests.post(
                url,
                headers=headers,
                data=json.dumps(payload),
                timeout=timeout,  # never block forever on a hung service
            )
            # Surface HTTP errors directly instead of as a confusing
            # JSON/KeyError further down.
            response.raise_for_status()
            response_data = response.json()
            print("data_res", response_data)
            rows.append(
                [
                    question,
                    response_data["sql"],
                    response_data["result"],
                    response_data["thought"],
                ]
            )
    except Exception as error:  # script boundary: report and keep partial results
        print("Error:", error)
        status = "Error"

    # Skip the export only when the run failed before producing any answer,
    # so an earlier result file is not replaced by an empty one.
    if rows or status is None:
        try:
            os.makedirs("./output", exist_ok=True)
            # Pass the header via columns= so it is written as the header row
            # rather than as a data row under a numeric 0..3 header.
            data = pd.DataFrame(rows, columns=columns)
            data.to_excel("./output/{}.xlsx".format('-'.join(keys)), index=False)
        except Exception as error:  # script boundary: report export failures
            print("Error:", error)
            status = "Error"

    return status
# Endpoint the generated questions are POSTed to. It can be overridden through
# the CONTRACTS_API_URL environment variable so a different host or API key
# can be used without editing this file.
# NOTE(review): the default embeds an API key in source control -- consider
# supplying it only through the environment.
url = os.environ.get(
    "CONTRACTS_API_URL",
    "http://localhost:8001/qwen/contracts?apikey=YUVietLgiGmtqzYUVIIGjrNoLMsGM0FI",
)
# HTTP headers sent with every request; the body is JSON.
headers = {
    "Content-Type": "application/json",
    # Uncomment the following line if needed and provide the appropriate value for data['apiKey']
    # 'Authorization': f'Bearer {data["apiKey"]}'
}
# # def some_function():
# # print(pdata["问题"])
# ls = pdata["问题"]
# qs = []
# tmp = []
# res_q = []
# # for ll in ["数据库字段中的“合同名称”或“聚焦行业”或“专业”包含“key”的项目有哪些?"]:
# # for ll in ["time公司承接的money的项目有哪些/多少今年是2024年。"]:
# # for ll in ["time公司承接的customer的项目有哪些/多少其中值为“customer”是数据库中“客商类型”字段的部分内容。今年是2024年。"]:
# for ll in [
# "company time承接的field的项目有哪些其中值为“company”是数据库中“所属分公司”字段的部分内容“field”是数据库中“合同名称”或“聚焦行业”字段的部分内容。今年是2024年。"
# ]:
# # for ll in ["time公司承接的field的项目有哪些/多少其中值为“field”是数据库中“合同名称”或“聚焦行业”字段的部分内容。"]:
# qs = []
# tmp = []
# tmp.append(ll)
# if "time" in ll:
# for ts in time:
# t0 = ll.replace("time", ts)
# qs.append(t0)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "major" in t1:
# for sp in major:
# t0 = t1.replace("major", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "company" in t1:
# for sp in company:
# t0 = t1.replace("company", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "field" in t1:
# for sp in field:
# t0 = t1.replace("field", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "key" in t1:
# for sp in key:
# t0 = t1.replace("key", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# # money
# for t1 in tmp:
# if "money" in t1:
# for sp in money:
# t0 = t1.replace("money", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "customer" in t1:
# for sp in customer:
# t0 = t1.replace("customer", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "area" in t1:
# for sp in area:
# t0 = t1.replace("area", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# # tmp = qs
# # qs = []
# # print(len(qs))
# res_q = [*res_q, *qs]
# print(len(res_q))
# # print(res_q)
# # return
# try:
# res = [["问题", "sql", "回答", "thought"]]
# for ll in res_q:
# if ll == "":
# break
# payload = {"return_type": "text", "question": ll}
# response = requests.post(url, headers=headers, data=json.dumps(payload))
# response_data = response.json()
# print("data_res", response_data)
# res.append(
# [
# ll,
# response_data["sql"],
# response_data["result"],
# response_data["thought"],
# ]
# )
# data = pd.DataFrame(res)
# data.to_excel("./行业-时间.xlsx", index=False)
# except Exception as error:
# print("Error:", error)
# return "Error"
if __name__ == "__main__":
    # Key sets this script can exercise; each one equals the "keys" of one
    # entry in template_list.
    keys_combination_list = [
        ['time', 'customer', 'company'],
        ['time', 'field'],
        ['time', 'company', 'major'],
        ['time', 'area'],
        ['keyword'],
        ['company', 'money'],
    ]

    # Only the company + money combination is run at the moment. To exercise
    # another template, pick a different index; the keyword template
    # (index 4) was previously run with k=20, the others with k=3.
    selected_keys = keys_combination_list[5]
    test(selected_keys, k=3)