bizwechat/test_acc.py
2025-02-17 10:34:35 +08:00

320 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import random
from itertools import product

import pandas as pd
import requests
# Prompt templates, one per combination of placeholder keys.
#   keys     - placeholders this template needs; form_question() picks the template whose
#              keys equal (as a set) the requested keys.
#   sql      - reference SQL shown to the model; lines starting with "根据..." are
#              natural-language hints rather than executable SQL.
#   question - prompt text; "[sql]" is replaced by `sql`, then every {placeholder} is filled.
# OR groups are parenthesised because AND binds tighter than OR in SQL.
template_list = [
    {  # time, customer, company
        "keys": ['time', 'customer', 'company'],
        "sql": """
SELECT `合同名称`, `签订日期`, `客户名称`, `所属分公司`
FROM contracts
WHERE `所属分公司` LIKE '%{company}%'
AND (`客户名称` LIKE '%{customer}%' OR `客商类型` LIKE '%{customer}%' OR `合同名称` LIKE '%{customer}%')
AND 根据{time}设置`签订日期`的筛选范围""",
        "question": "{time}{company}承接的“客户名称”或“客商类型”或“合同名称”包含{customer}的项目有哪些参考sql语句[sql]",
    },
    {  # time, industry (including major)
        "keys": ['time', 'field'],
        "sql": """
SELECT `合同名称`, `签订日期`, `聚焦行业`
FROM contracts
WHERE (`合同名称` LIKE '%{field}%' OR `聚焦行业` LIKE '%{field}%' OR `专业` LIKE '%{field}%')
AND 根据{time}设置`签订日期`的筛选范围""",
        "question": "{time}承接的“聚焦行业”或“行业”包含{field}的项目有哪些参考sql语句[sql]"
    },
    {  # time, company, major
        "keys": ['time', 'company', 'major'],
        "sql": """
SELECT `合同名称`, `签订日期`, `专业`, `所属分公司`
FROM contracts
WHERE (`合同名称` LIKE '%{major}%' OR `专业` LIKE '%{major}%')
AND `所属分公司` LIKE '%{company}%'
AND 根据{time}设置`签订日期`的筛选范围""",
        "question": "{time}{company}承接的“专业”包含{major}的项目有哪些参考sql语句[sql]"
    },
    {  # time, location
        "keys": ['time', 'area'],
        "sql": """
SELECT `合同名称`, `地点`
FROM contracts
WHERE `地点` LIKE '%{area}%'
AND 根据{time}设置`签订日期`的筛选范围""",
        "question": "{time}承接的地点在“{area}”的项目有哪些参考sql语句[sql]",
    },
    {  # keyword search (major, industry, company, customer, location)
        "keys": ['keyword'],
        "sql": """
SELECT `合同名称`, `所属分公司`, `专业`, `聚焦行业`, `客商类型`, `地点`
FROM contracts
WHERE `合同名称` LIKE '%{keyword}%' OR `聚焦行业` LIKE '%{keyword}%'
OR `所属分公司` LIKE '%{keyword}%' OR `客商类型` LIKE '%{keyword}%'
OR `地点` LIKE '%{keyword}%' OR `客户名称` LIKE '%{keyword}%'
OR `专业` LIKE '%{keyword}%'""",
        "question": "“地点”“专业”“聚焦行业”“所属分公司”“客商类型”“客户名称”或“合同名称”包含{keyword}的项目有哪些参考sql语句[sql]",
    },
    {  # company, amount
        "keys": ['company', 'money'],
        "sql": """
SELECT `合同名称`, `所属分公司`, `合同签订金额(人民币)`
FROM contracts
WHERE `所属分公司` LIKE '%{company}%'
AND 根据{money}筛选`合同签订金额(人民币)`的筛选范围""",
        "question": "{company}的合同金额{money}的项目有哪些参考sql语句[sql]"
    }]
