From a28cf2a02965f4fb84e34854f42edb20344f7bf8 Mon Sep 17 00:00:00 2001 From: veguAI <152010387+vegu-ai-tools@users.noreply.github.com> Date: Fri, 10 May 2024 21:29:29 +0300 Subject: [PATCH] 0.25.2 (#108) * fix typo * fix openai compat config save issue maybe * fix api_handles_prompt_template no longer saving changes after last fix * koboldcpp client * default to kobold ai api * linting * conversation cleanup tweak * 0.25.2 * allowed hosts to all on dev instance * ensure numbers on parameters when sending edited values * fix prompt parameter issues * remove debug message --- pyproject.toml | 2 +- src/talemate/__init__.py | 2 +- src/talemate/agents/conversation.py | 4 +- src/talemate/client/__init__.py | 3 +- src/talemate/client/koboldccp.py | 207 +++++++++++++++++- src/talemate/server/devtools.py | 16 +- talemate_frontend/package-lock.json | 4 +- talemate_frontend/package.json | 2 +- talemate_frontend/src/components/AIClient.vue | 17 ++ .../src/components/ClientModal.vue | 4 +- .../src/components/TalemateApp.vue | 2 +- talemate_frontend/vue.config.js | 1 + 12 files changed, 242 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index abd46e62..60e38161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "talemate" -version = "0.25.1" +version = "0.25.2" description = "AI-backed roleplay and narrative tools" authors = ["FinalWombat"] license = "GNU Affero General Public License v3.0" diff --git a/src/talemate/__init__.py b/src/talemate/__init__.py index da8a9f0f..60d9ae28 100644 --- a/src/talemate/__init__.py +++ b/src/talemate/__init__.py @@ -2,4 +2,4 @@ from .agents import Agent from .client import TextGeneratorWebuiClient from .tale_mate import * -VERSION = "0.25.1" +VERSION = "0.25.2" diff --git a/src/talemate/agents/conversation.py b/src/talemate/agents/conversation.py index ea1826af..6a7177f8 100644 --- a/src/talemate/agents/conversation.py +++ 
b/src/talemate/agents/conversation.py @@ -668,7 +668,9 @@ class ConversationAgent(Agent): total_result = util.handle_endofline_special_delimiter(total_result) - if total_result.startswith(":\n"): + log.info("conversation agent", total_result=total_result) + + if total_result.startswith(":\n") or total_result.startswith(": "): total_result = total_result[2:] # movie script format diff --git a/src/talemate/client/__init__.py b/src/talemate/client/__init__.py index 42298feb..b37f007e 100644 --- a/src/talemate/client/__init__.py +++ b/src/talemate/client/__init__.py @@ -5,9 +5,10 @@ from talemate.client.anthropic import AnthropicClient from talemate.client.cohere import CohereClient from talemate.client.google import GoogleClient from talemate.client.groq import GroqClient +from talemate.client.koboldccp import KoboldCppClient from talemate.client.lmstudio import LMStudioClient from talemate.client.mistral import MistralAIClient from talemate.client.openai import OpenAIClient from talemate.client.openai_compat import OpenAICompatibleClient from talemate.client.registry import CLIENT_CLASSES, get_client_class, register -from talemate.client.textgenwebui import TextGeneratorWebuiClient +from talemate.client.textgenwebui import TextGeneratorWebuiClient \ No newline at end of file diff --git a/src/talemate/client/koboldccp.py b/src/talemate/client/koboldccp.py index 91869e1d..cf31289c 100644 --- a/src/talemate/client/koboldccp.py +++ b/src/talemate/client/koboldccp.py @@ -1,16 +1,201 @@ -import asyncio -import json -import logging import random -from abc import ABC, abstractmethod -from typing import Callable, Union +import re -import requests +# import urljoin +from urllib.parse import urljoin +import httpx +import structlog -import talemate.client.system_prompts as system_prompts -import talemate.util as util +from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField from talemate.client.registry import register -from 
talemate.client.textgenwebui import RESTTaleMateClient -from talemate.emit import Emission, emit -# NOT IMPLEMENTED AT THIS POINT +log = structlog.get_logger("talemate.client.koboldcpp") + + +class KoboldCppClientDefaults(Defaults): + api_key: str = "" + + +@register() +class KoboldCppClient(ClientBase): + auto_determine_prompt_template: bool = True + client_type = "koboldcpp" + + class Meta(ClientBase.Meta): + name_prefix: str = "KoboldCpp" + title: str = "KoboldCpp" + enable_api_auth: bool = True + defaults: KoboldCppClientDefaults = KoboldCppClientDefaults() + + @property + def request_headers(self): + headers = {} + headers["Content-Type"] = "application/json" + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + return headers + + @property + def is_openai(self) -> bool: + """ + kcpp has two apis + + open-ai implementation at /v1 + their own implementation at /api/v1 + """ + return "/api/v1" not in self.api_url + + @property + def api_url_for_model(self) -> str: + if self.is_openai: + # join /models to url + return urljoin(self.api_url, "models") + else: + # join /model to url + return urljoin(self.api_url, "model") + + @property + def api_url_for_generation(self) -> str: + if self.is_openai: + # join /v1/completions + return urljoin(self.api_url, "completions") + else: + # join /api/v1/generate + return urljoin(self.api_url, "generate") + + def api_endpoint_specified(self, url: str) -> bool: + return "/v1" in self.api_url + + def ensure_api_endpoint_specified(self): + if not self.api_endpoint_specified(self.api_url): + # url doesn't specify the api endpoint + # default to koboldcpp's own api (/api/v1) + self.api_url = urljoin(self.api_url.rstrip("/") + "/", "/api/v1/") + if not self.api_url.endswith("/"): + self.api_url += "/" + + def __init__(self, **kwargs): + self.api_key = kwargs.pop("api_key", "") + super().__init__(**kwargs) + self.ensure_api_endpoint_specified() + + def tune_prompt_parameters(self, parameters: dict, kind: str): +
super().tune_prompt_parameters(parameters, kind) + if not self.is_openai: + # adjustments for united api + parameters["max_length"] = parameters.pop("max_tokens") + parameters["max_context_length"] = self.max_token_length + if "repetition_penalty_range" in parameters: + parameters["rep_pen_range"] = parameters.pop("repetition_penalty_range") + if "repetition_penalty" in parameters: + parameters["rep_pen"] = parameters.pop("repetition_penalty") + if parameters.get("stop_sequence"): + parameters["stop_sequence"] = parameters.pop("stopping_strings") + + if parameters.get("extra_stopping_strings"): + if "stop_sequence" in parameters: + parameters["stop_sequence"] += parameters.pop("extra_stopping_strings") + else: + parameters["stop_sequence"] = parameters.pop("extra_stopping_strings") + + + allowed_params = [ + "max_length", + "max_context_length", + "rep_pen", + "rep_pen_range", + "top_p", + "top_k", + "temperature", + "stop_sequence", + ] + else: + # adjustments for openai api + if "repetition_penalty" in parameters: + parameters["presence_penalty"] = parameters.pop( + "repetition_penalty" + ) + + allowed_params = ["max_tokens", "presence_penalty", "top_p", "temperature"] + + # drop unsupported params + for param in list(parameters.keys()): + if param not in allowed_params: + del parameters[param] + + def set_client(self, **kwargs): + self.api_key = kwargs.get("api_key", self.api_key) + self.ensure_api_endpoint_specified() + + async def get_model_name(self): + self.ensure_api_endpoint_specified() + async with httpx.AsyncClient() as client: + response = await client.get( + self.api_url_for_model, + timeout=2, + headers=self.request_headers, + ) + + if response.status_code == 404: + raise KeyError(f"Could not find model info at: {self.api_url_for_model}") + + response_data = response.json() + if self.is_openai: + # {"object": "list", "data": [{"id": "koboldcpp/dolphin-2.8-mistral-7b", "object": "model", "created": 1, "owned_by": "koboldcpp", "permission": [], "root": 
"koboldcpp"}]} + model_name = response_data.get("data")[0].get("id") + else: + # {"result": "koboldcpp/dolphin-2.8-mistral-7b"} + model_name = response_data.get("result") + + # split by "/" and take last + if model_name: + model_name = model_name.split("/")[-1] + + return model_name + + async def generate(self, prompt: str, parameters: dict, kind: str): + """ + Generates text from the given prompt and parameters. + """ + + parameters["prompt"] = prompt.strip(" ") + + async with httpx.AsyncClient() as client: + response = await client.post( + self.api_url_for_generation, + json=parameters, + timeout=None, + headers=self.request_headers, + ) + response_data = response.json() + + try: + if self.is_openai: + return response_data["choices"][0]["text"] + else: + return response_data["results"][0]["text"] + except (TypeError, KeyError) as exc: + log.error("Failed to generate text", exc=exc, response_data=response_data, response_status=response.status_code) + return "" + + def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict: + """ + adjusts temperature and repetition_penalty + by random values using the base value as a center + """ + + temp = prompt_config["temperature"] + rep_pen = prompt_config["rep_pen"] + + min_offset = offset * 0.3 + + prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset) + prompt_config["rep_pen"] = random.uniform( + rep_pen + min_offset * 0.3, rep_pen + offset * 0.3 + ) + + def reconfigure(self, **kwargs): + if "api_key" in kwargs: + self.api_key = kwargs.pop("api_key") + + super().reconfigure(**kwargs) diff --git a/src/talemate/server/devtools.py b/src/talemate/server/devtools.py index 8de27443..a2ce7222 100644 --- a/src/talemate/server/devtools.py +++ b/src/talemate/server/devtools.py @@ -11,6 +11,20 @@ class TestPromptPayload(pydantic.BaseModel): kind: str +def ensure_number(v): + """ + if v is a str but digit turn into into or float + """ + + if isinstance(v, str): + if v.isdigit(): + return 
int(v) + try: + return float(v) + except ValueError: + return v + return v + class DevToolsPlugin: router = "devtools" @@ -34,7 +48,7 @@ class DevToolsPlugin: log.info( "Testing prompt", payload={ - k: v for k, v in payload.generation_parameters.items() if k != "prompt" + k: ensure_number(v) for k, v in payload.generation_parameters.items() if k != "prompt" }, ) diff --git a/talemate_frontend/package-lock.json b/talemate_frontend/package-lock.json index b69d0a9b..7e37ac22 100644 --- a/talemate_frontend/package-lock.json +++ b/talemate_frontend/package-lock.json @@ -1,12 +1,12 @@ { "name": "talemate_frontend", - "version": "0.25.1", + "version": "0.25.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "talemate_frontend", - "version": "0.25.1", + "version": "0.25.2", "dependencies": { "@codemirror/lang-markdown": "^6.2.5", "@codemirror/theme-one-dark": "^6.1.2", diff --git a/talemate_frontend/package.json b/talemate_frontend/package.json index d93d5681..12c092ab 100644 --- a/talemate_frontend/package.json +++ b/talemate_frontend/package.json @@ -1,6 +1,6 @@ { "name": "talemate_frontend", - "version": "0.25.1", + "version": "0.25.2", "private": true, "scripts": { "serve": "vue-cli-service serve", diff --git a/talemate_frontend/src/components/AIClient.vue b/talemate_frontend/src/components/AIClient.vue index 0d6688d8..705c7443 100644 --- a/talemate_frontend/src/components/AIClient.vue +++ b/talemate_frontend/src/components/AIClient.vue @@ -244,6 +244,13 @@ export default { client.api_key = data.api_key; client.double_coercion = data.data.double_coercion; client.data = data.data; + for (let key in client.data.meta.extra_fields) { + if (client.data[key] === null || client.data[key] === undefined) { + client.data[key] = client.data.meta.defaults[key]; + } + client[key] = client.data[key]; + } + } else if(!client) { console.log("Adding new client", data); @@ -259,6 +266,16 @@ export default { double_coercion: data.data.double_coercion, data: 
data.data, }); + + // apply extra field defaults + let client = this.state.clients[this.state.clients.length - 1]; + for (let key in client.data.meta.extra_fields) { + if (client.data[key] === null || client.data[key] === undefined) { + client.data[key] = client.data.meta.defaults[key]; + } + client[key] = client.data[key]; + } + // sort the clients by name this.state.clients.sort((a, b) => (a.name > b.name) ? 1 : -1); } diff --git a/talemate_frontend/src/components/ClientModal.vue b/talemate_frontend/src/components/ClientModal.vue index 366a65a9..9b0e1fa3 100644 --- a/talemate_frontend/src/components/ClientModal.vue +++ b/talemate_frontend/src/components/ClientModal.vue @@ -56,9 +56,9 @@ - - diff --git a/talemate_frontend/src/components/TalemateApp.vue b/talemate_frontend/src/components/TalemateApp.vue index 19434268..b52d58a8 100644 --- a/talemate_frontend/src/components/TalemateApp.vue +++ b/talemate_frontend/src/components/TalemateApp.vue @@ -248,7 +248,7 @@ export default { messageHandlers: [], scene: {}, appConfig: {}, - autcompleting: false, + autocompleting: false, autocompletePartialInput: "", autocompleteCallback: null, autocompleteFocusElement: null, diff --git a/talemate_frontend/vue.config.js b/talemate_frontend/vue.config.js index f6275c64..09892f06 100644 --- a/talemate_frontend/vue.config.js +++ b/talemate_frontend/vue.config.js @@ -9,6 +9,7 @@ module.exports = defineConfig({ }, devServer: { + allowedHosts: "all", client: { overlay: { warnings: false,