我用 Python 编写了一个混淆矩阵计算代码:
def conf_mat(prob_arr, input_arr):
# confusion matrix
conf_arr = [[0, 0], [0, 0]]
for i in range(len(prob_arr)):
if int(input_arr[i]) == 1:
if float(prob_arr[i]) < 0.5:
conf_arr[0][1] = conf_arr[0][1] + 1
else:
conf_arr[0][0] = conf_arr[0][0] + 1
elif int(input_arr[i]) == 2:
if float(prob_arr[i]) >= 0.5:
conf_arr[1][0] = conf_arr[1][0] +1
else:
conf_arr[1][1] = conf_arr[1][1] +1
accuracy = float(conf_arr[0][0] + conf_arr[1][1])/(len(input_arr))
prob_arr
是我的分类代码返回的一个数组,样例数组如下:
[1.0, 1.0, 1.0, 0.41592955657342651, 1.0, 0.0053405015805891975, 4.5321494433440449e-299, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.70943426182688163, 1.0, 1.0, 1.0, 1.0]
input_arr
是数据集的原始类标签,它是这样的:
[2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1]
我的代码试图做的是: 我得到 prob_arr
和 input_arr
,对于每个类(1和2) ,我检查它们是否被错误分类。
但是我的代码只适用于两个类。如果我为多个分类数据运行这段代码,它就不工作了。我怎样才能为多个类制作这个?
例如,对于包含三个类的数据集,它应该返回: [[21, 7, 3], [3, 38, 6],[5, 4, 19]]
。