import os

import numpy as np
from cachetools import LRUCache
from sklearn.feature_extraction.text import TfidfVectorizer


class QACache:
    """LRU cache of question/answer pairs with TF-IDF based fuzzy lookup.

    Answers are stored in an ``LRUCache`` keyed by the question text. All
    cached questions are vectorised with a ``TfidfVectorizer`` so that
    ``find_similar`` can return the answer of the closest cached question.
    """

    # Answers that mean "no matching records"; never worth caching.
    _NO_RECORD_ANSWERS = ("没有符合条件的记录", "没有符合条件的记录\n")

    # An answer is kept only if its first line mentions one of these column
    # names.
    _HEADER_KEYWORDS = (
        '经办人', '所属分公司', '合同形式', '合同名称', '项目来源', '专业',
        '地点', '客户名称', '客商类型', '合同签订金额(人民币)', '签订日期',
        '合同有效期(结束)',
    )

    def __init__(self, maxsize=100, similarity_threshold=0.5):
        """Create an empty cache.

        Args:
            maxsize: Maximum number of question/answer pairs kept by the
                underlying ``LRUCache``.
            similarity_threshold: Score a cached question must exceed
                (strictly) for ``find_similar`` to return its answer.
        """
        self.cache = LRUCache(maxsize=maxsize)
        # NOTE(review): the default TfidfVectorizer tokenizer splits on word
        # boundaries, so unsegmented Chinese sentences presumably become a
        # single token each -- confirm this gives the intended similarity.
        self.vectorizer = TfidfVectorizer()
        self.questions = []
        self.question_vectors = []
        self.similarity_threshold = similarity_threshold

    def add(self, question_and_answer: dict):
        """Add question/answer pairs to the cache.

        Pairs whose answer is rejected by ``_check_answer`` are skipped.
        When at least one pair is stored, the question index is rebuilt from
        the cache keys and re-vectorised.

        Args:
            question_and_answer: Mapping of question text to answer text.

        Raises:
            TypeError: If ``question_and_answer`` is not a dict.
        """
        if not isinstance(question_and_answer, dict):
            raise TypeError("Invalid input. Expected a dictionary.")
        if not question_and_answer:
            print('传入的问答对为空。')
            return

        added = False
        for question, answer in question_and_answer.items():
            if not self._check_answer(answer):
                continue
            self.cache[question] = answer
            added = True
            print("question为:", question)

        if not added:
            print('没有有效的问题添加到缓存。')
            return

        # Rebuild the index from the cache itself: the LRU cache may have
        # evicted old entries (or overwritten a duplicate question), and
        # every indexed question must still resolve to an answer.
        self.questions = list(self.cache)
        self._vectorize_questions()

    def _check_answer(self, answer: str) -> bool:
        """Return True if ``answer`` is worth caching.

        An answer is accepted only when it is a non-empty string, is not the
        "no matching records" message, does not contain ``'varchar'``, and
        its first line contains one of the known column names.
        """
        if not isinstance(answer, str) or not answer:
            return False
        if answer in self._NO_RECORD_ANSWERS:
            return False
        # Answers are currently stringified DataFrames; other answer formats
        # may need different handling later.
        if 'varchar' in answer:
            return False
        header = answer.split('\n')[0]
        return any(key in header for key in self._HEADER_KEYWORDS)

    def _vectorize_questions(self):
        """Re-fit the vectorizer on all indexed questions."""
        self.question_vectors = self.vectorizer.fit_transform(self.questions)

    def find_similar(self, question):
        """Return the cached answer of the most similar question.

        Args:
            question: The user's question text.

        Returns:
            The cached answer when the best score is strictly greater than
            ``similarity_threshold``; otherwise ``None`` (also when nothing
            is cached).
        """
        if not self.questions:
            print('未缓存问答对。')
            return None

        question_vector = self.vectorizer.transform([question])
        # Use the sparse-aware ``@`` operator; ``np.dot`` does not understand
        # scipy sparse matrices. TfidfVectorizer L2-normalises rows by
        # default, so this dot product is the cosine similarity.
        product = self.question_vectors @ question_vector.T
        cosine_similarities = product.toarray().ravel()
        most_similar_idx = int(np.argmax(cosine_similarities))
        similarity_score = cosine_similarities[most_similar_idx]
        most_similar_question = self.questions[most_similar_idx]

        print('similarity_score:{}'.format(similarity_score))
        print('most_similar_question:{}'.format(most_similar_question))

        # Only return an answer when the score exceeds the threshold.
        if similarity_score > self.similarity_threshold:
            return self.cache[most_similar_question]
        return None


def load_qa_pairs(qa_filename: str, encoding=None) -> dict:
    """Parse question/answer pairs out of a log file.

    The file is split into chunks, each starting at a line containing
    ``'qa_cache'``. Only chunks whose first line contains ``'[SUCCESS]'`` are
    used: the third line of the chunk is the question and everything from
    the fifth line onward is the answer.

    If the file does not exist, an empty file is created and an empty dict
    is returned.

    Args:
        qa_filename: Path of the log file.
        encoding: Text encoding used to read the file. ``None`` (the default)
            keeps the platform default encoding.

    Returns:
        A dict mapping question text to answer text.
    """
    if not os.path.exists(qa_filename):
        # Preserve the original behaviour of creating an empty log file.
        with open(qa_filename, 'w'):
            pass

    with open(qa_filename, 'r', encoding=encoding) as file:
        lines = file.readlines()

    # Indices of chunk starts, plus a sentinel marking the end of the file.
    positions = [idx for idx, line in enumerate(lines) if 'qa_cache' in line]
    positions.append(len(lines))

    qa_dict = {}
    for start, end in zip(positions, positions[1:]):
        if '[SUCCESS]' not in lines[start]:
            continue
        chunk = lines[start:end]
        if len(chunk) < 3:
            # Truncated chunk: there is no question line to read.
            continue
        question = chunk[2].strip()
        answer = ''.join(chunk[4:])
        qa_dict[question] = answer
    return qa_dict


if __name__ == '__main__':
    qa_dict = load_qa_pairs('log/test1111.log')
    # NOTE(review): with a threshold of 1 and a strict ``>`` comparison this
    # demo will presumably never return an answer -- confirm the intent.
    qa_cache = QACache(similarity_threshold=1)
    qa_cache.add(qa_dict)

    # Usage note: ``add`` takes a single ``{question: answer}`` dict, e.g.
    #   qa_cache = QACache(maxsize=10)
    #   qa_cache.add({"question text": "answer text"})

    # Ask a question.
    user_question = "机房的项目有哪些"
    similar_answer = qa_cache.find_similar(user_question)
    # print(similar_answer)