spark_data_utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. # Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. import sys
  17. from argparse import ArgumentParser
  18. from collections import OrderedDict
  19. from contextlib import contextmanager
  20. from operator import itemgetter
  21. from time import time
  22. from pyspark import broadcast
  23. from pyspark.sql import Row, SparkSession, Window
  24. from pyspark.sql.functions import *
  25. from pyspark.sql.types import *
# Column layout of the raw input: _c0 is the label, _c1.._c13 are integer
# features, _c14.._c39 are categorical (string) features.
LABEL_COL = 0
INT_COLS = list(range(1, 14))
CAT_COLS = list(range(14, 40))
def get_column_counts_with_frequency_limit(df, frequency_limit = None):
    """Count occurrences of every (column_id, value) pair over the categorical columns.

    The categorical columns are exploded into (column_id, data) rows, where
    column_id is the 0-based offset within CAT_COLS, nulls are dropped, and
    the pairs are counted.

    frequency_limit is an optional comma-separated string. Each element is
    either "column:limit" (an absolute column index such as 15, with a minimum
    count for that column) or a bare number used as the default minimum count
    for every column without an explicit pair. Rows whose count falls below
    the applicable limit are filtered out.
    """
    cols = ['_c%d' % i for i in CAT_COLS]
    df = (df
        .select(posexplode(array(*cols)))
        .withColumnRenamed('pos', 'column_id')
        .withColumnRenamed('col', 'data')
        .filter('data is not null')
        .groupBy('column_id', 'data')
        .count())
    if frequency_limit:
        frequency_limit = frequency_limit.split(",")
        exclude = []
        default_limit = None
        for fl in frequency_limit:
            frequency_pair = fl.split(":")
            if len(frequency_pair) == 1:
                default_limit = int(frequency_pair[0])
            elif len(frequency_pair) == 2:
                # Per-column limit: keep rows of other columns, or rows meeting the limit.
                df = df.filter((col('column_id') != int(frequency_pair[0]) - CAT_COLS[0]) | (col('count') >= int(frequency_pair[1])))
                exclude.append(int(frequency_pair[0]))
        if default_limit:
            # Apply the default limit to every column without an explicit pair.
            remain = [x - CAT_COLS[0] for x in CAT_COLS if x not in exclude]
            df = df.filter((~col('column_id').isin(remain)) | (col('count') >= default_limit))
            # for comparing isin and separate filter
            # for i in remain:
            # df = df.filter((col('column_id') != i - CAT_COLS[0]) | (col('count') >= default_limit))
    return df
  56. def assign_id_with_window(df):
  57. windowed = Window.partitionBy('column_id').orderBy(desc('count'))
  58. return (df
  59. .withColumn('id', row_number().over(windowed))
  60. .withColumnRenamed('count', 'model_count'))
def assign_low_mem_partial_ids(df):
    """Assign per-partition partial ids as a first step toward per-column ids.

    To avoid some scaling issues with a simple window operation, we use a more
    complex method to compute the same thing, but in a more distributed spark
    specific way.
    """
    df = df.orderBy(asc('column_id'), desc('count'))
    # The monotonically_increasing_id is the partition id in the top 31 bits and the rest
    # is an increasing count of the rows within that partition. So we split it into two parts,
    # the partition id part_id and the count mono_id
    df = df.withColumn('part_id', spark_partition_id())
    return df.withColumn('mono_id', monotonically_increasing_id() - shiftLeft(col('part_id'), 33))
