bizwechat/config/__init__.py
2025-02-17 10:34:35 +08:00

135 lines
3.5 KiB
Python

from os import path as osp
import rtoml
from typing import Any, Optional
import requests
import time
from threading import Lock
from os import environ
from dataclasses import dataclass
# Deployment environment name; config_file() looks for files in the
# CONFIG_DIR/<ENVIRONMENT>/ subdirectory. Defaults to 'development'.
ENVIRONMENT = environ.get('ENVIRONMENT', 'development')
# Directory that contains this module (and the per-environment config dirs).
CONFIG_DIR = osp.dirname(__file__)
# Absolute, symlink-resolved parent directory of CONFIG_DIR.
BASE_DIR = osp.realpath(osp.join(CONFIG_DIR, osp.pardir))
def config_file(*paths):
    """Build a path inside the current environment's config directory.

    :param paths: Path components appended below ``CONFIG_DIR/ENVIRONMENT``.
    :return: The joined path string; whether the file exists is not checked.
    """
    env_root = (CONFIG_DIR, ENVIRONMENT)
    return osp.join(*env_root, *paths)
def load_config(*paths) -> dict[str, Any]:
    """Parse a TOML file from the current environment's config directory.

    :param paths: Path components passed to ``config_file``.
    :return: The parsed TOML document as a dict.
    :raises OSError: If the file cannot be opened (e.g. FileNotFoundError).
    """
    # The TOML spec mandates UTF-8. Relying on the platform default encoding
    # mis-decodes non-ASCII config text (e.g. Chinese prompts) on systems
    # whose locale is not UTF-8.
    with open(config_file(*paths), 'r', encoding='utf-8') as f:
        return rtoml.load(f)
@dataclass
class DatabaseConfig:
    """Settings for a single database, built from a ``database/*.conf`` file.

    All four fields are required and are plain strings.
    """

    connection_string: str
    metadata: str
    type: str
    product: str
def metadata(database:str):
    """Load the ``DatabaseConfig`` named *database*.

    Reads ``database/<database>.conf`` from the environment config directory.
    """
    # NOTE(review): *database* becomes part of a filesystem path unchecked;
    # if it can originate from user input, validate it first (e.g. against
    # ApiKeyConfig.databases) to rule out path traversal.
    conf_name = database + '.conf'
    settings = load_config('database', conf_name)
    return DatabaseConfig(**settings)
@dataclass
class QwenConfig:
    """Holds ``system`` text, a ``prompt`` template and its ``params``.

    ``prompt`` is a ``str.format`` template; see :meth:`gen_prompt`.
    """

    system: str
    prompt: str
    params: dict[str, Any]

    def gen_prompt(self, database:DatabaseConfig):
        """Render ``prompt`` for *database*.

        The template may reference ``{database}`` (including attribute access
        such as ``{database.type}``) and any key of ``params``. A ``params``
        key named ``database`` raises TypeError (duplicate keyword argument).
        """
        render = self.prompt.format
        return render(database=database, **self.params)
def qwen_config(path: Optional[str] = None):
    """Load a ``QwenConfig``.

    :param path: Config file path relative to the environment config
        directory; ``myqwen.conf`` is used when omitted.
    """
    source = 'myqwen.conf' if path is None else path
    return QwenConfig(**load_config(source))
@dataclass
class ApiKeyConfig:
    """Settings attached to one API key (a table in ``apikeys.conf``)."""

    dashscope_api_key: str
    admin: bool
    databases: list[str]
def api_key(key:str):
    """Look up the settings registered for API key *key*.

    :return: An ``ApiKeyConfig`` built from the ``key`` entry of
        ``apikeys.conf``, or ``None`` when the key is not listed there.
    """
    registry = load_config('apikeys.conf')
    if key not in registry:
        return None
    return ApiKeyConfig(**registry[key])
@dataclass
class BizWechatConfig:
    """WeChat Work app settings from the ``[bizwechat]`` table of ``app.conf``.

    ``corp_id``/``corp_secret`` are the credentials sent to the
    ``qyapi.weixin.qq.com`` gettoken endpoint.
    """

    corp_id: str
    corp_secret: str
    agent_id: int
    token: str
    aes_key: str
    qgi_api_key: str
@dataclass
class AppConfig:
    """Top-level application settings from ``app.conf`` (excluding ``[bizwechat]``)."""

    site_url: str
def app_config():
    """Load ``AppConfig`` from ``app.conf``.

    ``app.conf`` also holds the ``[bizwechat]`` table (consumed by
    ``bizwechat_config``); it is removed here because ``AppConfig`` has no
    such field.
    """
    d = load_config('app.conf')
    # Default of None: a deployment without a [bizwechat] table should still
    # load its app settings instead of failing with KeyError('bizwechat').
    d.pop('bizwechat', None)
    return AppConfig(**d)
