将 parks DataFrame 列转换为 python 列表

我处理一个包含两列的数据框架,mvv 和 count。

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |

我想获得两个包含 mvv 值和 count 值的列表

mvv = [1,2,3,4]
count = [5,9,3,1]

因此,我尝试了以下代码: 第一行应该返回一个行的 python 列表。我想看看第一个值:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

但是我在第二行得到一个错误消息:

AttributeError: getInt 属性错误: getInt

351522 次浏览

看,为什么你现在这样做不管用。首先,您尝试从 划船类型获取整数,您的集合的输出如下:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

如果你拿这样的东西:

>>> firstvalue = mvv_list[0].mvv
Out: 1

您将得到 mvv值。如果你想要数组的所有信息,你可以这样做:

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

但是如果你在另一个专栏中尝试同样的方法,你会得到:

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

这是因为 count是一个内置的方法。该列的名称与 count相同。解决办法是将 count的列名改为 _count:

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

但是不需要这种变通方法,因为您可以使用字典语法访问该列:

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

最终会成功的!

跟随一行代码可以得到所需的列表。

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()

如果你得到下面的错误:

AttributeError: ‘ list’对象没有属性‘ Collection’

这个代码将解决你的问题:

mvv_list = mvv_count_df.select('mvv').collect()


mvv_array = [int(i.mvv) for i in mvv_list]

下面的代码将帮助您

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()

这将给出作为列表的所有元素。

mvv_list = list(
mvv_count_df.select('mvv').toPandas()['mvv']
)

根据我的数据,我得到了这些基准:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0.52秒

>>> [row[col] for row in data.collect()]

0.271秒

>>> list(data.select(col).toPandas()[col])

0.427秒

结果是一样的

我做了一个基准分析,list(mvv_count_df.select('mvv').toPandas()['mvv'])是最快的方法,我很惊讶。

我使用 Spark 2.4.5的5节点 i3.xlarge 集群(每个节点有30.5 GB 的 RAM 和4个核)对10万/1亿行数据集运行了不同的方法。数据均匀分布在20个快速压缩 Parquet 文件与一个单列。

下面是基准测试结果(运行时间以秒为单位) :

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+


* cancelled after 800 seconds

收集驱动程序节点数据时应遵循的黄金法则:

  • 尝试用其他方法解决这个问题。向驱动程序节点收集数据代价高昂,无法利用 Spark 集群的能力,因此应尽可能避免这种做法。
  • 尽可能少地收集行。在收集数据之前对列进行聚合、重复、筛选和删除。尽可能少地向驱动程序节点发送数据。

如果你使用的是早于2.3的 Spark 版本,这可能不是最好的方法。

有关详细信息/基准测试结果,请参阅 给你

一个可能的解决方案是使用来自 pyspark.sql.functionscollect_list()函数。这将把所有列值聚合到一个 pypark 数组中,在收集到该数组时,该数组将转换为一个 python 列表:

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0]

让我们创建有问题的数据框架

df_test = spark.createDataFrame(
[
(1, 5),
(2, 9),
(3, 3),
(4, 1),
],
['mvv', 'count']
)
df_test.show()

也就是说

+---+-----+
|mvv|count|
+---+-----+
|  1|    5|
|  2|    9|
|  3|    3|
|  4|    1|
+---+-----+

然后应用 rdd.latMap (f) . Collection ()获取列表

test_list = df_test.select("mvv").rdd.flatMap(list).collect()
print(type(test_list))
print(test_list)

所以

<type 'list'>
[1, 2, 3, 4]

尽管有许多答案,但当您需要一个列表与 whenisin命令结合使用时,其中一些答案不会起作用。最简单有效的方法是使用列表内涵和 [0]来避免行名:

flatten_list_from_spark_df=[i[0] for i in df.select("your column").collect()]

另一种方法是使用熊猫数据帧,然后再使用 list函数,但不方便和有效

您可以首先使用 will 返回 Row 类型的列表来收集 df

row_list = df.select('mvv').collect()

在行上迭代以转换为列表

sno_id_array = [ int(row.mvv) for row in row_list]


sno_id_array
[1,2,3,4]

用平面地图

sno_id_array = df.select("mvv").rdd.flatMap(lambda x: x).collect()