为图例中的点设置固定大小

我正在做一些散点图,我想把图例中的点的大小设置为一个固定的等价值。

现在我有这个:

import matplotlib.pyplot as plt
import numpy as np


def rand_data():
return np.random.uniform(low=0., high=1., size=(100,))


# Generate data.
x1, y1 = [rand_data() for i in range(2)]
x2, y2 = [rand_data() for i in range(2)]




plt.figure()
plt.scatter(x1, y1, marker='o', label='first', s=20., c='b')
plt.scatter(x2, y2, marker='o', label='second', s=35., c='r')
# Plot legend.
plt.legend(loc="lower left", markerscale=2., scatterpoints=1, fontsize=10)
plt.show()

它产生了这个:

enter image description here

图例中的点的大小按比例缩放,但不相同。如何在不影响 scatter图形中的大小的情况下,将图例中的点的大小固定为相等的值?

82794 次浏览

我查了一下 matplotlib的源代码。坏消息是,似乎没有任何简单的方法来设置同样大小的点在图例中。这是特别困难的散点图(错误: 请参阅下面的更新)。基本上有两种选择:

  1. 更改 maplotlib代码
  2. 向表示图像中点的 PathCollection对象中添加一个转换。转换(缩放)必须考虑到原始大小。

这两种方法都不是很有趣,尽管第一种方法似乎更简单。scatter图在这方面特别具有挑战性。

然而,我有一个黑客,它可能做你想要的:

import matplotlib.pyplot as plt
import numpy as np


def rand_data():
return np.random.uniform(low=0., high=1., size=(100,))


# Generate data.
x1, y1 = [rand_data() for i in range(2)]
x2, y2 = [rand_data() for i in range(2)]


plt.figure()
plt.plot(x1, y1, 'o', label='first', markersize=np.sqrt(20.), c='b')
plt.plot(x2, y2, 'o', label='second', markersize=np.sqrt(35.), c='r')
# Plot legend.
lgnd = plt.legend(loc="lower left", numpoints=1, fontsize=10)


#change the marker size manually for both lines
lgnd.legendHandles[0]._legmarker.set_markersize(6)
lgnd.legendHandles[1]._legmarker.set_markersize(6)
plt.show()

这意味着:

enter image description here

这似乎是你想要的。

变化:

  • scatter变成了 plot,它改变了标记尺寸(因此是 sqrt) ,使得不可能使用变化的标记尺寸(如果这是有意的)
  • 图例中的两个标记的标记大小手动更改为6点

正如您所看到的,这使用了隐藏的下划线属性(_legmarker) ,并且非常难看。在 matplotlib的任何更新中都可能出现故障。

更新

哈,我找到了,一个更好的黑客:

import matplotlib.pyplot as plt
import numpy as np


def rand_data():
return np.random.uniform(low=0., high=1., size=(100,))


# Generate data.
x1, y1 = [rand_data() for i in range(2)]
x2, y2 = [rand_data() for i in range(2)]


plt.figure()
plt.scatter(x1, y1, marker='o', label='first', s=20., c='b')
plt.scatter(x2, y2, marker='o', label='second', s=35., c='r')
# Plot legend.
lgnd = plt.legend(loc="lower left", scatterpoints=1, fontsize=10)
lgnd.legendHandles[0]._sizes = [30]
lgnd.legendHandles[1]._sizes = [30]
plt.show()

现在,_sizes(另一个下划线属性)完成了这个任务。不需要触摸源代码,即使这是一个相当黑客。但是现在您可以使用 scatter提供的所有内容。

enter image description here

使用@DrV 的解决方案我并没有取得多大的成功,尽管我的用例可能是独一无二的。由于点的密度,我使用最小的标记大小,即 plt.plot(x, y, '.', ms=1, ...),并希望图例符号更大。

我遵循了我在 Matplotlib 论坛上找到的建议:

  1. 绘制数据(无标签)
  2. 记录轴极限(xlimits = plt.xlim())
  3. 使用适合图例的符号颜色和大小绘制远离真实数据的假数据
  4. 恢复轴极限(plt.xlim(xlimits))
  5. 创造传奇

下面是结果(点实际上没有线重要) : enter image description here

希望这能帮到别人。

与答案类似,假设你想要所有的记号笔都是相同的大小:

lgnd = plt.legend(loc="lower left", scatterpoints=1, fontsize=10)
for handle in lgnd.legendHandles:
handle.set_sizes([6.0])

使用 MatPlotlib 2.0

您可以创建一个 Line2D 对象,该对象与您选择的标记类似,只是具有您选择的不同标记大小,然后使用该对象构建图例。这很好,因为它不需要在你的坐标轴中放置一个对象(可能会触发一个调整大小的事件) ,也不需要使用任何隐藏的属性。唯一真正的缺点是您必须根据对象和标签列表显式地构造图例,但是这是一个文档良好的 matplotlib 特性,因此使用起来非常安全。

from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import numpy as np


def rand_data():
return np.random.uniform(low=0., high=1., size=(100,))


# Generate data.
x1, y1 = [rand_data() for i in range(2)]
x2, y2 = [rand_data() for i in range(2)]


plt.figure()
plt.scatter(x1, y1, marker='o', label='first', s=20., c='b')
plt.scatter(x2, y2, marker='o', label='second', s=35., c='r')


# Create dummy Line2D objects for legend
h1 = Line2D([0], [0], marker='o', markersize=np.sqrt(20), color='b', linestyle='None')
h2 = Line2D([0], [0], marker='o', markersize=np.sqrt(20), color='r', linestyle='None')


# Set axes limits
plt.gca().set_xlim(-0.2, 1.2)
plt.gca().set_ylim(-0.2, 1.2)


# Plot legend.
plt.legend([h1, h2], ['first', 'second'], loc="lower left", markerscale=2,
scatterpoints=1, fontsize=10)
plt.show()

resulting figure

只是另一个选择。这样做的好处是,它不会使用任何“私有”方法,甚至可以使用图例中的散点以外的其他对象。关键是将分散的 PathCollection映射到 HandlerPathCollection,并设置一个更新函数。

def update(handle, orig):
handle.update_from(orig)
handle.set_sizes([64])


plt.legend(handler_map={PathCollection : HandlerPathCollection(update_func=update)})

完整的代码示例:

import matplotlib.pyplot as plt
import numpy as np; np.random.seed(42)
from matplotlib.collections import PathCollection
from matplotlib.legend_handler import HandlerPathCollection, HandlerLine2D


colors = ["limegreen", "crimson", "indigo"]
markers = ["o", "s", r"$\clubsuit$"]
labels = ["ABC", "DEF", "XYZ"]
plt.plot(np.linspace(0,1,8), np.random.rand(8), marker="o", markersize=22, label="A line")
for i,(c,m,l) in enumerate(zip(colors,markers,labels)):
plt.scatter(np.random.rand(8),np.random.rand(8),
c=c, marker=m, s=10+np.exp(i*2.9), label=l)


def updatescatter(handle, orig):
handle.update_from(orig)
handle.set_sizes([64])


def updateline(handle, orig):
handle.update_from(orig)
handle.set_markersize(8)


plt.legend(handler_map={PathCollection : HandlerPathCollection(update_func=updatescatter),
plt.Line2D : HandlerLine2D(update_func = updateline)})


plt.show()

enter image description here