-
Notifications
You must be signed in to change notification settings - Fork 0
/
diet-restrictions-etl.py
39 lines (33 loc) · 1.29 KB
/
diet-restrictions-etl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import ast
import json
import sys

from pyspark.sql import SparkSession, functions, types
assert sys.version_info >= (3, 5)  # fail fast on interpreters older than Python 3.5
# add more functions as necessary
@functions.udf(returnType=types.ArrayType(types.StringType()))
def explode(row):
    """Return the names of the restrictions flagged true in *row*.

    *row* is expected to be the text of a dict literal mapping restriction
    names to flags, e.g. ``"{'vegan': True, 'kosher': False}"``.  Python
    literal syntax (single quotes, ``True``/``False``/``None``) is tried
    first; JSON syntax (``true``/``false``/``null``) is accepted as a
    fallback.

    Args:
        row: The serialized flag mapping, or ``None`` for a null column value.

    Returns:
        A list of the keys whose value is ``True`` (or the string
        ``"True"``), in the order they appear in the mapping.  Empty when
        *row* is ``None`` or does not describe a mapping.

    Raises:
        ValueError: If *row* is neither a valid Python literal nor valid
            JSON (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    if row is None:
        return []
    try:
        # Parse the literal directly instead of rewriting quotes and
        # booleans into JSON by text substitution: substitution breaks on
        # None values, on apostrophes inside keys, and on keys that happen
        # to contain "True" or "False".
        flags = ast.literal_eval(row)
    except (ValueError, SyntaxError):
        flags = json.loads(row)
    if not isinstance(flags, dict):
        # e.g. the bare literal "None": there is nothing to list.
        return []
    # `is True` rather than truthiness so values such as 1 are not
    # mistaken for a set flag.
    return [name for name, flag in flags.items()
            if flag is True or flag == "True"]
def main(inputs, output):
    """Write one JSON record per (business, active dietary restriction).

    Reads the parquet data at *inputs*, keeps the rows whose
    ``DietaryRestrictions`` column is neither null nor the string
    ``'None'``, expands the list produced by the ``explode`` UDF into one
    row per entry, and overwrites ``<output>/DietaryRestrictions.json``.

    Uses the module-level ``spark`` session.
    """
    restrictions = functions.col('DietaryRestrictions')
    has_restrictions = restrictions.isNotNull() & (restrictions != 'None')

    businesses = (
        spark.read.parquet(inputs)
        .where(has_restrictions)
        .select("business_id", "DietaryRestrictions")
    )

    # The UDF yields an array of restriction names; the built-in explode
    # turns each array element into its own output row.
    active_names = explode(businesses['DietaryRestrictions'])
    per_restriction = businesses.select(
        'business_id',
        functions.explode(active_names).alias('DietaryRestrictions'),
    )

    per_restriction.write.json(
        output + "/DietaryRestrictions.json", mode="overwrite")
if __name__ == '__main__':
    # Both the input and the output location are required; report a usage
    # message instead of failing with a bare IndexError.
    if len(sys.argv) < 3:
        sys.exit('usage: {} <inputs> <output>'.format(sys.argv[0]))
    inputs = sys.argv[1]
    output = sys.argv[2]
    spark = SparkSession.builder.appName('example code').getOrCreate()
    # Compare the version numerically: as strings, '10.0' >= '3.0' is False.
    spark_major_minor = tuple(
        int(part) for part in spark.version.split('.')[:2])
    assert spark_major_minor >= (3, 0)  # make sure we have Spark 3.0+
    spark.sparkContext.setLogLevel('WARN')
    sc = spark.sparkContext
    main(inputs, output)