// Copyright (c) Microsoft Corporation // The Microsoft Corporation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. using ManagedCommon; using Microsoft.CmdPal.Common.Services; using Microsoft.CmdPal.UI.ViewModels.Models; using Microsoft.CmdPal.UI.ViewModels.Services; using Microsoft.CommandPalette.Extensions; using Microsoft.CommandPalette.Extensions.Toolkit; using Microsoft.Extensions.DependencyInjection; using Windows.Foundation; namespace Microsoft.CmdPal.UI.ViewModels; public sealed class CommandProviderWrapper { public bool IsExtension => Extension is not null; private readonly bool isValid; private readonly ExtensionObject _commandProvider; private readonly TaskScheduler _taskScheduler; private readonly ICommandProviderCache? _commandProviderCache; public TopLevelViewModel[] TopLevelItems { get; private set; } = []; public TopLevelViewModel[] FallbackItems { get; private set; } = []; public string DisplayName { get; private set; } = string.Empty; public IExtensionWrapper? Extension { get; } public CommandPaletteHost ExtensionHost { get; private set; } public event TypedEventHandler? CommandsChanged; public string Id { get; private set; } = string.Empty; public IconInfoViewModel Icon { get; private set; } = new(null); public CommandSettingsViewModel? Settings { get; private set; } public bool IsActive { get; private set; } public string ProviderId => string.IsNullOrEmpty(Extension?.ExtensionUniqueId) ? Id : Extension.ExtensionUniqueId; public CommandProviderWrapper(ICommandProvider provider, TaskScheduler mainThread) { // This ctor is only used for in-proc builtin commands. So the Unsafe! // calls are pretty dang safe actually. _commandProvider = new(provider); _taskScheduler = mainThread; // Hook the extension back into us ExtensionHost = new CommandPaletteHost(provider); _commandProvider.Unsafe!.InitializeWithHost(ExtensionHost); _commandProvider.Unsafe!.ItemsChanged += CommandProvider_ItemsChanged; isValid = true; Id = provider.Id; DisplayName = provider.DisplayName; Icon = new(provider.Icon); Icon.InitializeProperties(); // Note: explicitly not InitializeProperties()ing the settings here. If // we do that, then we'd regress GH #38321 Settings = new(provider.Settings, this, _taskScheduler); Logger.LogDebug($"Initialized command provider {ProviderId}"); } public CommandProviderWrapper(IExtensionWrapper extension, TaskScheduler mainThread, ICommandProviderCache commandProviderCache) { _taskScheduler = mainThread; _commandProviderCache = commandProviderCache; Extension = extension; ExtensionHost = new CommandPaletteHost(extension); if (!Extension.IsRunning()) { throw new ArgumentException("You forgot to start the extension. This is a CmdPal error - we need to make sure to call StartExtensionAsync"); } var extensionImpl = extension.GetExtensionObject(); var providerObject = extensionImpl?.GetProvider(ProviderType.Commands); if (providerObject is not ICommandProvider provider) { throw new ArgumentException("extension didn't actually implement ICommandProvider"); } _commandProvider = new(provider); try { var model = _commandProvider.Unsafe!; // Hook the extension back into us model.InitializeWithHost(ExtensionHost); model.ItemsChanged += CommandProvider_ItemsChanged; isValid = true; Logger.LogDebug($"Initialized extension command provider {Extension.PackageFamilyName}:{Extension.ExtensionUniqueId}"); } catch (Exception e) { Logger.LogError("Failed to initialize CommandProvider for extension."); Logger.LogError($"Extension was {Extension!.PackageFamilyName}"); Logger.LogError(e.ToString()); } isValid = true; } private ProviderSettings GetProviderSettings(SettingsModel settings) { return settings.GetProviderSettings(this); } public async Task LoadTopLevelCommands(IServiceProvider serviceProvider, WeakReference pageContext) { if (!isValid) { IsActive = false; RecallFromCache(); return; } var settings = serviceProvider.GetService()!; var providerSettings = GetProviderSettings(settings); IsActive = providerSettings.IsEnabled; if (!IsActive) { RecallFromCache(); return; } var displayInfoInitialized = false; try { var model = _commandProvider.Unsafe!; Task loadTopLevelCommandsTask = new(model.TopLevelCommands); loadTopLevelCommandsTask.Start(); var commands = await loadTopLevelCommandsTask.ConfigureAwait(false); // On a BG thread here var fallbacks = model.FallbackCommands(); if (model is ICommandProvider2 two) { UnsafePreCacheApiAdditions(two); } // Load pinned commands from saved settings var pinnedCommands = LoadPinnedCommands(model, providerSettings); Id = model.Id; DisplayName = model.DisplayName; Icon = new(model.Icon); Icon.InitializeProperties(); displayInfoInitialized = true; // Update cached display name if (_commandProviderCache is not null && Extension?.ExtensionUniqueId is not null) { _commandProviderCache.Memorize(Extension.ExtensionUniqueId, new CommandProviderCacheItem(model.DisplayName)); } // Note: explicitly not InitializeProperties()ing the settings here. If // we do that, then we'd regress GH #38321 Settings = new(model.Settings, this, _taskScheduler); // We do need to explicitly initialize commands though InitializeCommands(commands, fallbacks, pinnedCommands, serviceProvider, pageContext); Logger.LogDebug($"Loaded commands from {DisplayName} ({ProviderId})"); } catch (Exception e) { Logger.LogError("Failed to load commands from extension"); Logger.LogError($"Extension was {Extension!.PackageFamilyName}"); Logger.LogError(e.ToString()); if (!displayInfoInitialized) { RecallFromCache(); } } } private void RecallFromCache() { var cached = _commandProviderCache?.Recall(ProviderId); if (cached is not null) { DisplayName = cached.DisplayName; } if (string.IsNullOrWhiteSpace(DisplayName)) { DisplayName = Extension?.PackageDisplayName ?? Extension?.PackageFamilyName ?? ProviderId; } } private void InitializeCommands(ICommandItem[] commands, IFallbackCommandItem[] fallbacks, ICommandItem[] pinnedCommands, IServiceProvider serviceProvider, WeakReference pageContext) { var settings = serviceProvider.GetService()!; var providerSettings = GetProviderSettings(settings); var ourContext = GetProviderContext(); var makeAndAdd = (ICommandItem? i, bool fallback) => { CommandItemViewModel commandItemViewModel = new(new(i), pageContext); TopLevelViewModel topLevelViewModel = new(commandItemViewModel, fallback, ExtensionHost, ourContext, settings, providerSettings, serviceProvider, i); topLevelViewModel.InitializeProperties(); return topLevelViewModel; }; var topLevelList = new List(); if (commands is not null) { topLevelList.AddRange(commands.Select(c => makeAndAdd(c, false))); } if (pinnedCommands is not null) { topLevelList.AddRange(pinnedCommands.Select(c => makeAndAdd(c, false))); } TopLevelItems = topLevelList.ToArray(); if (fallbacks is not null) { FallbackItems = fallbacks .Select(c => makeAndAdd(c, true)) .ToArray(); } } private ICommandItem[] LoadPinnedCommands(ICommandProvider model, ProviderSettings providerSettings) { var pinnedItems = new List(); if (model is ICommandProvider4 provider4) { foreach (var pinnedId in providerSettings.PinnedCommandIds) { try { var commandItem = provider4.GetCommandItem(pinnedId); if (commandItem is not null) { pinnedItems.Add(commandItem); } } catch (Exception e) { Logger.LogError($"Failed to load pinned command {pinnedId}: {e.Message}"); } } } return pinnedItems.ToArray(); } private void UnsafePreCacheApiAdditions(ICommandProvider2 provider) { var apiExtensions = provider.GetApiExtensionStubs(); Logger.LogDebug($"Provider supports {apiExtensions.Length} extensions"); foreach (var a in apiExtensions) { if (a is IExtendedAttributesProvider command2) { Logger.LogDebug($"{ProviderId}: Found an IExtendedAttributesProvider"); } } } public void PinCommand(string commandId, IServiceProvider serviceProvider) { var settings = serviceProvider.GetService()!; var providerSettings = GetProviderSettings(settings); if (!providerSettings.PinnedCommandIds.Contains(commandId)) { providerSettings.PinnedCommandIds.Add(commandId); SettingsModel.SaveSettings(settings); // Raise CommandsChanged so the TopLevelCommandManager reloads our commands this.CommandsChanged?.Invoke(this, new ItemsChangedEventArgs(-1)); } } public CommandProviderContext GetProviderContext() { return new() { ProviderId = ProviderId }; } public override bool Equals(object? obj) => obj is CommandProviderWrapper wrapper && isValid == wrapper.isValid; public override int GetHashCode() => _commandProvider.GetHashCode(); private void CommandProvider_ItemsChanged(object sender, IItemsChangedEventArgs args) => // We don't want to handle this ourselves - we want the // TopLevelCommandManager to know about this, so they can remove // our old commands from their own list. // // In handling this, a call will be made to `LoadTopLevelCommands` to // retrieve the new items. this.CommandsChanged?.Invoke(this, args); }