bizwechat/sqlcode/langchain_model.py
2025-02-17 10:34:35 +08:00

70 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

'''
Author: scutlzc scutlzc@gmail.com
Date: 2024-07-11 16:36:34
LastEditors: scutlzc scutlzc@gmail.com
LastEditTime: 2024-07-12 11:03:24
FilePath: \bizwechat\sqlcode\langchain_model.py
Description: Generator implementation that runs a request through a compiled LangGraph agent and retries when the agent's output cannot be parsed.
'''
from dashscope import Generation
import random
from http import HTTPStatus
from enum import Enum
from .qgi import Generator as BaseGenerator
from langchain_community.llms import Tongyi
from langchain_core.language_models import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from enum import StrEnum
from .modelloader import ModelManager, ModelLoader
from langchain.agents import AgentExecutor
from langchain_core.exceptions import OutputParserException
from sqlcode.utils import enrich_input
from typing import Optional
from langgraph.graph.graph import CompiledGraph
class Generator(BaseGenerator):
    """Generator that answers a request by invoking a compiled agent graph.

    A failed invocation is retried a bounded number of times. After an
    output-parsing failure (``OutputParserException`` or ``KeyError``) the
    retry is sent with an extra corrective hint added through
    ``enrich_input``; after a ``ConnectionError`` the retry re-sends the
    input unchanged unless a parsing failure has already occurred.
    """

    # Attempt budget used when ``max_retry`` is None; mirrors the constructor
    # default.
    _DEFAULT_MAX_RETRY = 2

    def __init__(
        self,
        agentExcutor: CompiledGraph,
        messages: list[dict[str, str]] | None,
        apikey: str,
        seed: int = 0,
        max_retry: Optional[int] = 2,
    ) -> None:
        """Store the agent and the generation settings.

        Args:
            agentExcutor: Compiled graph whose ``invoke`` is called for each
                request.
            messages: Message history, stored as ``self.message``; it is not
                read by ``_generate``.
            apikey: API key, stored on the instance.
            seed: Seed value, stored on the instance; ``_generate`` does not
                forward it to the agent.
            max_retry: Maximum number of invocation attempts per request.
                ``None`` selects the default budget of 2.
        """
        # NOTE(review): BaseGenerator.__init__ is not called here, as in the
        # original; confirm the base class needs no initialisation.
        self.message = messages
        self.apikey = apikey
        self.seed = seed
        self.agentExcutor = agentExcutor
        self.max_retry = max_retry

    def _generate(self, input: dict) -> tuple[HTTPStatus, dict]:
        """Invoke the agent for ``input`` and return ``(status, result)``.

        Args:
            input: Payload passed to the agent's ``invoke``.

        Returns:
            A ``(status_code, res)`` pair. ``res`` is a dict with the keys
            ``"input"``, ``"output"`` and ``"err"``.

            * ``HTTPStatus.OK``: ``output`` is taken from the last entry of
              the agent's ``"messages"``; ``err`` is ``None`` unless an
              earlier attempt hit a parsing failure, in which case it holds
              the text of the most recent error.
            * ``HTTPStatus.REQUEST_TIMEOUT``: the attempt budget ran out and
              the last failure was a ``ConnectionError``.
            * ``HTTPStatus.INTERNAL_SERVER_ERROR``: the attempt budget ran out
              on a parsing failure, or the budget was zero.

            In both failure cases ``output`` is the fixed fallback message.
        """
        # A None budget used to crash on ``cnt < None`` and a negative one left
        # the result unbound; normalise both so the method always returns.
        if self.max_retry is None:
            max_attempts = self._DEFAULT_MAX_RETRY
        else:
            max_attempts = max(self.max_retry, 0)

        attempts = 0        # every failed invocation; bounds the loop
        parse_failures = 0  # failures that call for the corrective hint
        err = None
        failure_status = HTTPStatus.INTERNAL_SERVER_ERROR

        while attempts < max_attempts:
            try:
                if parse_failures == 0:
                    payload = input
                    reported_err = None
                else:
                    # Tell the model that the previous run failed and that it
                    # must follow the prompt's output format strictly.
                    payload = enrich_input("之前agent执行流程发生错误请模型输出严格按照prompt要求", input)
                    reported_err = err
                response = self.agentExcutor.invoke(payload)["messages"][-1][1]
            except ConnectionError as connectionError:
                # Counted against the budget: previously this branch neither
                # advanced the counter nor left the loop, so a persistent
                # connection failure retried forever.
                attempts += 1
                err = str(connectionError)
                failure_status = HTTPStatus.REQUEST_TIMEOUT
            except OutputParserException as outputException:
                attempts += 1
                parse_failures += 1
                err = str(outputException)
                failure_status = HTTPStatus.INTERNAL_SERVER_ERROR
                print("outputException: {}".format(outputException))
            except KeyError as keyError:
                attempts += 1
                parse_failures += 1
                err = str(keyError)
                failure_status = HTTPStatus.INTERNAL_SERVER_ERROR
                print("KeyError: {}".format(keyError))
            else:
                return HTTPStatus.OK, {"input": input, "output": response, "err": reported_err}

        # Budget exhausted (or zero): return the fallback message with the
        # status of the last failure.
        res = {"input": input, "output": "模型处理异常,建议您将问题表述更加清晰,再重新尝试", "err": err}
        return failure_status, res