This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-25 20:10:53 +02:00
parent a578dd26a0
commit 2723ced901
8 changed files with 307 additions and 229 deletions

View File

@@ -22,7 +22,9 @@ def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType
model=model.identifier,
seed=seed,
base_url=base_url,
format="json"
format="json",
max_tool_call_fails=3,
temperature=0.0
)
if tools:
@@ -75,9 +77,21 @@ def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) ->
tool_msg = selected_tool.invoke(call)
messages.append(tool_msg)
ai_msg = llm.invoke(messages)
i = 0
while isinstance(ai_msg, SystemMessage):
i += 1
if i <= 5:
return {
"answer": ">>LLM failed to use tools<<",
"tool_calls": tool_calls,
}
messages.append(ai_msg)
ai_msg = llm.invoke(messages)
tool_calls.append({
"tool": call["name"],
"args": call["args"],
"times_failed": i
})
except IndexError: # LLM didnt use a tool -> jsut return the content
tool_calls = []
@@ -103,7 +117,6 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
messages = state["messages"]
last_message = messages[-1]
nonlocal index
assert isinstance(last_message, AIMessage) # this is just so the type checker is happy
if last_message.tool_calls:
index += 1
return "tools"
@@ -174,9 +187,9 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
should_continue,
)
workflow.add_edge("tools", "agent")
graph = workflow.compile()
# compose "history" supprts few shot prompting
start_messages = [SystemMessage(test.runnable_input['system_msg'])]
try:
@@ -187,11 +200,14 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
chunks = []
for chunk in graph.stream(
{"messages": start_messages},
stream_mode="values",
):
chunks.append(chunk["messages"][-1])
try:
for chunk in graph.stream({"messages": start_messages}, stream_mode="values", config={"recursion_limit": 10}):
chunks.append(chunk["messages"][-1])
except RecursionError:
return {
"answer": ">>Model did not come to a conclusion (Recusion Error)<<",
"tool_calls": tool_calls
}
return {
"answer": chunks[-1].content,