diff --git a/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/TopLevelCommandManager.cs b/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/TopLevelCommandManager.cs index b33a515b7a..540659f9e1 100644 --- a/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/TopLevelCommandManager.cs +++ b/src/modules/cmdpal/Microsoft.CmdPal.UI.ViewModels/TopLevelCommandManager.cs @@ -4,6 +4,7 @@ using System.Collections.Immutable; using System.Collections.ObjectModel; +using System.Diagnostics; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Messaging; @@ -53,53 +54,44 @@ public partial class TopLevelCommandManager : ObservableObject, { CommandProviderWrapper wrapper = new(provider, _taskScheduler); _builtInCommands.Add(wrapper); - await LoadTopLevelCommandsFromProvider(wrapper); + var commands = await LoadTopLevelCommandsFromProvider(wrapper); + lock (TopLevelCommands) + { + foreach (var c in commands) + { + TopLevelCommands.Add(c); + } + } } return true; } // May be called from a background thread - private async Task LoadTopLevelCommandsFromProvider(CommandProviderWrapper commandProvider) + private async Task> LoadTopLevelCommandsFromProvider(CommandProviderWrapper commandProvider) { WeakReference weakSelf = new(this); await commandProvider.LoadTopLevelCommands(_serviceProvider, weakSelf); var settings = _serviceProvider.GetService()!; - var makeAndAdd = (ICommandItem? i, bool fallback) => + + List commands = []; + + foreach (var item in commandProvider.TopLevelItems) { - var commandItemViewModel = new CommandItemViewModel(new(i), weakSelf); - var topLevelViewModel = new TopLevelViewModel(commandItemViewModel, fallback, commandProvider.ExtensionHost, commandProvider.ProviderId, settings, _serviceProvider); + commands.Add(item); + } - lock (TopLevelCommands) - { - TopLevelCommands.Add(topLevelViewModel); - } - }; - - await Task.Factory.StartNew( - () => - { - lock (TopLevelCommands) - { - foreach (var item in commandProvider.TopLevelItems) - { - TopLevelCommands.Add(item); - } - - foreach (var item in commandProvider.FallbackItems) - { - TopLevelCommands.Add(item); - } - } - }, - CancellationToken.None, - TaskCreationOptions.None, - _taskScheduler); + foreach (var item in commandProvider.FallbackItems) + { + commands.Add(item); + } commandProvider.CommandsChanged -= CommandProvider_CommandsChanged; commandProvider.CommandsChanged += CommandProvider_CommandsChanged; + + return commands; } // By all accounts, we're already on a background thread (the COM call @@ -239,25 +231,71 @@ public partial class TopLevelCommandManager : ObservableObject, private async Task StartExtensionsAndGetCommands(IEnumerable extensions) { - // TODO This most definitely needs a lock - foreach (var extension in extensions) - { - Logger.LogDebug($"Starting {extension.PackageFullName}"); - try - { - // start it ... - await extension.StartExtensionAsync(); + var timer = new Stopwatch(); + timer.Start(); - // ... and fetch the command provider from it. - CommandProviderWrapper wrapper = new(extension, _taskScheduler); - _extensionCommandProviders.Add(wrapper); - await LoadTopLevelCommandsFromProvider(wrapper); - } - catch (Exception ex) + // Start all extensions in parallel + var startTasks = extensions.Select(StartExtensionWithTimeoutAsync); + + // Wait for all extensions to start + var wrappers = (await Task.WhenAll(startTasks)).Where(wrapper => wrapper != null).Select(w => w!).ToList(); + + foreach (var wrapper in wrappers) + { + _extensionCommandProviders.Add(wrapper!); + } + + // Load the commands from the providers in parallel + var loadTasks = wrappers.Select(LoadCommandsWithTimeoutAsync); + + var commandSets = (await Task.WhenAll(loadTasks)).Where(results => results != null).Select(r => r!).ToList(); + + lock (TopLevelCommands) + { + foreach (var commands in commandSets) { - Logger.LogError(ex.ToString()); + foreach (var c in commands) + { + TopLevelCommands.Add(c); + } } } + + timer.Stop(); + Logger.LogDebug($"Loading extensions took {timer.ElapsedMilliseconds} ms"); + } + + private async Task StartExtensionWithTimeoutAsync(IExtensionWrapper extension) + { + Logger.LogDebug($"Starting {extension.PackageFullName}"); + try + { + await extension.StartExtensionAsync().WaitAsync(TimeSpan.FromSeconds(10)); + return new CommandProviderWrapper(extension, _taskScheduler); + } + catch (Exception ex) + { + Logger.LogError($"Failed to start extension {extension.PackageFullName}: {ex}"); + return null; // Return null for failed extensions + } + } + + private async Task?> LoadCommandsWithTimeoutAsync(CommandProviderWrapper wrapper) + { + try + { + return await LoadTopLevelCommandsFromProvider(wrapper!).WaitAsync(TimeSpan.FromSeconds(10)); + } + catch (TimeoutException) + { + Logger.LogError($"Loading commands from {wrapper!.ExtensionHost?.Extension?.PackageFullName} timed out"); + } + catch (Exception ex) + { + Logger.LogError($"Failed to load commands for extension {wrapper!.ExtensionHost?.Extension?.PackageFullName}: {ex}"); + } + + return null; } private void ExtensionService_OnExtensionRemoved(IExtensionService sender, IEnumerable extensions)