43 lines
1.5 KiB
Python
43 lines
1.5 KiB
Python
import fastapi
|
|
import config
|
|
from http import HTTPStatus
|
|
from pydantic import BaseModel, Field
|
|
from typing import Any
|
|
from sqlcode.qwenapi import Generator
|
|
|
|
# ASGI application object; HTTP routes in this module are registered on it
# through `@app.<method>(...)` decorators.
app = fastapi.FastAPI()
|
|
|
|
class Request(BaseModel):
    # JSON body accepted by the generation endpoint.
    # (Intentionally no class docstring: pydantic would publish it as the
    # schema description.)

    # Prompt text; mandatory (`...` marks the field as required).
    inputs: str = Field(..., description='Prompt')

    # Free-form generation settings; optional, None when omitted.
    parameters: Any | None = Field(None, description='Generation parameters')

    # Streaming flag; optional, False when omitted.
    stream: bool = Field(False, description='Whether to stream output tokens')
|
|
|
|
class Response(BaseModel):
    # JSON body returned by the generation endpoint.
    # (Intentionally no class docstring: pydantic would publish it as the
    # schema description.)

    # Generated output text; mandatory (`...` marks the field as required).
    generated_text: str = Field(..., description='Generated text')

    # Optional extra information about the generation; None when not supplied.
    details: Any = Field(None, description='Generation details')
|
|
|
|
def _extract_question(prompt: str) -> str:
    """Return the text between the first ``[QUESTION]`` tag and the next ``[/QUESTION]``.

    Returns the literal string ``'UNKNOWN'`` when the opening tag is missing or
    no closing tag follows it.
    """
    open_tag, close_tag = '[QUESTION]', '[/QUESTION]'
    start = prompt.find(open_tag)
    if start < 0:
        return 'UNKNOWN'
    start += len(open_tag)
    # Look for the closing tag only *after* the opening one, so a stray
    # '[/QUESTION]' earlier in the prompt cannot hide a well-formed pair.
    end = prompt.find(close_tag, start)
    if end < 0:
        return 'UNKNOWN'
    return prompt[start:end]


@app.post('/{model}')
async def generate(model: str, apikey: str, req: Request) -> Response:
    """Generate MySQL SQL for the question embedded in ``req.inputs``.

    Args:
        model: Model name from the URL path; forwarded to ``Generator``.
        apikey: Client API key, resolved through ``config.api_key``.
        req: Request body. Only ``inputs`` is read here; ``parameters`` and
            ``stream`` are currently ignored.

    Returns:
        A ``Response`` whose ``generated_text`` is ``ret.sql`` from the
        generator, or ``''`` when that value is falsy.

    Raises:
        fastapi.HTTPException: 401 UNAUTHORIZED when ``apikey`` is not recognised.
    """
    import asyncio

    client = config.api_key(apikey)
    if not client:
        raise fastapi.HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail='invalid apikey')

    # Build a fresh message list per request rather than sharing a module-level
    # one, in case Generator keeps or mutates it (not verifiable from here).
    message = [{'role': 'system',
                'content': '你擅长编写 MySQL 的 SQL 代码,请结合具体问题编写正确规范的 SQL 代码'
                }]

    generator = Generator(model=model,
                          messages=message,
                          apikey=client.dashscope_api_key,
                          seed=0,
                          )

    prompt = req.inputs
    question = _extract_question(prompt)

    # generator.generate is synchronous (it is not awaited) and presumably performs
    # a remote API call; run it in a worker thread so it does not block the event
    # loop and stall every other in-flight request.
    ret = await asyncio.to_thread(generator.generate, question, prompt)

    return Response(generated_text=(ret.sql or ''))