Apache Spark tips and recipes #1
How to groupBy and collect rows with multiple columns into a list
How to group a flat dataset into a case class that has, for example, a String field and a list field, where the list contains other objects:
case class Example(id: String, tags: List[SomeOtherClass])
Let's imagine datasets of cryptocurrency data: investors' portfolios and price reports.
Basic Models
case class CryptoToken(ticker: String, count: Double)
case class CryptoInvestor(investorId: String, countryCode: String, tokens: List[CryptoToken])
case class CryptoTokenPriceReport(ticker: String, priceInUSD: Double)
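A quick note before the data: the snippets in this post assume a SparkSession is already in scope and its implicits are imported, since that is what provides .toDS() on Scala collections and the encoders for the case classes. A minimal sketch of that setup (the app name and local master are just placeholders for this example):
import org.apache.spark.sql.SparkSession

// Minimal local SparkSession; app name and master are placeholder values.
val spark: SparkSession = SparkSession.builder()
  .appName("crypto-portfolio-report")
  .master("local[*]")
  .getOrCreate()

// Enables .toDS() and provides encoders for the case classes above.
import spark.implicits._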
Investors' portfolios
val investorsPortfolios: Dataset[CryptoInvestor] = List(
CryptoInvestor("kyc-investor1", "USA", List(
CryptoToken("btc", 0.5),
CryptoToken("eth", 5)
)),
CryptoInvestor("kyc-investor2", "UK", List(
CryptoToken("ltc", 1.5),
CryptoToken("eth", 3)
))
).toDS()
The result of investorsPortfolios.show(truncate = false):
+-------------+-----------+------------------------+
|investorId |countryCode|tokens |
+-------------+-----------+------------------------+
|kyc-investor1|USA |[[btc, 0.5], [eth, 5.0]]|
|kyc-investor2|UK |[[ltc, 1.5], [eth, 3.0]]|
+-------------+-----------+------------------------+
Current price reports
val tokensPriceReport: Dataset[CryptoTokenPriceReport] = List(
CryptoTokenPriceReport("btc", 8000.1),
CryptoTokenPriceReport("eth", 123.0),
CryptoTokenPriceReport("ltc", 51.2)
).toDS()
The result of tokensPriceReport.show(truncate = false):
+------+----------+
|ticker|priceInUSD|
+------+----------+
|btc |8000.1 |
|eth |123.0 |
|ltc |51.2 |
+------+----------+
Now, how do we join those datasets so that we get a result in the following format?
case class CryptoTokenWithCurrentValue(ticker: String, count: Double, priceInUSD: Double)
case class CryptoInvestorsReport(investorId: String, countryCode: String, tokensWithCurrentPricing: List[CryptoTokenWithCurrentValue])
To make the join possible, we first need to flatten the CryptoInvestor entries a bit, so that each one looks like a SQL table row:
case class CryptoInvestorTmpRow(investorId: String, countryCode: String, ticker: String, count: Double)

val investorsAndPriceReportJoin = investorsPortfolios.flatMap { investor =>
import investor._
tokens.map { token =>
import token._
CryptoInvestorTmpRow(investorId, countryCode, ticker, count)
}
}.as[CryptoInvestorTmpRow].join(tokensPriceReport, "ticker")
After the join, the new dataset will look like the following:
+------+-------------+-----------+-----+----------+
|ticker| investorId|countryCode|count|priceInUSD|
+------+-------------+-----------+-----+----------+
| btc|kyc-investor1| USA| 0.5| 8000.1|
| eth|kyc-investor1| USA| 5.0| 123.0|
| ltc|kyc-investor2| UK| 1.5| 51.2|
| eth|kyc-investor2| UK| 3.0| 123.0|
+------+-------------+-----------+-----+----------+
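As an aside, the same flattening done by the typed flatMap above can also be expressed with Spark SQL's explode function on the tokens array. This is only a sketch of an alternative, not what this post uses (the value name flattenedPortfolios is mine); the resulting untyped DataFrame can be joined with tokensPriceReport in the same way:
import org.apache.spark.sql.functions.{col, explode}

// Untyped alternative: explode the tokens array into one row per token,
// then pull the struct fields out as top-level columns.
val flattenedPortfolios = investorsPortfolios
  .withColumn("token", explode(col("tokens")))
  .select(
    col("investorId"),
    col("countryCode"),
    col("token.ticker").as("ticker"),
    col("token.count").as("count")
  )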
Now for the main subject of this post: how do we perform the groupBy?
Here we will use the collect_list and struct functions from the org.apache.spark.sql.functions package:
import org.apache.spark.sql.functions.{collect_list, struct}

investorsAndPriceReportJoin.groupBy("investorId", "countryCode").agg(
collect_list(struct("ticker", "count", "priceInUSD")) alias "tokensWithCurrentPricing"
).as[CryptoInvestorsReport]
The result after the groupBy will look like the following:
+-------------+-----------+---------------------------------------+
|investorId |countryCode|tokensWithCurrentPricing |
+-------------+-----------+---------------------------------------+
|kyc-investor1|USA |[[btc, 0.5, 8000.1], [eth, 5.0, 123.0]]|
|kyc-investor2|UK |[[ltc, 1.5, 51.2], [eth, 3.0, 123.0]] |
+-------------+-----------+---------------------------------------+
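Assuming the groupBy expression above is assigned to a value, say investorsReport (a name of my choosing, not from the snippet), it is a typed Dataset[CryptoInvestorsReport], so you can work with the case classes directly, for example to compute each investor's total portfolio value in USD. A small sketch, sensible only for tiny example data like this:
// Hypothetical name for the result of the groupBy/agg/as expression above.
val investorsReport: Dataset[CryptoInvestorsReport] =
  investorsAndPriceReportJoin.groupBy("investorId", "countryCode").agg(
    collect_list(struct("ticker", "count", "priceInUSD")) alias "tokensWithCurrentPricing"
  ).as[CryptoInvestorsReport]

// Collect to the driver and compute each investor's total portfolio value in USD.
investorsReport.collect().foreach { report =>
  val totalUSD = report.tokensWithCurrentPricing.map(t => t.count * t.priceInUSD).sum
  println(s"${report.investorId}: $$$totalUSD")
}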
Check out the full code on GitHub:
What do you think? Let me know in the comments below to help me improve this article.
Thanks for reading!