# Mirror of https://github.com/vegu-ai/talemate.git
# (synced 2025-12-16 11:47:48 +01:00; 67 lines, 1.9 KiB, Python)
import pydantic
|
|
from openai import AsyncOpenAI
|
|
|
|
from talemate.client.base import ClientBase
|
|
from talemate.client.registry import register
|
|
|
|
|
|
class Defaults(pydantic.BaseModel):
    # Default configuration values for the test client.
    # (Kept as `#` comments rather than a class docstring, because pydantic
    # may surface a docstring as the model's schema description.)

    # Base URL of the server; presumably consumed as `self.api_url` by the
    # client (which appends "/v1") -- confirm against ClientBase.
    api_url: str = pydantic.Field(default="http://localhost:1234")

    # Default maximum token length; exact semantics are defined by ClientBase.
    max_token_length: int = pydantic.Field(default=4096)


@register()
class TestClient(ClientBase):
    """Mock client that talks to an OpenAI-compatible chat-completions endpoint.

    It reports a fixed model name and forwards only a small whitelist of
    sampling parameters. ``generate`` never raises: failures are logged and
    an empty string is returned.
    """

    client_type = "test"

    class Meta(ClientBase.Meta):
        name_prefix: str = "test"
        title: str = "Test"
        defaults: Defaults = Defaults()

    def set_client(self, **kwargs):
        """Create the ``AsyncOpenAI`` client pointed at ``<api_url>/v1``.

        A trailing slash on ``api_url`` is stripped first, so that e.g.
        ``http://localhost:1234/`` yields ``http://localhost:1234/v1`` rather
        than ``http://localhost:1234//v1``.
        """
        base_url = self.api_url.rstrip("/") + "/v1"
        # Placeholder key: the OpenAI SDK requires some api_key to be set;
        # presumably the target endpoint ignores it -- confirm.
        self.client = AsyncOpenAI(base_url=base_url, api_key="sk-1111")

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        """
        Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients.

        This method is called before the prompt is sent to the client, and it allows the client to remove
        any parameters that it doesn't support.

        Only ``temperature`` and ``top_p`` are kept. ``parameters`` is
        modified in place; the method returns nothing.
        """
        super().tune_prompt_parameters(parameters, kind)

        valid_keys = {"temperature", "top_p"}

        # Collect the keys first: deleting from a dict while iterating it
        # directly would raise RuntimeError.
        for key in [k for k in parameters if k not in valid_keys]:
            del parameters[key]

    async def get_model_name(self):
        """
        This should return the name of the model that is being used.

        The test client always reports a fixed mock name.
        """
        return "Mock test model"

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.

        The stripped prompt is sent as a single user message. Returns the
        reply text, or ``""`` when the request fails or the reply carries
        no text content. Errors are logged, not raised.
        """
        human_message = {"role": "user", "content": prompt.strip()}

        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

        try:
            response = await self.client.chat.completions.create(
                model=self.model_name, messages=[human_message], **parameters
            )

            # `message.content` is Optional in the OpenAI SDK; normalize None
            # to "" so callers always receive a string, matching the
            # error path below.
            return response.choices[0].message.content or ""
        except Exception as e:
            self.log.error("generate error", e=e)
            return ""