def bizwechat_config():
    """Load ``BizWechatConfig`` from the ``[bizwechat]`` table of ``app.conf``.

    :raises KeyError: If ``app.conf`` has no ``bizwechat`` table.
    """
    section = load_config('app.conf')['bizwechat']
    return BizWechatConfig(**section)
@dataclass
class RefineProblemConfig:
    """Prompt text loaded by ``refineProblem_config`` (``problem_refine.conf``)."""

    prompt: str
def refineProblem_config(path: Optional[str] = None):
    """Load a ``RefineProblemConfig``.

    :param path: Config file path relative to the environment config
        directory; ``problem_refine.conf`` is used when omitted.
    """
    source = "problem_refine.conf" if path is None else path
    return RefineProblemConfig(**load_config(source))
import yaml
def model_config():
    """Load ``model.yaml`` from the environment config directory.

    Uses ``yaml.safe_load``, so only plain YAML types are constructed.

    :return: Whatever the YAML document contains (typically a dict).
    """
    # Decode explicitly as UTF-8 rather than the platform default encoding,
    # which differs on non-UTF-8 locales and would mis-read non-ASCII text.
    with open(config_file("model.yaml"), 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
class WxBiz:
    """In-process cache for the WeChat Work access token used by ``wxbiz_token``.

    State lives on the class itself; it is never instantiated.
    """

    # Guards reads and refreshes of the cached token across threads.
    lock = Lock()
    # Cached access token; None until one has been loaded or fetched.
    access_token: Optional[str] = None
    # Unix timestamp after which the cached token is considered expired;
    # 0 means nothing has been loaded yet in this process.
    expire_time: float = 0
def _read_cached_wxbiz_token(cfg_file: str) -> Optional[tuple[float, str]]:
    """Parse the on-disk token cache, formatted as ``"<expire_time>:<access_token>"``.

    :return: ``(expire_time, access_token)``, or ``None`` when the file is
        missing, unreadable, or malformed (e.g. a partially written file),
        so the caller fetches a fresh token instead of crashing.
    """
    try:
        with open(cfg_file, 'r', encoding='utf-8') as f:
            content = f.read().strip()
    except OSError:
        return None
    expire_str, sep, token = content.partition(':')
    if not sep or not token:
        return None
    try:
        return float(expire_str), token
    except ValueError:
        return None


def _fetch_wxbiz_token(timeout: float) -> tuple[float, str]:
    """Request a new access token from the WeChat Work ``gettoken`` API.

    The response is assumed to look like::

        {"errcode": 0, "errmsg": "ok", "access_token": "ACCESS_TOKEN", "expires_in": 7200}

    :param timeout: Seconds to wait for the HTTP request.
    :return: ``(expire_time, access_token)``, where ``expire_time`` is set
        60 s before the server-reported expiry so the token is renewed early.
    :raises RuntimeError: If ``errcode`` is missing or non-zero.
    """
    cfg = bizwechat_config()
    # Pass credentials via ``params`` so requests URL-encodes them.
    data = requests.get(
        'https://qyapi.weixin.qq.com/cgi-bin/gettoken',
        params={'corpid': cfg.corp_id, 'corpsecret': cfg.corp_secret},
        timeout=timeout,
    ).json()
    if data.get('errcode') != 0:
        raise RuntimeError(f"Failed to refresh token: {data}")
    return time.time() + data['expires_in'] - 60, data['access_token']


def wxbiz_token(timeout: float = 10.0):
    """Return a WeChat Work access token, refreshing it when it has expired.

    The token is cached in memory on ``WxBiz`` and persisted to
    ``CONFIG_DIR/wxbiz-access-token`` as ``"<expire_time>:<access_token>"``,
    so a new process can reuse an unexpired token. All work runs while
    holding ``WxBiz.lock``, so threads in one process never refresh at the
    same time.

    :param timeout: Seconds to wait for the token HTTP request. This must be
        bounded because the request runs while ``WxBiz.lock`` is held; a
        stalled request would otherwise block every caller.
    :return: The access token string.
    :raises RuntimeError: If the token API reports an error.
    """
    with WxBiz.lock:
        cfg_file = osp.join(CONFIG_DIR, 'wxbiz-access-token')
        # Consult the disk cache only until something is loaded in this
        # process (expire_time == 0 means nothing has been loaded yet).
        if WxBiz.expire_time == 0:
            cached = _read_cached_wxbiz_token(cfg_file)
            if cached is not None:
                WxBiz.expire_time, WxBiz.access_token = cached
        if WxBiz.expire_time < time.time():
            WxBiz.expire_time, WxBiz.access_token = _fetch_wxbiz_token(timeout)
            with open(cfg_file, 'w', encoding='utf-8') as f:
                f.write(f"{WxBiz.expire_time}:{WxBiz.access_token}")
        return WxBiz.access_token