The fold() method is not commonly used in Scala, as we have the reduce() method, which gives almost the same functionality. When we already have reduce(), what is the need for fold()? We are going to see the details as we go on in this post.
What is the fold() method?
The fold() method is similar to reduce(), but besides the operation that we are going to perform on the collection of items, it also takes an initial value that is used along with that operation.
The syntax of the fold method on an RDD[T] is:
def fold(zeroValue: T)(op: (T, T) => T): T
zeroValue - The initial value for the aggregation. Ideally it should be a neutral value for the operation you are going to perform, such as 0 for addition or 1 for multiplication, so that it does not affect the final result.
op - The function that we pass to the fold method. It combines two values at a time and is applied across all the elements of the collection.
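On an ordinary, non-distributed Scala List, the neutral values described above leave the result unchanged. A minimal sketch in plain Scala with no Spark involved; the object and value names are only for illustration.

```scala
// Plain Scala collections, no Spark: fold with neutral initial values.
object NeutralZeroExample {
  val numbers: List[Int] = List(1, 2, 3, 4, 5)

  // 0 is neutral for addition (0 + x == x), so it does not change the sum
  val sum: Int = numbers.fold(0)(_ + _)

  // 1 is neutral for multiplication (1 * x == x), so it does not change the product
  val product: Int = numbers.fold(1)(_ * _)

  def main(args: Array[String]): Unit =
    println(s"sum = $sum, product = $product") // sum = 15, product = 120
}
```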
For example, below is a sample code snippet showing how the fold method is used:
val rdd = sparkContext.parallelize(List(1, 2, 3, 4, 5))
rdd.fold(0)(_ + _)
Here the initial value for the aggregation is 0 and the function supplied is addition, _ + _. Conceptually, fold starts with the initial value and adds the first element to it, then adds the second element to that running sum, and so on until the last element. So the sum of all the elements plus the initial value is returned, and the output is 0+1+2+3+4+5 = 15.
In terms of plain Scala this is very simple. In the above case we gave 0 as the initial value, so the output is 15. But if you run the code snippet below, the output looks weird.
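For a local, sequential collection, scanLeft makes the running sum of a fold visible step by step. This is a plain Scala sketch with no Spark involved (the object name is illustrative); it also shows that on a local List a non-neutral initial value of 2 is added only once.

```scala
// Plain Scala, no Spark: the running accumulator of a sequential fold.
object SequentialFoldSteps {
  val numbers: List[Int] = List(1, 2, 3, 4, 5)

  // scanLeft returns the initial value followed by every intermediate result;
  // its last entry equals numbers.fold(0)(_ + _)
  val stepsFromZero: List[Int] = numbers.scanLeft(0)(_ + _) // List(0, 1, 3, 6, 10, 15)

  // On a local List the initial value is used exactly once
  val foldFromTwo: Int = numbers.fold(2)(_ + _) // 17

  def main(args: Array[String]): Unit = {
    println(stepsFromZero.mkString(" -> "))
    println(foldFromTwo)
  }
}
```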
val rdd = sparkContext.parallelize(List(1, 2, 3, 4, 5))
rdd.fold(2)(_ + _)
According to the above explanation the output should be 1+2+3+4+5+2 = 17, but Spark returns a bigger number, for example 19. Why is it so? Is the fold method giving a wrong result?
The answer is correct, but we need to understand carefully why it is so.
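Let's go back to the first line of the above code, where we parallelized a collection. Whenever we parallelize a collection, Spark divides its data into partitions. If we do not pass the number of partitions, Spark uses sparkContext.defaultParallelism, which in local mode is the number of cores given to the master (for example 2 with local[2]). The fold() operation is then applied in two steps: each partition is folded separately, starting from the initial value, and then the partition results are folded together on the driver, again starting from the initial value.
Let's assume the RDD has two partitions, with 1,2,3 in one partition and 4,5 in the other. The fold() method is then applied as below:
2(initial value) + 1+2+3 = 8 (first partition)
2(initial value) + 4+5 = 11 (second partition)
2(initial value) + 8 + 11 = 21 (merge on the driver)
So the initial value is added once for each partition and once more when the partition results are merged. With two partitions the output is 1+2+3+4+5+2+2+2 = 21. The 19 above is what you get when the RDD has a single partition: 1+2+3+4+5+2+2 = 19.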
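The arithmetic can be checked in plain Scala without a Spark cluster. Below is a minimal model of how RDD.fold treats partitions: each partition is folded starting from the initial value, and the partition results are then folded together, again starting from the initial value. It is a sketch of the behaviour, not Spark's own code; modelFold is a hypothetical name, and the partition layout is the assumed one from the example.

```scala
// Plain Scala model of RDD.fold's two-level folding; no Spark needed.
object PartitionedFoldModel {
  // Step 1: fold each partition separately, starting from zeroValue.
  // Step 2: fold the per-partition results together, starting from zeroValue again.
  def modelFold[T](partitions: Seq[Seq[T]], zeroValue: T)(op: (T, T) => T): T = {
    val partitionResults = partitions.map(_.fold(zeroValue)(op))
    partitionResults.fold(zeroValue)(op)
  }

  // Assumed layout from the text: 1,2,3 in one partition, 4,5 in the other
  val partitions: Seq[Seq[Int]] = Seq(Seq(1, 2, 3), Seq(4, 5))

  def main(args: Array[String]): Unit = {
    println(partitions.map(_.fold(2)(_ + _))) // per-partition results: 8 and 11
    println(modelFold(partitions, 2)(_ + _))  // 21
    println(modelFold(partitions, 0)(_ + _))  // 15, a neutral zero is harmless
  }
}
```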
To cross check, pass the number of partitions explicitly as the second argument of parallelize and compare the outputs, as shown below:
val rdd = sparkContext.parallelize(List(1, 2, 3, 4, 5), 1)
rdd.fold(2)(_ + _)
Here the number of partitions is one, so the whole collection is stored in a single partition of the RDD. The initial value is added once when that partition is folded and once more when the partition result is merged, so the output is 1+2+3+4+5+2+2 = 19.
val rdd = sparkContext.parallelize(List(1, 2, 3, 4, 5), 4)
rdd.fold(2)(_ + _)
Here the number of partitions is 4. The initial value is added once for each of the four partitions and once more when the partition results are merged, so the output is 1+2+3+4+5+2+2+2+2+2 = 25.
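To compare several partition counts without a cluster, the sketch below models parallelize(..., n) followed by fold(2)(_ + _) in plain Scala. The slicing helper is modelled on the way Spark splits a local collection into contiguous chunks, which is an assumption here, and slice and modelFold are hypothetical names. The model assumes that, as in RDD.fold, the initial value is used once per partition and once more for the merge, which predicts sum + zeroValue × (n + 1).

```scala
// Plain Scala model, no Spark needed. The contiguous slicing is an assumption
// modelled on Spark's parallelize; with + the split does not change the total.
object PartitionCountModel {
  // Hypothetical helper: split xs into numSlices contiguous, near-equal chunks.
  def slice[T](xs: Seq[T], numSlices: Int): Seq[Seq[T]] =
    (0 until numSlices).map { i =>
      val start = (i.toLong * xs.length / numSlices).toInt
      val end = ((i + 1).toLong * xs.length / numSlices).toInt
      xs.slice(start, end)
    }

  // Fold each chunk from zeroValue, then fold the chunk results from zeroValue.
  def modelFold(xs: Seq[Int], numSlices: Int, zeroValue: Int): Int =
    slice(xs, numSlices).map(_.fold(zeroValue)(_ + _)).fold(zeroValue)(_ + _)

  val numbers: Seq[Int] = Seq(1, 2, 3, 4, 5)

  def main(args: Array[String]): Unit =
    for (n <- Seq(1, 2, 4)) {
      // zero value added n times (once per partition) plus once for the merge
      val predicted = numbers.sum + 2 * (n + 1)
      println(s"$n partition(s): fold = ${modelFold(numbers, n, 2)}, predicted = $predicted")
    }
}
```

In a real Spark session, rdd.getNumPartitions shows how many partitions an RDD actually has, which is the n to plug into that formula.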