# QA cache module: stores question/answer pairs and serves the most similar cached answer via TF-IDF cosine similarity.
from cachetools import LRUCache
|
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
import numpy as np
|
|
import os
|
|
|
|
class QACache:
    """LRU-bounded cache of question/answer pairs with TF-IDF similarity lookup.

    Questions are vectorized with a TfidfVectorizer; ``find_similar`` returns
    the cached answer whose question has the highest cosine similarity to the
    query, provided the score exceeds ``similarity_threshold``.

    NOTE(review): evictions from the LRU cache are not mirrored in
    ``self.questions``, so a stale (evicted) question can still win the
    similarity search; ``find_similar`` uses ``.get`` so that case yields
    ``None`` instead of a KeyError.
    """

    def __init__(self, maxsize: int = 100, similarity_threshold: float = 0.5):
        """Create an empty cache.

        Args:
            maxsize: maximum number of QA pairs held by the LRU cache.
            similarity_threshold: minimum cosine similarity (exclusive) for
                ``find_similar`` to return a cached answer.
        """
        self.cache = LRUCache(maxsize=maxsize)
        self.vectorizer = TfidfVectorizer()
        self.questions = []           # insertion-ordered list of cached questions
        self.question_vectors = []    # TF-IDF matrix; rebuilt by _vectorize_questions
        self.similarity_threshold = similarity_threshold  # minimum similarity score

    def add(self, question_and_answer: dict) -> None:
        """Add QA pairs to the cache, then refit the question vectors once.

        Args:
            question_and_answer: mapping of question -> answer strings.

        Raises:
            TypeError: if the input is not a dict.
        """
        if not isinstance(question_and_answer, dict):
            raise TypeError("Invalid input. Expected a dictionary.")
        if question_and_answer == {}:
            print('传入的问答对为空。')
            return
        added_any = False
        for question, answer in question_and_answer.items():
            # Skip answers that the filter deems useless (errors, schema dumps).
            if not self._check_answer(answer):
                continue
            self.cache[question] = answer
            self.questions.append(question)
            print("question为:", question)
            added_any = True
        if added_any:
            # Refit once after the loop instead of once per question
            # (the original refit inside the loop was O(n^2) overall).
            self._vectorize_questions()
        else:
            print('没有有效的问题添加到缓存。')

    def _check_answer(self, answer: str) -> bool:
        """Filter answers: keep only non-empty DataFrame-style strings whose
        header line contains at least one known contract column name.

        Returns:
            True if the answer is worth caching, False otherwise.
        """
        if not isinstance(answer, str) or not answer:
            return False
        # "no matching records" responses carry no information worth caching
        if answer == "没有符合条件的记录" or answer == "没有符合条件的记录\n":
            return False
        # Answers are currently DataFrame-as-string; a raw schema dump
        # (contains 'varchar') is not a useful answer.
        # TODO(review): revisit if other answer formats are stored later.
        if answer.find('varchar') != -1:
            return False
        header_line = answer.split('\n')[0]
        for key in ['经办人', '所属分公司', '合同形式', '合同名称', '项目来源', '专业', '地点', '客户名称', '客商类型', '合同签订金额(人民币)', '签订日期', '合同有效期(结束)']:
            if key in header_line:
                return True
        return False

    def _vectorize_questions(self):
        """(Re)fit the TF-IDF model on all cached questions."""
        self.question_vectors = self.vectorizer.fit_transform(self.questions)

    def find_similar(self, question):
        """Return the cached answer most similar to *question*.

        Args:
            question: the query text.

        Returns:
            The cached answer string when the best cosine similarity exceeds
            ``similarity_threshold``; otherwise ``None``. Also ``None`` when
            the cache is empty or the matched entry was evicted.
        """
        if len(self.questions) == 0:
            print('未缓存问答对。')
            return None
        question_vector = self.vectorizer.transform([question])
        # TfidfVectorizer output is L2-normalised by default, so the dot
        # product of the sparse vectors equals cosine similarity.
        cosine_similarities = np.dot(self.question_vectors, question_vector.T).toarray().flatten()
        most_similar_idx = cosine_similarities.argmax()
        similarity_score = cosine_similarities[most_similar_idx]
        most_similar_question = self.questions[most_similar_idx]
        print('similarity_score:{}'.format(similarity_score))
        print('most_similar_question:{}'.format(most_similar_question))
        # Only return an answer when the similarity clears the threshold.
        if similarity_score > self.similarity_threshold:
            # .get(): the entry may have been evicted from the LRU cache
            # while its question is still present in self.questions.
            return self.cache.get(most_similar_question)
        else:
            return None
|
|
|
|
def load_qa_pairs(qa_filename: str) -> dict:
    """Parse a qa_cache log file into a {question: answer} dict.

    A chunk starts on any line containing 'qa_cache' and runs to the next
    such line (or EOF). Only chunks whose header line also contains
    '[SUCCESS]' are kept. Within a chunk, line index 2 is the question and
    the joined lines from index 4 onward are the answer.

    Side effect: creates an empty file when *qa_filename* does not exist
    (preserved from the original implementation), then returns {}.

    Args:
        qa_filename: path to the log file.

    Returns:
        Mapping of question text to answer text; later chunks overwrite
        earlier ones with the same question.
    """
    if not os.path.exists(qa_filename):
        # Preserve the original side effect: create an empty log file.
        with open(qa_filename, 'w'):
            pass
    with open(qa_filename, 'r') as file:
        lines = file.readlines()
    # Chunk boundaries: every line mentioning 'qa_cache', plus EOF sentinel.
    positions = [idx for idx, line in enumerate(lines) if 'qa_cache' in line]
    positions.append(len(lines))
    qa_dict = {}
    for i in range(len(positions) - 1):
        # Only chunks logged as successful carry a usable QA pair.
        if '[SUCCESS]' not in lines[positions[i]]:
            continue
        chunk = lines[positions[i]:positions[i + 1]]
        if len(chunk) < 3:
            # Malformed chunk: too short to hold a question line
            # (the original would raise IndexError here).
            continue
        question = chunk[2].strip()
        answer = ''.join(chunk[4:])
        qa_dict[question] = answer
    return qa_dict
|
|
|
|
if __name__ == '__main__':
    # Seed the cache with QA pairs recovered from a previous run's log.
    pairs = load_qa_pairs('log/test1111.log')
    cache = QACache(similarity_threshold=1)
    cache.add(pairs)

    # Look up the cached answer closest to a sample user question.
    user_question = "机房的项目有哪些"
    similar_answer = cache.find_similar(user_question)