49 lines
1.7 KiB
Python
49 lines
1.7 KiB
Python
'''
|
|
Author: scutlzc scutlzc@gmail.com
|
|
Date: 2024-07-12 09:23:37
|
|
LastEditors: scutlzc scutlzc@gmail.com
|
|
LastEditTime: 2024-07-12 09:51:30
|
|
FilePath: \bizwechat\sqlcode\modelloader.py
|
|
Description:
|
|
|
|
Copyright (c) 2024 by scutlzc scutlzc@gmail.com, All Rights Reserved.
|
|
'''
|
|
import yaml
|
|
# from langchain.llms import OpenAI
|
|
# from langchain.chains import LLamaChain, CustomChain
|
|
from langchain_community.llms import Tongyi
|
|
from langchain_community.chat_models import ChatTongyi
|
|
from sqlcode.baidu_qianfan_endpoint import QianfanLLMEndpoint
|
|
|
|
class ModelLoader:
    """Build LLM client objects from a configuration mapping.

    The configuration is read with this layout::

        {
            "models": {
                "<name>": {"type": "tongyi", "api_key": ..., "model_name": ...},
                "<other>": {"type": "qianfan", "ak": ..., "sk": ..., "model_name": ...},
            }
        }
    """

    def __init__(self, config):
        """Store the configuration mapping.

        Args:
            config: Already-parsed configuration mapping (presumably loaded
                from YAML by the caller -- confirm). It is kept by reference,
                not copied.
        """
        self.config = config

    def load_model(self, model_name):
        """Instantiate the model registered under ``model_name``.

        Args:
            model_name: Key under ``config['models']``.

        Returns:
            A ``ChatTongyi`` instance for ``type: tongyi`` entries, or a
            ``QianfanLLMEndpoint`` instance for ``type: qianfan`` entries.

        Raises:
            KeyError: If the config has no ``'models'`` section, if
                ``model_name`` is not configured, or if the entry lacks a
                key required by its type.
            ValueError: If the entry's ``type`` is not supported.
        """
        models = self.config['models']
        try:
            model_config = models[model_name]
        except KeyError:
            # Keep the KeyError type callers may rely on, but name the
            # configured models so a typo in ``model_name`` is easy to spot.
            raise KeyError(
                f"Unknown model name: {model_name!r} (configured: {list(models)})"
            ) from None

        model_type = model_config['type']
        if model_type == 'tongyi':
            return ChatTongyi(
                dashscope_api_key=model_config['api_key'],
                model=model_config['model_name'],
            )
        if model_type == 'qianfan':
            return QianfanLLMEndpoint(
                qianfan_ak=model_config['ak'],
                qianfan_sk=model_config['sk'],
                model=model_config['model_name'],
            )
        raise ValueError(f"Unknown model type: {model_type}")
|
|
|
|
|
|
class ModelManager:
    """Keep track of the active model and swap it on request.

    Model construction is delegated to the injected loader; this class only
    holds a reference to the most recently loaded model.
    """

    def __init__(self, model_loader):
        """Remember the loader; no model is active until ``switch_model``."""
        self.current_model = None
        self.model_loader = model_loader

    def switch_model(self, model_name):
        """Load ``model_name`` through the loader and make it the active model.

        The active model is replaced only after loading returns; if the loader
        raises, the exception propagates and the previous model stays active.
        A confirmation line is printed to stdout on success.
        """
        loaded = self.model_loader.load_model(model_name)
        self.current_model = loaded
        print(f"Switched to model: {model_name}")

    def get_model(self):
        """Return the active model, or ``None`` if none has been loaded yet."""
        return self.current_model
|
|
|
|
|