# Workbook holding the candidate values for each placeholder.
VALUE_SETS_PATH = "./valueSets.xlsx"
# Placeholder key -> sheet name; each sheet keeps its values in a column of the same name.
_SHEET_BY_KEY = {
    'time': "时间取值",
    'major': "专业取值",
    'company': "公司取值",
    'field': "行业取值",
    'customer': "客户取值",
    'area': "地区取值",
}
# Parse the workbook once (a list for sheet_name returns {sheet_name: DataFrame}) instead
# of re-reading it for every sheet, and drop blank cells so "nan" never reaches a question.
_value_sheets = pd.read_excel(VALUE_SETS_PATH, sheet_name=list(_SHEET_BY_KEY.values()))
raw_data = {
    key: _value_sheets[sheet][sheet].dropna().tolist()
    for key, sheet in _SHEET_BY_KEY.items()
}
# Amount-range values, hard-coded rather than read from the workbook.
raw_data['money'] = ["100万以上的", "50万以上的", "2000万以上", "400万以上", "10万以上", "100万到1000万", ]
# The keyword template searches across all of these fields, so pool their values.
raw_data['keyword'] = [item for key in ['area', 'company', 'customer', 'major', 'field'] for item in raw_data[key]]
def combine_key_val(keys: list, k=4, data=None) -> list:
    """Return the Cartesian product of randomly sampled values for the given fields.

    For each key, up to ``k`` distinct values are sampled (without replacement) from
    ``data[key]``. Each combination of the sampled values becomes one dict.

    Args:
        keys (list): Field names, e.g. ``['time', 'company']``.
        k (int, optional): Maximum number of values sampled per field. Defaults to 4.
            A field with fewer than ``k`` values contributes all of them, and this
            does not lower the cap for the other fields.
        data (dict, optional): Mapping of field name -> candidate values.
            Defaults to the module-level ``raw_data``.

    Returns:
        list: ``[{key1: value1, key2: value2, ...}, ...]``, one dict per combination.
        The list is empty if any key has no candidate values (including unknown keys).
    """
    if data is None:
        data = raw_data
    sampled = {}
    for key in keys:
        values = data.get(key, [])
        # Cap per key in a local expression; reassigning ``k`` here would shrink the
        # sample size for every key processed after a short value list.
        sampled[key] = random.sample(values, k=min(k, len(values)))
    print(sampled)
    names = list(sampled)
    # product() over the sampled lists enumerates every value combination.
    return [dict(zip(names, combination)) for combination in product(*sampled.values())]
def form_question(keys, k) -> list:
    """Render one prompt per sampled value combination for the template matching ``keys``.

    Args:
        keys (list): Field names; must equal (as a set) the ``keys`` of an entry in
            ``template_list``.
        k (int): Maximum number of values sampled per field (passed to ``combine_key_val``).

    Returns:
        list: Question strings in which ``[sql]`` is replaced by the template's reference
        SQL and every ``{placeholder}`` is filled from one combination.

    Raises:
        ValueError: If no entry in ``template_list`` uses exactly these keys.
    """
    wanted = set(keys)
    template = next((t for t in template_list if set(t['keys']) == wanted), None)
    if template is None:
        # Fail loudly instead of silently rendering an unrelated template.
        raise ValueError("no template in template_list for keys {}".format(sorted(wanted)))
    combinations = combine_key_val(keys, k=k)
    print(len(combinations))
    question_template = template['question'].replace('[sql]', template['sql'])
    return [question_template.format_map(combination) for combination in combinations]
def test(keys, k=3, timeout=600):
    """Send generated questions to the QA service and save the answers to Excel.

    Questions come from ``form_question(keys, k=k)``. Each one is POSTed to the
    module-level ``url`` with ``headers``. The ``sql``/``result``/``thought`` fields of
    each JSON reply are written to ``./output/<key1-key2-...>.xlsx``.

    NOTE(review): because this function is named ``test`` in a ``test_*.py`` file, pytest
    may try to collect it; it appears intended to be run from ``__main__`` only.

    Args:
        keys (list): Field names selecting the template (see ``form_question``).
        k (int, optional): Maximum values sampled per field. Defaults to 3.
        timeout (float, optional): Seconds to wait for each HTTP request. Defaults to 600.

    Returns:
        None on success, or the string ``"Error"`` if a request or the Excel export failed.
        Rows answered before a request failure are still written to the output file.
    """
    question_list = form_question(keys, k=k)
    columns = ["问题", "sql", "回答", "thought"]
    rows = []
    failed = False
    try:
        for question in question_list:
            if question == "":
                break
            payload = {"return_type": "text", "question": question}
            response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=timeout)
            # Surface HTTP errors directly instead of failing later on a non-JSON body.
            response.raise_for_status()
            response_data = response.json()
            print("data_res", response_data)
            rows.append(
                [
                    question,
                    response_data["sql"],
                    response_data["result"],
                    response_data["thought"],
                ]
            )
    except Exception as error:
        # Stop at the first failure, but fall through so collected answers are saved.
        print("Error:", error)
        failed = True
    try:
        os.makedirs("./output", exist_ok=True)
        # Use real column names so the sheet has no spurious 0..3 header row.
        pd.DataFrame(rows, columns=columns).to_excel(
            "./output/{}.xlsx".format('-'.join(keys)), index=False
        )
    except Exception as error:
        print("Error:", error)
        return "Error"
    return "Error" if failed else None