def assign_low_mem_final_ids(df):
    """Convert per-partition partial ids into contiguous 1-based ids per column.

    Expects the output of assign_low_mem_partial_ids (columns column_id, data,
    count, part_id, mono_id) and returns (column_id, data, id, model_count).
    """
    # Now we can find the minimum and maximum mono_ids within a given column/partition pair
    sub_model = df.groupBy('column_id', 'part_id').agg(max('mono_id').alias('top'), min('mono_id').alias('bottom'))
    sub_model = sub_model.withColumn('diff', col('top') - col('bottom') + 1)
    sub_model = sub_model.drop('top')
    # This window function is over the aggregated column/partition pair table. It will do a running sum of the rows
    # within that column
    windowed = Window.partitionBy('column_id').orderBy('part_id').rowsBetween(Window.unboundedPreceding, -1)
    sub_model = sub_model.withColumn('running_sum', sum('diff').over(windowed)).na.fill(0, ["running_sum"])
    joined = df.withColumnRenamed('column_id', 'i_column_id')
    joined = joined.withColumnRenamed('part_id', 'i_part_id')
    joined = joined.withColumnRenamed('count', 'model_count')
    # Then we can join the original input with the pair it is a part of
    joined = joined.join(sub_model, (col('i_column_id') == col('column_id')) & (col('part_id') == col('i_part_id')))
    # So with all that we can subtract bottom from mono_id making it start at 0 for each partition
    # and then add in the running_sum so the id is contiguous and unique for the entire column. + 1 to make it match the 1 based indexing
    # for row_number
    ret = joined.select(col('column_id'),
                        col('data'),
                        (col('mono_id') - col('bottom') + col('running_sum') + 1).cast(IntegerType()).alias('id'),
                        col('model_count'))
    return ret
  92. def get_column_models(combined_model):
  93. for i in CAT_COLS:
  94. model = (combined_model
  95. .filter('column_id == %d' % (i - CAT_COLS[0]))
  96. .drop('column_id'))
  97. yield i, model
  98. def col_of_rand_long():
  99. return (rand() * (1 << 52)).cast(LongType())
def skewed_join(df, model, col_name, cutoff):
    """Join df with the dictionary model for col_name while tolerating heavy skew.

    Most versions of spark don't have a good way to deal with a skewed join
    out of the box. Some do and if you want to replace this with one of those
    that would be great. Because we have statistics about the skewedness that
    we can use, we divide the model up into two parts: the highly skewed part
    (model_count >= cutoff) gets a broadcast join, with the result kept in a
    separate column; the rest gets a regular join spread over random keys.
    """
    b_model = broadcast(model.filter(col('model_count') >= cutoff)
                        .withColumnRenamed('data', col_name)
                        .drop('model_count'))
    df = (df
          .join(b_model, col_name, how='left')
          .withColumnRenamed('id', 'id_tmp'))
    # We also need to spread the skewed data that matched
    # evenly. We will use a source of randomness for this
    # but use a -1 for anything that still needs to be matched
    if 'ordinal' in df.columns:
        rand_column = col('ordinal')
    else:
        rand_column = col_of_rand_long()
    df = df.withColumn('join_rand',
        # null values are not in the model, they are filtered out
        # but can be a source of skewedness so include them in
        # the even distribution
        when(col('id_tmp').isNotNull() | col(col_name).isNull(), rand_column)
        .otherwise(lit(-1)))
    # Null out the string data that already matched to save memory
    df = df.withColumn(col_name,
        when(col('id_tmp').isNotNull(), None)
        .otherwise(col(col_name)))
    # Now do the second join, which will be a non broadcast join.
    # Sadly spark is too smart for its own good and will optimize out
    # joining on a column it knows will always be a constant value.
    # So we have to make a convoluted version of assigning a -1 to the
    # randomness column for the model itself to work around that.
    nb_model = (model
                .withColumn('join_rand', when(col('model_count') < cutoff, lit(-1)).otherwise(lit(-2)))
                .filter(col('model_count') < cutoff)
                .withColumnRenamed('data', col_name)
                .drop('model_count'))
    df = (df
          .join(nb_model, ['join_rand', col_name], how='left')
          .drop(col_name, 'join_rand')
          # Pick either join result as an answer
          .withColumn(col_name, coalesce(col('id'), col('id_tmp')))
          .drop('id', 'id_tmp'))
    return df
