我怎样才能画出混淆矩阵?

我使用 scikit-learn 将文本文档(22000)分类为100个类。我使用 scikit-learn 的混淆矩阵方法计算混淆矩阵。

model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm=metrics.confusion_matrix(test_labels,pred)
print(cm)
plt.imshow(cm, cmap='binary')

我的混淆矩阵是这样的:

[[3962  325    0 ...,    0    0    0]
[ 250 2765    0 ...,    0    0    0]
[   2    8   17 ...,    0    0    0]
...,
[   1    6    0 ...,    5    0    0]
[   1    1    0 ...,    0    0    0]
[   9    0    0 ...,    0    0    9]]

然而,我没有收到一个明确或清晰的情节。有没有更好的方法来做到这一点?

380852 次浏览

enter image description here

你可以使用 plt.matshow()代替 plt.imshow(),或者使用海运模块的 heatmap(请参阅文件)来绘制混淆矩阵

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
array = [[33,2,0,0,0,0,0,0,0,1,3],
[3,31,0,0,0,0,0,0,0,0,0],
[0,4,41,0,0,0,0,0,0,0,1],
[0,1,0,30,0,6,0,0,0,0,1],
[0,0,0,0,38,10,0,0,0,0,0],
[0,0,0,3,1,39,0,0,0,0,4],
[0,2,2,0,4,1,31,0,0,0,2],
[0,1,0,0,0,0,0,36,0,2,0],
[0,0,0,0,0,0,1,5,37,5,1],
[3,0,0,0,0,0,0,0,0,39,0],
[0,0,0,0,0,0,0,0,0,0,38]]
df_cm = pd.DataFrame(array, index = [i for i in "ABCDEFGHIJK"],
columns = [i for i in "ABCDEFGHIJK"])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True)

@ bninopaul 的答案并不完全适合初学者

这是你可以“复制并运行”的代码

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt


array = [[13,1,1,0,2,0],
[3,9,6,0,1,0],
[0,0,16,2,0,0],
[0,0,0,13,0,0],
[0,0,0,0,15,0],
[0,0,1,0,0,15]]


df_cm = pd.DataFrame(array, range(6), range(6))
# plt.figure(figsize=(10,7))
sn.set(font_scale=1.4) # for label size
sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}) # font size


plt.show()

result

如果你希望你的混淆矩阵中有 更多数据,包括“ 总数栏”和“ 总数线”,以及每个细胞中有 百分比(%) ,那就是 就像 Matlab 默认的那样(见下图)

enter image description here

包括热图和其他选项。

您应该对上面的模块(在 github 中共享)感到满意;)

Https://github.com/wcipriano/pretty-print-confusion-matrix


这个模块可以很容易地完成您的任务,并生成上面的输出和大量的参数来定制您的 CM: enter image description here