bizwechat/config/__init__.py

135 lines
3.5 KiB
Python
Raw Normal View History

2025-02-17 10:34:35 +08:00
from os import path as osp
import rtoml
from typing import Any, Optional
import requests
import time
from threading import Lock
from os import environ
from dataclasses import dataclass
# Active deployment environment; selects the sub-directory of CONFIG_DIR that
# config_file() resolves paths against. Defaults to 'development'.
ENVIRONMENT: str = environ.get('ENVIRONMENT', 'development')

# Directory of this package; per-environment config trees live beneath it.
CONFIG_DIR: str = osp.dirname(__file__)

# Parent directory of CONFIG_DIR, with symlinks resolved.
_parent_of_config = osp.join(CONFIG_DIR, '..')
BASE_DIR: str = osp.realpath(_parent_of_config)
def config_file(*paths):
    """Return the path of a config file for the active environment.

    The given path components are joined, in order, beneath
    ``CONFIG_DIR/ENVIRONMENT``.
    """
    env_dir = osp.join(CONFIG_DIR, ENVIRONMENT)
    return osp.join(env_dir, *paths)
def load_config(*paths) -> dict[str, Any]:
    """Parse a TOML config file of the active environment.

    Args:
        *paths: path components relative to the environment's config
            directory (resolved via ``config_file``).

    Returns:
        The parsed TOML document as a dict.

    Raises:
        OSError: the file cannot be opened (e.g. FileNotFoundError).
        Any parsing error raised by ``rtoml.load`` for invalid TOML.
    """
    # TOML documents are UTF-8 by definition; decode explicitly instead of
    # relying on the platform's locale encoding, which can mis-decode
    # non-ASCII config text on hosts whose default encoding is not UTF-8.
    with open(config_file(*paths), 'r', encoding='utf-8') as f:
        return rtoml.load(f)
@dataclass
class DatabaseConfig:
    """Settings for one database, as read from ``database/<name>.conf``.

    The field order below defines the positional constructor signature and
    must not be changed.
    """

    # Connection string for the database.
    connection_string: str
    # Free-form metadata text (presumably a schema description - TODO confirm).
    metadata: str
    # Database type identifier, as stored in the config file.
    type: str
    # Product name, as stored in the config file.
    product: str
def metadata(database: str):
    """Load the DatabaseConfig for the database named *database*.

    The settings are read from ``database/<database>.conf`` in the active
    environment's config directory.
    """
    conf_name = database + '.conf'
    settings = load_config('database', conf_name)
    return DatabaseConfig(**settings)
@dataclass
class QwenConfig:
    """Prompt settings for the Qwen model (``myqwen.conf`` by default)."""

    # Presumably the system prompt sent to the model - TODO confirm with caller.
    system: str
    # ``str.format`` template; rendered with ``database`` plus every key of
    # ``params`` as named fields.
    prompt: str
    # Extra named values substituted into ``prompt``.
    params: dict[str, Any]

    def gen_prompt(self, database: DatabaseConfig):
        """Render ``prompt`` with *database* and the entries of ``params``.

        A ``params`` key named ``database`` collides with the explicit
        argument and makes ``str.format`` raise TypeError.
        """
        template = self.prompt
        return template.format(**self.params, database=database)
def qwen_config(path: Optional[str] = None):
    """Load the Qwen prompt configuration.

    Args:
        path: config file relative to the environment config directory;
            ``myqwen.conf`` is used when it is None.
    """
    conf_name = 'myqwen.conf' if path is None else path
    return QwenConfig(**load_config(conf_name))
@dataclass
class ApiKeyConfig:
    """Per-API-key settings stored as a table in ``apikeys.conf``."""

    # DashScope API key associated with this entry.
    dashscope_api_key: str
    # Admin flag for this key.
    admin: bool
    # Names of the databases associated with this key.
    databases: list[str]
def api_key(key: str):
    """Look up *key* in ``apikeys.conf``.

    Returns an ApiKeyConfig built from the table stored under *key*, or None
    when the file has no such entry. The file is re-read on every call.
    """
    known_keys = load_config('apikeys.conf')
    if key not in known_keys:
        return None
    return ApiKeyConfig(**known_keys[key])
@dataclass
class BizWechatConfig:
    """WeChat Work settings from the ``[bizwechat]`` table of ``app.conf``."""

    # Corp ID and secret; sent to the gettoken API to obtain an access token.
    corp_id: str
    corp_secret: str
    # Application agent ID.
    agent_id: int
    # Presumably the callback verification token - TODO confirm.
    token: str
    # Presumably the callback message encryption key - TODO confirm.
    aes_key: str
    # API key value kept alongside the WeChat settings.
    qgi_api_key: str
@dataclass
class AppConfig:
    """Top-level application settings from ``app.conf``."""

    # Base URL of the site.
    site_url: str
