У меня есть кадр данных pyspark со столбцом MapType(StringType(), FloatType()), и я получу список всех ключей, появляющихся в столбце. Например, имея этот кадр данных:

+---+--------------------+
| ID|                 map|
+---+--------------------+
|  0|[a -> 3.0, b -> 2...|
|  1|[a -> 1.0, b -> 4...|
|  2|[a -> 6.0, c -> 5...|
|  3|[a -> 6.0, f -> 8...|
|  4|[a -> 2.0, c -> 1...|
|  5|[c -> 1.0, d -> 1...|
|  6|[a -> 4.0, c -> 1...|
|  7|[a -> 2.0, d -> 1...|
|  8|          [a -> 2.0]|
|  9|[e -> 1.0, f -> 1.0]|
| 10|          [g -> 1.0]|
| 11|[e -> 2.0, b -> 3.0]|
+---+--------------------+

Я ожидаю получить следующий список:

['a', 'b', 'c', 'd', 'e', 'f', 'g']

Я уже пробовал

df.select(explode(col('map'))).groupby('key').count().select('key').collect()
df.select(explode(col('map'))).select('key').drop_duplicates().collect()
df.select(explode(col('map'))).select('key').distinct().collect()
df.select(explode(map_keys(col('map')))).select('key').distinct().collect()

...

Но для каждой из этих команд я получаю разные результаты не только для разных команд, но и при выполнении одной и той же команды в одном и том же кадре данных.

Например:

keys_1 = df.select(explode(col('map'))).select('key').drop_duplicates().collect()
keys_1 = [row['key'] for row in keys_1]

А также:

keys_2 = df.select(explode(col('map'))).select('key').drop_duplicates().collect()
keys_2 = [row['key'] for row in keys_2]

Потом довольно часто бывает, что len(keys_1) != len(keys_2).

В моем фрейме данных около 10e7 строк, и для моего столбца карты имеется около 2000 различных ключей.

Обратите внимание, что на небольшом наборе данных это работает без проблем, но, к сожалению, довольно сложно найти большой набор данных.

Пример кода небольшого набора данных:

df = spark.createDataFrame([
  (0, {'a':3.0, 'b':2.0, 'c':2.0}),
  (1, {'a':1.0, 'b':4.0, 'd':6.0}),
  (2, {'a':6.0, 'e':5.0, 'c':5.0}),
  (3, {'f':8.0, 'a':6.0, 'g':4.0}),
  (4, {'a':2.0, 'c':1.0, 'd':3.0}),
  (5, {'d':1.0, 'g':5.0, 'c':1.0}),
  (6, {'a':4.0, 'c':1.0, 'f':1.0}),
  (7, {'a':2.0, 'e':2.0, 'd':1.0}),
  (8, {'a':2.0}),
  (9, {'e':1.0, 'f':1.0}),
  (10, {'g':1.0}),
  (11, {'b':3.0, 'e':2.0})
],
  ['ID', 'map']
)
df.select(explode(col('map'))).groupby('key').count().select('key').collect()
2
olileo 18 Окт 2019 в 12:13
Можете ли вы также вставить схему для оригинального df, пожалуйста. Я
 – 
Aditya Vikram Singh
20 Окт 2020 в 22:45

1 ответ

Это должно работать

df.select(explode($"map")).select($"key").distinct().collect()
0
tom 19 Окт 2021 в 11:23