diff --git a/src/modules/AdvancedPaste/AdvancedPaste/Services/OpenAI/CustomTextTransformService.cs b/src/modules/AdvancedPaste/AdvancedPaste/Services/OpenAI/CustomTextTransformService.cs index 992a88dd7c..58eba35989 100644 --- a/src/modules/AdvancedPaste/AdvancedPaste/Services/OpenAI/CustomTextTransformService.cs +++ b/src/modules/AdvancedPaste/AdvancedPaste/Services/OpenAI/CustomTextTransformService.cs @@ -3,53 +3,56 @@ // See the LICENSE file in the project root for more information. using System; -using System.Text.Json; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; - using AdvancedPaste.Helpers; using AdvancedPaste.Models; using AdvancedPaste.Telemetry; using Azure; using ManagedCommon; using Microsoft.PowerToys.Telemetry; +using OpenAI; +using OpenAI.Chat; namespace AdvancedPaste.Services.OpenAI; public sealed class CustomTextTransformService(IAICredentialsProvider aiCredentialsProvider, IPromptModerationService promptModerationService) : ICustomTextTransformService { - private const string ModelName = "gpt-3.5-turbo-instruct"; + private const string ModelName = "gpt-3.5-turbo"; private readonly IAICredentialsProvider _aiCredentialsProvider = aiCredentialsProvider; private readonly IPromptModerationService _promptModerationService = promptModerationService; - private async Task<Completions> GetAICompletionAsync(string systemInstructions, string userMessage, CancellationToken cancellationToken) + private async Task<ChatCompletion> GetAICompletionAsync(string systemInstructions, string userMessage, CancellationToken cancellationToken) { var fullPrompt = systemInstructions + "\n\n" + userMessage; await _promptModerationService.ValidateAsync(fullPrompt, cancellationToken); - OpenAIClient azureAIClient = new(_aiCredentialsProvider.Key); + ChatClient chatClient = new(ModelName, _aiCredentialsProvider.Key); - var response = await azureAIClient.GetCompletionsAsync( - new() + var messages = new List<ChatMessage> + { +
ChatMessage.CreateSystemMessage(systemInstructions), + ChatMessage.CreateUserMessage(userMessage), + }; + + var response = await chatClient.CompleteChatAsync( + messages, + new ChatCompletionOptions { - DeploymentName = ModelName, - Prompts = - { - fullPrompt, - }, - Temperature = 0.01F, - MaxTokens = 2000, + Temperature = 0.01f, + MaxOutputTokenCount = 2000, }, cancellationToken); - if (response.Value.Choices[0].FinishReason == "length") + if (response.Value.FinishReason == ChatFinishReason.Length) { Logger.LogDebug("Cut off due to length constraints"); } - return response; + return response.Value; } public async Task<string> TransformTextAsync(string prompt, string inputText, CancellationToken cancellationToken, IProgress<double> progress) @@ -84,13 +87,13 @@ Output: var response = await GetAICompletionAsync(systemInstructions, userMessage, cancellationToken); var usage = response.Usage; - AdvancedPasteGenerateCustomFormatEvent telemetryEvent = new(usage.PromptTokens, usage.CompletionTokens, ModelName); + AdvancedPasteGenerateCustomFormatEvent telemetryEvent = new(usage.InputTokenCount, usage.OutputTokenCount, ModelName); PowerToysTelemetry.Log.WriteEvent(telemetryEvent); var logEvent = new AIServiceFormatEvent(telemetryEvent); Logger.LogDebug($"{nameof(TransformTextAsync)} complete; {logEvent.ToJsonString()}"); - return response.Choices[0].Text; + return response.Content[0].Text; } catch (Exception ex) { diff --git a/src/modules/AdvancedPaste/AdvancedPaste/Services/OpenAI/KernelService.cs b/src/modules/AdvancedPaste/AdvancedPaste/Services/OpenAI/KernelService.cs index 1f3b25dbcc..0b13fe5264 100644 --- a/src/modules/AdvancedPaste/AdvancedPaste/Services/OpenAI/KernelService.cs +++ b/src/modules/AdvancedPaste/AdvancedPaste/Services/OpenAI/KernelService.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information.
using System.Collections.Generic; - using AdvancedPaste.Models; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; @@ -26,8 +25,33 @@ public sealed class KernelService(IKernelQueryCacheService queryCacheService, IA protected override void AddChatCompletionService(IKernelBuilder kernelBuilder) => kernelBuilder.AddOpenAIChatCompletion(ModelName, _aiCredentialsProvider.Key); - protected override AIServiceUsage GetAIServiceUsage(ChatMessageContent chatMessage) => - chatMessage.Metadata?.GetValueOrDefault("Usage") is CompletionsUsage completionsUsage - ? new(PromptTokens: completionsUsage.PromptTokens, CompletionTokens: completionsUsage.CompletionTokens) - : AIServiceUsage.None; + protected override AIServiceUsage GetAIServiceUsage(ChatMessageContent chatMessage) + { + // Try to get usage information from metadata + if (chatMessage.Metadata?.TryGetValue("Usage", out var usageObj) == true) + { + // Handle different possible usage types through reflection to be version-agnostic + var usageType = usageObj.GetType(); + + try + { + // Try common property names for prompt tokens + var promptTokensProp = usageType.GetProperty("PromptTokens") ?? usageType.GetProperty("InputTokens") ?? usageType.GetProperty("InputTokenCount"); + var completionTokensProp = usageType.GetProperty("CompletionTokens") ?? usageType.GetProperty("OutputTokens") ?? usageType.GetProperty("OutputTokenCount"); + + if (promptTokensProp != null && completionTokensProp != null) + { + var promptTokens = (int)(promptTokensProp.GetValue(usageObj) ?? 0); + var completionTokens = (int)(completionTokensProp.GetValue(usageObj) ?? 0); + return new AIServiceUsage(promptTokens, completionTokens); + } + } + catch + { + // If reflection fails, fall back to no usage + } + } + + return AIServiceUsage.None; + } }