"""
Author: scutlzc scutlzc@gmail.com
Date: 2024-07-12 09:23:37
LastEditors: scutlzc scutlzc@gmail.com
LastEditTime: 2024-07-12 09:51:30
FilePath: bizwechat/sqlcode/modelloader.py
Description: Build LLM clients (Tongyi / Qianfan) from a configuration
    mapping and keep track of which one is currently active.

Copyright (c) 2024 by scutlzc scutlzc@gmail.com, All Rights Reserved.
"""
import yaml  # noqa: F401  currently unused: the config is passed in pre-parsed
from langchain_community.llms import Tongyi  # noqa: F401  currently unused
from langchain_community.chat_models import ChatTongyi

from sqlcode.baidu_qianfan_endpoint import QianfanLLMEndpoint


class ModelLoader:
    """Instantiate LLM clients from a configuration mapping.

    Based on the keys read in :meth:`load_model`, the config is expected to
    look like::

        {
            "models": {
                "<name>": {"type": "tongyi", "api_key": ..., "model_name": ...},
                "<name>": {"type": "qianfan", "ak": ..., "sk": ..., "model_name": ...},
            }
        }
    """

    def __init__(self, config):
        """Store *config*, an already-parsed mapping (the caller loads it)."""
        self.config = config

    def load_model(self, model_name):
        """Build and return the client configured under ``config['models'][model_name]``.

        Args:
            model_name: Key of the model entry inside ``config['models']``.

        Returns:
            A ``ChatTongyi`` instance for entries with ``type: tongyi``, or a
            ``QianfanLLMEndpoint`` instance for entries with ``type: qianfan``.

        Raises:
            KeyError: If the model entry, or a required field in it, is missing.
            ValueError: If the entry's ``type`` is not a supported model type.
        """
        try:
            model_config = self.config['models'][model_name]
        except KeyError as err:
            # Re-raise with context; a bare KeyError('name') is hard to diagnose.
            raise KeyError(
                f"No configuration found for model {model_name!r} in config['models']"
            ) from err

        model_type = model_config['type']
        if model_type == 'tongyi':
            return ChatTongyi(
                dashscope_api_key=model_config['api_key'],
                model=model_config['model_name'],
            )
        if model_type == 'qianfan':
            return QianfanLLMEndpoint(
                qianfan_ak=model_config['ak'],
                qianfan_sk=model_config['sk'],
                model=model_config['model_name'],
            )
        raise ValueError(f"Unknown model type: {model_type}")


class ModelManager:
    """Hold the currently active model and switch between configured ones."""

    def __init__(self, model_loader):
        """Create a manager with no active model.

        Args:
            model_loader: Object providing ``load_model(model_name)``,
                e.g. a :class:`ModelLoader`.
        """
        self.model_loader = model_loader
        self.current_model = None  # stays None until switch_model() succeeds

    def switch_model(self, model_name):
        """Load *model_name* via the loader and make it the current model.

        The assignment happens only after loading returns, so if
        ``load_model`` raises, the previously active model is kept.
        """
        self.current_model = self.model_loader.load_model(model_name)
        print(f"Switched to model: {model_name}")

    def get_model(self):
        """Return the active model, or ``None`` if no switch has succeeded yet."""
        return self.current_model