def apply_models(df, models, broadcast_model = False, skew_broadcast_pct = 1.0):
    """Replace each categorical string column with its dictionary id.

    models is a sequence of (col index, model df, model row count aggregate,
    would_broadcast flag) tuples as produced downstream of load_column_models.
    Unmatched values end up as 0 via the final fillna.
    """
    # sort the models so broadcast joins come first. This is
    # so we reduce the amount of shuffle data sooner than later
    # If we parsed the string hex values to ints early on this would
    # not make a difference.
    models = sorted(models, key=itemgetter(3), reverse=True)
    for i, model, original_rows, would_broadcast in models:
        col_name = '_c%d' % i
        if not (would_broadcast or broadcast_model):
            # The data is highly skewed so we need to offset that
            cutoff = int(original_rows * skew_broadcast_pct/100.0)
            df = skewed_join(df, model, col_name, cutoff)
        else:
            # broadcast joins can handle skewed data so no need to
            # do anything special
            model = (model.drop('model_count')
                     .withColumnRenamed('data', col_name))
            model = broadcast(model) if broadcast_model else model
            df = (df
                  .join(model, col_name, how='left')
                  .drop(col_name)
                  .withColumnRenamed('id', col_name))
    return df.fillna(0, ['_c%d' % i for i in CAT_COLS])
  173. def transform_log(df, transform_log = False):
  174. cols = ['_c%d' % i for i in INT_COLS]
  175. if transform_log:
  176. for col_name in cols:
  177. df = df.withColumn(col_name, log(df[col_name] + 3))
  178. return df.fillna(0, cols)
  179. def would_broadcast(spark, str_path):
  180. sc = spark.sparkContext
  181. config = sc._jsc.hadoopConfiguration()
  182. path = sc._jvm.org.apache.hadoop.fs.Path(str_path)
  183. fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(config)
  184. stat = fs.listFiles(path, True)
  185. sum = 0
  186. while stat.hasNext():
  187. sum = sum + stat.next().getLen()
  188. sql_conf = sc._jvm.org.apache.spark.sql.internal.SQLConf()
  189. cutoff = sql_conf.autoBroadcastJoinThreshold() * sql_conf.fileCompressionFactor()
  190. return sum <= cutoff
  191. def delete_data_source(spark, path):
  192. sc = spark.sparkContext
  193. config = sc._jsc.hadoopConfiguration()
  194. path = sc._jvm.org.apache.hadoop.fs.Path(path)
  195. sc._jvm.org.apache.hadoop.fs.FileSystem.get(config).delete(path, True)
  196. def load_raw(spark, folder, day_range):
  197. label_fields = [StructField('_c%d' % LABEL_COL, IntegerType())]
  198. int_fields = [StructField('_c%d' % i, IntegerType()) for i in INT_COLS]
  199. str_fields = [StructField('_c%d' % i, StringType()) for i in CAT_COLS]
  200. schema = StructType(label_fields + int_fields + str_fields)
  201. paths = [os.path.join(folder, 'day_%d' % i) for i in day_range]
  202. return (spark
  203. .read
  204. .schema(schema)
  205. .option('sep', '\t')
  206. .csv(paths))
def rand_ordinal(df):
    """Add a random long 'ordinal' column used for shuffling and partitioning."""
    # create a random long from the double precision float.
    # The fraction part of a double is 52 bits, so we try to capture as much
    # of that as possible
    return df.withColumn('ordinal', col_of_rand_long())
  212. def day_from_ordinal(df, num_days):
  213. return df.withColumn('day', (col('ordinal') % num_days).cast(IntegerType()))
  214. def day_from_input_file(df):
  215. return df.withColumn('day', substring_index(input_file_name(), '_', -1).cast(IntegerType()))
