'''
Author: scutlzc scutlzc@gmail.com
Date: 2024-07-11 16:36:34
LastEditors: scutlzc scutlzc@gmail.com
LastEditTime: 2024-07-12 11:03:24
FilePath: \bizwechat\sqlcode\langchain_model.py
Description: LangChain/LangGraph-backed generator for the sqlcode package.
             (File header generated by koroFileHeader; configuration guide:
             https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE)
'''
from dashscope import Generation
|
|||
|
|
import random
|
|||
|
|
from http import HTTPStatus
|
|||
|
|
from enum import Enum
|
|||
|
|
from .qgi import Generator as BaseGenerator
|
|||
|
|
from langchain_community.llms import Tongyi
|
|||
|
|
from langchain_core.language_models import BaseLLM
|
|||
|
|
from langchain.prompts import PromptTemplate
|
|||
|
|
from langchain.chains.llm import LLMChain
|
|||
|
|
from enum import StrEnum
|
|||
|
|
from .modelloader import ModelManager, ModelLoader
|
|||
|
|
from langchain.agents import AgentExecutor
|
|||
|
|
from langchain_core.exceptions import OutputParserException
|
|||
|
|
from sqlcode.utils import enrich_input
|
|||
|
|
from typing import Optional
|
|||
|
|
from langgraph.graph.graph import CompiledGraph
|
|||
|
|
|
|||
|
|
class Generator(BaseGenerator):
|
|||
|
|
def __init__(self, agentExcutor:CompiledGraph, messages:list[dict[str,str]]|None, apikey:str, seed:int=0, max_retry: Optional[int] = 2) -> None:
|
|||
|
|
self.message = messages
|
|||
|
|
self.apikey = apikey
|
|||
|
|
self.seed = seed
|
|||
|
|
self.agentExcutor = agentExcutor
|
|||
|
|
self.max_retry = max_retry
|
|||
|
|
|
|||
|
|
def _generate(self, input: dict) -> tuple[HTTPStatus]:
|
|||
|
|
seed = self.seed
|
|||
|
|
if seed == 0:
|
|||
|
|
seed = random.randint(1, 10000)
|
|||
|
|
|
|||
|
|
cnt = 0
|
|||
|
|
err = None
|
|||
|
|
while cnt < self.max_retry:
|
|||
|
|
try:
|
|||
|
|
if cnt == 0:
|
|||
|
|
response = self.agentExcutor.invoke(input)["messages"][-1][1]
|
|||
|
|
res = {"input": input, "output": response, "err": None}
|
|||
|
|
else:
|
|||
|
|
new_input = enrich_input("之前agent执行流程发生错误,请模型输出严格按照prompt要求", input)
|
|||
|
|
response = self.agentExcutor.invoke(new_input)["messages"][-1][1]
|
|||
|
|
res = {"input": input, "output": response, "err": err}
|
|||
|
|
status_code = HTTPStatus.OK
|
|||
|
|
except ConnectionError as connectionError:
|
|||
|
|
status_code = HTTPStatus.REQUEST_TIMEOUT
|
|||
|
|
err = str(connectionError)
|
|||
|
|
except OutputParserException as outputException:
|
|||
|
|
cnt+=1
|
|||
|
|
err = str(outputException)
|
|||
|
|
print("outputException: {}".format(outputException))
|
|||
|
|
except KeyError as keyError:
|
|||
|
|
cnt+=1
|
|||
|
|
err = str(keyError)
|
|||
|
|
print("KeyError: {}".format(keyError))
|
|||
|
|
else:
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
if cnt == self.max_retry:
|
|||
|
|
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
|||
|
|
res = {"input": input, "output": "模型处理异常,建议您将问题表述更加清晰,再重新尝试", "err": err}
|
|||
|
|
|
|||
|
|
return status_code, res
|
|||
|
|
|