mf1
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from types import NoneType
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
from libs.ollama_functions import OllamaFunctions
|
||||
from libs.ollama_functions import OllamaFunctionsLSM, OllamaFunctionsT2S
|
||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
|
||||
from libs.classes import Test, Model
|
||||
from libs.classes import Technique, Test, Model
|
||||
from langchain.tools import Tool
|
||||
from typing import Literal
|
||||
|
||||
@@ -10,22 +10,31 @@ from langgraph.graph import StateGraph, MessagesState
|
||||
import json
|
||||
from pydantic import ValidationError
|
||||
|
||||
def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType = None):
|
||||
if model.supports_tools:
|
||||
from suite_settings.techniques import techniques
|
||||
|
||||
def _get_llm(model: Model, base_url: str, seed: int, technique: Technique, tools: list[Tool]|NoneType = None):
|
||||
if technique == techniques[1]: # Native
|
||||
llm = ChatOllama(
|
||||
model=model.identifier,
|
||||
seed=seed,
|
||||
base_url=base_url
|
||||
)
|
||||
else:
|
||||
llm = OllamaFunctions(
|
||||
elif technique == techniques[903]: # Long System Message
|
||||
llm = OllamaFunctionsLSM(
|
||||
model=model.identifier,
|
||||
seed=seed,
|
||||
base_url=base_url,
|
||||
format="json",
|
||||
max_tool_call_fails=3,
|
||||
temperature=0.0
|
||||
)
|
||||
elif technique == techniques[572]: # ToolMessages to SystemMessages
|
||||
llm = OllamaFunctionsT2S(
|
||||
model=model.identifier,
|
||||
seed=seed,
|
||||
base_url=base_url,
|
||||
format="json",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unkown Technique in _get_llm()")
|
||||
|
||||
if tools:
|
||||
llm = llm.bind_tools(tools=tools)
|
||||
@@ -33,7 +42,7 @@ def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType
|
||||
return llm
|
||||
|
||||
|
||||
def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
|
||||
def basic_prompt(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> str:
|
||||
|
||||
messages = [SystemMessage(test.runnable_input['system_msg'])]
|
||||
try:
|
||||
@@ -42,20 +51,20 @@ def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
|
||||
pass
|
||||
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
|
||||
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed)
|
||||
llm = _get_llm(model=model, base_url=base_url, technique=technique, seed=seed)
|
||||
ai_msg = llm.invoke(messages)
|
||||
assert isinstance(ai_msg.content, str)
|
||||
return ai_msg.content
|
||||
|
||||
|
||||
|
||||
def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> dict:
|
||||
def one_tool_call_answer(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict:
|
||||
|
||||
tools_dict = test.runnable_input['tools']
|
||||
tools = []
|
||||
for key in tools_dict:
|
||||
tools.append(tools_dict[key])
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)
|
||||
|
||||
messages = [SystemMessage(test.runnable_input['system_msg'])]
|
||||
try:
|
||||
@@ -108,7 +117,7 @@ def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) ->
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
|
||||
def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict[str, str|list]:
|
||||
def agent_with_tools(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict[str, str|list]:
|
||||
|
||||
tool_calls = []
|
||||
index = -1
|
||||
@@ -173,7 +182,7 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
|
||||
for key in tools_dict:
|
||||
tools.append(tools_dict[key])
|
||||
tool_node = NxToolNode(tools)
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)
|
||||
|
||||
workflow = StateGraph(MessagesState)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user