Scikit-learn train_test_split with index

How do I get the original indices of the data when using train_test_split()?

I have the following:

from sklearn.model_selection import train_test_split  # sklearn.cross_validation was removed in scikit-learn 0.20
import numpy as np
data = np.reshape(np.random.randn(20),(10,2)) # 10 training examples
labels = np.random.randint(2, size=10) # 10 labels
x1, x2, y1, y2 = train_test_split(data, labels, test_size=0.2)

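But this doesn't give me the indices of the original data. One workaround is to add the indices to the data (e.g. data = [(i, d) for i, d in enumerate(data)]), pass that to train_test_split, and then unpack it again. Is there a cleaner solution?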


Scikit-learn works well with pandas, so I suggest you use it. Here's an example:

In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
data = np.reshape(np.random.randn(20),(10,2)) # 10 training examples
labels = np.random.randint(2, size=10) # 10 labels


In [2]: # Giving columns in X a name
X = pd.DataFrame(data, columns=['Column_1', 'Column_2'])
y = pd.Series(labels)


In [3]:
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.25,
                                                    random_state=0)


In [4]: X_test
Out[4]:
   Column_1  Column_2
2     -1.39     -1.86
8      0.48     -0.81
4     -0.10     -1.83

In [5]: y_test
Out[5]:
2    1
8    1
4    1
dtype: int32

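Scikit-learn functions and estimators accept DataFrames and Series directly, and train_test_split keeps the original index labels, so X_test.index tells you which original rows ended up in the test set.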
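As a small illustration, here is a sketch that assumes a recent pandas and scikit-learn; the seeded generator and test_size=0.2 are arbitrary choices added for reproducibility. The split DataFrames keep the original index labels, so .index identifies the original rows and selecting those labels from X gives back the same rows:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)  # seeded so the example is reproducible
X = pd.DataFrame(rng.standard_normal((10, 2)), columns=["Column_1", "Column_2"])
y = pd.Series(rng.integers(0, 2, size=10))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# The index labels of the split are the labels of the original rows.
test_idx = X_test.index.to_numpy()
print(test_idx)

# Selecting those labels from the original frame gives back the same rows,
# and y_test carries the same labels as X_test.
assert X.loc[test_idx].equals(X_test)
assert (y_test.index == X_test.index).all()
```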

Say you want to fit a LogisticRegression. Here's a convenient way to retrieve the coefficients labelled by feature name:

In [6]:
from sklearn.linear_model import LogisticRegression

model = LogisticRegression()
model = model.fit(X_train, y_train)

# Retrieve coefficients: index is the feature name (['Column_1', 'Column_2'] here)
df_coefs = pd.DataFrame(model.coef_[0], index=X.columns, columns=['Coefficient'])
df_coefs
Out[6]:
          Coefficient
Column_1     0.076987
Column_2    -0.352463

You can use pandas DataFrames or Series as Julien said, but if you want to restrict yourself to NumPy, you can pass an additional array of indices; train_test_split splits every array you give it with the same shuffle:

from sklearn.model_selection import train_test_split
import numpy as np
n_samples, n_features, n_classes = 10, 2, 2
data = np.random.randn(n_samples, n_features)  # 10 training examples
labels = np.random.randint(n_classes, size=n_samples)  # 10 labels
indices = np.arange(n_samples)
(
data_train,
data_test,
labels_train,
labels_test,
indices_train,
indices_test,
) = train_test_split(data, labels, indices, test_size=0.2)
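To check that the returned index arrays really line up with the split data, a short sketch like the following can be used (the seeded generator and random_state=0 are arbitrary choices added for reproducibility):

```python
import numpy as np
from sklearn.model_selection import train_test_split

n_samples, n_features, n_classes = 10, 2, 2
rng = np.random.default_rng(42)
data = rng.standard_normal((n_samples, n_features))
labels = rng.integers(n_classes, size=n_samples)
indices = np.arange(n_samples)

data_train, data_test, labels_train, labels_test, indices_train, indices_test = (
    train_test_split(data, labels, indices, test_size=0.2, random_state=0)
)

# All arrays are split with the same permutation, so the index arrays
# select exactly the rows that ended up in each split.
assert np.array_equal(data[indices_train], data_train)
assert np.array_equal(labels[indices_test], labels_test)

# Together the two index arrays cover every original row exactly once.
assert np.array_equal(np.sort(np.concatenate([indices_train, indices_test])), indices)
print(indices_test)
```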

The docs mention that train_test_split is just a convenience function on top of ShuffleSplit.

I just rearranged some of their code to make my own example. Note that the actual solution is the middle block of code; the rest is imports and setup for a runnable example.

from sklearn.model_selection import ShuffleSplit
# safe_indexing was deprecated in scikit-learn 0.22 and removed in 0.24; _safe_indexing is its private replacement
from sklearn.utils import _safe_indexing, indexable
from itertools import chain
import numpy as np
X = np.reshape(np.random.randn(20),(10,2)) # 10 training examples
y = np.random.randint(2, size=10) # 10 labels
seed = 1

cv = ShuffleSplit(random_state=seed, test_size=0.25)
arrays = indexable(X, y)
train, test = next(cv.split(X=X))
# For each array, yield its train part, its test part, and the two index arrays
iterator = list(chain.from_iterable((
    _safe_indexing(a, train),
    _safe_indexing(a, test),
    train,
    test
    ) for a in arrays)
)
X_train, X_test, train_is, test_is, y_train, y_test, _, _ = iterator

print(X)
print(train_is)
print(X_train)

Now I have the actual indices: train_is and test_is.
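If you'd rather not depend on _safe_indexing (a private helper), plain NumPy fancy indexing does the same job for arrays. The sketch below also checks the claim that train_test_split sits on top of ShuffleSplit: with the same random_state and test_size, the two produced the same split in the scikit-learn versions I'm aware of, but treat that equality as an implementation detail rather than a documented guarantee. The seeded generator is an assumption added for reproducibility.

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit, train_test_split

rng = np.random.default_rng(0)
X = rng.standard_normal((10, 2))
y = rng.integers(2, size=10)
seed = 1

# Take a single split from ShuffleSplit and index the arrays directly.
cv = ShuffleSplit(n_splits=1, test_size=0.25, random_state=seed)
train_is, test_is = next(cv.split(X))
X_train, X_test = X[train_is], X[test_is]
y_train, y_test = y[train_is], y[test_is]

# train_test_split with the same seed and test_size, for comparison
# (relies on an implementation detail, see above).
X_train2, X_test2, y_train2, y_test2 = train_test_split(
    X, y, test_size=0.25, random_state=seed
)
print(test_is)
```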

Here's a simpler solution (Jibwa's answer makes it look more complicated than it is) that doesn't require generating indices yourself: just use a ShuffleSplit object to generate a single split.

import numpy as np
from sklearn.model_selection import ShuffleSplit # or StratifiedShuffleSplit
sss = ShuffleSplit(n_splits=1, test_size=0.1)

data_size = 100
X = np.reshape(np.random.rand(data_size*2),(data_size,2))
y = np.random.randint(2, size=data_size)

sss.get_n_splits(X, y)  # only returns n_splits (1 here); not needed for the split itself
# Take the single split; train_index and test_index are positions into X and y
train_index, test_index = next(sss.split(X, y))

X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
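The import comment mentions StratifiedShuffleSplit. Swapping it in keeps the class proportions of y in both parts, and the indices are obtained the same way. A minimal sketch, using a deliberately balanced y and random_state=0 (both assumptions for illustration) so the effect is easy to see:

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

data_size = 100
rng = np.random.default_rng(0)
X = rng.random((data_size, 2))
y = np.array([0] * 50 + [1] * 50)  # balanced labels, chosen for illustration

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=0)
train_index, test_index = next(sss.split(X, y))

X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]

# Class counts in each part follow the 50/50 ratio of y.
print(np.bincount(y_train), np.bincount(y_test))  # prints [45 45] [5 5]
```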

If you are using pandas, you can read the index from the .index attribute of whichever split you need. train_test_split carries the original pandas index over to the new DataFrames.

In your code, if data is a DataFrame, x1.index gives the index labels of the training rows in the original data; with the default RangeIndex those labels are also the original row positions.
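One caveat: .index returns index labels, which match row positions only for the default RangeIndex. If the original frame has some other index, pandas' Index.get_indexer converts the labels back to positions. A sketch with a made-up string index ("row0", "row1", ...) and random_state=0, both chosen only for illustration:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

data = np.arange(20).reshape(10, 2)
X = pd.DataFrame(data, columns=["a", "b"], index=[f"row{i}" for i in range(10)])
labels = np.arange(10) % 2

x1, x2, y1, y2 = train_test_split(X, labels, test_size=0.2, random_state=0)

train_labels = x1.index  # original index labels of the training rows
train_positions = X.index.get_indexer(train_labels)  # positions in X (and in data)

# The positions select the same rows from the raw NumPy arrays.
assert np.array_equal(data[train_positions], x1.to_numpy())
assert np.array_equal(labels[train_positions], y1)
print(list(train_labels[:3]), train_positions[:3])
```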