Median / quantiles within a PySpark groupBy

I would like to compute group quantiles on a Spark dataframe (using PySpark). Either an approximate or an exact result would be fine. I prefer a solution that can be used within a groupBy / agg context, so that it can be mixed with other PySpark aggregate functions. If that is not possible for some reason, a different approach is fine as well.

This question is related, but does not show how to use approxQuantile as an aggregate function.

I also have access to the percentile_approx Hive UDF, but I don't know how to use it as an aggregate function.

To be concrete, suppose I have the following dataframe:

from pyspark import SparkContext
import pyspark.sql.functions as f


sc = SparkContext()


df = sc.parallelize([
    ['A', 1],
    ['A', 2],
    ['A', 3],
    ['B', 4],
    ['B', 5],
    ['B', 6],
]).toDF(('grp', 'val'))


df_grp = df.groupBy('grp').agg(f.magic_percentile('val', 0.5).alias('med_val'))
df_grp.show()

The expected result is:

+----+-------+
| grp|med_val|
+----+-------+
|   A|      2|
|   B|      5|
+----+-------+

Since you have access to percentile_approx, one simple solution would be to use it in a SQL command:

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)


df.registerTempTable("df")
df2 = sqlContext.sql("select grp, percentile_approx(val, 0.5) as med_val from df group by grp")

(UPDATE: now it is possible, see accepted answer above)


Unfortunately, and to the best of my knowledge, it seems that it is not possible to do this with "pure" PySpark commands (the solution by Shaido provides a workaround with SQL), and the reason is very elementary: in contrast with other aggregate functions, such as mean, approxQuantile does not return a Column type, but a list.

Let's see a quick example with your sample data:

spark.version
# u'2.2.0'


import pyspark.sql.functions as func
from pyspark.sql import DataFrameStatFunctions as statFunc


# aggregate with mean works OK:
df_grp_mean = df.groupBy('grp').agg(func.mean(df['val']).alias('mean_val'))
df_grp_mean.show()
# +---+--------+
# |grp|mean_val|
# +---+--------+
# |  B|     5.0|
# |  A|     2.0|
# +---+--------+


# try aggregating by median:
df_grp_med = df.groupBy('grp').agg(statFunc(df).approxQuantile('val', [0.5], 0.1))
# AssertionError: all exprs should be Column


# mean aggregation is a Column, but median is a list:


type(func.mean(df['val']))
# pyspark.sql.column.Column


type(statFunc(df).approxQuantile('val', [0.5], 0.1))
# list

I doubt that a window-based approach will make any difference, since as I said the underlying reason is a very elementary one.

See also my answer here for some more details.

I guess you don't need it anymore, but I'll leave it here for future generations (i.e. me next week when I forget).

from pyspark.sql import Window
import pyspark.sql.functions as F


grp_window = Window.partitionBy('grp')
magic_percentile = F.expr('percentile_approx(val, 0.5)')


df.withColumn('med_val', magic_percentile.over(grp_window))

Or to address exactly your question, this also works:

df.groupBy('grp').agg(magic_percentile.alias('med_val'))

And as a bonus, you can pass an array of percentiles:

quantiles = F.expr('percentile_approx(val, array(0.25, 0.5, 0.75))')

And you'll get a list in return.

The simplest way to do this with pyspark==2.4.5 is:

from pyspark.sql.functions import expr

df \
    .groupby('grp') \
    .agg(expr('percentile(val, array(0.5))')[0].alias('50%')) \
    .show()


output:

+---+---+
|grp|50%|
+---+---+
|  B|5.0|
|  A|2.0|
+---+---+

A problem with percentile_approx(val, 0.5): if the values are, e.g., [1, 2, 3, 4], it returns 2 as the median, whereas the UDF below returns 2.5:

import statistics

import pyspark.sql.functions as F
from pyspark.sql.types import DoubleType

# Exact median, computed in Python over the values collected for each group
# (the float cast is needed so the return value matches DoubleType)
median_udf = F.udf(lambda x: float(statistics.median(x)) if x else None, DoubleType())

... .groupBy('something').agg(median_udf(F.collect_list(F.col('value'))).alias('median'))

This seems to be fully solved as of pyspark >= 3.1.0, where percentile_approx is available as a built-in aggregate function:

import pyspark.sql.functions as func


df.groupBy("grp").agg(func.percentile_approx("val", 0.5).alias("median"))

For further information see: https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.functions.percentile_approx.html