Mirror of https://github.com/vegu-ai/talemate.git (synced 2025-12-15 19:27:47 +01:00)
0.25.5 (#121)
* openai compat client: switch to /completions instead of chat/completions
* openai compat client: pass frequency penalty
* 0.25.5
* fix version
* remove debug message
* fix openai compat client not saving coercion settings
* openai compatible client: "API handles prompt template" switches over to chat/completions api
* wording
* mistral std template
* fix error when setting llm prompt template if model name contained /
* lock sentence-transformers to 2.2.2 since >=2.3.0 breaks instructor model loading
* support png tEXt
* openai compat client: fix repetition_penalty KeyError issue
* presence_penalty is not equal to repetition_penalty and needs its own dedicated definition
* round presence penalty randomization to one decimal place
* fix filename
* same fixes for presence_penalty ported to koboldcpp client
* kcpp client: remove a1111 setup spam
* kcpp client: fixes to presence_penalty jiggle
* mistral.ai: default model 8x22b
* mistral.ai: 7b and 8x7b taken out of JSON_OBJECT_RESPONSE_MODELS
poetry.lock (generated): 1098 changes
File diff suppressed because it is too large
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"

 [tool.poetry]
 name = "talemate"
-version = "0.25.4"
+version = "0.25.5"
 description = "AI-backed roleplay and narrative tools"
 authors = ["FinalWombat"]
 license = "GNU Affero General Public License v3.0"

@@ -51,7 +51,8 @@ chromadb = ">=0.4.17,<1"
 InstructorEmbedding = "^1.0.1"
 torch = ">=2.1.0"
 torchaudio = ">=2.3.0"
-sentence-transformers="^2.2.2"
+# locked for instructor embeddings
+sentence-transformers="==2.2.2"

 [tool.poetry.dev-dependencies]
 pytest = "^6.2"
@@ -2,4 +2,4 @@ from .agents import Agent
 from .client import TextGeneratorWebuiClient
 from .tale_mate import *

-VERSION = "0.25.4"
+VERSION = "0.25.5"
@@ -5,7 +5,7 @@ from talemate.client.anthropic import AnthropicClient
 from talemate.client.cohere import CohereClient
 from talemate.client.google import GoogleClient
 from talemate.client.groq import GroqClient
-from talemate.client.koboldccp import KoboldCppClient
+from talemate.client.koboldcpp import KoboldCppClient
 from talemate.client.lmstudio import LMStudioClient
 from talemate.client.mistral import MistralAIClient
 from talemate.client.openai import OpenAIClient
@@ -755,3 +755,29 @@ class ClientBase:
             new_lines.append(line)

         return "\n".join(new_lines)
+
+
+    def process_response_for_indirect_coercion(self, prompt:str, response:str) -> str:
+
+        """
+        A lot of remote APIs don't let us control the prompt template and we cannot directly
+        append the beginning of the desired response to the prompt.
+
+        With indirect coercion we tell the LLM what the beginning of the response should be
+        and then hopefully it will adhere to it and we can strip it off the actual response.
+        """
+
+        _, right = prompt.split("\nStart your response with: ")
+        expected_response = right.strip()
+        if (
+            expected_response
+            and expected_response.startswith("{")
+        ):
+            if response.startswith("```json") and response.endswith("```"):
+                response = response[7:-3].strip()
+
+        if right and response.startswith(right):
+            response = response[len(right) :].strip()
+
+        return response
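For illustration, a minimal standalone sketch of the stripping behaviour described in the docstring above. The prompt and response strings are invented, and the helper name is not Talemate's; the real method lives on ClientBase.

```python
# Illustrative sketch of indirect-coercion stripping (not the Talemate API).
def strip_indirect_coercion(prompt: str, response: str) -> str:
    # The prompt ends with an instruction telling the model how to begin its reply.
    _, expected_start = prompt.split("\nStart your response with: ")
    expected_start = expected_start.strip()

    # JSON coercions often come back wrapped in a ```json fence; unwrap it first.
    if expected_start.startswith("{") and response.startswith("```json") and response.endswith("```"):
        response = response[7:-3].strip()

    # If the model obeyed and echoed the requested opening, strip it off.
    if expected_start and response.startswith(expected_start):
        response = response[len(expected_start):].strip()
    return response


prompt = 'Describe the scene.\nStart your response with: {"description":'
response = '```json\n{"description": "A quiet harbor at dusk."}\n```'
print(strip_indirect_coercion(prompt, response))  # -> "A quiet harbor at dusk."}
```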
@@ -128,12 +128,6 @@ class KoboldCppClient(ClientBase):
                 "stop_sequence",
             ]
         else:
             # adjustments for openai api
-            if "repetition_penalty" in parameters:
-                parameters["presence_penalty"] = parameters.pop(
-                    "repetition_penalty"
-                )
-
             allowed_params = ["max_tokens", "presence_penalty", "top_p", "temperature"]

         # drop unsupported params
@@ -243,19 +237,27 @@ class KoboldCppClient(ClientBase):

         if "rep_pen" in prompt_config:
             rep_pen_key = "rep_pen"
         elif "frequency_penalty" in prompt_config:
             rep_pen_key = "frequency_penalty"
+        elif "presence_penalty" in prompt_config:
+            rep_pen_key = "presence_penalty"
         else:
             rep_pen_key = "repetition_penalty"

-        rep_pen = prompt_config[rep_pen_key]
-
         min_offset = offset * 0.3

         prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
-        prompt_config[rep_pen_key] = random.uniform(
-            rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
-        )
+        try:
+            if rep_pen_key == "presence_penalty":
+                presence_penalty = prompt_config["presence_penalty"]
+                prompt_config["presence_penalty"] = round(random.uniform(
+                    presence_penalty + 0.1, presence_penalty + offset
+                ),1)
+            else:
+                rep_pen = prompt_config[rep_pen_key]
+                prompt_config[rep_pen_key] = random.uniform(
+                    rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
+                )
+        except KeyError:
+            pass

     def reconfigure(self, **kwargs):
         if "api_key" in kwargs:
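The rounding added above keeps jiggled presence_penalty values to one decimal place. A small sketch, assuming the 0.2 base value introduced in the presets later in this commit:

```python
import random

# Sketch of the presence_penalty "jiggle": nudge the base value upward by a
# random amount and round to one decimal place, as the hunk above does.
def jiggle_presence_penalty(prompt_config: dict, offset: float = 0.3) -> None:
    presence_penalty = prompt_config["presence_penalty"]
    prompt_config["presence_penalty"] = round(
        random.uniform(presence_penalty + 0.1, presence_penalty + offset), 1
    )

config = {"temperature": 0.65, "presence_penalty": 0.2}  # values from the presets in this commit
jiggle_presence_penalty(config)
print(config["presence_penalty"])  # e.g. 0.3, 0.4 or 0.5, always one decimal place
```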
@@ -293,10 +295,11 @@ class KoboldCppClient(ClientBase):

         sd_model = response_data[0].get("model_name") if response_data else None

-        log.info("automatic1111_setup", sd_model=sd_model)
         if not sd_model:
             return False

+        log.info("automatic1111_setup", sd_model=sd_model)
+
         visual_agent.actions["automatic1111"].config["api_url"].value = self.url
         visual_agent.is_enabled = True
         return True
@@ -25,12 +25,16 @@ SUPPORTED_MODELS = [
     "mistral-large-latest",
 ]

-JSON_OBJECT_RESPONSE_MODELS = SUPPORTED_MODELS
+JSON_OBJECT_RESPONSE_MODELS = [
+    "open-mixtral-8x22b",
+    "mistral-small-latest",
+    "mistral-medium-latest",
+    "mistral-large-latest",
+]


 class Defaults(pydantic.BaseModel):
     max_token_length: int = 16384
-    model: str = "open-mixtral-8x7b"
+    model: str = "open-mixtral-8x22b"


 @register()
@@ -53,7 +57,7 @@ class MistralAIClient(ClientBase):
     requires_prompt_template: bool = False
     defaults: Defaults = Defaults()

-    def __init__(self, model="open-mixtral-8x7b", **kwargs):
+    def __init__(self, model="open-mixtral-8x22b", **kwargs):
         self.model_name = model
         self.api_key_status = None
         self.config = load_config()
@@ -115,7 +119,7 @@ class MistralAIClient(ClientBase):
             return

         if not self.model_name:
-            self.model_name = "open-mixtral-8x7b"
+            self.model_name = "open-mixtral-8x22b"

         if max_token_length and not isinstance(max_token_length, int):
             max_token_length = int(max_token_length)
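Narrowing JSON_OBJECT_RESPONSE_MODELS means JSON-object response mode should only be requested for models in that list. A hedged sketch of that kind of gating; the helper below is hypothetical and this diff does not show Talemate's actual call site:

```python
# Hypothetical helper: only request JSON-object output for models known to support it.
JSON_OBJECT_RESPONSE_MODELS = [
    "open-mixtral-8x22b",
    "mistral-small-latest",
    "mistral-medium-latest",
    "mistral-large-latest",
]

def response_format_kwargs(model: str, wants_json: bool) -> dict:
    if wants_json and model in JSON_OBJECT_RESPONSE_MODELS:
        return {"response_format": {"type": "json_object"}}
    return {}

print(response_format_kwargs("open-mixtral-8x22b", wants_json=True))  # JSON mode allowed
print(response_format_kwargs("open-mixtral-8x7b", wants_json=True))   # {} since 8x7b was removed from the list
```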
@@ -136,13 +136,15 @@ class ModelPrompt:
         """

         matches = []

+        cleaned_model_name = model_name.replace("/", "__")
+
         # Iterate over all templates in the loader's directory
         for template_name in self.env.list_templates():
             # strip extension
             template_name_match = os.path.splitext(template_name)[0]
             # Check if the model name is in the template filename
-            if template_name_match.lower() in model_name.lower():
+            if template_name_match.lower() in cleaned_model_name.lower():
                 matches.append(template_name)

         # If there are no matches, return None
@@ -163,16 +165,17 @@ class ModelPrompt:
         """

         template_name = template_name.split(".jinja2")[0]

+        cleaned_model_name = model_name.replace("/", "__")
+
         shutil.copyfile(
             os.path.join(STD_TEMPLATE_PATH, template_name + ".jinja2"),
-            os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2"),
+            os.path.join(USER_TEMPLATE_PATH, cleaned_model_name + ".jinja2"),
         )

-        return os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2")
+        return os.path.join(USER_TEMPLATE_PATH, cleaned_model_name + ".jinja2")

     def query_hf_for_prompt_template_suggestion(self, model_name: str):
         print("query_hf_for_prompt_template_suggestion", model_name)
         api = huggingface_hub.HfApi()

         try:
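The replace("/", "__") above is the fix for model names that contain a slash (HuggingFace-style repo ids), which would otherwise be treated as a path separator when used as a template filename. A minimal sketch; the directory default and model name here are illustrative, not Talemate's actual constants:

```python
import os

# Model names like "org/model" contain "/", which would create a nested path
# if used directly as a template filename. Flatten the separator first.
def template_filename(model_name: str, template_dir: str = "templates/llm-prompt/user") -> str:
    cleaned_model_name = model_name.replace("/", "__")
    return os.path.join(template_dir, cleaned_model_name + ".jinja2")

print(template_filename("mistralai/Mistral-7B-Instruct-v0.2"))
# templates/llm-prompt/user/mistralai__Mistral-7B-Instruct-v0.2.jinja2 (on POSIX)
```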
@@ -1,5 +1,5 @@
 import urllib
-
+import random
 import pydantic
 import structlog
 from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
@@ -20,6 +20,7 @@ class Defaults(pydantic.BaseModel):
     max_token_length: int = 8192
     model: str = ""
+    api_handles_prompt_template: bool = False
     double_coercion: str = None


 class ClientConfig(BaseClientConfig):
@@ -43,9 +44,9 @@ class OpenAICompatibleClient(ClientBase):
         "api_handles_prompt_template": ExtraField(
             name="api_handles_prompt_template",
             type="bool",
-            label="API Handles Prompt Template",
+            label="API handles prompt template (chat/completions)",
             required=False,
-            description="The API handles the prompt template, meaning your choice in the UI for the prompt template below will be ignored.",
+            description="The API handles the prompt template, meaning your choice in the UI for the prompt template below will be ignored. This is not recommended and should only be used if the API does not support the `completions` endpoint or you don't know which prompt template to use.",
         )
     }
@@ -83,13 +84,12 @@ class OpenAICompatibleClient(ClientBase):
     def tune_prompt_parameters(self, parameters: dict, kind: str):
         super().tune_prompt_parameters(parameters, kind)

-        keys = list(parameters.keys())
-
-        valid_keys = ["temperature", "top_p", "max_tokens"]
-
-        for key in keys:
-            if key not in valid_keys:
-                del parameters[key]
+        allowed_params = ["max_tokens", "presence_penalty", "top_p", "temperature"]
+
+        # drop unsupported params
+        for param in list(parameters.keys()):
+            if param not in allowed_params:
+                del parameters[param]

     def prompt_template(self, system_message: str, prompt: str):
@@ -117,16 +117,27 @@ class OpenAICompatibleClient(ClientBase):
         """
         Generates text from the given prompt and parameters.
         """
-        human_message = {"role": "user", "content": prompt.strip()}
-
-        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

         try:
-            response = await self.client.chat.completions.create(
-                model=self.model_name, messages=[human_message], **parameters
-            )
-
-            return response.choices[0].message.content
+            if self.api_handles_prompt_template:
+                # OpenAI API handles prompt template
+                # Use the chat completions endpoint
+                self.log.debug("generate (chat/completions)", prompt=prompt[:128] + " ...", parameters=parameters)
+                human_message = {"role": "user", "content": prompt.strip()}
+                response = await self.client.chat.completions.create(
+                    model=self.model_name, messages=[human_message], **parameters
+                )
+                response = response.choices[0].message.content
+                return self.process_response_for_indirect_coercion(prompt, response)
+            else:
+                # Talemate handles prompt template
+                # Use the completions endpoint
+                self.log.debug("generate (completions)", prompt=prompt[:128] + " ...", parameters=parameters)
+                parameters["prompt"] = prompt
+                response = await self.client.completions.create(
+                    model=self.model_name, **parameters
+                )
+                return response.choices[0].text
         except PermissionDeniedError as e:
             self.log.error("generate error", e=e)
             emit("status", message="Client API: Permission Denied", status="error")
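A condensed sketch of the branch above using the openai Python client: when the remote API applies its own prompt template, a single user message goes to chat/completions; when Talemate renders the template itself, the already-formatted prompt goes to the plain completions endpoint. The base_url, model name and prompt below are placeholders, not Talemate defaults:

```python
import asyncio
from openai import AsyncOpenAI

# Placeholder endpoint and key; point these at your own OpenAI-compatible server.
client = AsyncOpenAI(base_url="http://localhost:5000/v1", api_key="sk-placeholder")

async def generate(prompt: str, api_handles_prompt_template: bool, **parameters) -> str:
    if api_handles_prompt_template:
        # Let the server apply its chat template: send one user message.
        response = await client.chat.completions.create(
            model="my-model",
            messages=[{"role": "user", "content": prompt.strip()}],
            **parameters,
        )
        return response.choices[0].message.content
    # The prompt template was already rendered locally: use raw completions.
    response = await client.completions.create(model="my-model", prompt=prompt, **parameters)
    return response.choices[0].text

# Example (requires a running server):
# print(asyncio.run(generate("[INST] Say hi [/INST]", api_handles_prompt_template=False, max_tokens=32)))
```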
@@ -151,7 +162,33 @@ class OpenAICompatibleClient(ClientBase):
             self.api_key = kwargs["api_key"]
+        if "api_handles_prompt_template" in kwargs:
+            self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
+        # TODO: why isn't this calling super()?
         if "enabled" in kwargs:
             self.enabled = bool(kwargs["enabled"])

         if "double_coercion" in kwargs:
             self.double_coercion = kwargs["double_coercion"]

+        log.warning("reconfigure", kwargs=kwargs)
+
         self.set_client(**kwargs)

+    def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
+        """
+        adjusts temperature and presence penalty
+        by random values using the base value as a center
+        """
+
+        temp = prompt_config["temperature"]
+
+        min_offset = offset * 0.3
+
+        prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
+
+        try:
+            presence_penalty = prompt_config["presence_penalty"]
+            prompt_config["presence_penalty"] = round(random.uniform(
+                presence_penalty + 0.1, presence_penalty + offset
+            ),1)
+        except KeyError:
+            pass
@@ -11,10 +11,15 @@ __all__ = [
     "PRESET_SIMPLE_1",
 ]

+# TODO: refactor abstraction and make configurable
+
+PRESENCE_PENALTY_BASE = 0.2
+
 PRESET_TALEMATE_CONVERSATION = {
     "temperature": 0.65,
     "top_p": 0.47,
     "top_k": 42,
+    "presence_penalty": PRESENCE_PENALTY_BASE,
     "repetition_penalty": 1.18,
     "repetition_penalty_range": 2048,
 }
@@ -23,6 +28,7 @@ PRESET_TALEMATE_CREATOR = {
     "temperature": 0.7,
     "top_p": 0.9,
     "top_k": 20,
+    "presence_penalty": PRESENCE_PENALTY_BASE,
     "repetition_penalty": 1.15,
     "repetition_penalty_range": 512,
 }
@@ -31,6 +37,7 @@ PRESET_LLAMA_PRECISE = {
     "temperature": 0.7,
     "top_p": 0.1,
     "top_k": 40,
+    "presence_penalty": PRESENCE_PENALTY_BASE,
     "repetition_penalty": 1.18,
 }
@@ -45,6 +52,7 @@ PRESET_DIVINE_INTELLECT = {
     "temperature": 1.31,
     "top_p": 0.14,
     "top_k": 49,
+    "presence_penalty": PRESENCE_PENALTY_BASE,
     "repetition_penalty_range": 1024,
     "repetition_penalty": 1.17,
 }
@@ -53,6 +61,7 @@ PRESET_SIMPLE_1 = {
     "temperature": 0.7,
     "top_p": 0.9,
     "top_k": 20,
+    "presence_penalty": PRESENCE_PENALTY_BASE,
     "repetition_penalty": 1.15,
 }
@@ -51,6 +51,39 @@ class TextGeneratorWebuiClient(ClientBase):
         # is this needed?
         parameters["max_new_tokens"] = parameters["max_tokens"]
         parameters["stop"] = parameters["stopping_strings"]

+        # textgenwebui does not error on unsupported parameters
+        # but we should still drop them so they don't get passed to the API
+        # and show up in our prompt debugging tool.
+
+        # note that this is not the full list of their supported parameters
+        # but only those we send.
+
+        allowed_params = [
+            "temperature",
+            "top_p",
+            "top_k",
+            "max_tokens",
+            "repetition_penalty",
+            "repetition_penalty_range",
+            "max_tokens",
+            "stopping_strings",
+            "skip_special_tokens",
+            "stream",
+            # is this needed?
+            "max_new_tokens",
+            "stop",
+            # talemate internal
+            # These will be removed before sending to the API
+            # but we keep them here since they are used during the prompt finalization
+            "extra_stopping_strings",
+        ]
+
+        # drop unsupported params
+        for param in list(parameters.keys()):
+            if param not in allowed_params:
+                del parameters[param]
+
     def set_client(self, **kwargs):
         self.api_key = kwargs.get("api_key", self.api_key)
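The allow-list above is applied with the same small filtering idiom used by the other clients in this commit; a generic sketch with placeholder parameter values:

```python
# Keep only parameters the backend understands; everything else is dropped
# before the payload is sent (and before it shows up in prompt debugging).
def filter_parameters(parameters: dict, allowed_params: list) -> dict:
    for param in list(parameters.keys()):
        if param not in allowed_params:
            del parameters[param]
    return parameters

params = {"temperature": 0.7, "top_p": 0.9, "tfs": 0.95}  # "tfs" is a placeholder extra
print(filter_parameters(params, ["temperature", "top_p", "max_tokens"]))
# {'temperature': 0.7, 'top_p': 0.9}
```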
@@ -5,7 +5,7 @@ import json
 import re
 import textwrap
 from typing import List, Union
-
+import struct
 import isodate
 import structlog
 from colorama import Back, Fore, Style, init
@@ -179,6 +179,29 @@ def color_emotes(text: str, color: str = "blue") -> str:
 def extract_metadata(img_path, img_format):
     return chara_read(img_path)

+
+def read_metadata_from_png_text(image_path:str) -> dict:
+
+    """
+    Reads the character metadata from the tEXt chunk of a PNG image.
+    """
+
+    # Read the image
+    with open(image_path, 'rb') as f:
+        png_data = f.read()
+
+    # Split the PNG data into chunks
+    offset = 8  # Skip the PNG signature
+    while offset < len(png_data):
+        length = struct.unpack('!I', png_data[offset:offset+4])[0]
+        chunk_type = png_data[offset+4:offset+8]
+        chunk_data = png_data[offset+8:offset+8+length]
+        if chunk_type == b'tEXt':
+            keyword, text_data = chunk_data.split(b'\x00', 1)
+            if keyword == b'chara':
+                return json.loads(base64.b64decode(text_data).decode('utf-8'))
+        offset += 12 + length
+
+    raise ValueError('No character metadata found.')
+
+
 def chara_read(img_url, input_format=None):
     if input_format is None:
@@ -194,7 +217,6 @@ def chara_read(img_url, input_format=None):
     image = Image.open(io.BytesIO(image_data))

     exif_data = image.getexif()
-
     if format == "webp":
         try:
             if 37510 in exif_data:
@@ -235,7 +257,15 @@ def chara_read(img_url, input_format=None):
                 return base64_decoded_data
             else:
                 log.warn("chara_load", msg="No chara data found in PNG image.")
-                return False
+                log.warn("chara_load", msg="Trying to read from PNG text.")
+
+                try:
+                    return read_metadata_from_png_text(img_url)
+                except ValueError:
+                    return False
+                except Exception as exc:
+                    log.error("chara_load", msg="Error reading metadata from PNG text.", exc_info=exc)
+                    return False
     else:
         return None
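A quick way to sanity-check the tEXt parsing is to build the chunk framing by hand and walk it the same way read_metadata_from_png_text does. The bytes below are not a viewable PNG (no IHDR/IEND, dummy CRC) and the character payload is invented:

```python
import base64
import json
import struct

# Build an in-memory PNG-like byte string: signature + one tEXt chunk carrying
# base64-encoded character JSON under the "chara" keyword.
chara = base64.b64encode(json.dumps({"name": "Example Character"}).encode("utf-8"))
text_data = b"chara\x00" + chara
png_data = (
    b"\x89PNG\r\n\x1a\n"                              # PNG signature (8 bytes)
    + struct.pack("!I", len(text_data)) + b"tEXt"     # chunk length + type
    + text_data + b"\x00\x00\x00\x00"                 # chunk data + dummy CRC
)

# Walk the chunks the way the new helper does: length, type, data, then skip
# length + 12 bytes (4 length + 4 type + 4 CRC) to reach the next chunk.
offset = 8
while offset < len(png_data):
    length = struct.unpack("!I", png_data[offset:offset + 4])[0]
    chunk_type = png_data[offset + 4:offset + 8]
    chunk_data = png_data[offset + 8:offset + 8 + length]
    if chunk_type == b"tEXt":
        keyword, payload = chunk_data.split(b"\x00", 1)
        if keyword == b"chara":
            print(json.loads(base64.b64decode(payload).decode("utf-8")))  # {'name': 'Example Character'}
    offset += 12 + length
```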
talemate_frontend/package-lock.json (generated): 4 changes
@@ -1,12 +1,12 @@
 {
   "name": "talemate_frontend",
-  "version": "0.25.4",
+  "version": "0.25.5",
   "lockfileVersion": 2,
   "requires": true,
   "packages": {
     "": {
       "name": "talemate_frontend",
-      "version": "0.25.4",
+      "version": "0.25.5",
       "dependencies": {
         "@codemirror/lang-markdown": "^6.2.5",
         "@codemirror/theme-one-dark": "^6.1.2",
talemate_frontend/package.json

@@ -1,6 +1,6 @@
 {
   "name": "talemate_frontend",
-  "version": "0.25.4",
+  "version": "0.25.5",
   "private": true,
   "scripts": {
     "serve": "vue-cli-service serve",
templates/llm-prompt/std/Mistral.jinja2 (new file): 1 line
@@ -0,0 +1 @@
+<s>[INST] {{ system_message }} {{ user_message }} [/INST] {{ coercion_message }}
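Rendered with Jinja2, the new Mistral std template produces the usual [INST] instruct format. The messages below are placeholders; Talemate's own rendering supplies the real system/user content and the optional coercion opener:

```python
from jinja2 import Template

template = Template(
    "<s>[INST] {{ system_message }} {{ user_message }} [/INST] {{ coercion_message }}"
)
print(template.render(
    system_message="You are a narrator.",          # placeholder system prompt
    user_message="Describe the harbor at dusk.",   # placeholder user prompt
    coercion_message='{"description":',            # optional response opener
))
# <s>[INST] You are a narrator. Describe the harbor at dusk. [/INST] {"description":
```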