Defining User-Defined Functions (UDFs)
PySpark User-Defined Functions (UDFs) are an easy way to make Python code scalable. With PySpark, a Python function can be applied to the rows of a DataFrame that is distributed across the nodes of a computing cluster, so the work runs in parallel and can complete much faster than on a single machine.
A guide to defining a simple UDF is provided in the Databricks documentation.
In the example below we compute the one-dimensional n-point discrete Fourier Transform (DFT) with the efficient Fast Fourier Transform (FFT) algorithm, using the numpy.fft.fft function (applied, in this example, to the sine of each waveform value). The task requires the following steps:
1 - Extraction of the NXCALS vectornumeric data to be transformed:
data = DataQuery.builder(spark).byVariables().system('CMW') \
    .startTime('2019-01-01 16:00:00.000').endTime('2019-04-01 17:00:00.000') \
    .variable('L4LT.PULLER-I:WAVEFORM').build().select('nxcals_value')
2 - Extraction of the array of values stored in the "elements" field of the vectornumeric structure:
vectors = data.withColumn("nx_value", col("nxcals_value")["elements"]).select("nx_value")
3 - Definition of the return type of the UDF (a struct holding the real and imaginary parts of the transform):
schema = StructType([
    StructField("real", ArrayType(DoubleType()), False),
    StructField("imag", ArrayType(DoubleType()), False)
])
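4 - Definition of the UDF itself, an ordinary Python function:
def fft_arrays(x):
    # FFT of sin(x): np.sin is applied to the waveform before the transform
    tran = np.fft.fft(np.sin(x))
    r = tran.real
    i = tran.imag
    # plain Python lists match the (real, imag) fields of the schema
    return r.tolist(), i.tolist()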
5 - The function must then be wrapped as a Spark UDF, together with its return type:
spark_fft_array = F.udf(fft_arrays, schema)
6 - At this stage the data can be processed and the results displayed (note the references to the struct fields real and imag):
result = vectors.withColumn('fft', spark_fft_array('nx_value'))
result.show()
result.select('nx_value', 'fft') \
    .withColumn("fft_real", result["fft"].getField("real")) \
    .withColumn("fft_imag", result["fft"].getField("imag")) \
    .show()
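Because the body of a UDF is an ordinary Python function, it can be tested locally on a plain list before being wrapped with F.udf. The sketch below assumes only numpy. The None guard is our addition, not part of the original example: it is meant for rows where the nullable nx_value column is null, since np.sin(None) raises an exception inside the Spark task, whereas returning None simply yields a null fft struct.

```python
import numpy as np


def fft_arrays(x):
    """Body of the UDF: FFT of sin(x), split into real and imaginary parts.

    Returns None for a null input array (Spark passes SQL NULL as None).
    """
    if x is None:
        return None
    tran = np.fft.fft(np.sin(x))
    # .tolist() converts numpy float64 values to plain Python floats,
    # which Spark maps to ArrayType(DoubleType())
    return tran.real.tolist(), tran.imag.tolist()


# Call the function directly, without Spark, on a small sample vector
sample = [0.0, 0.5, 1.0, 1.5]
real, imag = fft_arrays(sample)
print(len(real), len(imag))  # 4 4
print(fft_arrays(None))  # None
```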
Complete example:
import numpy as np
import pyspark.sql.functions as F
from nxcals.api.extraction.data.builders import DataQuery
from pyspark.sql.functions import col
from pyspark.sql.types import ArrayType
from pyspark.sql.types import DoubleType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType
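# `spark` is assumed to be the active SparkSession (as in the NXCALS PySpark environment)
data = DataQuery.builder(spark).byVariables().system('CMW') \
    .startTime('2019-01-01 16:00:00.000').endTime('2019-04-01 17:00:00.000') \
    .variable('L4LT.PULLER-I:WAVEFORM').build().select('nxcals_value')
# keep only the array of values stored in the "elements" field of the vectornumeric struct
vectors = data.withColumn("nx_value", col("nxcals_value")["elements"]).select("nx_value")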
vectors.printSchema()
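# return type of the UDF: a struct holding the real and imaginary parts
schema = StructType([
    StructField("real", ArrayType(DoubleType()), False),
    StructField("imag", ArrayType(DoubleType()), False)
])
def fft_arrays(x):
    # FFT of sin(x); the result is a complex numpy array
    tran = np.fft.fft(np.sin(x))
    r = tran.real
    i = tran.imag
    # plain Python lists match the (real, imag) fields of the schema
    return r.tolist(), i.tolist()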
spark_fft_array = F.udf(fft_arrays, schema)
result = vectors.withColumn('fft', spark_fft_array('nx_value'))
result.printSchema()
result.select('nx_value', 'fft') \
    .withColumn("fft_real", result["fft"].getField("real")) \
    .withColumn("fft_imag", result["fft"].getField("imag")) \
    .show()
Expected script output:
root
 |-- nx_value: array (nullable = true)
 |    |-- element: double (containsNull = true)

root
 |-- nx_value: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- fft: struct (nullable = true)
 |    |-- real: array (nullable = false)
 |    |    |-- element: double (containsNull = true)
 |    |-- imag: array (nullable = false)
 |    |    |-- element: double (containsNull = true)

+--------------------+--------------------+--------------------+--------------------+
| nx_value| fft| fft_real| fft_imag|
+--------------------+--------------------+--------------------+--------------------+
|[0.00439453125, 0...|[[16.811978928008...|[16.8119789280082...|[0.0, 0.007864349...|
|[0.00341796875, 0...|[[16.783170489420...|[16.7831704894204...|[0.0, -0.01019607...|
|[0.0029296875, 0....|[[16.854947463144...|[16.8549474631446...|[0.0, -0.03797981...|
|[0.00341796875, 0...|[[16.899380750936...|[16.8993807509361...|[0.0, 0.020026648...|
|[0.00341796875, 0...|[[16.764127651229...|[16.7641276512299...|[0.0, -0.03394368...|
|[0.00341796875, 0...|[[16.778287699233...|[16.7782876992332...|[0.0, -0.01204766...|
|[0.00390625, 0.00...|[[16.760221438055...|[16.7602214380559...|[0.0, -0.01970519...|
|[0.00390625, 0.00...|[[16.807584439540...|[16.8075844395409...|[0.0, -0.01540521...|
|[0.00390625, 0.00...|[[16.653776668788...|[16.6537766687882...|[0.0, -0.00396300...|
|[0.0029296875, 0....|[[16.706022464185...|[16.7060224641854...|[0.0, -0.02806074...|
|[0.0029296875, 0....|[[16.697233496680...|[16.6972334966801...|[0.0, -0.05580112...|
|[0.00341796875, 0...|[[16.808561017950...|[16.8085610179509...|[0.0, 0.018091221...|
|[0.0029296875, 0....|[[16.973599082733...|[16.9735990827338...|[0.0, -0.02788702...|
|[0.00341796875, 0...|[[17.019985540267...|[17.0199855402675...|[0.0, -0.00899781...|
|[0.00390625, 0.00...|[[16.849088110147...|[16.8490881101470...|[0.0, -0.04303762...|
|[0.00341796875, 0...|[[16.811978921605...|[16.8119789216054...|[0.0, 0.009134023...|
|[0.00341796875, 0...|[[16.929165714848...|[16.9291657148487...|[0.0, -0.00706610...|
|[0.00341796875, 0...|[[16.867642627176...|[16.8676426271764...|[0.0, -0.02141728...|
|[0.0029296875, 0....|[[16.824185910810...|[16.8241859108101...|[0.0, 0.010009693...|
|[0.00341796875, 0...|[[16.870572331090...|[16.8705723310909...|[0.0, -0.03254147...|
+--------------------+--------------------+--------------------+--------------------+
only showing top 20 rows
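Two features of this output follow from the input being real-valued: each fft_imag array starts with 0.0, and each fft_real array starts with the sum of the transformed samples (here the sum of sin(x), since the UDF applies np.sin first). The numpy sketch below, run on a made-up signal rather than NXCALS data, checks both properties and shows that numpy.fft.rfft returns only the non-redundant half of the spectrum; using rfft in the UDF would be an optional change, not part of the original example.

```python
import numpy as np

# Hypothetical real-valued test signal standing in for one waveform
x = np.sin(np.linspace(0.0, 1.0, 8))

full = np.fft.fft(x)
half = np.fft.rfft(x)

# The zero-frequency bin is the sum of the samples and is purely real
print(np.isclose(full[0].real, x.sum()), full[0].imag)

# For real input the spectrum is conjugate-symmetric: X[n-k] == conj(X[k]),
# so rfft keeps only the n//2 + 1 non-redundant bins
print(len(full), len(half))  # 8 5
print(np.allclose(full[: len(x) // 2 + 1], half))  # True
print(np.allclose(full[1:][::-1], np.conj(full[1:])))  # True
```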
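For reference, numpy.fft.fft evaluates the n-point DFT X[k] = sum over m = 0..n-1 of x[m] * exp(-2*pi*i*k*m/n), for k = 0..n-1, with no normalization on the forward transform. The naive O(n^2) sketch below (naive_dft is a hypothetical name chosen for illustration) computes that sum directly using only the standard library; the FFT returns the same values up to rounding in O(n log n) operations, which is why the example relies on numpy instead of a hand-written loop.

```python
import cmath
import math


def naive_dft(x):
    """Direct O(n^2) evaluation of the one-dimensional n-point DFT.

    Illustrative only: numpy.fft.fft returns the same values
    (up to rounding) using the much faster FFT algorithm.
    """
    n = len(x)
    return [
        sum(x[m] * cmath.exp(-2j * math.pi * k * m / n) for m in range(n))
        for k in range(n)
    ]


# approximately [10, -2+2j, -2, -2-2j], up to rounding noise;
# X[0] is simply the sum of the input samples
print(naive_dft([1.0, 2.0, 3.0, 4.0]))
```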