def psudo_sort_by_day_plus(spark, df, num_days):
    """Approximately sort by (day, ordinal) without a full global sort.

    Sort is very expensive because it needs to calculate the partitions
    which in our case may involve rereading all of the data. In some cases
    we can avoid this by repartitioning the data and sorting within a single partition.
    """
    shuffle_parts = int(spark.conf.get('spark.sql.shuffle.partitions'))
    extra_parts = int(shuffle_parts/num_days)
    if extra_parts <= 0:
        df = df.repartition('day')
    else:
        # We want to spread out the computation to about the same amount as shuffle_parts
        divided = (col('ordinal') / num_days).cast(LongType())
        extra_ident = divided % extra_parts
        df = df.repartition(col('day'), extra_ident)
    return df.sortWithinPartitions('day', 'ordinal')
  230. def load_combined_model(spark, model_folder):
  231. path = os.path.join(model_folder, 'combined.parquet')
  232. return spark.read.parquet(path)
  233. def save_combined_model(df, model_folder, mode=None):
  234. path = os.path.join(model_folder, 'combined.parquet')
  235. df.write.parquet(path, mode=mode)
  236. def delete_combined_model(spark, model_folder):
  237. path = os.path.join(model_folder, 'combined.parquet')
  238. delete_data_source(spark, path)
  239. def load_low_mem_partial_ids(spark, model_folder):
  240. path = os.path.join(model_folder, 'partial_ids.parquet')
  241. return spark.read.parquet(path)
  242. def save_low_mem_partial_ids(df, model_folder, mode=None):
  243. path = os.path.join(model_folder, 'partial_ids.parquet')
  244. df.write.parquet(path, mode=mode)
  245. def delete_low_mem_partial_ids(spark, model_folder):
  246. path = os.path.join(model_folder, 'partial_ids.parquet')
  247. delete_data_source(spark, path)
  248. def load_column_models(spark, model_folder, count_required):
  249. for i in CAT_COLS:
  250. path = os.path.join(model_folder, '%d.parquet' % i)
  251. df = spark.read.parquet(path)
  252. if count_required:
  253. values = df.agg(sum('model_count').alias('sum'), count('*').alias('size')).collect()
  254. else:
  255. values = df.agg(sum('model_count').alias('sum')).collect()
  256. yield i, df, values[0], would_broadcast(spark, path)
  257. def save_column_models(column_models, model_folder, mode=None):
  258. for i, model in column_models:
  259. path = os.path.join(model_folder, '%d.parquet' % i)
  260. model.write.parquet(path, mode=mode)
  261. def save_model_size(model_size, path, write_mode):
  262. if os.path.exists(path) and write_mode == 'errorifexists':
  263. print('Error: model size file %s exists' % path)
  264. sys.exit(1)
  265. os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
  266. with open(path, 'w') as fp:
  267. json.dump(model_size, fp, indent=4)
  268. _benchmark = {}
  269. @contextmanager
  270. def _timed(step):
  271. start = time()
  272. yield
  273. end = time()
  274. _benchmark[step] = end - start
  275. def _parse_args():
  276. parser = ArgumentParser()
  277. parser.add_argument(
  278. '--mode',
  279. required=True,
  280. choices=['generate_models', 'transform'])
  281. parser.add_argument('--days', required=True)
  282. parser.add_argument('--input_folder', required=True)
  283. parser.add_argument('--output_folder')
  284. parser.add_argument('--model_size_file')
  285. parser.add_argument('--model_folder', required=True)
  286. parser.add_argument(
  287. '--write_mode',
  288. choices=['overwrite', 'errorifexists'],
  289. default='errorifexists')
  290. parser.add_argument('--frequency_limit')
  291. parser.add_argument('--no_numeric_log_col', action='store_true')
  292. #Support for running in a lower memory environment
  293. parser.add_argument('--low_mem', action='store_true')
  294. parser.add_argument(
  295. '--output_ordering',
  296. choices=['total_random', 'day_random', 'any', 'input'],
  297. default='total_random')
  298. parser.add_argument(
  299. '--output_partitioning',
  300. choices=['day', 'none'],
  301. default='none')
  302. parser.add_argument('--dict_build_shuffle_parallel_per_day', type=int, default=2)
  303. parser.add_argument('--apply_shuffle_parallel_per_day', type=int, default=25)
  304. parser.add_argument('--skew_broadcast_pct', type=float, default=1.0)
  305. parser.add_argument('--debug_mode', action='store_true')
  306. args = parser.parse_args()
  307. start, end = args.days.split('-')
  308. args.day_range = list(range(int(start), int(end) + 1))
  309. args.days = len(args.day_range)
  310. return args