def app_config():
    """Load the general application settings from ``app.conf``.

    The ``bizwechat`` table in the same file is handled by
    ``bizwechat_config`` and is removed before building AppConfig.

    Raises:
        TypeError: app.conf contains keys that AppConfig does not define
            (or lacks ``site_url``).
    """
    d = load_config('app.conf')
    # Drop the WeChat section only if present: a deployment without a
    # [bizwechat] table must still be able to load its app settings instead
    # of failing with KeyError.
    d.pop('bizwechat', None)
    return AppConfig(**d)
def bizwechat_config():
    """Return the WeChat Work settings from the ``bizwechat`` table of ``app.conf``."""
    section = load_config('app.conf')['bizwechat']
    return BizWechatConfig(**section)
@dataclass
class RefineProblemConfig:
    """Settings for problem refinement (``problem_refine.conf`` by default)."""

    # Prompt text used for problem refinement.
    prompt: str
def refineProblem_config(path: Optional[str] = None):
    """Load the problem-refinement configuration.

    Args:
        path: config file relative to the environment config directory;
            ``problem_refine.conf`` is used when it is None.
    """
    conf_name = "problem_refine.conf" if path is None else path
    return RefineProblemConfig(**load_config(conf_name))
import yaml
def model_config():
    """Load and return the parsed contents of ``model.yaml`` for the active environment.

    Parsed with ``yaml.safe_load``, so no arbitrary Python objects are
    constructed from the file.
    """
    # Decode explicitly as UTF-8 rather than the platform locale encoding so
    # non-ASCII content is read identically on every host.
    with open(config_file("model.yaml"), 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
class WxBiz:
    """Shared in-process cache for the WeChat Work access token.

    Used through the class attributes themselves; ``wxbiz_token`` reads and
    updates them while holding ``lock``.
    """

    # Serializes the check-and-refresh of the token across threads.
    lock = Lock()
    # Cached access token; None until one has been loaded or fetched.
    access_token: Optional[str] = None
    # Unix timestamp after which the token must be refreshed; 0 means nothing
    # has been loaded yet in this process.
    expire_time: float = 0
def _read_token_cache(cache_file: str) -> Optional[tuple[float, str]]:
    """Return ``(expire_time, access_token)`` stored in *cache_file*.

    Returns None when the file is missing, unreadable or malformed, so the
    caller fetches a fresh token instead of failing on a corrupt cache.
    """
    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            raw = f.read()
    except OSError:
        return None
    # File format is "<expire_time>:<token>"; split on the first ':' only.
    expire_text, sep, token = raw.partition(':')
    if not sep or not token:
        return None
    try:
        return float(expire_text), token
    except ValueError:
        return None


def _write_token_cache(cache_file: str, expire_time: float, token: str) -> None:
    """Persist *expire_time* and *token* to *cache_file* atomically.

    Writes a per-process temporary file and renames it over *cache_file*, so
    a concurrent reader never observes a half-written token.
    """
    import os

    tmp_file = f"{cache_file}.{os.getpid()}.tmp"
    with open(tmp_file, 'w', encoding='utf-8') as f:
        f.write(f"{expire_time}:{token}")
    os.replace(tmp_file, cache_file)


def wxbiz_token(timeout: float = 10.0):
    """Return the WeChat Work access token, refreshing it when it has expired.

    The token is cached on WxBiz in memory and in the file
    ``CONFIG_DIR/wxbiz-access-token``, which is consulted the first time this
    process needs a token. The whole check-and-refresh runs under
    ``WxBiz.lock``.

    Args:
        timeout: seconds to wait for the gettoken HTTP request.

    Raises:
        RuntimeError: the gettoken API answered with a non-zero ``errcode``.
        requests.RequestException: the HTTP request failed or timed out.
    """
    with WxBiz.lock:
        cfg_file = osp.join(CONFIG_DIR, 'wxbiz-access-token')
        # Nothing loaded yet in this process: try the on-disk cache first.
        if WxBiz.expire_time == 0:
            cached = _read_token_cache(cfg_file)
            if cached is not None:
                WxBiz.expire_time, WxBiz.access_token = cached
        if WxBiz.expire_time < time.time():
            cfg = bizwechat_config()
            # Expected successful response:
            # {"errcode": 0, "errmsg": "ok",
            #  "access_token": "ACCESS_TOKEN", "expires_in": 7200}
            data = requests.get(
                'https://qyapi.weixin.qq.com/cgi-bin/gettoken',
                params={'corpid': cfg.corp_id, 'corpsecret': cfg.corp_secret},
                # Without a timeout a stalled connection would block forever
                # while holding WxBiz.lock, hanging every other caller.
                timeout=timeout,
            ).json()
            if data.get('errcode') == 0:
                WxBiz.access_token = data['access_token']
                # Treat the token as expired 60 s early to avoid using it at
                # the very end of its lifetime.
                WxBiz.expire_time = time.time() + data['expires_in'] - 60
                _write_token_cache(cfg_file, WxBiz.expire_time, WxBiz.access_token)
            else:
                raise RuntimeError(f"Failed to refresh token: {data}")
        return WxBiz.access_token