How to flatten a tree via LINQ?

So I have a simple tree:

class MyNode
{
public MyNode Parent;
public IEnumerable<MyNode> Elements;
int group = 1;
}

I have an IEnumerable<MyNode>. I want to get all MyNode objects (including the inner node objects in Elements) as one flat list, filtered where group == 1. How can I do this via LINQ?


You can flatten a tree like this:

IEnumerable<MyNode> Flatten(IEnumerable<MyNode> e) =>
e.SelectMany(c => Flatten(c.Elements)).Concat(e);

You can then filter by group using Where(...).
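As a rough illustration of the filtering step, here is a minimal sketch. It assumes a simplified MyNode whose group is exposed as a public field named Group (the question's field is private), and it treats a missing child list as empty; TreeHelpers is a hypothetical container name.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sample tree: a root in group 1 with two children in groups 2 and 1.
var tree = new[]
{
    new MyNode
    {
        Group = 1,
        Elements = new[] { new MyNode { Group = 2 }, new MyNode { Group = 1 } }
    }
};

// Flatten first, then filter the flat sequence with Where.
List<MyNode> groupOne = TreeHelpers.Flatten(tree).Where(n => n.Group == 1).ToList();
Console.WriteLine(groupOne.Count); // 2: the root and its second child

static class TreeHelpers
{
    // Descendants first, then the nodes themselves.
    // Assumption: a null Elements collection means "no children".
    public static IEnumerable<MyNode> Flatten(IEnumerable<MyNode> e) =>
        e.SelectMany(c => Flatten(c.Elements ?? Enumerable.Empty<MyNode>())).Concat(e);
}

class MyNode
{
    public IEnumerable<MyNode> Elements;
    public int Group = 1; // public here so the lambda can read it (assumption)
}
```

Because Where is applied after flattening, a node is kept or dropped on its own group value regardless of its parent's group.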

To earn some "points for style", convert Flatten to an extension method in a static class.

public static IEnumerable<MyNode> Flatten(this IEnumerable<MyNode> e) =>
e.SelectMany(c => c.Elements.Flatten()).Concat(e);

To earn more points for "even better style", convert Flatten to a generic extension method that takes a tree and a function that produces descendants from a node:

public static IEnumerable<T> Flatten<T>(
this IEnumerable<T> e
,   Func<T,IEnumerable<T>> f
) => e.SelectMany(c => f(c).Flatten(f)).Concat(e);

Call this function like this:

IEnumerable<MyNode> tree = ....
var res = tree.Flatten(node => node.Elements);

If you would prefer flattening in pre-order rather than in post-order, switch around the sides of the Concat(...).

void Main()
{
var allNodes = GetTreeNodes().Flatten(x => x.Elements);


allNodes.Dump();
}


public static class ExtensionMethods
{
public static IEnumerable<T> Flatten<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> childrenSelector = null)
{
if (source == null)
{
return new List<T>();
}


var list = source;


if (childrenSelector != null)
{
foreach (var item in source)
{
list = list.Concat(childrenSelector(item).Flatten(childrenSelector));
}
}


return list;
}
}


IEnumerable<MyNode> GetTreeNodes() {
return new[] {
new MyNode { Elements = new[] { new MyNode() }},
new MyNode { Elements = new[] { new MyNode(), new MyNode(), new MyNode() }}
};
}


class MyNode
{
public MyNode Parent;
public IEnumerable<MyNode> Elements;
int group = 1;
}

The problem with the accepted answer is that it is inefficient if the tree is deep. If the tree is very deep then it blows the stack. You can solve the problem by using an explicit stack:

public static IEnumerable<MyNode> Traverse(this MyNode root)
{
var stack = new Stack<MyNode>();
stack.Push(root);
while(stack.Count > 0)
{
var current = stack.Pop();
yield return current;
foreach(var child in current.Elements)
stack.Push(child);
}
}

Assuming n nodes in a tree of height h and a branching factor considerably less than n, this method is O(1) in stack space, O(h) in heap space and O(n) in time. The other algorithm given is O(h) in stack, O(1) in heap and O(nh) in time. If the branching factor is small compared to n then h is between O(lg n) and O(n), which illustrates that the naïve algorithm can use a dangerous amount of stack and a large amount of time if h is close to n.
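To make the "h close to n" case concrete, the sketch below builds a degenerate tree that is a single chain, so the height equals the node count, and walks it with the explicit-stack traversal. The simplified MyNode (with Elements defaulting to an empty array) and the chosen depth of 100,000 are assumptions made for illustration only.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Build a degenerate tree iteratively: one long chain, so h equals n.
const int depth = 100_000;
var root = new MyNode();
var tail = root;
for (int i = 1; i < depth; i++)
{
    var child = new MyNode();
    tail.Elements = new[] { child };
    tail = child;
}

// The traversal keeps its pending work on a heap-allocated Stack<MyNode>,
// so the call stack stays flat no matter how deep the chain is.
int visited = root.Traverse().Count();
Console.WriteLine(visited); // 100000

static class NodeExtensions
{
    public static IEnumerable<MyNode> Traverse(this MyNode root)
    {
        var stack = new Stack<MyNode>();
        stack.Push(root);
        while (stack.Count > 0)
        {
            var current = stack.Pop();
            yield return current;
            foreach (var child in current.Elements)
                stack.Push(child);
        }
    }
}

class MyNode
{
    // Assumption: leaves carry an empty collection rather than null.
    public IEnumerable<MyNode> Elements = Array.Empty<MyNode>();
}
```

A recursive flatten over the same chain would nest one iterator per level, which is the situation the paragraph above warns can exhaust the call stack.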

Now that we have a traversal, your query is straightforward:

root.Traverse().Where(item=>item.group == 1);

Just for completeness, here is the combination of the answers from dasblinkenlight and Eric Lippert. Unit tested and everything. :-)

 public static IEnumerable<T> Flatten<T>(
this IEnumerable<T> items,
Func<T, IEnumerable<T>> getChildren)
{
var stack = new Stack<T>();
foreach(var item in items)
stack.Push(item);


while(stack.Count > 0)
{
var current = stack.Pop();
yield return current;


var children = getChildren(current);
if (children == null) continue;


foreach (var child in children)
stack.Push(child);
}
}

In case anyone else finds this, but also needs to know the level after they've flattened the tree, this expands on Konamiman's combination of dasblinkenlight and Eric Lippert's solutions:

    public static IEnumerable<Tuple<T, int>> FlattenWithLevel<T>(
this IEnumerable<T> items,
Func<T, IEnumerable<T>> getChilds)
{
var stack = new Stack<Tuple<T, int>>();
foreach (var item in items)
stack.Push(new Tuple<T, int>(item, 1));


while (stack.Count > 0)
{
var current = stack.Pop();
yield return current;
foreach (var child in getChilds(current.Item1))
stack.Push(new Tuple<T, int>(child, current.Item2 + 1));
}
}

Update:

For people interested in the level of nesting (depth): one of the good things about the explicit enumerator stack implementation is that at any moment (and in particular when yielding the element) stack.Count represents the depth currently being processed. Taking this into account and using C# 7.0 value tuples, we can simply change the method declaration as follows:

public static IEnumerable<(T Item, int Level)> ExpandWithLevel<T>(
this IEnumerable<T> source, Func<T, IEnumerable<T>> elementSelector)

and yield statement:

yield return (item, stack.Count);

Then we can implement the original method by applying simple Select on the above:

public static IEnumerable<T> Expand<T>(
this IEnumerable<T> source, Func<T, IEnumerable<T>> elementSelector) =>
source.ExpandWithLevel(elementSelector).Select(e => e.Item);
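Putting the two changes from this update together with the enumerator-stack method from the original part of this answer, the result might look like the sketch below. The Node class and the sample tree are hypothetical and exist only to show the levels that come out; root items are reported at level 0.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sample: a -> (a1 -> a1x, a2), and a second root b.
var tree = new[]
{
    new Node("a", new Node("a1", new Node("a1x")), new Node("a2")),
    new Node("b")
};
var result = tree.ExpandWithLevel(n => n.Children)
                 .Select(p => $"{p.Item.Name}:{p.Level}")
                 .ToList();
Console.WriteLine(string.Join(", ", result)); // a:0, a1:1, a1x:2, a2:1, b:0

static class TreeExtensions
{
    public static IEnumerable<(T Item, int Level)> ExpandWithLevel<T>(
        this IEnumerable<T> source, Func<T, IEnumerable<T>> elementSelector)
    {
        var stack = new Stack<IEnumerator<T>>();
        var e = source.GetEnumerator();
        try
        {
            while (true)
            {
                while (e.MoveNext())
                {
                    var item = e.Current;
                    // The number of suspended parent enumerators is the depth.
                    yield return (item, stack.Count);
                    var elements = elementSelector(item);
                    if (elements == null) continue;
                    stack.Push(e);
                    e = elements.GetEnumerator();
                }
                if (stack.Count == 0) break;
                e.Dispose();
                e = stack.Pop();
            }
        }
        finally
        {
            e.Dispose();
            while (stack.Count != 0) stack.Pop().Dispose();
        }
    }
}

class Node
{
    public string Name;
    public Node[] Children;
    public Node(string name, params Node[] children) { Name = name; Children = children; }
}
```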

Original:

Surprisingly no one (even Eric) showed the "natural" iterative port of a recursive pre-order DFT, so here it is:

    public static IEnumerable<T> Expand<T>(
this IEnumerable<T> source, Func<T, IEnumerable<T>> elementSelector)
{
var stack = new Stack<IEnumerator<T>>();
var e = source.GetEnumerator();
try
{
while (true)
{
while (e.MoveNext())
{
var item = e.Current;
yield return item;
var elements = elementSelector(item);
if (elements == null) continue;
stack.Push(e);
e = elements.GetEnumerator();
}
if (stack.Count == 0) break;
e.Dispose();
e = stack.Pop();
}
}
finally
{
e.Dispose();
while (stack.Count != 0) stack.Pop().Dispose();
}
}

Combining Dave's and Ivan Stoev's answer in case you need the level of nesting and the list flattened "in order" and not reversed like in the answer given by Konamiman.

 public static class HierarchicalEnumerableUtils
{
private static IEnumerable<Tuple<T, int>> ToLeveled<T>(this IEnumerable<T> source, int level)
{
if (source == null)
{
return null;
}
else
{
return source.Select(item => new Tuple<T, int>(item, level));
}
}


public static IEnumerable<Tuple<T, int>> FlattenWithLevel<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> elementSelector)
{
var stack = new Stack<IEnumerator<Tuple<T, int>>>();
var leveledSource = source.ToLeveled(0);
var e = leveledSource.GetEnumerator();
try
{
while (true)
{
while (e.MoveNext())
{
var item = e.Current;
yield return item;
var elements = elementSelector(item.Item1).ToLeveled(item.Item2 + 1);
if (elements == null) continue;
stack.Push(e);
e = elements.GetEnumerator();
}
if (stack.Count == 0) break;
e.Dispose();
e = stack.Pop();
}
}
finally
{
e.Dispose();
while (stack.Count != 0) stack.Pop().Dispose();
}
}
}

Building on Konamiman's answer, and the comment that the ordering is unexpected, here's a version with an explicit sort param:

public static IEnumerable<T> TraverseAndFlatten<T, V>(this IEnumerable<T> items, Func<T, IEnumerable<T>> nested, Func<T, V> orderBy)
{
var stack = new Stack<T>();
foreach (var item in items.OrderBy(orderBy))
stack.Push(item);


while (stack.Count > 0)
{
var current = stack.Pop();
yield return current;


var children = nested(current);
if (children == null) continue;


// A stack pops in reverse push order, so siblings come out in descending key order.
foreach (var child in children.OrderBy(orderBy))
stack.Push(child);
}
}

And a sample usage:

var flattened = doc.TraverseAndFlatten(x => x.DependentDocuments, y => y.Document.DocDated).ToList();

Below is Ivan Stoev's code with the additional feature of reporting the index of every object along the path. For example, a search for "Item_120" in:

Item_0--Item_00
        Item_01

Item_1--Item_10
        Item_11
        Item_12--Item_120
would return the item and an int array [1,2,0]. Obviously, nesting level is also available, as length of the array.

public static IEnumerable<(T, int[])> Expand<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> getChildren) {
var stack = new Stack<IEnumerator<T>>();
var e = source.GetEnumerator();
List<int> indexes = new List<int>() { -1 };
try {
while (true) {
while (e.MoveNext()) {
var item = e.Current;
indexes[stack.Count]++;
yield return (item, indexes.Take(stack.Count + 1).ToArray());
var elements = getChildren(item);
if (elements == null) continue;
stack.Push(e);
e = elements.GetEnumerator();
if (indexes.Count == stack.Count)
indexes.Add(-1);
}
if (stack.Count == 0) break;
e.Dispose();
indexes[stack.Count] = -1;
e = stack.Pop();
}
} finally {
e.Dispose();
while (stack.Count != 0) stack.Pop().Dispose();
}
}

I found some small issues with the answers given here:

  • What if the initial list of items is null?
  • What if there is a null value in the list of children?

I built on the previous answers and came up with the following:

public static class IEnumerableExtensions
{
public static IEnumerable<T> Flatten<T>(
this IEnumerable<T> items,
Func<T, IEnumerable<T>> getChildren)
{
if (items == null)
yield break;


var stack = new Stack<T>(items);
while (stack.Count > 0)
{
var current = stack.Pop();
yield return current;


if (current == null) continue;


var children = getChildren(current);
if (children == null) continue;


foreach (var child in children)
stack.Push(child);
}
}
}

And the unit tests:

[TestClass]
public class IEnumerableExtensionsTests
{
[TestMethod]
public void NullList()
{
IEnumerable<Test> items = null;
var flattened = items.Flatten(i => i.Children);
Assert.AreEqual(0, flattened.Count());
}
[TestMethod]
public void EmptyList()
{
var items = new Test[0];
var flattened = items.Flatten(i => i.Children);
Assert.AreEqual(0, flattened.Count());
}
[TestMethod]
public void OneItem()
{
var items = new[] { new Test() };
var flattened = items.Flatten(i => i.Children);
Assert.AreEqual(1, flattened.Count());
}
[TestMethod]
public void OneItemWithChild()
{
var items = new[] { new Test { Id = 1, Children = new[] { new Test { Id = 2 } } } };
var flattened = items.Flatten(i => i.Children);
Assert.AreEqual(2, flattened.Count());
Assert.IsTrue(flattened.Any(i => i.Id == 1));
Assert.IsTrue(flattened.Any(i => i.Id == 2));
}
[TestMethod]
public void OneItemWithNullChild()
{
var items = new[] { new Test { Id = 1, Children = new Test[] { null } } };
var flattened = items.Flatten(i => i.Children);
Assert.AreEqual(2, flattened.Count());
Assert.IsTrue(flattened.Any(i => i.Id == 1));
Assert.IsTrue(flattened.Any(i => i == null));
}
class Test
{
public int Id { get; set; }
public IEnumerable<Test> Children { get; set; }
}
}

A completely different option is to have a proper OO design.

For example, ask MyNode itself to return all of its nodes, flattened.

Like this:

class MyNode
{
public MyNode Parent;
public IEnumerable<MyNode> Elements;
int group = 1;


public IEnumerable<MyNode> GetAllNodes()
{
// Start with this node so that it is part of the result.
var self = new[] { this };
if (Elements == null)
{
return self;
}


return self.Concat(Elements.SelectMany(e => e.GetAllNodes()));
}
}

Now you could ask the top level MyNode to get all the nodes.

var flatten = topNode.GetAllNodes();

If you can't edit the class, then this isn't an option. Otherwise, I think this could be preferable to a separate (recursive) LINQ method.

This uses LINQ, so I think this answer is applicable here ;)

Here is a ready-to-use implementation that uses a Queue and returns the flattened tree with each node first and then its children.

public static IEnumerable<T> Flatten<T>(this IEnumerable<T> items,
Func<T,IEnumerable<T>> getChildren)
{
if (items == null)
yield break;


var queue = new Queue<T>();


foreach (var item in items) {
if (item == null)
continue;


queue.Enqueue(item);


while (queue.Count > 0) {
var current = queue.Dequeue();
yield return current;


if (current == null)
continue;


var children = getChildren(current);
if (children == null)
continue;


foreach (var child in children)
queue.Enqueue(child);
}
}


}

Every once in a while I try to scratch at this problem and devise my own solution that supports arbitrarily deep structures (no recursion), performs a depth-first (pre-order) traversal, and doesn't abuse too many LINQ queries or preemptively execute recursion on the children. After digging around in the .NET source and trying many solutions, I've finally come up with this one. It ended up being very close to Ivan Stoev's answer (which I only saw just now); however, mine doesn't use infinite loops or have unusual code flow.

public static IEnumerable<T> Traverse<T>(
this IEnumerable<T> source,
Func<T, IEnumerable<T>> fnRecurse)
{
if (source != null)
{
Stack<IEnumerator<T>> enumerators = new Stack<IEnumerator<T>>();
try
{
enumerators.Push(source.GetEnumerator());
while (enumerators.Count > 0)
{
var top = enumerators.Peek();
while (top.MoveNext())
{
yield return top.Current;


var children = fnRecurse(top.Current);
if (children != null)
{
top = children.GetEnumerator();
enumerators.Push(top);
}
}


enumerators.Pop().Dispose();
}
}
finally
{
while (enumerators.Count > 0)
enumerators.Pop().Dispose();
}
}
}


Most of the answers presented here are producing depth-first or zig-zag sequences. For example starting with the tree below:

        1                   2
       / \                 / \
      /   \               /   \
     /     \             /     \
    /       \           /       \
   11       12         21       22
  / \       / \       / \       / \
 /   \     /   \     /   \     /   \
111 112   121 122   211 212   221 222

dasblinkenlight's answer produces this flattened sequence:

111, 112, 121, 122, 11, 12, 211, 212, 221, 222, 21, 22, 1, 2

Konamiman's answer (that generalizes Eric Lippert's answer) produces this flattened sequence:

2, 22, 222, 221, 21, 212, 211, 1, 12, 122, 121, 11, 112, 111

Ivan Stoev's answer produces this flattened sequence:

1, 11, 111, 112, 12, 121, 122, 2, 21, 211, 212, 22, 221, 222

If you are interested in a breadth-first sequence like this:

1, 2, 11, 12, 21, 22, 111, 112, 121, 122, 211, 212, 221, 222

...then this is the solution for you:

public static IEnumerable<T> Flatten<T>(this IEnumerable<T> source,
Func<T, IEnumerable<T>> childrenSelector)
{
var queue = new Queue<T>(source);
while (queue.Count > 0)
{
var current = queue.Dequeue();
yield return current;
var children = childrenSelector(current);
if (children == null) continue;
foreach (var child in children) queue.Enqueue(child);
}
}

The difference in the implementation is basically using a Queue instead of a Stack. No actual sorting is taking place.
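To see the Stack versus Queue difference side by side, here is a small sketch over a two-level tree. The method names DepthFirst and BreadthFirst, the integer node labels, and the dictionary-based children lookup are all hypothetical choices made for this example.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical tree: roots 1 and 2, each with two leaf children.
var childMap = new Dictionary<int, int[]>
{
    [1] = new[] { 11, 12 },
    [2] = new[] { 21, 22 }
};
Func<int, IEnumerable<int>> getChildren =
    n => childMap.TryGetValue(n, out var c) ? c : null;

var roots = new[] { 1, 2 };
Console.WriteLine(string.Join(", ", roots.DepthFirst(getChildren)));   // 2, 22, 21, 1, 12, 11
Console.WriteLine(string.Join(", ", roots.BreadthFirst(getChildren))); // 1, 2, 11, 12, 21, 22

static class TraversalExtensions
{
    // Stack: last in, first out, so the most recently pushed node is visited next.
    public static IEnumerable<T> DepthFirst<T>(
        this IEnumerable<T> source, Func<T, IEnumerable<T>> childrenSelector)
    {
        var stack = new Stack<T>();
        foreach (var item in source) stack.Push(item);
        while (stack.Count > 0)
        {
            var current = stack.Pop();
            yield return current;
            var children = childrenSelector(current);
            if (children == null) continue;
            foreach (var child in children) stack.Push(child);
        }
    }

    // Queue: first in, first out, so a whole level is visited before the next one.
    public static IEnumerable<T> BreadthFirst<T>(
        this IEnumerable<T> source, Func<T, IEnumerable<T>> childrenSelector)
    {
        var queue = new Queue<T>(source);
        while (queue.Count > 0)
        {
            var current = queue.Dequeue();
            yield return current;
            var children = childrenSelector(current);
            if (children == null) continue;
            foreach (var child in children) queue.Enqueue(child);
        }
    }
}
```

The two method bodies are identical apart from the container, which is what changes the visiting order.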


Caution: this implementation is far from optimal regarding memory-efficiency, since a large percentage of the total number of elements will end up being stored in the internal queue during the enumeration. Stack-based tree-traversals are much more efficient regarding memory usage than Queue-based implementations.

Based on a previous answer, here is a pre-order flatten:

        public static IEnumerable<T> Flatten<T>(
this IEnumerable<T> e
, Func<T, IEnumerable<T>> f
) => e.Concat(e.SelectMany(c => f(c).Flatten(f)));

A safe, generic version of the best answer, in case you need it.

public static IEnumerable<T> DepthFirstTraversal<T>(
    this T root,
    Func<T, IEnumerable<T>> branchSelector)
{
    ArgumentNullException.ThrowIfNull(branchSelector);

    var stack = new Stack<T>();
    stack.Push(root);
    while (stack.Count > 0)
    {
        var current = stack.Pop();
        yield return current;

        if (current == null)
        {
            continue;
        }

        // Treat a null child sequence as "no children".
        foreach (var child in branchSelector(current) ?? Enumerable.Empty<T>())
        {
            stack.Push(child);
        }
    }
}