def _main():
    """Entry point: either build dictionary models or transform the raw data.

    'generate_models' counts categorical values (with optional frequency
    limits), assigns dense ids, and writes one parquet model per column.
    'transform' applies those models to the raw data, optionally log-scales
    the integer columns, and writes the result with the requested ordering
    and partitioning.
    """
    args = _parse_args()
    spark = SparkSession.builder.getOrCreate()

    df = load_raw(spark, args.input_folder, args.day_range)

    if args.mode == 'generate_models':
        spark.conf.set('spark.sql.shuffle.partitions', args.days * args.dict_build_shuffle_parallel_per_day)
        with _timed('generate models'):
            col_counts = get_column_counts_with_frequency_limit(df, args.frequency_limit)
            if args.low_mem:
                # in low memory mode we have to save an intermediate result
                # because if we try to do it in one query spark ends up assigning the
                # partial ids in two different locations that are not guaranteed to line up
                # this prevents that from happening by assigning the partial ids
                # and then writing them out.
                save_low_mem_partial_ids(
                    assign_low_mem_partial_ids(col_counts),
                    args.model_folder,
                    args.write_mode)
                save_combined_model(
                    assign_low_mem_final_ids(load_low_mem_partial_ids(spark, args.model_folder)),
                    args.model_folder,
                    args.write_mode)
                if not args.debug_mode:
                    delete_low_mem_partial_ids(spark, args.model_folder)
            else:
                save_combined_model(
                    assign_id_with_window(col_counts),
                    args.model_folder,
                    args.write_mode)
            save_column_models(
                get_column_models(load_combined_model(spark, args.model_folder)),
                args.model_folder,
                args.write_mode)
            if not args.debug_mode:
                delete_combined_model(spark, args.model_folder)

    if args.mode == 'transform':
        spark.conf.set('spark.sql.shuffle.partitions', args.days * args.apply_shuffle_parallel_per_day)
        with _timed('transform'):
            # Attach ordinal/day columns as required by the requested ordering.
            if args.output_ordering == 'total_random':
                df = rand_ordinal(df)
                if args.output_partitioning == 'day':
                    df = day_from_ordinal(df, args.days)
            elif args.output_ordering == 'day_random':
                df = rand_ordinal(df)
                df = day_from_input_file(df)
            elif args.output_ordering == 'input':
                df = df.withColumn('ordinal', monotonically_increasing_id())
                if args.output_partitioning == 'day':
                    df = day_from_input_file(df)
            else:  # any ordering
                if args.output_partitioning == 'day':
                    df = day_from_input_file(df)
            models = list(load_column_models(spark, args.model_folder, bool(args.model_size_file)))
            if args.model_size_file:
                save_model_size(
                    OrderedDict(('_c%d' % i, agg.size) for i, _, agg, _ in models),
                    args.model_size_file,
                    args.write_mode)
            # Keep only what apply_models needs: (index, model df, total count, broadcast flag).
            models = [(i, df, agg.sum, flag) for i, df, agg, flag in models]
            df = apply_models(
                df,
                models,
                not args.low_mem,
                args.skew_broadcast_pct)
            df = transform_log(df, not args.no_numeric_log_col)
            if args.output_partitioning == 'day':
                partitionBy = 'day'
            else:
                partitionBy = None
            if args.output_ordering == 'total_random':
                if args.output_partitioning == 'day':
                    df = psudo_sort_by_day_plus(spark, df, args.days)
                else:  # none
                    # Don't do a full sort it is expensive. Order is random so
                    # just make it random
                    df = df.repartition('ordinal').sortWithinPartitions('ordinal')
                df = df.drop('ordinal')
            elif args.output_ordering == 'day_random':
                df = psudo_sort_by_day_plus(spark, df, args.days)
                df = df.drop('ordinal')
                if args.output_partitioning != 'day':
                    df = df.drop('day')
            elif args.output_ordering == 'input':
                if args.low_mem:
                    # This is the slowest option. We totally messed up the order so we have to put
                    # it back in the correct order
                    df = df.orderBy('ordinal')
                else:
                    # Applying the dictionary happened within a single task so we are already really
                    # close to the correct order, just need to sort within the partition
                    df = df.sortWithinPartitions('ordinal')
                df = df.drop('ordinal')
                if args.output_partitioning != 'day':
                    df = df.drop('day')
            # else: any ordering so do nothing the ordering does not matter
            df.write.parquet(
                args.output_folder,
                mode=args.write_mode,
                partitionBy=partitionBy)
    print('=' * 100)
    print(_benchmark)

if __name__ == '__main__':
    _main()