How can I prevent synchronous continuations on a Task?

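I have some library (socket-networking) code that provides a Task-based API for pending responses to requests, based on TaskCompletionSource<T>. However, there's an annoyance in the TPL: it seems to be impossible to prevent synchronous continuations. What I would like to be able to do is either:

  • tell a TaskCompletionSource<T> that it should not allow callers to attach with TaskContinuationOptions.ExecuteSynchronously, or
  • set the result (SetResult/TrySetResult) in a way that specifies that TaskContinuationOptions.ExecuteSynchronously should be ignored, using the pool instead

Specifically, the issue I have is that the incoming data is processed by a dedicated reader, and if a caller can attach with TaskContinuationOptions.ExecuteSynchronously, they can stall the reader (which affects more than just them). Previously, I have worked around this with some hackery that detects whether any continuations are present and, if they are, pushes the completion onto the ThreadPool. However, this has a significant impact if the caller has saturated their work queue, as the completion will not get processed in a timely fashion. If they are using Task.Wait() (or similar), they will then essentially deadlock themselves. Likewise, this is why the reader is on a dedicated thread rather than using workers.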

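So, before I try to nag the TPL team: am I missing an option?

Key points:

  • I don't want external callers to be able to hijack my thread
  • I can't use the ThreadPool as an implementation, as it needs to work when the pool is saturated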

The example below produces this output (ordering may vary based on timing):

Continuation on: Main thread
Press [return]
Continuation on: Thread pool

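The problem is that a random caller managed to get a continuation on "Main thread". In the real code, this would be interrupting the primary reader; bad things!

Code: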

using System;
using System.Threading;
using System.Threading.Tasks;

static class Program
{
    static void Identify()
    {
        var thread = Thread.CurrentThread;
        string name = thread.IsThreadPoolThread
            ? "Thread pool" : thread.Name;
        if (string.IsNullOrEmpty(name))
            name = "#" + thread.ManagedThreadId;
        Console.WriteLine("Continuation on: " + name);
    }
    static void Main()
    {
        Thread.CurrentThread.Name = "Main thread";
        var source = new TaskCompletionSource<int>();
        var task = source.Task;
        task.ContinueWith(delegate {
            Identify();
        });
        task.ContinueWith(delegate {
            Identify();
        }, TaskContinuationOptions.ExecuteSynchronously);
        source.TrySetResult(123);
        Console.WriteLine("Press [return]");
        Console.ReadLine();
    }
}

What about instead of doing

var task = source.Task;

you do this instead

var task = source.Task.ContinueWith<Int32>( x => x.Result );

Thus you are always adding one continuation which will be executed asynchronously and then it doesn't matter if the subscribers want a continuation in the same context. It's sort of currying the task, isn't it?
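A self-contained sketch of this wrapping idea, assuming the library hands out only the wrapper task and never source.Task itself. WrappedTaskDemo and its fields are hypothetical names, and passing TaskScheduler.Default explicitly is an addition not in the answer (ContinueWith otherwise uses TaskScheduler.Current).

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

static class WrappedTaskDemo
{
    public static int SetterThreadId, CallerContinuationThreadId;

    public static void Run()
    {
        var source = new TaskCompletionSource<int>();

        // Library side: expose a wrapper instead of source.Task. The wrapper has no
        // ExecuteSynchronously flag, so TrySetResult queues it instead of running it inline.
        // If source faults, t.Result rethrows and the wrapper faults too.
        Task<int> exposed = source.Task.ContinueWith(
            t => t.Result,
            CancellationToken.None,
            TaskContinuationOptions.None,
            TaskScheduler.Default);

        // Caller side: asks for synchronous execution, but can at most run inline on the
        // pool thread that completes 'exposed', not on the thread calling TrySetResult.
        Task observed = exposed.ContinueWith(
            _ => { CallerContinuationThreadId = Environment.CurrentManagedThreadId; },
            TaskContinuationOptions.ExecuteSynchronously);

        SetterThreadId = Environment.CurrentManagedThreadId;
        source.TrySetResult(123);
        observed.Wait();
    }
}
```

The caller's ExecuteSynchronously continuation still runs inline, but on the pool thread that completes the wrapper rather than on the thread calling TrySetResult. That means this approach depends on the thread pool, which the question rules out for the case where the pool is saturated.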

Updated: I posted a separate answer to deal with ContinueWith as opposed to await (because ContinueWith doesn't care about the current synchronization context).

You could use a dumb synchronization context to impose asynchrony upon continuations triggered by calling SetResult/SetCanceled/SetException on TaskCompletionSource. I believe the current synchronization context (at the point of await tcs.Task) is the criterion TPL uses to decide whether to make such a continuation synchronous or asynchronous.

The following works for me:

if (notifyAsync)
{
    tcs.SetResultAsync(null);
}
else
{
    tcs.SetResult(null);
}

SetResultAsync is implemented like this:

using System;
using System.Threading;
using System.Threading.Tasks;

public static class TaskExt
{
    static public void SetResultAsync<T>(this TaskCompletionSource<T> tcs, T result)
    {
        FakeSynchronizationContext.Execute(() => tcs.SetResult(result));
    }

    // FakeSynchronizationContext
    class FakeSynchronizationContext : SynchronizationContext
    {
        private static readonly ThreadLocal<FakeSynchronizationContext> s_context =
            new ThreadLocal<FakeSynchronizationContext>(() => new FakeSynchronizationContext());

        private FakeSynchronizationContext() { }

        public static FakeSynchronizationContext Instance { get { return s_context.Value; } }

        public static void Execute(Action action)
        {
            var savedContext = SynchronizationContext.Current;
            SynchronizationContext.SetSynchronizationContext(FakeSynchronizationContext.Instance);
            try
            {
                action();
            }
            finally
            {
                SynchronizationContext.SetSynchronizationContext(savedContext);
            }
        }

        // SynchronizationContext methods

        public override SynchronizationContext CreateCopy()
        {
            return this;
        }

        public override void OperationStarted()
        {
            throw new NotImplementedException("OperationStarted");
        }

        public override void OperationCompleted()
        {
            throw new NotImplementedException("OperationCompleted");
        }

        public override void Post(SendOrPostCallback d, object state)
        {
            throw new NotImplementedException("Post");
        }

        public override void Send(SendOrPostCallback d, object state)
        {
            throw new NotImplementedException("Send");
        }
    }
}

SynchronizationContext.SetSynchronizationContext is very cheap in terms of the overhead it adds. In fact, a very similar approach is taken by the implementation of WPF Dispatcher.BeginInvoke.

TPL compares the target synchronization context at the point of await to that of the point of tcs.SetResult. If the synchronization context is the same (or there is no synchronization context at both places), the continuation is called directly, synchronously. Otherwise, it's queued using SynchronizationContext.Post on the target synchronization context, i.e., the normal await behavior. What this approach does is always impose the SynchronizationContext.Post behavior (or a pool thread continuation if there's no target synchronization context).
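A minimal sketch of that mechanism, assuming no ambient SynchronizationContext (as in a console app). NonInliningContext and AwaitInliningDemo are hypothetical names; NonInliningContext plays the role of FakeSynchronizationContext but inherits Post/Send instead of throwing. For contrast, the sketch also attaches an ExecuteSynchronously ContinueWith continuation, which the context does not affect.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical stand-in for FakeSynchronizationContext: any context whose type is not
// the base SynchronizationContext. Post/Send are inherited, so nothing throws if called.
sealed class NonInliningContext : SynchronizationContext { }

static class AwaitInliningDemo
{
    public static int SetterThreadId, AwaitThreadId, ContinueWithThreadId;

    static async Task ObserveAsync(Task<int> task)
    {
        // Assuming no ambient context, nothing is captured here, so the continuation
        // would normally be eligible to run inline inside SetResult.
        await task;
        AwaitThreadId = Environment.CurrentManagedThreadId;
    }

    public static void Run()
    {
        var tcs = new TaskCompletionSource<int>();
        Task awaiter = ObserveAsync(tcs.Task);
        Task continuation = tcs.Task.ContinueWith(
            _ => { ContinueWithThreadId = Environment.CurrentManagedThreadId; },
            TaskContinuationOptions.ExecuteSynchronously);

        SetterThreadId = Environment.CurrentManagedThreadId;
        var saved = SynchronizationContext.Current;
        SynchronizationContext.SetSynchronizationContext(new NonInliningContext());
        try
        {
            // The await continuation is queued to the pool; the ContinueWith one runs inline.
            tcs.SetResult(123);
        }
        finally
        {
            SynchronizationContext.SetSynchronizationContext(saved);
        }

        Task.WaitAll(awaiter, continuation);
    }
}
```

Only the await continuation leaves the setting thread; the ContinueWith continuation still runs inline on it.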

Updated: this won't work for task.ContinueWith, because ContinueWith doesn't care about the current synchronization context. It does, however, work for await task (fiddle), and also for await task.ConfigureAwait(false).

OTOH, the approach in my separate answer works for ContinueWith.

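If you can, and are ready to, use reflection, this should do it:

using System.Collections.Generic;
using System.Reflection;
using System.Threading.Tasks;

public static class MakeItAsync
{
    static public void TrySetAsync<T>(this TaskCompletionSource<T> source, T result)
    {
        const BindingFlags flags = BindingFlags.NonPublic | BindingFlags.Instance;
        var continuationField = typeof(Task).GetField("m_continuationObject", flags);
        object continuationObject = continuationField.GetValue(source.Task);

        // m_continuationObject holds null, a single continuation, or a List<object>
        var continuations = continuationObject as List<object>;
        if (continuations == null && continuationObject != null)
            continuations = new List<object> { continuationObject };

        if (continuations != null)
        {
            foreach (object c in continuations)
            {
                // entries without an m_options field (e.g. await continuations) are left alone
                var optionField = c == null ? null : c.GetType().GetField("m_options", flags);
                if (optionField == null)
                    continue;

                var options = (TaskContinuationOptions)optionField.GetValue(c);
                options &= ~TaskContinuationOptions.ExecuteSynchronously;
                optionField.SetValue(c, options);
            }
        }

        source.TrySetResult(result);
    }
}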

I don't think there's anything in TPL which provides explicit API control over TaskCompletionSource.SetResult continuations. I decided to keep my initial answer for controlling this behavior for async/await scenarios.

Here is another solution, which imposes asynchrony upon ContinueWith if the tcs.SetResult-triggered continuation takes place on the same thread SetResult was called on:

using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

public static class TaskExt
{
    static readonly ConcurrentDictionary<Task, Thread> s_tcsTasks =
        new ConcurrentDictionary<Task, Thread>();

    // SetResultAsync
    static public void SetResultAsync<TResult>(
        this TaskCompletionSource<TResult> @this,
        TResult result)
    {
        s_tcsTasks.TryAdd(@this.Task, Thread.CurrentThread);
        try
        {
            @this.SetResult(result);
        }
        finally
        {
            Thread thread;
            s_tcsTasks.TryRemove(@this.Task, out thread);
        }
    }

    // ContinueWithAsync, TODO: more overrides
    static public Task ContinueWithAsync<TResult>(
        this Task<TResult> @this,
        Action<Task<TResult>> action,
        TaskContinuationOptions continuationOptions = TaskContinuationOptions.None)
    {
        return @this.ContinueWith((Func<Task<TResult>, Task>)(t =>
        {
            Thread thread = null;
            s_tcsTasks.TryGetValue(t, out thread);
            if (Thread.CurrentThread == thread)
            {
                // same thread which called SetResultAsync, avoid potential deadlocks

                // using thread pool
                return Task.Run(() => action(t));

                // not using thread pool (TaskCreationOptions.LongRunning creates a normal thread)
                // return Task.Factory.StartNew(() => action(t), TaskCreationOptions.LongRunning);
            }
            else
            {
                // continue on the same thread
                var task = new Task(() => action(t));
                task.RunSynchronously();
                // already completed; returning it (not Task.FromResult(task)) lets Unwrap propagate exceptions
                return task;
            }
        }), continuationOptions).Unwrap();
    }
}

Updated to address the comment:

I don't control the caller - I can't get them to use a specific continue-with variant: if I could, the problem would not exist in the first place

I wasn't aware you don't control the caller. Nevertheless, if you don't control it, you're probably not passing the TaskCompletionSource object directly to the caller, either. Logically, you'd be passing the token part of it, i.e. tcs.Task. In which case, the solution might be even easier, by adding another extension method to the above:

// ImposeAsync, TODO: more overrides
static public Task<TResult> ImposeAsync<TResult>(this Task<TResult> @this)
{
    return @this.ContinueWith(new Func<Task<TResult>, Task<TResult>>(antecedent =>
    {
        Thread thread = null;
        s_tcsTasks.TryGetValue(antecedent, out thread);
        if (Thread.CurrentThread == thread)
        {
            // continue on a pool thread
            return antecedent.ContinueWith(t => t,
                TaskContinuationOptions.None).Unwrap();
        }
        else
        {
            return antecedent;
        }
    }), TaskContinuationOptions.ExecuteSynchronously).Unwrap();
}

Use:

// library code
var source = new TaskCompletionSource<int>();
var task = source.Task.ImposeAsync();
// ...

// client code
task.ContinueWith(delegate
{
    Identify();
}, TaskContinuationOptions.ExecuteSynchronously);

// ...
// library code
source.SetResultAsync(123);

This actually works for both await and ContinueWith (fiddle) and is free of reflection hacks.

New in .NET 4.6:

.NET 4.6 contains a new TaskCreationOptions: RunContinuationsAsynchronously.
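A minimal sketch of the .NET 4.6 option applied to the question's scenario, assuming .NET Framework 4.6+ or .NET Core; RunContinuationsAsyncDemo is a hypothetical name. The option is passed to the TaskCompletionSource<T> constructor, and even a continuation registered with ExecuteSynchronously no longer runs on the thread that calls TrySetResult.

```csharp
using System;
using System.Threading.Tasks;

static class RunContinuationsAsyncDemo
{
    public static int SetterThreadId, ContinuationThreadId;

    public static void Run()
    {
        // Every continuation (ContinueWith, including ExecuteSynchronously, and await)
        // is queued instead of being run inline by TrySetResult.
        var source = new TaskCompletionSource<int>(
            TaskCreationOptions.RunContinuationsAsynchronously);

        Task continuation = source.Task.ContinueWith(
            _ => { ContinuationThreadId = Environment.CurrentManagedThreadId; },
            TaskContinuationOptions.ExecuteSynchronously);

        SetterThreadId = Environment.CurrentManagedThreadId;
        source.TrySetResult(123);
        continuation.Wait();
    }
}
```

Continuations on the default scheduler are still queued to the thread pool, so this option alone does not address the saturated-pool concern from the question.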


Since you're willing to use Reflection to access private fields...

You can mark the TCS's Task with the TASK_STATE_THREAD_WAS_ABORTED flag, which would cause all continuations not to be inlined.

const int TASK_STATE_THREAD_WAS_ABORTED = 134217728;


var stateField = typeof(Task).GetField("m_stateFlags", BindingFlags.NonPublic | BindingFlags.Instance);
stateField.SetValue(task, (int) stateField.GetValue(task) | TASK_STATE_THREAD_WAS_ABORTED);

Edit:

Instead of using Reflection emit, I suggest you use expressions. This is much more readable and has the advantage of being PCL-compatible:

var taskParameter = Expression.Parameter(typeof (Task));
const string stateFlagsFieldName = "m_stateFlags";
var setter =
    Expression.Lambda<Action<Task>>(
        Expression.Assign(Expression.Field(taskParameter, stateFlagsFieldName),
            Expression.Or(Expression.Field(taskParameter, stateFlagsFieldName),
                Expression.Constant(TASK_STATE_THREAD_WAS_ABORTED))), taskParameter).Compile();

Without using Reflection:

If anyone's interested, I've figured out a way to do this without Reflection, but it is a bit "dirty" as well, and of course carries a non-negligible perf penalty:

// .NET Framework only: Thread.Abort throws PlatformNotSupportedException on .NET Core / .NET 5+
try
{
    Thread.CurrentThread.Abort();
}
catch (ThreadAbortException)
{
    source.TrySetResult(123);
    Thread.ResetAbort();
}

The simulated-abort approach looked really good, but led to the TPL hijacking threads in some scenarios.

I then had an implementation similar to checking the continuation object, but which just checked for any continuation (since there are actually too many scenarios for the given code to work well); however, that meant that even things like Task.Wait resulted in a thread-pool lookup.

Ultimately, after inspecting lots and lots of IL, the only safe and useful scenario is the SetOnInvokeMres scenario (manual-reset-event-slim continuation). There are lots of other scenarios:

  • some aren't safe, and lead to thread hijacking
  • the rest aren't useful, as they ultimately lead to the thread-pool

So in the end, I opted to check for a non-null continuation-object; if it is null, fine (no continuations); if it is non-null, special-case check for SetOnInvokeMres - if it is that: fine (safe to invoke); otherwise, let the thread-pool perform the TrySetComplete, without telling the task to do anything special like spoofing abort. Task.Wait uses the SetOnInvokeMres approach, which is the specific scenario we want to try really hard not to deadlock.
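A sketch of the decision described above, assuming a predicate like the IsSyncSafe delegate (a null continuation object or a SetOnInvokeMres counts as safe). CompletionHelper and TrySetResultSafely are hypothetical names; the predicate is a parameter only so the sketch compiles without touching private fields, and ThreadPool.QueueUserWorkItem stands in for "let the thread-pool perform the TrySetComplete".

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

static class CompletionHelper
{
    // isSyncSafe is assumed to behave like the IL-generated IsSyncSafe delegate.
    public static void TrySetResultSafely<T>(
        TaskCompletionSource<T> source, T result, Func<Task, bool> isSyncSafe)
    {
        if (isSyncSafe(source.Task))
        {
            // No continuations, or only a Task.Wait waiter: completing inline is safe.
            source.TrySetResult(result);
        }
        else
        {
            // Some other continuation could hijack this thread: complete from the pool.
            ThreadPool.QueueUserWorkItem(_ => source.TrySetResult(result));
        }
    }
}
```

Completing from the pool only when some other kind of continuation is present keeps the Task.Wait case (SetOnInvokeMres) on the inline path, which is the scenario the answer most wants to keep from deadlocking.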

// IsSyncSafe is a static Func<Task, bool> field (declaration not shown)
Type taskType = typeof(Task);
FieldInfo continuationField = taskType.GetField("m_continuationObject", BindingFlags.Instance | BindingFlags.NonPublic);
Type safeScenario = taskType.GetNestedType("SetOnInvokeMres", BindingFlags.NonPublic);
if (continuationField != null && continuationField.FieldType == typeof(object) && safeScenario != null)
{
    var method = new DynamicMethod("IsSyncSafe", typeof(bool), new[] { typeof(Task) }, typeof(Task), true);
    var il = method.GetILGenerator();
    il.Emit(OpCodes.Ldarg_0);
    il.Emit(OpCodes.Ldfld, continuationField);
    Label nonNull = il.DefineLabel(), goodReturn = il.DefineLabel();
    // check if null
    il.Emit(OpCodes.Brtrue_S, nonNull);
    il.MarkLabel(goodReturn);
    il.Emit(OpCodes.Ldc_I4_1);
    il.Emit(OpCodes.Ret);

    // check if is a SetOnInvokeMres - if so, we're OK
    il.MarkLabel(nonNull);
    il.Emit(OpCodes.Ldarg_0);
    il.Emit(OpCodes.Ldfld, continuationField);
    il.Emit(OpCodes.Isinst, safeScenario);
    il.Emit(OpCodes.Brtrue_S, goodReturn);

    il.Emit(OpCodes.Ldc_I4_0);
    il.Emit(OpCodes.Ret);

    IsSyncSafe = (Func<Task, bool>)method.CreateDelegate(typeof(Func<Task, bool>));
}