bizwechat/tgi_app.py

43 lines
1.5 KiB
Python
Raw Permalink Normal View History

2025-02-17 10:34:35 +08:00
import asyncio
from http import HTTPStatus
from typing import Any

import fastapi
from pydantic import BaseModel, Field

import config
from sqlcode.qwenapi import Generator
# Application object on which this module's route is registered.
app = fastapi.FastAPI()
# Request body accepted by the generation endpoint.
# Documented with comments rather than a class docstring: pydantic publishes a
# model's docstring as its schema description, which would alter the API docs.
class Request(BaseModel):
    # Required prompt text.
    inputs: str = Field(
        ...,
        description='Prompt',
    )
    # Optional, free-form generation parameters; no shape is enforced here.
    parameters: Any | None = Field(
        None,
        description='Generation parameters',
    )
    # Streaming flag; off unless the caller sets it.
    stream: bool = Field(
        False,
        description='Whether to stream output tokens',
    )
# Response body returned by the generation endpoint.
# Documented with comments rather than a class docstring: pydantic publishes a
# model's docstring as its schema description, which would alter the API docs.
class Response(BaseModel):
    # Required generated text.
    generated_text: str = Field(
        ...,
        description='Generated text',
    )
    # Optional, free-form generation details; defaults to None.
    details: Any = Field(
        None,
        description='Generation details',
    )
def _extract_question(prompt: str) -> str:
    """Return the text between ``[QUESTION]`` and the following ``[/QUESTION]``.

    Returns ``'UNKNOWN'`` when either tag is missing.
    """
    # partition() only looks for the closing tag *after* the opening tag, so a
    # stray '[/QUESTION]' earlier in the prompt cannot hide a valid pair.
    _, opened, remainder = prompt.partition('[QUESTION]')
    question, closed, _ = remainder.partition('[/QUESTION]')
    if opened and closed:
        return question
    return 'UNKNOWN'


@app.post('/{model}')
async def generate(model: str, apikey: str, req: Request) -> Response:
    """Generate SQL for the question embedded in the request prompt.

    Args:
        model: Path segment of the URL; forwarded to ``Generator`` as the
            model name.
        apikey: Key resolved through ``config.api_key``. It is a plain ``str``
            that is not part of the path, so FastAPI reads it from the query
            string.
        req: Request body. Only ``req.inputs`` is used.

    Returns:
        A ``Response`` whose ``generated_text`` is the generator result's
        ``sql`` attribute, or ``''`` when that attribute is falsy.

    Raises:
        fastapi.HTTPException: 401 when ``config.api_key`` returns a falsy
            value for ``apikey``.
    """
    client = config.api_key(apikey)
    if not client:
        raise fastapi.HTTPException(
            status_code=HTTPStatus.UNAUTHORIZED,
            detail='invalid apikey',
        )

    # System prompt (Chinese): tells the model it is good at writing MySQL SQL
    # and should write correct, well-formed SQL for the specific question.
    messages = [{
        'role': 'system',
        'content': '你擅长编写 MySQL 的 SQL 代码,请结合具体问题编写正确规范的 SQL 代码',
    }]
    generator = Generator(
        model=model,
        messages=messages,
        apikey=client.dashscope_api_key,
        seed=0,
    )

    # NOTE(review): req.parameters and req.stream are accepted but never read
    # here; confirm whether they should be forwarded to the generator.
    prompt = req.inputs
    question = _extract_question(prompt)

    # Generator.generate is a synchronous call (its result is used without
    # await). Presumably it performs a slow request to the model API (TODO
    # confirm), so run it in a worker thread instead of holding the event
    # loop for the whole call.
    ret = await asyncio.to_thread(generator.generate, question, prompt)
    return Response(generated_text=(ret.sql or ''))