Scatter plot in Pandas/Pyplot: how to plot by category

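I am trying to make a simple scatter plot in pyplot from a pandas DataFrame, and I want an efficient way to plot two variables while having the marker symbols determined by a third column (the key). I have tried various approaches using df.groupby, without success. A sample script is below. It colors the markers according to 'key1', but I would like a legend showing the 'key1' categories. Am I close? Thanks.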

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()

With plt.scatter, I can only think of one way: use a proxy artist:

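import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x = ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

# Proxy artists: one legend marker per key1 value, colored with the scatter's colormap.
# (value - 4) / 4 maps 4, 6, 8 onto 0, 0.5, 1, matching the autoscaled color range.
ccm = x.get_cmap()
circles = [Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((np.array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)
plt.show()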

And the result is:

[Image: resulting plot]
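The hard-coded (value - 4) / 4 scaling above only works because 4 and 8 happen to be the minimum and maximum of key1. A minimal sketch that avoids this, assuming you are happy to create the normalisation yourself: build a matplotlib Normalize from the data range, pass it to scatter via norm=, and reuse it for the proxy markers so the legend colors always match the points. The Agg backend and the output filename are only there so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import Normalize
from matplotlib.lines import Line2D

df = pd.DataFrame(np.random.normal(10, 1, 30).reshape(10, 3),
                  columns=("one", "two", "three"))
df["key1"] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

# Explicit normalisation spanning the key range, shared by points and legend.
norm = Normalize(vmin=df["key1"].min(), vmax=df["key1"].max())

fig, ax = plt.subplots()
sc = ax.scatter(df["one"], df["two"], marker="o", c=df["key1"], norm=norm, alpha=0.8)

# One proxy marker per category, colored through the same colormap and norm.
categories = np.unique(df["key1"])
handles = [
    Line2D([], [], linestyle="", marker="o", markersize=10,
           markerfacecolor=sc.get_cmap()(norm(v)), markeredgecolor="none")
    for v in categories
]
leg = ax.legend(handles, [str(v) for v in categories],
                loc="center left", bbox_to_anchor=(1, 0.5))
fig.savefig("scatter_proxy_legend.png", bbox_inches="tight")  # hypothetical filename
```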

You can use scatter for this, but that requires having numerical values for your key1, and you won't have a legend, as you noticed.

It's better to just use plot for discrete categories like this. For example:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)


# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))


groups = df.groupby('label')


# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()


plt.show()

[Image: resulting plot]
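Since the question asks for the symbol, not just the color, to depend on the key, the same groupby loop can also hand each group its own marker. A minimal sketch, assuming an arbitrary marker list cycled with itertools.cycle (the list itself is just an illustrative choice); each group is still one Line2D with linestyle='', so the legend picks up both color and shape.

```python
import itertools

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(["a", "b", "c"], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

# Illustrative marker list; cycle() repeats it if there are more groups than markers.
markers = itertools.cycle(["o", "s", "^", "D", "v"])

fig, ax = plt.subplots()
ax.margins(0.05)
for (name, group), marker in zip(df.groupby("label"), markers):
    # One unconnected line per category: distinct color (from the cycle) and marker.
    ax.plot(group.x, group.y, marker=marker, linestyle="", ms=12, label=name)
leg = ax.legend(numpoints=1)
fig.savefig("groups_markers.png")  # hypothetical filename
```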

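If you'd like things to look like the default pandas style, update the rcParams with the pandas stylesheet and use its color generator (I'm also tweaking the legend slightly). Note that this relies on pd.tools.plotting, which only exists in old pandas releases and has since been removed: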

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)


# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))


groups = df.groupby('label')


# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')


fig, ax = plt.subplots()
ax.set_prop_cycle(color=colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')


plt.show()

[Image: resulting plot]
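Because pd.tools.plotting is gone from current pandas, the snippet above will not run on recent versions. A rough modern stand-in, assuming matplotlib's bundled "ggplot" style is close enough to the old pandas look (it is a substitute, not the same stylesheet): apply the style with plt.style.context and use Axes.set_prop_cycle, which replaces the removed Axes.set_color_cycle.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(["a", "b", "c"], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))
groups = df.groupby("label")

# Assumption: the built-in "ggplot" style approximates the old pandas stylesheet.
with plt.style.context("ggplot"):
    # Take one color per group from the style's own color cycle.
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"][:len(groups)]
    fig, ax = plt.subplots()
    ax.set_prop_cycle(color=colors)  # modern replacement for set_color_cycle
    ax.margins(0.05)
    for name, group in groups:
        ax.plot(group.x, group.y, marker="o", linestyle="", ms=12, label=name)
    leg = ax.legend(numpoints=1, loc="upper left")
    fig.savefig("groups_ggplot_style.png")  # hypothetical filename
```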

This is simple to do as a one-liner with Seaborn (pip install seaborn):

sns.scatterplot(x="one", y="two", data=df, hue="key1")

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)


df = pd.DataFrame(
np.random.normal(10, 1, 30).reshape(10, 3),
index=pd.date_range('2010-01-01', freq='M', periods=10),
columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)


sns.scatterplot(x="one", y="two", data=df, hue="key1")

[Image: resulting plot]

Here is the dataframe for reference:

[Image: the DataFrame]

Since you have three variable columns in your data, you may want to plot all pairwise dimensions with:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1")

[Image: resulting pair plot]

mlxtend's category_scatter (https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/) is another option.

You can also try Altair or ggplot, which focus on declarative visualisations.

import numpy as np
import pandas as pd
np.random.seed(1974)


# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Altair code

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

[Image: resulting Altair chart]

ggplot code

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

[Image: resulting ggplot chart]

You can use df.plot.scatter and pass an array to the c= argument defining the color of each point:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

[Image: resulting plot]
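The np.where call with a '-' placeholder followed by per-key assignments can be collapsed into a single Series.map over a lookup dict. df.plot.scatter will not build a per-category legend from an explicit color list, so this sketch adds proxy handles by hand. The color_map dict is a hypothetical choice of colors, and the Agg backend is only there so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

df = pd.DataFrame(np.random.normal(10, 1, 30).reshape(10, 3),
                  columns=("one", "two", "three"))
df["key1"] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

# Hypothetical category-to-color lookup; map() replaces the np.where chain.
color_map = {4: "r", 6: "g", 8: "b"}
colors = df["key1"].map(color_map)

ax = df.plot.scatter(x="one", y="two", c=colors.tolist())

# One proxy marker per category so the legend shows the key1 values.
handles = [Line2D([], [], linestyle="", marker="o", color=col)
           for col in color_map.values()]
leg = ax.legend(handles, [str(k) for k in color_map], title="key1")
ax.figure.savefig("scatter_mapped_colors.png")  # hypothetical filename
```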

It's rather hacky, but you could use the 'one' column as a float index (a Float64Index in older pandas) to do everything in one go:

df.set_index('one').sort_index().groupby('key1')['two'].plot(style='--o', legend=True)

[Image: resulting plot]

Note that as of pandas 0.20.3, sorting the index is necessary, and the legend is a bit wonky.
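If the index trick misbehaves, an explicit loop should give the same dashed line-with-markers per key without turning 'one' into the index. A minimal sketch: sort by 'one' once (groupby keeps the row order within each group), then plot each key1 group as its own labelled line.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.normal(10, 1, 30).reshape(10, 3),
                  columns=("one", "two", "three"))
df["key1"] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

fig, ax = plt.subplots()
# Sorting first keeps each group's x values increasing, so the dashes connect cleanly.
for key, grp in df.sort_values("one").groupby("key1"):
    ax.plot(grp["one"], grp["two"], "--o", label=str(key))
leg = ax.legend(title="key1")
fig.savefig("grouped_lines.png")  # hypothetical filename
```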

From matplotlib 3.1 onwards you can use PathCollection.legend_elements(). An example is shown in the "Automated legend creation" example of the matplotlib gallery. The advantage is that a single scatter call is enough.

In this case:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3),
index = pd.date_range('2010-01-01', freq = 'M', periods = 10),
columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)




fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

[Image: resulting plot]

If the keys are not numbers, it would look like this:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3),
index = pd.date_range('2010-01-01', freq = 'M', periods = 10),
columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")


labels, index = np.unique(df["key1"], return_inverse=True)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

[Image: resulting plot]
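np.unique always sorts the labels, so the legend order is alphabetical. If you need a specific order, one option is a pandas Categorical with an explicit category list, using its integer codes as the color values. A minimal sketch; the C, A, B order is purely illustrative. legend_elements() returns one handle per unique code in ascending order, which lines up with cat.categories.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.normal(10, 1, 30).reshape(10, 3),
                  columns=("one", "two", "three"))
df["key1"] = list("AAABBBCCCC")

# Illustrative legend order: C first, then A, then B.
cat = pd.Categorical(df["key1"], categories=["C", "A", "B"])

fig, ax = plt.subplots()
# cat.codes gives 0 for "C", 1 for "A", 2 for "B", used as the color values.
sc = ax.scatter(df["one"], df["two"], marker="o", c=cat.codes, alpha=0.8)
handles, _ = sc.legend_elements()  # one handle per unique code, ascending
leg = ax.legend(handles, list(cat.categories), title="key1")
fig.savefig("scatter_categorical_order.png")  # hypothetical filename
```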

seaborn has a wrapper function, scatterplot, that does this more concisely:

sns.scatterplot(data=df, x='one', y='two', hue='key1')