# Endpoint of the contracts QA service under test.
# NOTE(review): an API key is committed in the default below; prefer setting
# CONTRACTS_QA_URL in the environment and rotating this key.
url = os.environ.get(
    "CONTRACTS_QA_URL",
    "http://localhost:8001/qwen/contracts?apikey=YUVietLgiGmtqzYUVIIGjrNoLMsGM0FI",
)
headers = {
    "Content-Type": "application/json",
    # Uncomment the following line if bearer auth is needed, supplying the API key value
    # 'Authorization': f'Bearer {data["apiKey"]}'
}
# # def some_function():
# # print(pdata["问题"])
# ls = pdata["问题"]
# qs = []
# tmp = []
# res_q = []
# # for ll in ["数据库字段中的“合同名称”或“聚焦行业”或“专业”包含“key”的项目有哪些?"]:
# # for ll in ["time公司承接的money的项目有哪些/多少今年是2024年。"]:
# # for ll in ["time公司承接的customer的项目有哪些/多少其中值为“customer”是数据库中“客商类型”字段的部分内容。今年是2024年。"]:
# for ll in [
# "company time承接的field的项目有哪些其中值为“company”是数据库中“所属分公司”字段的部分内容“field”是数据库中“合同名称”或“聚焦行业”字段的部分内容。今年是2024年。"
# ]:
# # for ll in ["time公司承接的field的项目有哪些/多少其中值为“field”是数据库中“合同名称”或“聚焦行业”字段的部分内容。"]:
# qs = []
# tmp = []
# tmp.append(ll)
# if "time" in ll:
# for ts in time:
# t0 = ll.replace("time", ts)
# qs.append(t0)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "major" in t1:
# for sp in major:
# t0 = t1.replace("major", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "company" in t1:
# for sp in company:
# t0 = t1.replace("company", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "field" in t1:
# for sp in field:
# t0 = t1.replace("field", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "key" in t1:
# for sp in key:
# t0 = t1.replace("key", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# # money
# for t1 in tmp:
# if "money" in t1:
# for sp in money:
# t0 = t1.replace("money", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "customer" in t1:
# for sp in customer:
# t0 = t1.replace("customer", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# tmp = qs
# qs = []
# for t1 in tmp:
# if "area" in t1:
# for sp in area:
# t0 = t1.replace("area", sp)
# qs.append(t0)
# else:
# qs.append(t1)
# # tmp = qs
# # qs = []
# # print(len(qs))
# res_q = [*res_q, *qs]
# print(len(res_q))
# # print(res_q)
# # return
# try:
# res = [["问题", "sql", "回答", "thought"]]
# for ll in res_q:
# if ll == "":
# break
# payload = {"return_type": "text", "question": ll}
# response = requests.post(url, headers=headers, data=json.dumps(payload))
# response_data = response.json()
# print("data_res", response_data)
# res.append(
# [
# ll,
# response_data["sql"],
# response_data["result"],
# response_data["thought"],
# ]
# )
# data = pd.DataFrame(res)
# data.to_excel("./行业-时间.xlsx", index=False)
# except Exception as error:
# print("Error:", error)
# return "Error"
if __name__ == "__main__":
    # One key combination per entry in template_list, in the same order.
    keys_combination_list = [
        ['time', 'customer', 'company'],
        ['time', 'field'],
        ['time', 'company', 'major'],
        ['time', 'area'],
        ['keyword'],
        ['company', 'money'],
    ]
    # Earlier runs used k=3 for every combination except the keyword-only one
    # (index 4), which used k=20. Change run_index/sample_size to rerun another
    # combination, or call form_question(...) to inspect prompts without calling the service.
    run_index, sample_size = 5, 3
    test(keys_combination_list[run_index], k=sample_size)