
sql - Find month-to-date and month-to-go on a PySpark dataframe

I have the following dataframe in Spark (using PySpark):

  • DT_BORD_REF: timestamp column
  • COUNTRY_ALPHA: country Alpha-3 code
  • working_day_flag: whether the date is a working day in that country

I need to add two fields:

  • count of working days from the beginning of the month for that country (month to date)
  • count of working days remaining until the end of that month for that country (month to go)

It seems like an application of a window function, but I can't figure out how to write it. For example, given the sample below, FRA on 2021-01-05 should get month to date = 2 (Jan 4 and 5 are working days) and month to go = 1 (only Jan 6 remains).

+-------------------+-------------+----------------+
|        DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|
+-------------------+-------------+----------------+
|2021-01-01 00:00:00|          FRA|               N|
|2021-01-01 00:00:00|          ITA|               N|
|2021-01-01 00:00:00|          BRA|               N|
|2021-01-02 00:00:00|          BRA|               N|
|2021-01-02 00:00:00|          FRA|               N|
|2021-01-02 00:00:00|          ITA|               N|
|2021-01-03 00:00:00|          ITA|               N|
|2021-01-03 00:00:00|          BRA|               N|
|2021-01-03 00:00:00|          FRA|               N|
|2021-01-04 00:00:00|          BRA|               Y|
|2021-01-04 00:00:00|          FRA|               Y|
|2021-01-04 00:00:00|          ITA|               Y|
|2021-01-05 00:00:00|          FRA|               Y|
|2021-01-05 00:00:00|          BRA|               Y|
|2021-01-05 00:00:00|          ITA|               Y|
|2021-01-06 00:00:00|          ITA|               N|
|2021-01-06 00:00:00|          FRA|               Y|
|2021-01-06 00:00:00|          BRA|               Y|
|2021-01-07 00:00:00|          ITA|               Y|
+-------------------+-------------+----------------+
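For reference, here is a minimal sketch that reconstructs this sample dataframe (assuming an existing SparkSession; the date strings are parsed into the timestamp column shown above):

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

rows = [
    ("2021-01-01", "FRA", "N"), ("2021-01-01", "ITA", "N"), ("2021-01-01", "BRA", "N"),
    ("2021-01-02", "BRA", "N"), ("2021-01-02", "FRA", "N"), ("2021-01-02", "ITA", "N"),
    ("2021-01-03", "ITA", "N"), ("2021-01-03", "BRA", "N"), ("2021-01-03", "FRA", "N"),
    ("2021-01-04", "BRA", "Y"), ("2021-01-04", "FRA", "Y"), ("2021-01-04", "ITA", "Y"),
    ("2021-01-05", "FRA", "Y"), ("2021-01-05", "BRA", "Y"), ("2021-01-05", "ITA", "Y"),
    ("2021-01-06", "ITA", "N"), ("2021-01-06", "FRA", "Y"), ("2021-01-06", "BRA", "Y"),
    ("2021-01-07", "ITA", "Y"),
]
df = (spark.createDataFrame(rows, ["DT_BORD_REF", "COUNTRY_ALPHA", "working_day_flag"])
          # parse the string dates into proper timestamps
          .withColumn("DT_BORD_REF", F.to_timestamp("DT_BORD_REF")))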


1 Answer


You can do a conditional count using count_if (available in Spark SQL 3.0 and later):

df.createOrReplaceTempView('df')

result = spark.sql("""
select *,
    -- running count of working days from the start of the month
    -- up to and including the current date (month to date)
    count_if(working_day_flag = 'Y')
        over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref)
        month_to_date,
    -- count of working days strictly after the current date
    -- through the end of the month (month to go)
    count_if(working_day_flag = 'Y')
        over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref
             rows between 1 following and unbounded following)
        month_to_go
from df
""")

result.show()
+-------------------+-------------+----------------+-------------+-----------+
|        DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|month_to_date|month_to_go|
+-------------------+-------------+----------------+-------------+-----------+
|2021-01-01 00:00:00|          BRA|               N|            0|          3|
|2021-01-02 00:00:00|          BRA|               N|            0|          3|
|2021-01-03 00:00:00|          BRA|               N|            0|          3|
|2021-01-04 00:00:00|          BRA|               Y|            1|          2|
|2021-01-05 00:00:00|          BRA|               Y|            2|          1|
|2021-01-06 00:00:00|          BRA|               Y|            3|          0|
|2021-01-01 00:00:00|          ITA|               N|            0|          3|
|2021-01-02 00:00:00|          ITA|               N|            0|          3|
|2021-01-03 00:00:00|          ITA|               N|            0|          3|
|2021-01-04 00:00:00|          ITA|               Y|            1|          2|
|2021-01-05 00:00:00|          ITA|               Y|            2|          1|
|2021-01-06 00:00:00|          ITA|               N|            2|          1|
|2021-01-07 00:00:00|          ITA|               Y|            3|          0|
|2021-01-01 00:00:00|          FRA|               N|            0|          3|
|2021-01-02 00:00:00|          FRA|               N|            0|          3|
|2021-01-03 00:00:00|          FRA|               N|            0|          3|
|2021-01-04 00:00:00|          FRA|               Y|            1|          2|
|2021-01-05 00:00:00|          FRA|               Y|            2|          1|
|2021-01-06 00:00:00|          FRA|               Y|            3|          0|
+-------------------+-------------+----------------+-------------+-----------+
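
Note: count_if requires Spark 3.0 or later. On earlier versions, the same conditional count can be written with a plain sum over a case expression — a sketch under that assumption (the coalesce handles each month's last row, where the empty frame makes sum return null while count_if returns 0):

result = spark.sql("""
select *,
    sum(case when working_day_flag = 'Y' then 1 else 0 end)
        over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref)
        month_to_date,
    -- sum over an empty frame yields null, so wrap it in coalesce
    coalesce(sum(case when working_day_flag = 'Y' then 1 else 0 end)
        over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref
             rows between 1 following and unbounded following), 0)
        month_to_go
from df
""")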

If you want a similar solution using the PySpark DataFrame API:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

# window over each country and month, ordered by date; the default frame
# (unbounded preceding to current row) gives a running count
result = df.withColumn(
    'month_to_date',
    F.count(  # count() ignores nulls, so only rows where the flag is 'Y' are counted
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
              .orderBy('dt_bord_ref')
    )
).withColumn(
    'month_to_go',
    F.count(
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
              .orderBy('dt_bord_ref')
              # frame starts one row after the current one: the remaining days of the month
              .rowsBetween(1, Window.unboundedFollowing)
    )
)
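
Since the two columns differ only in the window frame, you could also factor the repeated expression into a small helper — a sketch (the name working_day_count is illustrative, not from the original answer):

import pyspark.sql.functions as F
from pyspark.sql.window import Window

def working_day_count(frame_start=Window.unboundedPreceding,
                      frame_end=Window.currentRow):
    # conditional count of working days over a per-country, per-month window;
    # the explicit frame is equivalent to the default one here because
    # there is a single row per country per date
    w = (Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
               .orderBy('dt_bord_ref')
               .rowsBetween(frame_start, frame_end))
    return F.count(F.when(F.col('working_day_flag') == 'Y', 1)).over(w)

result = (df.withColumn('month_to_date', working_day_count())
            .withColumn('month_to_go',
                        working_day_count(1, Window.unboundedFollowing)))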
