bizwechat/sqlcode/modelloader.py

49 lines
1.7 KiB
Python
Raw Normal View History

2025-02-17 10:34:35 +08:00
'''
Author: scutlzc scutlzc@gmail.com
Date: 2024-07-12 09:23:37
LastEditors: scutlzc scutlzc@gmail.com
LastEditTime: 2024-07-12 09:51:30
FilePath: \bizwechat\sqlcode\modelloader.py
Description: Load LangChain model clients (Tongyi, Qianfan) from a config mapping and switch between them.
Copyright (c) 2024 by scutlzc scutlzc@gmail.com, All Rights Reserved.
'''
import yaml
# from langchain.llms import OpenAI
# from langchain.chains import LLamaChain, CustomChain
from langchain_community.llms import Tongyi
from langchain_community.chat_models import ChatTongyi
from sqlcode.baidu_qianfan_endpoint import QianfanLLMEndpoint
class ModelLoader:
    """Build LangChain model clients from an already-parsed config mapping.

    The config is read as ``config['models'][<model name>]``, where each
    entry is itself a mapping.

    For ``type: tongyi`` entries the loader reads these keys:
    ``type``, ``api_key`` and ``model_name``.

    For ``type: qianfan`` entries it reads these keys:
    ``type``, ``ak``, ``sk`` and ``model_name``.
    """

    def __init__(self, config):
        """Store the configuration mapping.

        Args:
            config: Mapping with a ``models`` key (see the class docstring),
                e.g. the result of ``yaml.safe_load`` on a config file.
        """
        self.config = config

    def load_model(self, model_name):
        """Instantiate the model configured under *model_name*.

        Args:
            model_name: Key into ``config['models']``.

        Returns:
            A ``ChatTongyi`` instance for ``tongyi`` entries, or a
            ``QianfanLLMEndpoint`` instance for ``qianfan`` entries.

        Raises:
            KeyError: If *model_name* is not configured, or the config lacks
                ``models`` or a required per-model key.
            ValueError: If the entry's ``type`` is not supported.
        """
        models = self.config['models']
        try:
            model_config = models[model_name]
        except KeyError:
            # Keep the KeyError type callers may already catch, but say which
            # model was requested and which ones are actually configured.
            available = ', '.join(sorted(map(str, models))) or '<none>'
            raise KeyError(
                f"Unknown model name: {model_name!r} (configured: {available})"
            ) from None

        model_type = model_config['type']
        if model_type == 'tongyi':
            return ChatTongyi(
                dashscope_api_key=model_config['api_key'],
                model=model_config['model_name'],
            )
        if model_type == 'qianfan':
            return QianfanLLMEndpoint(
                qianfan_ak=model_config['ak'],
                qianfan_sk=model_config['sk'],
                model=model_config['model_name'],
            )
        raise ValueError(f"Unknown model type: {model_type}")
class ModelManager:
    """Track one active model and replace it on request.

    Models come from *model_loader*, which must expose a
    ``load_model(model_name)`` method (for example ``ModelLoader``).
    """

    def __init__(self, model_loader):
        # Nothing is active until switch_model() has been called.
        self.current_model = None
        self.model_loader = model_loader

    def switch_model(self, model_name):
        """Load *model_name* through the loader and make it the active model.

        A confirmation line is printed to stdout after the switch. If the
        loader raises, the exception propagates and the previously active
        model stays in place.
        """
        loader = self.model_loader
        new_model = loader.load_model(model_name)
        self.current_model = new_model
        print(f"Switched to model: {model_name}")

    def get_model(self):
        """Return the active model, or ``None`` if none has been loaded yet."""
        active = self.current_model
        return active