This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-26 21:20:47 +02:00
parent 2723ced901
commit 5d7ce3cf71
12 changed files with 2055 additions and 2350 deletions

View File

@@ -1,8 +1,8 @@
from types import NoneType
from langchain_ollama.chat_models import ChatOllama
from libs.ollama_functions import OllamaFunctions
from libs.ollama_functions import OllamaFunctionsLSM, OllamaFunctionsT2S
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
from libs.classes import Test, Model
from libs.classes import Technique, Test, Model
from langchain.tools import Tool
from typing import Literal
@@ -10,22 +10,31 @@ from langgraph.graph import StateGraph, MessagesState
import json
from pydantic import ValidationError
def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType = None):
if model.supports_tools:
from suite_settings.techniques import techniques
def _get_llm(model: Model, base_url: str, seed: int, technique: Technique, tools: list[Tool]|NoneType = None):
if technique == techniques[1]: # Native
llm = ChatOllama(
model=model.identifier,
seed=seed,
base_url=base_url
)
else:
llm = OllamaFunctions(
elif technique == techniques[903]: # Long System Message
llm = OllamaFunctionsLSM(
model=model.identifier,
seed=seed,
base_url=base_url,
format="json",
max_tool_call_fails=3,
temperature=0.0
)
elif technique == techniques[572]: # ToolMessages to SystemMessages
llm = OllamaFunctionsT2S(
model=model.identifier,
seed=seed,
base_url=base_url,
format="json",
)
else:
raise ValueError("Unkown Technique in _get_llm()")
if tools:
llm = llm.bind_tools(tools=tools)
@@ -33,7 +42,7 @@ def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType
return llm
def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
def basic_prompt(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> str:
messages = [SystemMessage(test.runnable_input['system_msg'])]
try:
@@ -42,20 +51,20 @@ def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
pass
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
llm = _get_llm(model=model, base_url=base_url, seed=seed)
llm = _get_llm(model=model, base_url=base_url, technique=technique, seed=seed)
ai_msg = llm.invoke(messages)
assert isinstance(ai_msg.content, str)
return ai_msg.content
def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> dict:
def one_tool_call_answer(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict:
tools_dict = test.runnable_input['tools']
tools = []
for key in tools_dict:
tools.append(tools_dict[key])
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)
messages = [SystemMessage(test.runnable_input['system_msg'])]
try:
@@ -108,7 +117,7 @@ def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) ->
"tool_calls": tool_calls,
}
def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict[str, str|list]:
def agent_with_tools(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict[str, str|list]:
tool_calls = []
index = -1
@@ -173,7 +182,7 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
for key in tools_dict:
tools.append(tools_dict[key])
tool_node = NxToolNode(tools)
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)
workflow = StateGraph(MessagesState)