bizwechat/sqlcode/langchain_model.py

70 lines
2.9 KiB
Python
Raw Permalink Normal View History

2025-02-17 10:34:35 +08:00
'''
Author: scutlzc scutlzc@gmail.com
Date: 2024-07-11 16:36:34
LastEditors: scutlzc scutlzc@gmail.com
LastEditTime: 2024-07-12 11:03:24
FilePath: bizwechat/sqlcode/langchain_model.py
Description: Generator backed by a compiled LangGraph agent; returns an (HTTPStatus, result) pair and retries when the agent output cannot be parsed.
'''
from dashscope import Generation
import random
from http import HTTPStatus
from enum import Enum
from .qgi import Generator as BaseGenerator
from langchain_community.llms import Tongyi
from langchain_core.language_models import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from enum import StrEnum
from .modelloader import ModelManager, ModelLoader
from langchain.agents import AgentExecutor
from langchain_core.exceptions import OutputParserException
from sqlcode.utils import enrich_input
from typing import Optional
from langgraph.graph.graph import CompiledGraph
class Generator(BaseGenerator):
    """Generator backed by a compiled LangGraph agent.

    ``_generate`` invokes ``agentExcutor`` and turns the outcome into an
    ``(HTTPStatus, result)`` pair, retrying a bounded number of times when the
    call fails with a connection, output-parsing or missing-key error.
    """

    # Attempt budget used when ``max_retry`` is None (same as the constructor default).
    _DEFAULT_MAX_RETRY = 2

    def __init__(self, agentExcutor: CompiledGraph, messages: list[dict[str, str]] | None, apikey: str, seed: int = 0, max_retry: Optional[int] = 2) -> None:
        """Store the agent and call settings.

        Args:
            agentExcutor: compiled agent graph whose ``invoke`` is called by
                ``_generate``.
            messages: stored as ``self.message``; not read by this class
                (presumably used by ``BaseGenerator`` -- verify).
            apikey: stored as ``self.apikey``; not read by this class.
            seed: stored as ``self.seed``; not read by this class.
            max_retry: total number of agent invocations ``_generate`` may
                make. ``None`` means the default of 2; values below 1 still
                allow a single attempt.
        """
        self.message = messages
        self.apikey = apikey
        self.seed = seed
        self.agentExcutor = agentExcutor
        self.max_retry = max_retry

    def _generate(self, input: dict) -> tuple[HTTPStatus, dict]:
        """Invoke the agent, retrying on recoverable failures.

        Args:
            input: agent input passed to ``agentExcutor.invoke``. After an
                output-parsing or KeyError failure, later attempts pass it
                through ``enrich_input`` with a corrective instruction first.

        Returns:
            ``(status, result)`` where ``result`` is
            ``{"input": input, "output": ..., "err": ...}``.

            * Success: ``HTTPStatus.OK``; ``"output"`` is item ``[1]`` of the
              last entry in the agent's ``"messages"``; ``"err"`` is the last
              error seen if a corrective retry was needed, else ``None``.
            * Every attempt failed: ``HTTPStatus.REQUEST_TIMEOUT`` if the last
              failure was a ``ConnectionError``, otherwise
              ``HTTPStatus.INTERNAL_SERVER_ERROR``; ``"output"`` is a fixed
              user-facing fallback message and ``"err"`` the last error text.

        Exceptions other than ConnectionError, OutputParserException and
        KeyError propagate unchanged.
        """
        limit = self._DEFAULT_MAX_RETRY if self.max_retry is None else self.max_retry
        # Always call the agent at least once, even for a non-positive budget.
        attempts = max(1, limit)
        err = None
        status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        # Set once the agent produced unusable output; later attempts then add
        # an instruction asking the model to follow the prompt strictly.
        needs_hint = False
        for _ in range(attempts):
            try:
                if needs_hint:
                    agent_input = enrich_input("之前agent执行流程发生错误请模型输出严格按照prompt要求", input)
                else:
                    agent_input = input
                # NOTE(review): assumes each entry of "messages" is indexable and
                # item [1] is the reply text (e.g. a (role, content) pair) -- confirm.
                response = self.agentExcutor.invoke(agent_input)["messages"][-1][1]
            except ConnectionError as exc:
                # Transport failure: counts as an attempt (previously this
                # looped forever) and is retried without the corrective hint.
                status_code = HTTPStatus.REQUEST_TIMEOUT
                err = str(exc)
            except OutputParserException as exc:
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
                err = str(exc)
                needs_hint = True
                print("outputException: {}".format(exc))
            except KeyError as exc:
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
                err = str(exc)
                needs_hint = True
                print("KeyError: {}".format(exc))
            else:
                return HTTPStatus.OK, {"input": input, "output": response, "err": err if needs_hint else None}
        return status_code, {"input": input, "output": "模型处理异常,建议您将问题表述更加清晰,再重新尝试", "err": err}