53 lines
1.5 KiB
Python
53 lines
1.5 KiB
Python
from langchain_ollama.chat_models import ChatOllama
|
|
from langchain_core.messages import SystemMessage, HumanMessage
|
|
from libs.test_class import Test
|
|
from langchain.tools import Tool
|
|
|
|
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
    """Send the test's prompt to an Ollama chat model and return the reply text.

    Args:
        model: Name of the Ollama model to run.
        seed: Sampling seed forwarded to ``ChatOllama``.
        test: Test case whose ``runnable_input`` supplies ``'system_msg'``
            (``None`` means "no system prompt") and ``'human_msg'``.
        base_url: Base URL of the Ollama server.

    Returns:
        The ``content`` of the model's response message.
    """
    system_msg = test.runnable_input['system_msg']
    human_msg = test.runnable_input['human_msg']

    # A None system message means the conversation is only the human turn.
    prompt = [human_msg] if system_msg is None else [system_msg, human_msg]

    llm = ChatOllama(model=model, seed=seed, base_url=base_url)
    ai_msg = llm.invoke(prompt)
    return ai_msg.content
|
|
|
|
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
    """Run a single round of tool calling against an Ollama chat model.

    The model is invoked on the test's prompt with the test's tools bound.
    If the response requests a tool, the first requested tool is executed,
    the request and the tool's result are appended to the conversation, and
    the model is invoked again to produce the final answer.

    Args:
        model: Name of the Ollama model to run.
        seed: Sampling seed forwarded to ``ChatOllama``.
        test: Test case whose ``runnable_input`` supplies ``'system_msg'``
            (``None`` means "no system prompt"), ``'human_msg'`` and
            ``'tools'`` (a mapping from lookup key to tool object).
        base_url: Base URL of the Ollama server.

    Returns:
        The ``content`` of the final model response: the follow-up answer
        when a tool was called, otherwise the first response.

    Raises:
        KeyError: if the model requests a tool whose name (exact or
            lower-cased) is not a key of ``'tools'``.
    """
    system_msg = test.runnable_input['system_msg']
    human_msg = test.runnable_input['human_msg']
    tools_dict = test.runnable_input['tools']
    tools = list(tools_dict.values())

    # A None system message means the conversation is only the human turn.
    prompt = [human_msg] if system_msg is None else [system_msg, human_msg]

    llm = ChatOllama(model=model, seed=seed, base_url=base_url).bind_tools(tools)

    ai_msg = llm.invoke(prompt)

    # Check for tool calls explicitly instead of catching IndexError around
    # the whole follow-up: a broad except would also hide IndexErrors raised
    # by the tool itself or by the second invoke.
    if not ai_msg.tool_calls:
        return ai_msg.content

    prompt.append(ai_msg)

    # Only the first requested tool call is answered.
    tool_call = ai_msg.tool_calls[0]
    tool_name = tool_call["name"]
    # NOTE(review): assumes the keys of 'tools' are the tool names, possibly
    # lower-cased. Prefer an exact match, then fall back to the lower-cased key.
    if tool_name in tools_dict:
        selected_tool = tools_dict[tool_name]
    else:
        selected_tool = tools_dict[tool_name.lower()]

    # Invoking the tool with the whole tool-call dict (not just its args)
    # presumably yields a message tied to the call id -- confirm in LangChain docs.
    tool_msg = selected_tool.invoke(tool_call)
    prompt.append(tool_msg)

    ai_msg = llm.invoke(prompt)
    return ai_msg.content
|