* TabbyAPI Client Addition and presets refactoring (#126)

* feat: frequency_penalty (will make tabbyAPI custom wrapper)

* feat: add FREQUENCY_PENALTY_BASE and adj. conversation template

* feat: use `client_type` of `openai_compat` to send FIXED preset

* change from client name

* feat: pass client_type into presets.configure(...)

* wip: base TabbyAPI client

* feat: add import to register TabbyAPI client

* feat: adjust `presence_penalty` so it has a range of 0.1-0.5 (higher values will likely degrade performance)

* feat: add additional samplers/settings for TabbyAPI

* feat: keep presence_penalty in a range of 0.1-0.5

* feat: keep min_p in a range of 0.05 to 0.15

* update tabbyapi.py

* feat: add MIN_P_BASE and TEMP_LAST and change to tabbyapi client only for now

* fix: add /v1 as default API route to TabbyAPI

* feat: implement CustomAPIClient to allow all TabbyAPI parameters

* fix: change to "temperature_last" instead of "temp_last"

* feat: convert presets to dictionary mappings to make cleaner/more flexible

* fix: account for the original substring `in` checks and remove the TabbyAPI client call

* fix: move the exact-match token value return below the substring checks, since the value realistically should never be None and the substrings wouldn't otherwise be checked

* chore: remove automatic 'token' import due to IDE

---------

Co-authored-by: vegu-ai-tools <152010387+vegu-ai-tools@users.noreply.github.com>

* tabbyapi client: auto-set model name
tabbyapi client: use urljoin to prevent errors when the user adds a trailing slash

* expose presets to config and ux for editing

* some more help text

* tweak min, max and step size for some of the inference parameter sliders

* min_p step size to 0.01

* preset editor - allow reset to defaults

* fix preset reset

* don't persist inference_defaults to config file

* only persist presets to config if they have been changed

* ensure defaults are loaded

* rename config to parameters for more clarity

* update default inference params
textgenwebui support for min_p, frequency_penalty and presence_penalty

* overridable function to clean prompt params

* add `supported_parameters` class property to clients and revisit all of the clients to add any missing supported parameters

* ux tweaks

* supported_parameters moved to a property function

* top p decrease step size

* only show audio stop button if there is actually audio playing

* relock

* allow setting presence and frequency penalty to 0

* lower default frequency penalty

* frequency and presence penalty step size to 0.01

* set default model to gpt-4o

---------

Co-authored-by: official-elinas <57051565+official-elinas@users.noreply.github.com>
veguAI committed 2024-05-31 13:07:57 +03:00 (committed by GitHub)
parent 9a2bbd78a4
commit cdcc804ffa
26 changed files with 1321 additions and 816 deletions

poetry.lock (generated, 890 changed lines): diff suppressed because it is too large

@@ -21,6 +21,7 @@ from talemate.emit import emit
from talemate.events import GameLoopEvent
from talemate.prompts import Prompt
from talemate.scene_message import CharacterMessage, DirectorMessage
from talemate.exceptions import LLMAccuracyError
from .base import (
Agent,
@@ -661,6 +662,11 @@ class ConversationAgent(Agent):
empty_result_count += 1
if empty_result_count >= 2:
break
# if result is empty, raise an error
if not total_result:
raise LLMAccuracyError("Received empty response from AI")
result = result.replace(" :", ":")


@@ -134,6 +134,10 @@ class EditorAgent(Agent):
if not self.actions["fix_exposition"].enabled:
return content
# if no content was generated, return it as is
if not content:
return content
if not character.is_player:
if '"' not in content and "*" not in content:


@@ -10,5 +10,6 @@ from talemate.client.lmstudio import LMStudioClient
from talemate.client.mistral import MistralAIClient
from talemate.client.openai import OpenAIClient
from talemate.client.openai_compat import OpenAICompatibleClient
from talemate.client.tabbyapi import TabbyAPIClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.textgenwebui import TextGeneratorWebuiClient


@@ -58,6 +58,15 @@ class AnthropicClient(ClientBase):
def anthropic_api_key(self):
return self.config.get("anthropic", {}).get("api_key")
@property
def supported_parameters(self):
return [
"temperature",
"top_p",
"top_k",
"max_tokens",
]
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
@@ -160,14 +169,6 @@ class AnthropicClient(ClientBase):
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.


@@ -41,6 +41,7 @@ class PromptData(pydantic.BaseModel):
time: Union[float, int]
agent_stack: list[str] = pydantic.Field(default_factory=list)
generation_parameters: dict = pydantic.Field(default_factory=dict)
inference_preset: str = None
class ErrorAction(pydantic.BaseModel):
@@ -63,6 +64,20 @@ class ExtraField(pydantic.BaseModel):
required: bool
description: str
class ParameterReroute(pydantic.BaseModel):
talemate_parameter: str
client_parameter: str
def reroute(self, parameters: dict):
if self.talemate_parameter in parameters:
parameters[self.client_parameter] = parameters[self.talemate_parameter]
del parameters[self.talemate_parameter]
def __str__(self):
return self.client_parameter
def __eq__(self, other):
return str(self) == str(other)
class ClientBase:
api_url: str
@@ -81,6 +96,7 @@ class ClientBase:
finalizers: list[str] = []
double_coercion: Union[str, None] = None
client_type = "base"
class Meta(pydantic.BaseModel):
experimental: Union[None, str] = None
@@ -126,6 +142,15 @@ class ClientBase:
def max_tokens_param_name(self):
return "max_tokens"
@property
def supported_parameters(self):
# each client should override this with the parameters it supports
return [
"temperature",
"max_tokens",
]
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
@@ -236,8 +261,6 @@ class ClientBase:
if "narrate" in kind:
return system_prompts.NARRATOR
if "story" in kind:
return system_prompts.NARRATOR
if "director" in kind:
return system_prompts.DIRECTOR
if "create" in kind:
@@ -269,8 +292,6 @@ class ClientBase:
if "narrate" in kind:
return system_prompts.NARRATOR_NO_DECENSOR
if "story" in kind:
return system_prompts.NARRATOR_NO_DECENSOR
if "director" in kind:
return system_prompts.DIRECTOR_NO_DECENSOR
if "create" in kind:
@@ -428,7 +449,7 @@ class ClientBase:
def generate_prompt_parameters(self, kind: str):
parameters = {}
self.tune_prompt_parameters(
presets.configure(parameters, kind, self.max_token_length), kind
presets.configure(parameters, kind, self.max_token_length, self), kind
)
return parameters
@@ -469,6 +490,21 @@ class ClientBase:
else:
parameters["extra_stopping_strings"] = dialog_stopping_strings
def clean_prompt_parameters(self, parameters: dict):
"""
Does some final adjustments to the prompt parameters before sending
"""
# apply any parameter reroutes
for param in self.supported_parameters:
if isinstance(param, ParameterReroute):
param.reroute(parameters)
# drop any parameters that are not supported by the client
for key in list(parameters.keys()):
if key not in self.supported_parameters:
del parameters[key]
def finalize(self, parameters: dict, prompt: str):
prompt = util.replace_special_tokens(prompt)
@@ -478,6 +514,7 @@ class ClientBase:
prompt, applied = fn(parameters, prompt)
if applied:
return prompt
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
@@ -537,6 +574,8 @@ class ClientBase:
time_start = time.time()
extra_stopping_strings = prompt_param.pop("extra_stopping_strings", [])
self.clean_prompt_parameters(prompt_param)
self.log.debug(
"send_prompt",
@@ -577,6 +616,7 @@ class ClientBase:
client_type=self.client_type,
time=time_end - time_start,
generation_parameters=prompt_param,
inference_preset=client_context_attribute("inference_preset"),
).model_dump(),
)
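
The `supported_parameters` property together with `ParameterReroute` replaces the per-client whitelists that previously lived in each `tune_prompt_parameters` override. A minimal sketch of the intended pattern, using a hypothetical client whose backend expects `stop` instead of `stopping_strings` (the client and the parameter values are illustrative only):

class ExampleClient(ClientBase):
    client_type = "example"

    @property
    def supported_parameters(self):
        return [
            "temperature",
            "max_tokens",
            ParameterReroute(
                talemate_parameter="stopping_strings", client_parameter="stop"
            ),
        ]

# clean_prompt_parameters() applies the reroutes and then drops everything
# the client does not list, e.g.
#   {"temperature": 0.7, "top_k": 40, "stopping_strings": ["\n"], "max_tokens": 75}
# becomes
#   {"temperature": 0.7, "stop": ["\n"], "max_tokens": 75}
# ParameterReroute.__eq__ compares by the client-side name, so the rerouted
# key passes the membership check against supported_parameters.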


@@ -2,7 +2,7 @@ import pydantic
import structlog
from cohere import AsyncClient
from talemate.client.base import ClientBase, ErrorAction
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
@@ -58,6 +58,18 @@ class CohereClient(ClientBase):
def cohere_api_key(self):
return self.config.get("cohere", {}).get("api_key")
@property
def supported_parameters(self):
return [
"temperature",
ParameterReroute(talemate_parameter="top_p", client_parameter="p"),
ParameterReroute(talemate_parameter="top_k", client_parameter="k"),
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop_sequences"),
"frequency_penalty",
"presence_penalty",
"max_tokens",
]
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
@@ -160,18 +172,22 @@ class CohereClient(ClientBase):
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
def clean_prompt_parameters(self, parameters: dict):
super().clean_prompt_parameters(parameters)
# if temperature is set, it needs to be clamped between 0 and 1.0
if "temperature" in parameters:
parameters["temperature"] = max(0.0, min(1.0, parameters["temperature"]))
# if stop_sequences is set, max 5 items
if "stop_sequences" in parameters:
parameters["stop_sequences"] = parameters["stop_sequences"][:5]
# if both frequency_penalty and presence_penalty are set, drop frequency_penalty
if "presence_penalty" in parameters and "frequency_penalty" in parameters:
del parameters["frequency_penalty"]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.


@@ -38,6 +38,7 @@ class ContextModel(BaseModel):
nuke_repetition: float = Field(0.0, ge=0.0, le=3.0)
conversation: ConversationContext = Field(default_factory=ConversationContext)
length: int = 96
inference_preset: str = None
# Define the context variable as an empty dictionary


@@ -10,9 +10,10 @@ from vertexai.generative_models import (
GenerativeModel,
ResponseValidationError,
SafetySetting,
GenerationConfig,
)
from talemate.client.base import ClientBase, ErrorAction, ExtraField
from talemate.client.base import ClientBase, ErrorAction, ExtraField, ParameterReroute
from talemate.client.registry import register
from talemate.client.remote import RemoteServiceMixin
from talemate.config import Client as BaseClientConfig
@@ -54,7 +55,7 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
auto_break_repetition_enabled = False
decensor_enabled = True
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "Google"
title: str = "Google"
@@ -140,6 +141,16 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
return safety_settings
@property
def supported_parameters(self):
return [
"temperature",
"top_p",
"top_k",
ParameterReroute(talemate_parameter="max_tokens", client_parameter="max_output_tokens"),
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop_sequences"),
]
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
@@ -237,11 +248,19 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
if "disable_safety_settings" in kwargs:
self.disable_safety_settings = kwargs["disable_safety_settings"]
def clean_prompt_parameters(self, parameters: dict):
super().clean_prompt_parameters(parameters)
log.warning("clean_prompt_parameters", parameters=parameters)
# if top_k is 0, remove it
if "top_k" in parameters and parameters["top_k"] == 0:
del parameters["top_k"]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if not self.ready:
raise Exception("Google cloud setup incomplete")
@@ -268,10 +287,11 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
try:
chat = self.model_instance.start_chat()
response = await chat.send_message_async(
human_message,
safety_settings=self.safety_settings,
generation_config=parameters,
)
self._returned_prompt_tokens = self.prompt_tokens(prompt)


@@ -2,7 +2,7 @@ import pydantic
import structlog
from groq import AsyncGroq, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
@@ -60,6 +60,17 @@ class GroqClient(ClientBase):
def groq_api_key(self):
return self.config.get("groq", {}).get("api_key")
@property
def supported_parameters(self):
return [
"temperature",
"top_p",
"presence_penalty",
"frequency_penalty",
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop"),
"max_tokens",
]
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
@@ -162,14 +173,6 @@ class GroqClient(ClientBase):
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.


@@ -7,7 +7,7 @@ from urllib.parse import urljoin, urlparse
import httpx
import structlog
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ParameterReroute
from talemate.client.registry import register
import talemate.util as util
@@ -81,6 +81,33 @@ class KoboldCppClient(ClientBase):
else:
return "max_length"
@property
def supported_parameters(self):
if not self.is_openai:
# koboldcpp united api
return [
ParameterReroute(talemate_parameter="max_tokens", client_parameter="max_length"),
"max_context_length",
ParameterReroute(talemate_parameter="repetition_penalty", client_parameter="rep_pen"),
ParameterReroute(talemate_parameter="repetition_penalty_range", client_parameter="rep_pen_range"),
"top_p",
"top_k",
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop_sequence"),
"temperature",
]
else:
# openai api
return [
"max_tokens",
"presence_penalty",
"top_p",
"temperature",
]
def api_endpoint_specified(self, url: str) -> bool:
return "/v1" in self.api_url
@@ -97,51 +124,11 @@ class KoboldCppClient(ClientBase):
super().__init__(**kwargs)
self.ensure_api_endpoint_specified()
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
if not self.is_openai:
# adjustments for united api
parameters["max_length"] = parameters.pop("max_tokens")
parameters["max_context_length"] = self.max_token_length
if "repetition_penalty_range" in parameters:
parameters["rep_pen_range"] = parameters.pop("repetition_penalty_range")
if "repetition_penalty" in parameters:
parameters["rep_pen"] = parameters.pop("repetition_penalty")
if parameters.get("stop_sequence"):
parameters["stop_sequence"] = parameters.pop("stopping_strings")
if parameters.get("extra_stopping_strings"):
if "stop_sequence" in parameters:
parameters["stop_sequence"] += parameters.pop("extra_stopping_strings")
else:
parameters["stop_sequence"] = parameters.pop("extra_stopping_strings")
allowed_params = [
"max_length",
"max_context_length",
"rep_pen",
"rep_pen_range",
"top_p",
"top_k",
"temperature",
"stop_sequence",
]
else:
allowed_params = ["max_tokens", "presence_penalty", "top_p", "temperature"]
# drop unsupported params
for param in list(parameters.keys()):
if param not in allowed_params:
del parameters[param]
def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key", self.api_key)
self.ensure_api_endpoint_specified()
async def get_model_name(self):
self.ensure_api_endpoint_specified()
async with httpx.AsyncClient() as client:


@@ -1,7 +1,7 @@
import pydantic
from openai import AsyncOpenAI
from talemate.client.base import ClientBase
from talemate.client.base import ClientBase, ParameterReroute
from talemate.client.registry import register
@@ -14,26 +14,25 @@ class Defaults(pydantic.BaseModel):
class LMStudioClient(ClientBase):
auto_determine_prompt_template: bool = True
client_type = "lmstudio"
class Meta(ClientBase.Meta):
name_prefix: str = "LMStudio"
title: str = "LMStudio"
defaults: Defaults = Defaults()
@property
def supported_parameters(self):
return [
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop"),
]
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def get_model_name(self):
model_name = await super().get_model_name()


@@ -4,7 +4,7 @@ from mistralai.async_client import MistralAsyncClient
from mistralai.exceptions import MistralAPIStatusException
from mistralai.models.chat_completion import ChatMessage
from talemate.client.base import ClientBase, ErrorAction
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
@@ -69,6 +69,14 @@ class MistralAIClient(ClientBase):
def mistralai_api_key(self):
return self.config.get("mistralai", {}).get("api_key")
@property
def supported_parameters(self):
return [
"temperature",
"top_p",
"max_tokens",
]
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
@@ -171,17 +179,10 @@ class MistralAIClient(ClientBase):
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
def clean_prompt_parameters(self, parameters: dict):
super().clean_prompt_parameters(parameters)
# clamp temperature to 0.1 and 1.0
# Unhandled Error: Status: 422. Message: {"object":"error","message":{"detail":[{"type":"less_than_equal","loc":["body","temperature"],"msg":"Input should be less than or equal to 1","input":1.31,"ctx":{"le":1.0},"url":"https://errors.pydantic.dev/2.6/v/less_than_equal"}]},"type":"invalid_request_error","param":null,"code":null}
if "temperature" in parameters:
parameters["temperature"] = min(1.0, max(0.1, parameters["temperature"]))


@@ -94,7 +94,7 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
class Defaults(pydantic.BaseModel):
max_token_length: int = 16384
model: str = "gpt-4-turbo"
model: str = "gpt-4o"
@register()
@@ -117,7 +117,7 @@ class OpenAIClient(ClientBase):
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
def __init__(self, model="gpt-4-turbo", **kwargs):
def __init__(self, model="gpt-4o", **kwargs):
self.model_name = model
self.api_key_status = None
self.config = load_config()
@@ -129,6 +129,15 @@ class OpenAIClient(ClientBase):
def openai_api_key(self):
return self.config.get("openai", {}).get("api_key")
@property
def supported_parameters(self):
return [
"temperature",
"top_p",
"presence_penalty",
"max_tokens",
]
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
@@ -241,26 +250,6 @@ class OpenAIClient(ClientBase):
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
# GPT-3.5 models tend to run away with the generated
# response size so we allow talemate to set the max_tokens
#
# GPT-4 on the other hand seems to benefit from letting it
# decide the generation length naturally and it will generally
# produce reasonably sized responses
if self.model_name.startswith("gpt-3.5-"):
valid_keys.append("max_tokens")
for key in keys:
if key not in valid_keys:
del parameters[key]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.


@@ -32,7 +32,7 @@ class OpenAICompatibleClient(ClientBase):
client_type = "openai_compat"
conversation_retries = 0
config_cls = ClientConfig
class Meta(ClientBase.Meta):
title: str = "OpenAI Compatible API"
name_prefix: str = "OpenAI Compatible API"
@@ -70,6 +70,15 @@ class OpenAICompatibleClient(ClientBase):
"""
return not self.api_handles_prompt_template
@property
def supported_parameters(self):
return [
"temperature",
"top_p",
"presence_penalty",
"max_tokens",
]
def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key", self.api_key)
self.api_handles_prompt_template = kwargs.get(
@@ -81,16 +90,6 @@ class OpenAICompatibleClient(ClientBase):
kwargs.get("model") or kwargs.get("model_name") or self.model_name
)
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
allowed_params = ["max_tokens", "presence_penalty", "top_p", "temperature"]
# drop unsupported params
for param in list(parameters.keys()):
if param not in allowed_params:
del parameters[param]
def prompt_template(self, system_message: str, prompt: str):
log.debug(


@@ -1,3 +1,12 @@
from typing import TYPE_CHECKING
from talemate.config import load_config, InferencePresets
from talemate.emit.signals import handlers
from talemate.client.context import set_client_context_attribute
import structlog
if TYPE_CHECKING:
from talemate.client.base import ClientBase
__all__ = [
"configure",
"set_max_tokens",
@@ -11,228 +20,152 @@ __all__ = [
"PRESET_SIMPLE_1",
]
# TODO: refactor abstraction and make configurable
log = structlog.get_logger("talemate.client.presets")
PRESENCE_PENALTY_BASE = 0.2
config = load_config(as_model=True)
PRESET_TALEMATE_CONVERSATION = {
"temperature": 0.65,
"top_p": 0.47,
"top_k": 42,
"presence_penalty": PRESENCE_PENALTY_BASE,
"repetition_penalty": 1.18,
"repetition_penalty_range": 2048,
}
PRESET_TALEMATE_CREATOR = {
"temperature": 0.7,
"top_p": 0.9,
"top_k": 20,
"presence_penalty": PRESENCE_PENALTY_BASE,
"repetition_penalty": 1.15,
"repetition_penalty_range": 512,
}
PRESET_LLAMA_PRECISE = {
"temperature": 0.7,
"top_p": 0.1,
"top_k": 40,
"presence_penalty": PRESENCE_PENALTY_BASE,
"repetition_penalty": 1.18,
}
PRESET_DETERMINISTIC = {
"temperature": 0.1,
"top_p": 1,
"top_k": 0,
"repetition_penalty": 1.0,
}
PRESET_DIVINE_INTELLECT = {
"temperature": 1.31,
"top_p": 0.14,
"top_k": 49,
"presence_penalty": PRESENCE_PENALTY_BASE,
"repetition_penalty_range": 1024,
"repetition_penalty": 1.17,
}
PRESET_SIMPLE_1 = {
"temperature": 0.7,
"top_p": 0.9,
"top_k": 20,
"presence_penalty": PRESENCE_PENALTY_BASE,
"repetition_penalty": 1.15,
}
PRESET_ANALYTICAL = {
"temperature": 0.1,
"top_p": 0.9,
"top_k": 20,
# Load the config
CONFIG = {
"inference": config.presets.inference,
}
def configure(config: dict, kind: str, total_budget: int):
# Sync the config when it is saved
def sync_config(event):
CONFIG["inference"] = InferencePresets(
**event.data.get("presets", {}).get("inference", {})
)
handlers["config_saved"].connect(sync_config)
def get_inference_parameters(preset_name: str) -> dict:
"""
Returns the inference parameters for the given preset name.
"""
presets = CONFIG["inference"].model_dump()
if preset_name in presets:
return presets[preset_name]
raise ValueError(f"Preset name {preset_name} not found in presets.inference")
def configure(parameters: dict, kind: str, total_budget: int, client: "ClientBase"):
"""
Sets the config based on the kind of text to generate.
"""
set_preset(config, kind)
set_max_tokens(config, kind, total_budget)
return config
set_preset(parameters, kind, client)
set_max_tokens(parameters, kind, total_budget)
return parameters
def set_max_tokens(config: dict, kind: str, total_budget: int):
def set_max_tokens(parameters: dict, kind: str, total_budget: int):
"""
Sets the max_tokens in the config based on the kind of text to generate.
"""
config["max_tokens"] = max_tokens_for_kind(kind, total_budget)
return config
parameters["max_tokens"] = max_tokens_for_kind(kind, total_budget)
return parameters
def set_preset(config: dict, kind: str):
def set_preset(parameters: dict, kind: str, client: "ClientBase"):
"""
Sets the preset in the config based on the kind of text to generate.
"""
config.update(preset_for_kind(kind))
parameters.update(preset_for_kind(kind, client))
def preset_for_kind(kind: str):
# TODO: can this just be checking all keys in inference.presets.inference?
PRESET_SUBSTRING_MAPPINGS = {
"deterministic": "deterministic",
"creative": "creative",
"analytical": "analytical",
"analyze": "analytical",
}
# tag based
if "deterministic" in kind:
return PRESET_DETERMINISTIC
elif "creative" in kind:
return PRESET_DIVINE_INTELLECT
elif "simple" in kind:
return PRESET_SIMPLE_1
elif "analytical" in kind:
return PRESET_ANALYTICAL
elif kind == "conversation":
return PRESET_TALEMATE_CONVERSATION
elif kind == "conversation_old":
return PRESET_TALEMATE_CONVERSATION # Assuming old conversation uses the same preset
elif kind == "conversation_long":
return PRESET_TALEMATE_CONVERSATION # Assuming long conversation uses the same preset
elif kind == "conversation_select_talking_actor":
return PRESET_TALEMATE_CONVERSATION # Assuming select talking actor uses the same preset
elif kind == "summarize":
return PRESET_LLAMA_PRECISE
elif kind == "analyze":
return PRESET_SIMPLE_1
elif kind == "analyze_creative":
return PRESET_DIVINE_INTELLECT
elif kind == "analyze_long":
return PRESET_SIMPLE_1 # Assuming long analysis uses the same preset as simple
elif kind == "analyze_freeform":
return PRESET_LLAMA_PRECISE
elif kind == "analyze_freeform_short":
return PRESET_LLAMA_PRECISE # Assuming short freeform analysis uses the same preset as precise
elif kind == "narrate":
return PRESET_LLAMA_PRECISE
elif kind == "story":
return PRESET_DIVINE_INTELLECT
elif kind == "create":
return PRESET_TALEMATE_CREATOR
elif kind == "create_concise":
return PRESET_TALEMATE_CREATOR # Assuming concise creation uses the same preset as creator
elif kind == "create_precise":
return PRESET_LLAMA_PRECISE
elif kind == "director":
return PRESET_SIMPLE_1
elif kind == "director_short":
return (
PRESET_SIMPLE_1 # Assuming short direction uses the same preset as simple
)
elif kind == "director_yesno":
return (
PRESET_SIMPLE_1 # Assuming yes/no direction uses the same preset as simple
)
elif kind == "edit_dialogue":
return PRESET_DIVINE_INTELLECT
elif kind == "edit_add_detail":
return PRESET_DIVINE_INTELLECT # Assuming adding detail uses the same preset as divine intellect
elif kind == "edit_fix_exposition":
return PRESET_DETERMINISTIC # Assuming fixing exposition uses the same preset as divine intellect
elif kind == "edit_fix_continuity":
return PRESET_DETERMINISTIC
elif kind == "visualize":
return PRESET_SIMPLE_1
else:
return PRESET_SIMPLE_1 # Default preset if none of the kinds match
PRESET_MAPPING = {
"conversation": "conversation",
"conversation_select_talking_actor": "analytical",
"summarize": "summarization",
"analyze": "analytical",
"analyze_long": "analytical",
"analyze_freeform": "analytical",
"analyze_freeform_short": "analytical",
"narrate": "creative",
"create": "creative_instruction",
"create_short": "creative_instruction",
"create_concise": "creative_instruction",
"director": "scene_direction",
"edit_add_detail": "creative",
"edit_fix_exposition": "deterministic",
"edit_fix_continuity": "deterministic",
"visualize": "creative_instruction",
}
def max_tokens_for_kind(kind: str, total_budget: int):
if kind == "conversation":
return 75
elif kind == "conversation_old":
return 75
elif kind == "conversation_long":
return 300
elif kind == "conversation_select_talking_actor":
return 30
elif kind == "summarize":
return 500
elif kind == "analyze":
return 500
elif kind == "analyze_creative":
return 1024
elif kind == "analyze_long":
return 2048
elif kind == "analyze_freeform":
return 500
elif kind == "analyze_freeform_medium":
return 192
elif kind == "analyze_freeform_medium_short":
return 128
elif kind == "analyze_freeform_short":
return 10
elif kind == "narrate":
return 500
elif kind == "story":
return 300
elif kind == "create":
return min(1024, int(total_budget * 0.35))
elif kind == "create_concise":
return min(400, int(total_budget * 0.25))
elif kind == "create_precise":
return min(400, int(total_budget * 0.25))
elif kind == "create_short":
return 25
elif kind == "director":
return min(192, int(total_budget * 0.25))
elif kind == "director_short":
return 25
elif kind == "director_yesno":
return 2
elif kind == "edit_dialogue":
return 100
elif kind == "edit_add_detail":
return 200
elif kind == "edit_fix_exposition":
return 1024
elif kind == "edit_fix_continuity":
return 512
elif kind == "visualize":
return 150
# tag based
elif "extensive" in kind:
return 2048
elif "long" in kind:
return 1024
elif "medium2" in kind:
return 512
elif "medium" in kind:
return 192
elif "short2" in kind:
return 128
elif "short" in kind:
return 75
elif "tiny2" in kind:
return 25
elif "tiny" in kind:
return 10
elif "yesno" in kind:
return 2
else:
return 150 # Default value if none of the kinds match
def preset_for_kind(kind: str, client: "ClientBase") -> dict:
# Check for an exact mapping first, then fall back to the substring
# checks (preserving the order of the original elif chain)
preset_name = PRESET_MAPPING.get(kind)
if not preset_name:
for substring, value in PRESET_SUBSTRING_MAPPINGS.items():
if substring in kind:
preset_name = value
if not preset_name:
log.warning(f"No preset found for kind {kind}, defaulting to 'scene_direction'", presets=CONFIG["inference"])
preset_name = "scene_direction"
set_client_context_attribute("inference_preset", preset_name)
return get_inference_parameters(preset_name)
TOKEN_MAPPING = {
"conversation": 75,
"conversation_select_talking_actor": 30,
"summarize": 500,
"analyze": 500,
"analyze_long": 2048,
"analyze_freeform": 500,
"analyze_freeform_medium": 192,
"analyze_freeform_medium_short": 128,
"analyze_freeform_short": 10,
"narrate": 500,
"story": 300,
"create": lambda total_budget: min(1024, int(total_budget * 0.35)),
"create_concise": lambda total_budget: min(400, int(total_budget * 0.25)),
"create_short": 25,
"director": lambda total_budget: min(192, int(total_budget * 0.25)),
"edit_add_detail": 200,
"edit_fix_exposition": 1024,
"edit_fix_continuity": 512,
"visualize": 150,
}
TOKEN_SUBSTRING_MAPPINGS = {
"extensive": 2048,
"long": 1024,
"medium2": 512,
"medium": 192,
"short2": 128,
"short": 75,
"tiny2": 25,
"tiny": 10,
"yesno": 2,
}
def max_tokens_for_kind(kind: str, total_budget: int) -> int:
token_value = TOKEN_MAPPING.get(kind)
if callable(token_value):
return token_value(total_budget)
# If no exact match, check for substrings (order of original elifs)
for substring, value in TOKEN_SUBSTRING_MAPPINGS.items():
if substring in kind:
return value
if token_value is not None:
return token_value
return 150 # Default value if none of the kinds match
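
A few worked lookups showing the new mapping-based flow; `client` stands for any ClientBase instance, and the `narrate_detail*` kinds are made-up examples rather than kinds used elsewhere in the codebase:

preset_for_kind("summarize", client)        # exact hit in PRESET_MAPPING -> "summarization" preset
preset_for_kind("narrate_detail", client)   # no exact or substring hit -> warning, falls back to "scene_direction"

max_tokens_for_kind("create", 4096)               # callable entry -> min(1024, int(4096 * 0.35)) == 1024
max_tokens_for_kind("narrate_detail_short", 4096) # no exact entry, substring "short" -> 75
max_tokens_for_kind("summarize", 4096)            # exact entry, no substring match -> 500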


@@ -23,13 +23,5 @@ class RemoteServiceMixin:
self.config = config
self.set_client(max_token_length=self.max_token_length)
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def status(self):
self.emit_status()


@@ -0,0 +1,236 @@
import urllib
import aiohttp
import random
import pydantic
import structlog
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
from talemate.client.base import ClientBase, ExtraField
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig
from talemate.emit import emit
from talemate.client.utils import urljoin
log = structlog.get_logger("talemate.client.tabbyapi")
EXPERIMENTAL_DESCRIPTION = """Use this client to use all of TabbyAPI's features"""
class CustomAPIClient:
def __init__(self, base_url, api_key):
self.base_url = base_url
self.api_key = api_key
async def get_model_name(self):
url = urljoin(self.base_url, "model")
headers = {
"x-api-key": self.api_key,
}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
if response.status != 200:
raise Exception(f"Request failed: {response.status}")
response_data = await response.json()
model_name = response_data.get("id")
# split by "/" and take last
if model_name:
model_name = model_name.split("/")[-1]
return model_name
async def create_chat_completion(self, model, messages, **parameters):
url = urljoin(self.base_url, "chat/completions")
headers = {
"x-api-key": self.api_key,
"Content-Type": "application/json",
}
data = {
"model": model,
"messages": messages,
**parameters,
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
raise Exception(f"Request failed: {response.status}")
return await response.json()
async def create_completion(self, model, **parameters):
url = urljoin(self.base_url, "completions")
headers = {
"x-api-key": self.api_key,
"Content-Type": "application/json",
}
data = {
"model": model,
**parameters,
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
raise Exception(f"Request failed: {response.status}")
return await response.json()
class Defaults(pydantic.BaseModel):
api_url: str = "http://localhost:5000/v1"
api_key: str = ""
max_token_length: int = 8192
model: str = ""
api_handles_prompt_template: bool = False
double_coercion: str = None
class ClientConfig(BaseClientConfig):
api_handles_prompt_template: bool = False
@register()
class TabbyAPIClient(ClientBase):
client_type = "tabbyapi"
conversation_retries = 0
config_cls = ClientConfig
class Meta(ClientBase.Meta):
title: str = "TabbyAPI"
name_prefix: str = "TabbyAPI"
experimental: str = EXPERIMENTAL_DESCRIPTION
enable_api_auth: bool = True
manual_model: bool = False
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = {
"api_handles_prompt_template": ExtraField(
name="api_handles_prompt_template",
type="bool",
label="API handles prompt template (chat/completions)",
required=False,
description="The API handles the prompt template, meaning your choice in the UI for the prompt template below will be ignored. This is not recommended and should only be used if the API does not support the `completions` endpoint or you don't know which prompt template to use.",
)
}
def __init__(self, model=None, api_key=None, api_handles_prompt_template=False, **kwargs):
self.model_name = model
self.api_key = api_key
self.api_handles_prompt_template = api_handles_prompt_template
super().__init__(**kwargs)
@property
def experimental(self):
return EXPERIMENTAL_DESCRIPTION
@property
def can_be_coerced(self):
"""
Determines whether or not this client can pass LLM coercion. (e.g., is able to predefine partial LLM output in the prompt)
"""
return not self.api_handles_prompt_template
@property
def supported_parameters(self):
return [
"max_tokens",
"presence_penalty",
"frequency_penalty",
"repetition_penalty_range",
"min_p",
"top_p",
"temperature_last",
"temperature"
]
def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key", self.api_key)
self.api_handles_prompt_template = kwargs.get("api_handles_prompt_template", self.api_handles_prompt_template)
self.client = CustomAPIClient(base_url=self.api_url, api_key=self.api_key)
self.model_name = kwargs.get("model") or kwargs.get("model_name") or self.model_name
def prompt_template(self, system_message: str, prompt: str):
log.debug(
"IS API HANDLING PROMPT TEMPLATE",
api_handles_prompt_template=self.api_handles_prompt_template,
)
if not self.api_handles_prompt_template:
return super().prompt_template(system_message, prompt)
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
async def get_model_name(self):
return await self.client.get_model_name()
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
try:
if self.api_handles_prompt_template:
# Custom API handles prompt template
# Use the chat completions endpoint
self.log.debug("generate (chat/completions)", prompt=prompt[:128] + " ...", parameters=parameters)
human_message = {"role": "user", "content": prompt.strip()}
response = await self.client.create_chat_completion(self.model_name, [human_message], **parameters)
response = response['choices'][0]['message']['content']
return self.process_response_for_indirect_coercion(prompt, response)
else:
# Talemate handles prompt template
# Use the completions endpoint
self.log.debug("generate (completions)", prompt=prompt[:128] + " ...", parameters=parameters)
parameters["prompt"] = prompt
response = await self.client.create_completion(self.model_name, **parameters)
return response['choices'][0]['text']
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
emit("status", message="Client API: Permission Denied", status="error")
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message="Error during generation (check logs)", status="error")
return ""
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if "max_token_length" in kwargs:
self.max_token_length = int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
if "api_key" in kwargs:
self.api_key = kwargs["api_key"]
if "api_handles_prompt_template" in kwargs:
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
log.warning("reconfigure", kwargs=kwargs)
self.set_client(**kwargs)
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
"""
adjusts temperature and presence penalty by random values using the base value as a center
"""
temp = prompt_config["temperature"]
min_offset = offset * 0.3
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
# keep min_p in a tight range to avoid unwanted tokens
prompt_config["min_p"] = random.uniform(0.05, 0.15)
try:
presence_penalty = prompt_config["presence_penalty"]
adjusted_presence_penalty = round(random.uniform(
presence_penalty + 0.1, presence_penalty + offset
),1)
# Ensure presence_penalty does not exceed 0.5 and does not fall below 0.1
prompt_config["presence_penalty"] = min(0.5, max(0.1, adjusted_presence_penalty))
except KeyError:
pass
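
For reference, a rough sketch of how the thin CustomAPIClient is driven on the completions path of generate() above, i.e. when Talemate renders the prompt template itself; the URL, API key and prompt below are placeholder values:

import asyncio

async def demo():
    client = CustomAPIClient(base_url="http://localhost:5000/v1", api_key="example-key")
    model = await client.get_model_name()  # GET .../model, keeps the part after the last "/"
    response = await client.create_completion(
        model,
        prompt="Narrator: The door creaks open and",
        max_tokens=75,
        temperature=0.85,
        min_p=0.1,
        temperature_last=True,  # TabbyAPI sampler flag, passed through untouched
    )
    print(response["choices"][0]["text"])

asyncio.run(demo())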


@@ -39,6 +39,36 @@ class TextGeneratorWebuiClient(ClientBase):
headers["Authorization"] = f"Bearer {self.api_key}"
return headers
@property
def supported_parameters(self):
# textgenwebui does not error on unsupported parameters
# but we should still drop them so they don't get passed to the API
# and show up in our prompt debugging tool.
# note that this is not the full list of their supported parameters
# but only those we send.
return [
"temperature",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
"repetition_penalty",
"repetition_penalty_range",
"stopping_strings",
"skip_special_tokens",
"max_tokens",
"stream",
# are these needed?
"max_new_tokens",
"stop",
# talemate internal
# These will be removed before sending to the API
# but we keep them here since they are used during the prompt finalization
"extra_stopping_strings",
]
def __init__(self, **kwargs):
self.api_key = kwargs.pop("api_key", "")
super().__init__(**kwargs)
@@ -51,39 +81,6 @@ class TextGeneratorWebuiClient(ClientBase):
# is this needed?
parameters["max_new_tokens"] = parameters["max_tokens"]
parameters["stop"] = parameters["stopping_strings"]
# textgenwebui does not error on unsupported parameters
# but we should still drop them so they don't get passed to the API
# and show up in our prompt debugging tool.
# note that this is not the full list of their supported parameters
# but only those we send.
allowed_params = [
"temperature",
"top_p",
"top_k",
"max_tokens",
"repetition_penalty",
"repetition_penalty_range",
"max_tokens",
"stopping_strings",
"skip_special_tokens",
"stream",
# is this needed?
"max_new_tokens",
"stop",
# talemate internal
# These will be removed before sending to the API
# but we keep them here since they are used during the prompt finalization
"extra_stopping_strings",
]
# drop unsupported params
for param in list(parameters.keys()):
if param not in allowed_params:
del parameters[param]
def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key", self.api_key)


@@ -1,33 +1,9 @@
import copy
import random
from urllib.parse import urljoin as _urljoin
__all__ = ["urljoin"]
def jiggle_randomness(prompt_config: dict, offset: float = 0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
rep_pen = prompt_config["repetition_penalty"]
copied_config = copy.deepcopy(prompt_config)
min_offset = offset * 0.3
copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
copied_config["repetition_penalty"] = random.uniform(
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
)
return copied_config
def jiggle_enabled_for(kind: str):
if kind in ["conversation", "story"]:
return True
if kind.startswith("narrate"):
return True
return False
def urljoin(base, *args):
base = f"{base.rstrip('/')}/"
return _urljoin(base, *args)
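
The wrapper exists because the standard library's urljoin drops the final path segment whenever the base URL lacks a trailing slash, which is exactly the error the commit message mentions. For example:

from urllib.parse import urljoin as _urljoin

_urljoin("http://localhost:5000/v1", "model")  # -> "http://localhost:5000/model" (loses /v1)
urljoin("http://localhost:5000/v1", "model")   # -> "http://localhost:5000/v1/model"
urljoin("http://localhost:5000/v1/", "model")  # -> "http://localhost:5000/v1/model" (trailing slash is harmless)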


@@ -192,6 +192,58 @@ class RecentScene(BaseModel):
date: str
cover_image: Union[Asset, None] = None
class InferenceParameters(BaseModel):
temperature: float = 1.0
temperature_last: bool = True
top_p: float | None = 1.0
top_k: int | None = 0
min_p: float | None = 0.1
presence_penalty: float | None = 0.2
frequency_penalty: float | None = 0.05
repetition_penalty: float | None= 1.0
repetition_penalty_range: int | None = 1024
# this determines whether or not it should be persisted
# to the config file
changed: bool = False
class InferencePresets(BaseModel):
analytical: InferenceParameters = InferenceParameters(
temperature=0.7,
presence_penalty=0,
frequency_penalty=0,
repetition_penalty=1.0,
min_p=0,
)
conversation: InferenceParameters = InferenceParameters(
temperature=0.85,
repetition_penalty_range=2048
)
creative: InferenceParameters = InferenceParameters()
creative_instruction: InferenceParameters = InferenceParameters(
temperature=0.85,
repetition_penalty_range=512
)
deterministic: InferenceParameters = InferenceParameters(
temperature=0.1,
top_p=1,
top_k=0,
repetition_penalty=1.0,
min_p=0,
)
scene_direction: InferenceParameters = InferenceParameters(
temperature=0.85,
)
summarization: InferenceParameters = InferenceParameters(
temperature=0.7,
)
class Presets(BaseModel):
inference_defaults: InferencePresets = InferencePresets()
inference: InferencePresets = InferencePresets()
def gnerate_intro_scenes():
"""
@@ -353,6 +405,8 @@ class Config(BaseModel):
tts: TTSConfig = TTSConfig()
recent_scenes: RecentScenes = RecentScenes()
presets: Presets = Presets()
class Config:
extra = "ignore"
@@ -414,6 +468,22 @@ def save_config(config, file_path: str = "./config.yaml"):
log.error("config validation", error=e)
return None
# we don't want to persist the following, so we drop them:
# - presets.inference_defaults
try:
config["presets"].pop("inference_defaults")
except KeyError:
pass
# for normal presets we only want to persist if they have changed
for preset_name, preset in list(config["presets"]["inference"].items()):
if not preset.get("changed"):
config["presets"]["inference"].pop(preset_name)
# if presets is empty, remove it
if not config["presets"]["inference"]:
config.pop("presets")
with open(file_path, "w") as file:
yaml.dump(config, file)
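
A small sketch of the intended persistence behavior; the attribute paths follow the models above and the value change is illustrative:

config = load_config(as_model=True)

# tweak one preset and mark it, so save_config() keeps it in config.yaml
config.presets.inference.conversation.temperature = 0.9
config.presets.inference.conversation.changed = True

# Untouched presets keep changed=False and are stripped on save, so they keep
# tracking presets.inference_defaults, which itself is never written to disk.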


@@ -11,6 +11,10 @@
<v-icon start>mdi-application</v-icon>
Application
</v-tab>
<v-tab value="presets">
<v-icon start>mdi-tune</v-icon>
Presets
</v-tab>
<v-tab value="creator">
<v-icon start>mdi-palette-outline</v-icon>
Creator
@@ -263,6 +267,12 @@
</v-card>
</v-window-item>
<!-- PRESETS -->
<v-window-item value="presets">
<AppConfigPresets :immutable-config="app_config" ref="presets"></AppConfigPresets>
</v-window-item>
<!-- CREATOR -->
<v-window-item value="creator">
@@ -325,8 +335,13 @@
</template>
<script>
import AppConfigPresets from './AppConfigPresets.vue';
export default {
name: 'AppConfig',
components: {
AppConfigPresets,
},
data() {
return {
tab: 'game',
@@ -431,6 +446,13 @@ export default {
},
saveConfig() {
// check if presets component is present
if(this.$refs.presets) {
// update app_config.presets from $refs.presets.config
this.app_config.presets = this.$refs.presets.config;
}
this.sendRequest({
action: 'save',
config: this.app_config,


@@ -0,0 +1,135 @@
<template>
<v-alert density="compact" type="warning" variant="text">
<p>
This interface is a work in progress and right now serves as a very basic way to edit inference parameter presets.
</p>
<p class="text-caption text-grey">
Not all clients support all parameters, and generally it is assumed that the client implementation
handles the parameters in a sane way, especially if values are passed for all of them. <span class="text-primary">All presets are used</span> and will be selected depending on the action the agent is performing. If you don't know what these mean, it is recommended to leave them as they are.
</p>
</v-alert>
<v-row>
<v-col cols="4">
<!-- list with all presets by key, read from `config` -->
<v-list slim selectable v-model:selected="selected" color="primary">
<v-list-item v-for="(preset, preset_key) in config.inference" :key="preset_key" :value="preset_key" prepend-icon="mdi-tune">
<v-list-item-title>{{ toLabel(preset_key) }}</v-list-item-title>
</v-list-item>
</v-list>
</v-col>
<v-col cols="8">
<!--
class InferenceParameters(BaseModel):
temperature: float = 1.0
temperature_last: bool = True
top_p: float | None = 1.0
top_k: int | None = 0
min_p: float | None = 0.1
presence_penalty: float | None = 0.2
frequency_penalty: float | None = 0.2
repetition_penalty: float | None= 1.1
repetition_penalty_range: int | None = 1024
Display editable form for the selected preset
Will use sliders for float and int values, and checkboxes for bool values
-->
<div v-if="selected.length === 1">
<v-form>
<v-card>
<v-card-title>
<v-row no-gutters>
<v-col cols="8">
{{ toLabel(selected[0]) }}
</v-col>
<v-col cols="4" class="text-right">
<v-btn variant="text" size="small" color="warning" prepend-icon="mdi-refresh" @click="config.inference[selected[0]] = {...immutableConfig.presets.inference_defaults[selected[0]]}">Reset</v-btn>
</v-col>
</v-row>
</v-card-title>
<v-card-text>
<v-slider thumb-label="always" density="compact" v-model="config.inference[selected[0]].temperature" min="0.1" max="2.0" step="0.05" label="Temperature" @update:model-value="setPresetChanged(selected[0])"></v-slider>
<v-slider thumb-label="always" density="compact" v-model="config.inference[selected[0]].top_p" min="0.1" max="1.0" step="0.05" label="Top P" @update:model-value="setPresetChanged(selected[0])"></v-slider>
<v-slider thumb-label="always" density="compact" v-model="config.inference[selected[0]].top_k" min="0" max="1024" step="1" label="Top K" @update:model-value="setPresetChanged(selected[0])"></v-slider>
<v-slider thumb-label="always" density="compact" v-model="config.inference[selected[0]].min_p" min="0" max="1.0" step="0.01" label="Min P" @update:model-value="setPresetChanged(selected[0])"></v-slider>
<v-slider thumb-label="always" density="compact" v-model="config.inference[selected[0]].presence_penalty" min="0" max="1.0" step="0.01" label="Presence Penalty" @update:model-value="setPresetChanged(selected[0])"></v-slider>
<v-slider thumb-label="always" density="compact" v-model="config.inference[selected[0]].frequency_penalty" min="0" max="1.0" step="0.01" label="Frequency Penalty" @update:model-value="setPresetChanged(selected[0])"></v-slider>
<v-slider thumb-label="always" density="compact" v-model="config.inference[selected[0]].repetition_penalty" min="1.0" max="1.20" step="0.01" label="Repetition Penalty" @update:model-value="setPresetChanged(selected[0])"></v-slider>
<v-slider thumb-label="always" density="compact" v-model="config.inference[selected[0]].repetition_penalty_range" min="0" max="4096" step="256" label="Repetition Penalty Range" @update:model-value="setPresetChanged(selected[0])"></v-slider>
<v-checkbox density="compact" v-model="config.inference[selected[0]].temperature_last" label="Sample temperature last" @update:model-value="setPresetChanged(selected[0])"></v-checkbox>
</v-card-text>
</v-card>
</v-form>
</div>
<div v-else>
<v-alert color="grey" variant="text">Select a preset to edit</v-alert>
</div>
</v-col>
</v-row>
</template>
<script>
export default {
name: 'AppConfigPresets',
components: {
},
props: {
immutableConfig: Object,
},
watch: {
immutableConfig: {
handler: function(newVal) {
console.log("immutableConfig changed", newVal)
if(!newVal) {
this.config = {};
return;
}
this.config = {...newVal.presets};
},
immediate: true,
deep: true,
},
},
emits: [
'update',
],
data() {
return {
selected: [],
config: {
inference: {},
},
}
},
methods: {
setPresetChanged(presetName) {
// this ensures that the change gets saved
console.log("setPresetChanged", presetName)
this.config.inference[presetName].changed = true;
},
toLabel(key) {
return key.replace(/_/g, ' ').replace(/\b\w/g, l => l.toUpperCase());
},
},
}
</script>


@@ -3,7 +3,7 @@
<span>{{ queue.length }} sound(s) queued</span>
<v-icon :color="isPlaying ? 'green' : ''" v-if="!isMuted" @click="toggleMute">mdi-volume-high</v-icon>
<v-icon :color="isPlaying ? 'red' : ''" v-else @click="toggleMute">mdi-volume-off</v-icon>
<v-icon class="ml-1" @click="stopAndClear">mdi-stop-circle-outline</v-icon>
<v-icon v-if="isPlaying" class="ml-1" @click="stopAndClear">mdi-stop-circle-outline</v-icon>
</div>
</template>


@@ -99,6 +99,7 @@ export default {
time: parseInt(data.data.time),
num: this.total++,
generation_parameters: data.data.generation_parameters,
inference_preset: data.data.inference_preset,
// immutable copy of original generation parameters
original_generation_parameters: JSON.parse(JSON.stringify(data.data.generation_parameters)),
original_prompt: data.data.prompt,


@@ -5,9 +5,9 @@
<v-card-title>
#{{ prompt.num }}
<v-chip color="grey-lightne-1" variant="text">{{ prompt.agent_name }}</v-chip>
<v-chip color="grey" variant="text">{{ prompt.agent_action }}</v-chip>
<v-divider vertical></v-divider>
<v-chip color="grey" variant="text">{{ prompt.kind }}</v-chip>
<v-chip size="small" label class="mr-1" color="primary" variant="tonal"><strong class="mr-1">action</strong>{{ prompt.agent_action }}</v-chip>
<v-chip class="mr-1" size="small" color="grey" label variant="tonal"><strong class="mr-1">task</strong> {{ prompt.kind }}</v-chip>
<v-chip size="small" color="grey" label variant="tonal"><strong class="mr-1">preset</strong> {{ prompt.inference_preset }}</v-chip>
<v-chip size="small" color="primary" variant="text" label>{{ prompt.prompt_tokens }}<v-icon size="14"
class="ml-1">mdi-arrow-down-bold</v-icon></v-chip>
<v-chip size="small" color="secondary" variant="text" label>{{ prompt.response_tokens }}<v-icon size="14"