# bizwechat/sqlcode/qa_cache.py
# Snapshot of the revision dated 2025-02-17 10:34:35 +08:00 (114 lines, 4.6 KiB).

from cachetools import LRUCache
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
import os
class QACache:
    """Bounded question/answer cache with TF-IDF based fuzzy question lookup.

    Answers are stored in an ``LRUCache`` keyed by question text.  The cached
    questions are vectorised with ``TfidfVectorizer`` so that a new question
    can be matched against the most similar cached one.

    Attributes:
        cache: ``LRUCache`` mapping question -> answer.
        vectorizer: ``TfidfVectorizer`` refitted after every ``add`` call.
        questions: Questions currently held in ``cache``; row ``i`` of
            ``question_vectors`` corresponds to ``questions[i]``.
        question_vectors: TF-IDF matrix of ``questions`` (an empty list until
            the first successful ``add``).
        similarity_threshold: A match is returned only when its score is
            strictly greater than this value.

    NOTE(review): the vectorizer is used with its default tokenizer, which
    presumably does not segment unspaced Chinese text into words -- confirm
    that this gives the intended similarity behaviour.
    """

    # Column names; an answer is kept only if its first line contains one.
    _ANSWER_HEADER_KEYS = (
        '经办人', '所属分公司', '合同形式', '合同名称', '项目来源', '专业',
        '地点', '客户名称', '客商类型', '合同签订金额(人民币)', '签订日期',
        '合同有效期(结束)',
    )
    # Answers that mean "the query returned no rows".
    _NO_RECORD_ANSWERS = ("没有符合条件的记录", "没有符合条件的记录\n")

    def __init__(self, maxsize=100, similarity_threshold=0.5):
        """Create an empty cache.

        Args:
            maxsize: Maximum number of question/answer pairs kept.
            similarity_threshold: Minimum (exclusive) similarity score a
                cached question must reach to count as a match.
        """
        self.cache = LRUCache(maxsize=maxsize)
        self.vectorizer = TfidfVectorizer()
        self.questions = []
        self.question_vectors = []
        self.similarity_threshold = similarity_threshold

    def add(self, question_and_answer: dict):
        """Add question/answer pairs to the cache and refit the vectorizer.

        Pairs whose answer is rejected by ``_check_answer`` are skipped.

        Args:
            question_and_answer: Mapping of question text to answer text.

        Raises:
            TypeError: If ``question_and_answer`` is not a dict.
        """
        if not isinstance(question_and_answer, dict):
            raise TypeError("Invalid input. Expected a dictionary.")
        if not question_and_answer:
            print('传入的问答对为空。')
            return

        for question, answer in question_and_answer.items():
            if not self._check_answer(answer):
                continue
            self.cache[question] = answer
            print("question为:", question)

        # Rebuild the question list from the cache instead of appending to
        # it: the LRU cache may have evicted old entries or overwritten an
        # existing key, and the list must mirror exactly what can still be
        # looked up (no stale entries, no duplicates).
        self.questions = list(self.cache)

        # Vectorise only when there is at least one cached question.
        if self.questions:
            self._vectorize_questions()
        else:
            print('没有有效的问题添加到缓存。')

    def _check_answer(self, answer: str) -> bool:
        """Return True if ``answer`` is worth caching.

        An answer is rejected when it is not a non-empty string, when it is
        the "no matching records" message, or when it contains ``varchar``.
        Otherwise it is accepted only if its first line contains one of the
        known column names.
        """
        if not isinstance(answer, str) or not answer:
            return False
        if answer in self._NO_RECORD_ANSWERS:
            return False
        # Answers are currently stringified DataFrames; other answer formats
        # may need their own rules later.
        if 'varchar' in answer:
            return False
        header_line = answer.split('\n')[0]
        return any(key in header_line for key in self._ANSWER_HEADER_KEYS)

    def _vectorize_questions(self):
        """Refit the vectorizer on all cached questions."""
        self.question_vectors = self.vectorizer.fit_transform(self.questions)

    def find_similar(self, question):
        """Return the answer of the most similar cached question.

        Args:
            question: The question text to look up.

        Returns:
            The cached answer when the best similarity score is strictly
            greater than ``similarity_threshold``; otherwise ``None`` (also
            when nothing has been cached yet).
        """
        if not self.questions:
            print('未缓存问答对。')
            return None

        question_vector = self.vectorizer.transform([question])
        # Use the matrix-product operator rather than np.dot, which is not
        # aware of scipy sparse containers.  With the vectorizer's default
        # L2 normalisation the products are cosine similarities.
        scores = (self.question_vectors @ question_vector.T).toarray().ravel()
        most_similar_idx = int(scores.argmax())
        similarity_score = scores[most_similar_idx]
        most_similar_question = self.questions[most_similar_idx]
        print('similarity_score:{}'.format(similarity_score))
        print('most_similar_question:{}'.format(most_similar_question))

        # Return the answer only when the score exceeds the threshold.
        if similarity_score > self.similarity_threshold:
            # .get() keeps a lookup miss from raising if the cache was
            # modified behind our back.
            return self.cache.get(most_similar_question)
        return None
def load_qa_pairs(qa_filename: str, encoding=None) -> dict:
    """Parse successful question/answer pairs out of a log file.

    The log is split into chunks, each starting at a line that contains
    ``qa_cache`` and running up to the next such line (or end of file).
    Only chunks whose first line contains ``[SUCCESS]`` are used.  Within a
    chunk, the third line (stripped) is the question and everything from
    the fifth line onwards, joined verbatim, is the answer.

    If the file does not exist it is created empty and ``{}`` is returned.

    Args:
        qa_filename: Path of the log file.
        encoding: Text encoding passed to ``open``; ``None`` keeps the
            platform default.

    Returns:
        Dict mapping question text to answer text.  When a question occurs
        in several chunks, the last one wins.
    """
    if not os.path.exists(qa_filename):
        # Append mode creates the file but, unlike 'w', cannot truncate it
        # if another process created it in the meantime.
        with open(qa_filename, 'a', encoding=encoding):
            pass
        return {}

    with open(qa_filename, 'r', encoding=encoding) as file:
        lines = file.readlines()

    # Indices of chunk header lines, plus an end sentinel so the last chunk
    # extends to the end of the file.
    boundaries = [idx for idx, line in enumerate(lines) if 'qa_cache' in line]
    boundaries.append(len(lines))

    qa_dict = {}
    for start, end in zip(boundaries, boundaries[1:]):
        chunk = lines[start:end]
        if '[SUCCESS]' not in chunk[0]:
            continue
        # A truncated chunk has no question line; skip it instead of
        # failing with IndexError.
        if len(chunk) < 3:
            continue
        qa_dict[chunk[2].strip()] = ''.join(chunk[4:])
    return qa_dict
if __name__ == '__main__':
    # Manual smoke test: load logged Q/A pairs, cache them, then look up a
    # sample user question.  QACache.add() takes a dict mapping question
    # text to answer text.
    loaded_pairs = load_qa_pairs('log/test1111.log')
    demo_cache = QACache(similarity_threshold=1)
    demo_cache.add(loaded_pairs)

    # With a threshold of 1 and a strict ">" comparison in find_similar(),
    # this run is mainly useful for the diagnostics that find_similar()
    # prints (best score and best-matching question).
    matched_answer = demo_cache.find_similar("机房的项目有哪些")