如何保存和加载 xgost 模型?

来自 XGBoost 指南:

经过训练,模型可以保存。

bst.save_model('0001.model')

模型及其特征映射也可以转储到一个文本文件中。

# dump model
bst.dump_model('dump.raw.txt')
# dump model with feature map
bst.dump_model('dump.raw.txt', 'featmap.txt')

保存的模型可以加载如下:

bst = xgb.Booster({'nthread': 4})  # init model
bst.load_model('model.bin')  # load data

接下来是我的问题。

  1. save_modeldump_model有什么区别?
  2. 保存 '0001.model'和保存 'dump.raw.txt','featmap.txt'有什么区别?
  3. 为什么载入 model.bin的型号名称与保存 0001.model的型号名称不同?
  4. 假设我训练了两个模型: model_Amodel_B。我想保存这两种型号以备将来使用。我应该使用哪个 save & load函数?你能帮忙展示一下清晰的过程吗?
135100 次浏览

两个函数 save_modeldump_model保存模型,不同的是,在 dump_model可以保存功能名称和保存文本格式的树。

load_model将与来自 save_model的模型一起工作。来自 dump_model的模型可以与 Xgbfi一起使用。

在加载模型期间,您需要指定保存模型的路径。在示例中,bst.load_model("model.bin")模型是从文件 model.bin加载的-它只是一个带有模型的文件名。祝你好运!

EDIT : 从 Xgost 文档(对于版本 1.3.3)来看,应该使用 dump_model()保存模型以便进一步解释。为了保存和加载模型,应该使用 save_model()load_model()。详情请参阅 医生

Xgboost 的 Learning APIScikit-Learn API也有所不同。后者保存 best_ntree_limit变量,该变量是在早期停止训练期间设置的。你可以在我的文章 如何在 Python 中保存和加载 Xgost?中读到细节

save_model()方法识别文件名的格式,如果指定了 *.json,则模型保存在 JSON,否则就是文本文件。

保存和加载 xgost 模型的一种简单方法是使用 joblib 库。

import joblib
#save model
joblib.dump(xgb, filename)


#load saved model
xgb = joblib.load(filename)

我是这样解决这个问题的:

import pickle
file_name = "xgb_reg.pkl"


# save
pickle.dump(xgb_model, open(file_name, "wb"))


# load
xgb_model_loaded = pickle.load(open(file_name, "rb"))


# test
ind = 1
test = X_val[ind]
xgb_model_loaded.predict(test)[0] == xgb_model.predict(test)[0]


Out[1]: True

不要使用 pickle 或 joblib,因为这可能会引入对 xgost 版本的依赖性。保存和恢复模型的规范方法是使用 load_modelsave_model

如果希望存储或存档模型以便进行长期存储,请使用 save _ model (Python)和 xgb.save (R)。

这个 是 XGBoost 最新版本的相关文档,它还解释了 dump_modelsave_model之间的区别。

注意,在使用 bst.save_model时,可以通过指定 json 作为扩展来将模型序列化/反序列化为 json。如果保存和恢复模型的速度对您来说并不重要,那么这非常方便,因为它允许您对模型进行正确的版本控制,因为它是一个简单的文本文件。

如果你正在使用 sklearn API,你可以使用以下方法:


xgb_model_latest = xgboost.XGBClassifier() # or which ever sklearn booster you're are using


xgb_model_latest.load_model("model.json") # or model.bin if you are using binary format and not the json

如果您使用上面的 booster 方法进行加载,您将在 python api 中获得 xgost booster,而不是在 sklearn api 中获得 sklearn booster。

因此,如果您正在使用 sklearn API,那么这似乎是加载保存的 xgost 模型数据的最简单的方法。