CmdPal: Resilient loading of extensions (#45720)

## Summary of the Pull Request

This PR improves the loading of extensions in the Command Palette and
allows extensions that missed the initial timeout to finish loading.

<!-- Please review the items on the PR checklist before submitting-->
## PR Checklist

- [x] Closes: #45711
<!-- - [ ] Closes: #yyy (add separate lines for additional resolved
issues) -->
- [ ] **Communication:** I've discussed this with core contributors
already. If the work hasn't been agreed, this work might be rejected
- [ ] **Tests:** Added/updated and all pass
- [ ] **Localization:** All end-user-facing strings can be localized
- [ ] **Dev docs:** Added/updated
- [ ] **New binaries:** Added on the required places
- [ ] [JSON for
signing](https://github.com/microsoft/PowerToys/blob/main/.pipelines/ESRPSigning_core.json)
for new binaries
- [ ] [WXS for
installer](https://github.com/microsoft/PowerToys/blob/main/installer/PowerToysSetup/Product.wxs)
for new binaries and localization folder
- [ ] [YML for CI
pipeline](https://github.com/microsoft/PowerToys/blob/main/.pipelines/ci/templates/build-powertoys-steps.yml)
for new test projects
- [ ] [YML for signed
pipeline](https://github.com/microsoft/PowerToys/blob/main/.pipelines/release.yml)
- [ ] **Documentation updated:** If checked, please file a pull request
on [our docs
repo](https://github.com/MicrosoftDocs/windows-uwp/tree/docs/hub/powertoys)
and link it here: #xxx

<!-- Provide a more detailed description of the PR, other things fixed,
or any additional comments/features here -->
## Detailed Description of the Pull Request / Additional comments

<!-- Describe how you validated the behavior. Add automated tests
wherever possible, but list manual validation steps taken as well -->
## Validation Steps Performed
This commit is contained in:
Jiří Polášek
2026-03-03 18:56:44 +01:00
committed by GitHub
parent c066cc3deb
commit ce2e72832c
3 changed files with 312 additions and 72 deletions

View File

@@ -976,6 +976,7 @@ NTAPI
ntdll
NTSTATUS
NTSYSAPI
nullability
NULLCURSOR
nullonfailure
numberbox

View File

@@ -215,9 +215,7 @@ public sealed class CommandProviderWrapper : ICommandProviderContext
}
catch (Exception e)
{
Logger.LogError("Failed to load commands from extension");
Logger.LogError($"Extension was {Extension!.PackageFamilyName}");
Logger.LogError(e.ToString());
Logger.LogError($"Failed to load commands from extension {Extension!.PackageFamilyName}", e);
if (!displayInfoInitialized)
{

View File

@@ -5,6 +5,7 @@
using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using CommunityToolkit.Mvvm.Messaging;
@@ -19,13 +20,18 @@ using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.CmdPal.UI.ViewModels;
public partial class TopLevelCommandManager : ObservableObject,
public sealed partial class TopLevelCommandManager : ObservableObject,
IRecipient<ReloadCommandsMessage>,
IRecipient<PinCommandItemMessage>,
IRecipient<UnpinCommandItemMessage>,
IRecipient<PinToDockMessage>,
IDisposable
{
private static readonly TimeSpan ExtensionStartTimeout = TimeSpan.FromSeconds(10);
private static readonly TimeSpan CommandLoadTimeout = TimeSpan.FromSeconds(10);
private static readonly TimeSpan BackgroundStartTimeout = TimeSpan.FromSeconds(60);
private static readonly TimeSpan BackgroundCommandLoadTimeout = TimeSpan.FromSeconds(60);
private readonly IServiceProvider _serviceProvider;
private readonly ICommandProviderCache _commandProviderCache;
private readonly TaskScheduler _taskScheduler;
@@ -39,11 +45,14 @@ public partial class TopLevelCommandManager : ObservableObject,
// deadlock.
private readonly Lock _dockBandsLock = new();
private readonly SupersedingAsyncGate _reloadCommandsGate;
private CancellationTokenSource _extensionLoadCts = new();
private CancellationToken _currentExtensionLoadCancellationToken;
public TopLevelCommandManager(IServiceProvider serviceProvider, ICommandProviderCache commandProviderCache)
{
_serviceProvider = serviceProvider;
_commandProviderCache = commandProviderCache;
_currentExtensionLoadCancellationToken = _extensionLoadCts.Token;
_taskScheduler = _serviceProvider.GetService<TaskScheduler>()!;
WeakReferenceMessenger.Default.Register<ReloadCommandsMessage>(this);
WeakReferenceMessenger.Default.Register<PinCommandItemMessage>(this);
@@ -260,8 +269,15 @@ public partial class TopLevelCommandManager : ObservableObject,
private async Task ReloadAllCommandsAsyncCore(CancellationToken cancellationToken)
{
IsLoading = true;
// Invalidate any background continuations from the previous load cycle
await _extensionLoadCts.CancelAsync().ConfigureAwait(false);
_extensionLoadCts.Dispose();
_extensionLoadCts = new();
_currentExtensionLoadCancellationToken = _extensionLoadCts.Token;
var extensionService = _serviceProvider.GetService<IExtensionService>()!;
await extensionService.SignalStopExtensionsAsync();
await extensionService.SignalStopExtensionsAsync().ConfigureAwait(false);
lock (TopLevelCommands)
{
@@ -273,8 +289,8 @@ public partial class TopLevelCommandManager : ObservableObject,
DockBands.Clear();
}
await LoadBuiltinsAsync();
_ = Task.Run(LoadExtensionsAsync);
await LoadBuiltinsAsync().ConfigureAwait(false);
_ = Task.Run(LoadExtensionsAsync, cancellationToken);
}
// Load commands from our extensions. Called on a background thread.
@@ -292,16 +308,15 @@ public partial class TopLevelCommandManager : ObservableObject,
extensionService.OnExtensionAdded -= ExtensionService_OnExtensionAdded;
extensionService.OnExtensionRemoved -= ExtensionService_OnExtensionRemoved;
var extensions = (await extensionService.GetInstalledExtensionsAsync()).ToImmutableList();
var ct = _currentExtensionLoadCancellationToken;
var extensions = (await extensionService.GetInstalledExtensionsAsync().ConfigureAwait(false)).ToImmutableList();
lock (_commandProvidersLock)
{
_extensionCommandProviders.Clear();
}
if (extensions is not null)
{
await StartExtensionsAndGetCommands(extensions);
}
await StartExtensionsAndGetCommands(extensions, ct).ConfigureAwait(false);
extensionService.OnExtensionAdded += ExtensionService_OnExtensionAdded;
extensionService.OnExtensionRemoved += ExtensionService_OnExtensionRemoved;
@@ -316,46 +331,219 @@ public partial class TopLevelCommandManager : ObservableObject,
private void ExtensionService_OnExtensionAdded(IExtensionService sender, IEnumerable<IExtensionWrapper> extensions)
{
var ct = _currentExtensionLoadCancellationToken;
// When we get an extension install event, hop off to a BG thread
_ = Task.Run(async () =>
{
// for each newly installed extension, start it and get commands
// from it. One single package might have more than one
// IExtensionWrapper in it.
await StartExtensionsAndGetCommands(extensions);
});
_ = Task.Run(
async () =>
{
// for each newly installed extension, start it and get commands
// from it. One single package might have more than one
// IExtensionWrapper in it.
await StartExtensionsAndGetCommands(extensions, ct).ConfigureAwait(false);
},
ct);
}
private async Task StartExtensionsAndGetCommands(IEnumerable<IExtensionWrapper> extensions)
private async Task StartExtensionsAndGetCommands(IEnumerable<IExtensionWrapper> extensions, CancellationToken ct)
{
var timer = new Stopwatch();
timer.Start();
var timer = Stopwatch.StartNew();
// Start all extensions in parallel
var startTasks = extensions.Select(StartExtensionWithTimeoutAsync);
var startResults = await Task.WhenAll(extensions.Select(TryStartExtensionAsync)).ConfigureAwait(false);
// Wait for all extensions to start
var wrappers = (await Task.WhenAll(startTasks)).Where(wrapper => wrapper is not null).Select(w => w!).ToList();
var startedWrappers = new List<CommandProviderWrapper>();
foreach (var r in startResults)
{
if (r.IsStarted)
{
startedWrappers.Add(r.Wrapper);
}
else if (r.IsTimedOut)
{
_ = StartExtensionWhenReadyAsync(r.Extension, r.PendingStartTask, r.Stopwatch, ct);
}
}
// Register started extensions and load their commands
var loadSummary = await RegisterAndLoadCommandsAsync(startedWrappers, ct).ConfigureAwait(false);
timer.Stop();
Logger.LogInfo($"Loaded {loadSummary.CommandCount} command(s) and {loadSummary.DockBandCount} band(s) from {startedWrappers.Count} extension(s) in {timer.ElapsedMilliseconds} ms");
}
private async Task<RegisterAndLoadSummary> RegisterAndLoadCommandsAsync(ICollection<CommandProviderWrapper> wrappers, CancellationToken ct)
{
lock (_commandProvidersLock)
{
_extensionCommandProviders.AddRange(wrappers);
}
// Load the commands from the providers in parallel
var loadTasks = wrappers.Select(LoadCommandsWithTimeoutAsync);
var loadResults = await Task.WhenAll(wrappers.Select(w => TryLoadCommandsAsync(w, ct))).ConfigureAwait(false);
var commandSets = (await Task.WhenAll(loadTasks)).Where(results => results is not null).Select(r => r!).ToList();
var totalCommands = 0;
var totalDockBands = 0;
var timedOut = new List<CommandLoadResult>();
List<TopLevelViewModel> commandsToAdd = [];
List<TopLevelViewModel> dockBandsToAdd = [];
foreach (var providerObjects in commandSets)
foreach (var r in loadResults)
{
var commandsCount = providerObjects.Commands?.Count() ?? 0;
var bandsCount = providerObjects.DockBands?.Count() ?? 0;
Logger.LogDebug($"(some provider) Loaded {commandsCount} commands and {bandsCount} bands");
lock (TopLevelCommands)
if (r.IsLoaded)
{
if (providerObjects.Commands is IEnumerable<TopLevelViewModel> commands)
var commands = r.TopLevelObjectSets.Commands;
if (commands is not null)
{
foreach (var c in commands)
{
commandsToAdd.Add(c);
totalCommands++;
}
}
var bands = r.TopLevelObjectSets.DockBands;
if (bands is not null)
{
foreach (var b in bands)
{
dockBandsToAdd.Add(b);
totalDockBands++;
}
}
}
else if (r.IsTimedOut)
{
timedOut.Add(r);
}
}
lock (TopLevelCommands)
{
foreach (var c in commandsToAdd)
{
TopLevelCommands.Add(c);
}
}
lock (_dockBandsLock)
{
foreach (var b in dockBandsToAdd)
{
DockBands.Add(b);
}
}
// Fire background continuations for timed-out loads outside the lock
foreach (var r in timedOut)
{
// It's weird to repeat the condition here, but it allows the compiler to track nullability of other properties
if (r.IsTimedOut)
{
_ = AppendCommandsWhenReadyAsync(r.Wrapper, r.PendingLoadTask, r.Stopwatch, ct);
}
}
return new RegisterAndLoadSummary(totalCommands, totalDockBands);
}
private async Task<ExtensionStartResult> TryStartExtensionAsync(IExtensionWrapper extension)
{
Logger.LogDebug($"Starting {extension.PackageFullName}");
var sw = Stopwatch.StartNew();
var ct = _currentExtensionLoadCancellationToken;
var startTask = extension.StartExtensionAsync();
try
{
await startTask.WaitAsync(ExtensionStartTimeout, ct).ConfigureAwait(false);
Logger.LogInfo($"Started extension {extension.PackageFullName} in {sw.ElapsedMilliseconds} ms");
return ExtensionStartResult.Started(extension, new CommandProviderWrapper(extension, _taskScheduler, _commandProviderCache));
}
catch (TimeoutException)
{
Logger.LogWarning($"Starting extension {extension.PackageFullName} timed out after {sw.ElapsedMilliseconds} ms, continuing in background");
return ExtensionStartResult.TimedOut(extension, startTask, sw);
}
catch (OperationCanceledException)
{
Logger.LogDebug($"Starting extension {extension.PackageFullName} was cancelled after {sw.ElapsedMilliseconds} ms");
return ExtensionStartResult.Failed(extension);
}
catch (Exception ex)
{
Logger.LogError($"Failed to start extension {extension.PackageFullName} after {sw.ElapsedMilliseconds} ms: {ex}");
return ExtensionStartResult.Failed(extension);
}
}
private async Task StartExtensionWhenReadyAsync(
IExtensionWrapper extension,
Task startTask,
Stopwatch sw,
CancellationToken ct)
{
try
{
await startTask.WaitAsync(BackgroundStartTimeout, ct).ConfigureAwait(false);
var wrapper = new CommandProviderWrapper(extension, _taskScheduler, _commandProviderCache);
Logger.LogInfo($"Late-started extension {extension.PackageFullName} in {sw.ElapsedMilliseconds} ms, loading commands and bands");
await RegisterAndLoadCommandsAsync([wrapper], ct).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
// Reload happened -- discard stale results
}
catch (Exception ex)
{
Logger.LogError($"Background start/load of extension {extension.PackageFullName} failed after {sw.ElapsedMilliseconds} ms: {ex}");
}
}
private async Task<CommandLoadResult> TryLoadCommandsAsync(CommandProviderWrapper wrapper, CancellationToken ct)
{
var sw = Stopwatch.StartNew();
var loadTask = LoadTopLevelCommandsFromProvider(wrapper);
try
{
var result = await loadTask.WaitAsync(CommandLoadTimeout, ct).ConfigureAwait(false);
var commandCount = result.Commands?.Count ?? 0;
var dockBandCount = result.DockBands?.Count ?? 0;
Logger.LogInfo($"Loaded {commandCount} command(s) and {dockBandCount} band(s) from {wrapper.ExtensionHost?.Extension?.PackageFullName} in {sw.ElapsedMilliseconds} ms");
return CommandLoadResult.Loaded(wrapper, result);
}
catch (TimeoutException)
{
Logger.LogWarning($"Loading commands and bands from {wrapper.ExtensionHost?.Extension?.PackageFullName} timed out after {sw.ElapsedMilliseconds} ms, continuing in background");
return CommandLoadResult.TimedOut(wrapper, loadTask, sw);
}
catch (OperationCanceledException)
{
Logger.LogDebug($"Loading commands and bands from {wrapper.ExtensionHost?.Extension?.PackageFullName} was cancelled after {sw.ElapsedMilliseconds} ms");
return CommandLoadResult.Failed(wrapper);
}
catch (Exception ex)
{
Logger.LogError($"Failed to load commands and bands for extension {wrapper.ExtensionHost?.Extension?.PackageFullName} after {sw.ElapsedMilliseconds} ms: {ex}");
return CommandLoadResult.Failed(wrapper);
}
}
private async Task AppendCommandsWhenReadyAsync(
CommandProviderWrapper wrapper,
Task<TopLevelObjectSets> loadTask,
Stopwatch sw,
CancellationToken ct)
{
try
{
var topLevelObjectSets = await loadTask.WaitAsync(BackgroundCommandLoadTimeout, ct).ConfigureAwait(false);
var commands = topLevelObjectSets.Commands;
if (commands is not null)
{
lock (TopLevelCommands)
{
foreach (var c in commands)
{
@@ -364,57 +552,30 @@ public partial class TopLevelCommandManager : ObservableObject,
}
}
lock (_dockBandsLock)
var dockBands = topLevelObjectSets.DockBands;
if (dockBands is not null)
{
if (providerObjects.DockBands is IEnumerable<TopLevelViewModel> bands)
lock (_dockBandsLock)
{
foreach (var c in bands)
foreach (var band in dockBands)
{
DockBands.Add(c);
DockBands.Add(band);
}
}
}
Logger.LogInfo($"Late-loaded {commands?.Count ?? 0} command(s) and {dockBands?.Count ?? 0} band(s) from {wrapper.ExtensionHost?.Extension?.PackageFullName} in {sw.ElapsedMilliseconds} ms");
}
timer.Stop();
Logger.LogDebug($"Loading extensions took {timer.ElapsedMilliseconds} ms");
}
private async Task<CommandProviderWrapper?> StartExtensionWithTimeoutAsync(IExtensionWrapper extension)
{
Logger.LogDebug($"Starting {extension.PackageFullName}");
try
catch (OperationCanceledException)
{
await extension.StartExtensionAsync().WaitAsync(TimeSpan.FromSeconds(10));
return new CommandProviderWrapper(extension, _taskScheduler, _commandProviderCache);
// Reload happened - discard stale results
}
catch (Exception ex)
{
Logger.LogError($"Failed to start extension {extension.PackageFullName}: {ex}");
return null; // Return null for failed extensions
Logger.LogError($"Background loading of commands and bands from {wrapper.ExtensionHost?.Extension?.PackageFullName} failed after {sw.ElapsedMilliseconds} ms: {ex}");
}
}
private record TopLevelObjectSets(IEnumerable<TopLevelViewModel>? Commands, IEnumerable<TopLevelViewModel>? DockBands);
private async Task<TopLevelObjectSets?> LoadCommandsWithTimeoutAsync(CommandProviderWrapper wrapper)
{
try
{
return await LoadTopLevelCommandsFromProvider(wrapper!).WaitAsync(TimeSpan.FromSeconds(10));
}
catch (TimeoutException)
{
Logger.LogError($"Loading commands from {wrapper!.ExtensionHost?.Extension?.PackageFullName} timed out");
}
catch (Exception ex)
{
Logger.LogError($"Failed to load commands for extension {wrapper!.ExtensionHost?.Extension?.PackageFullName}: {ex}");
}
return null;
}
private void ExtensionService_OnExtensionRemoved(IExtensionService sender, IEnumerable<IExtensionWrapper> extensions)
{
// When we get an extension uninstall event, hop off to a BG thread
@@ -515,7 +676,7 @@ public partial class TopLevelCommandManager : ObservableObject,
}
public void Receive(ReloadCommandsMessage message) =>
ReloadAllCommandsAsync().ConfigureAwait(false);
_ = ReloadAllCommandsAsync();
public void Receive(PinCommandItemMessage message)
{
@@ -611,7 +772,87 @@ public partial class TopLevelCommandManager : ObservableObject,
public void Dispose()
{
_extensionLoadCts.Cancel();
_extensionLoadCts.Dispose();
_reloadCommandsGate.Dispose();
GC.SuppressFinalize(this);
}
private sealed class ExtensionStartResult
{
public IExtensionWrapper Extension { get; }
public CommandProviderWrapper? Wrapper { get; private init; }
public Task? PendingStartTask { get; private init; }
public Stopwatch? Stopwatch { get; private init; }
[MemberNotNullWhen(true, nameof(Wrapper))]
public bool IsStarted => Wrapper is not null;
[MemberNotNullWhen(true, nameof(PendingStartTask), nameof(Stopwatch))]
public bool IsTimedOut => PendingStartTask is not null;
private ExtensionStartResult(IExtensionWrapper extension)
{
Extension = extension;
}
public static ExtensionStartResult Started(IExtensionWrapper extension, CommandProviderWrapper wrapper)
{
return new ExtensionStartResult(extension) { Wrapper = wrapper };
}
public static ExtensionStartResult TimedOut(IExtensionWrapper extension, Task pendingStartTask, Stopwatch sw)
{
return new ExtensionStartResult(extension) { PendingStartTask = pendingStartTask, Stopwatch = sw };
}
public static ExtensionStartResult Failed(IExtensionWrapper extension)
{
return new ExtensionStartResult(extension);
}
}
private sealed class CommandLoadResult
{
public TopLevelObjectSets? TopLevelObjectSets { get; private init; }
public CommandProviderWrapper Wrapper { get; }
public Task<TopLevelObjectSets>? PendingLoadTask { get; private init; }
public Stopwatch? Stopwatch { get; private init; }
[MemberNotNullWhen(true, nameof(TopLevelObjectSets))]
public bool IsLoaded => TopLevelObjectSets is not null;
[MemberNotNullWhen(true, nameof(PendingLoadTask), nameof(Stopwatch))]
public bool IsTimedOut => PendingLoadTask is not null;
private CommandLoadResult(CommandProviderWrapper wrapper)
{
Wrapper = wrapper;
}
public static CommandLoadResult Loaded(CommandProviderWrapper wrapper, TopLevelObjectSets topLevelObjectSets)
{
return new CommandLoadResult(wrapper) { TopLevelObjectSets = topLevelObjectSets };
}
public static CommandLoadResult TimedOut(CommandProviderWrapper wrapper, Task<TopLevelObjectSets> pendingLoadTask, Stopwatch sw)
{
return new CommandLoadResult(wrapper) { PendingLoadTask = pendingLoadTask, Stopwatch = sw };
}
public static CommandLoadResult Failed(CommandProviderWrapper wrapper)
{
return new CommandLoadResult(wrapper);
}
}
private readonly record struct RegisterAndLoadSummary(int CommandCount, int DockBandCount);
private record TopLevelObjectSets(ICollection<TopLevelViewModel>? Commands, ICollection<TopLevelViewModel>? DockBands);
}