Syed Jafer K

It's all about Trade-Offs

POTD #15 – Count all triplets with given sum in sorted array | Geeks For Geeks

Problem Statement

Geeks For Geeks : https://www.geeksforgeeks.org/problems/count-all-triplets-with-given-sum-in-sorted-array/1

Given a sorted array arr[] and a target value, the task is to count triplets (i, j, k) of valid indices, such that arr[i] + arr[j] + arr[k] = target and i < j < k.
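
As a reference point, here is a minimal brute-force sketch of the definition above. It tries every index triple with i < j < k, so it runs in O(n³) time. The function name `count_triplets_brute` is hypothetical and is not part of the Geeks For Geeks template; it is only useful for checking faster versions.

```python
def count_triplets_brute(arr, target):
    """Count index triples i < j < k with arr[i] + arr[j] + arr[k] == target.

    O(n^3) reference implementation, meant only for cross-checking.
    """
    n = len(arr)
    count = 0
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(j + 1, n):
                if arr[i] + arr[j] + arr[k] == target:
                    count += 1
    return count


print(count_triplets_brute([-3, -1, -1, 0, 1, 2], -2))  # 4
print(count_triplets_brute([-2, 0, 1, 1, 5], 1))        # 0
```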

Input: arr[] = [-3, -1, -1, 0, 1, 2], target = -2
Output: 4
Explanation: Four triplets add up to -2:
arr[0] + arr[3] + arr[4] = (-3) + 0 + (1) = -2
arr[0] + arr[1] + arr[5] = (-3) + (-1) + (2) = -2
arr[0] + arr[2] + arr[5] = (-3) + (-1) + (2) = -2
arr[1] + arr[2] + arr[3] = (-1) + (-1) + (0) = -2
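
One quick way to double-check this example is to enumerate every index triple with `itertools.combinations` and keep the ones that hit the target. This is a verification aid only, not a solution to submit; `matches` is just an illustrative variable name.

```python
from itertools import combinations

arr = [-3, -1, -1, 0, 1, 2]
target = -2

# combinations(range(n), 3) yields index triples already in i < j < k order.
matches = [
    (i, j, k)
    for i, j, k in combinations(range(len(arr)), 3)
    if arr[i] + arr[j] + arr[k] == target
]

print(matches)       # [(0, 1, 5), (0, 2, 5), (0, 3, 4), (1, 2, 3)]
print(len(matches))  # 4
```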

Input: arr[] = [-2, 0, 1, 1, 5], target = 1
Output: 0
Explanation: There is no triplet whose sum is equal to 1. 

My Approach:

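Initially, I tried the approach below: store the indices of every value in a hash map, check every pair (i, j), and binary-search the stored indices for a third index k > j. All test cases but one passed, and I ended up with 6 failed submissions. Its time complexity is O(n² log n): O(n²) pairs, each followed by an O(log n) binary search.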

class Solution:
    def countTriplets(self, arr, target):
        hash_set = {}
        total = len(arr)
        cnt = 0
        
        # Build the hash_set with indices for each value in arr
        for i in range(total):
            if arr[i] not in hash_set:
                hash_set[arr[i]] = []
            hash_set[arr[i]].append(i)
        
        # Iterate through all pairs (itr, jtr)
        for itr in range(total):
            for jtr in range(itr + 1, total):
                rem = target - arr[itr] - arr[jtr]
                
                # Check for remaining value in hash_set
                if rem in hash_set:
                    # Use binary search to count indices greater than jtr
                    indices = hash_set[rem]
                    low, high = 0, len(indices)
                    
                    while low < high:
                        mid = (low + high) // 2
                        if indices[mid] > jtr:
                            high = mid
                        else:
                            low = mid + 1
                    
                    cnt += len(indices) - low

        return cnt
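
The hand-written binary search above finds the first position in `indices` that holds an index greater than `jtr`. Python's standard `bisect.bisect_right` computes exactly that position, so the same idea can be sketched more compactly as below. `count_triplets_bisect` is a hypothetical name, and the complexity is unchanged at O(n² log n).

```python
from bisect import bisect_right
from collections import defaultdict


def count_triplets_bisect(arr, target):
    # Map each value to the sorted list of indices where it occurs.
    positions = defaultdict(list)
    for idx, value in enumerate(arr):
        positions[value].append(idx)  # indices are appended in increasing order

    count = 0
    n = len(arr)
    for i in range(n):
        for j in range(i + 1, n):
            rem = target - arr[i] - arr[j]
            indices = positions.get(rem)  # .get() does not insert missing keys
            if indices:
                # bisect_right gives the first position whose index is > j,
                # so every index from there on is a valid k.
                count += len(indices) - bisect_right(indices, j)
    return count
```

Using the library function removes the off-by-one risk of a hand-rolled `low`/`high` loop, but it does not address the O(n² log n) cost.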

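After reading a few blogs, I switched to the two-pointer method. Because the array is sorted, fixing arr[i] and moving two pointers inward finds all matching pairs for that i in O(n), so the whole solution runs in O(n²) time with O(1) extra space.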

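class Solution:
    def countTriplets(self, arr, target):
        n = len(arr)
        res = 0

        # Fix the first element arr[i]; look for pairs (left, right) after it.
        for i in range(n - 2):
            left = i + 1
            right = n - 1

            while left < right:
                # Named curr_sum to avoid shadowing the built-in sum().
                curr_sum = arr[i] + arr[left] + arr[right]

                if curr_sum < target:
                    # Sum too small: move left towards larger values.
                    left += 1

                elif curr_sum > target:
                    # Sum too large: move right towards smaller values.
                    right -= 1

                else:
                    # Match found. Count the runs of duplicates at both ends
                    # so that every index combination is counted.
                    ele1 = arr[left]
                    ele2 = arr[right]
                    cnt1 = 0
                    cnt2 = 0

                    while left <= right and arr[left] == ele1:
                        left += 1
                        cnt1 += 1

                    while left <= right and arr[right] == ele2:
                        right -= 1
                        cnt2 += 1

                    if ele1 == ele2:
                        # The array is sorted, so every element in the window
                        # was equal; any 2 of those cnt1 copies form a pair.
                        res += (cnt1 * (cnt1 - 1)) // 2
                    else:
                        # Each copy of ele1 pairs with each copy of ele2.
                        res += cnt1 * cnt2

        return res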
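
The inner `while` loop of the two-pointer solution is a standalone pattern: counting pairs with a given sum inside a sorted range. The sketch below pulls it out into a hypothetical helper `count_pairs` to make the duplicate handling explicit. When `arr[left] == arr[right]`, every element between them is equal (the range is sorted), so any two of those `cnt` elements form a pair, giving cnt * (cnt - 1) / 2 pairs; otherwise each copy of the left value pairs with each copy of the right value, giving cnt1 * cnt2. The names `count_pairs` and `count_triplets` are illustrative only.

```python
def count_pairs(arr, lo, hi, target):
    """Count index pairs lo <= p < q <= hi with arr[p] + arr[q] == target.

    Assumes arr[lo:hi + 1] is sorted in non-decreasing order.
    """
    pairs = 0
    left, right = lo, hi
    while left < right:
        pair_sum = arr[left] + arr[right]
        if pair_sum < target:
            left += 1          # need a larger sum
        elif pair_sum > target:
            right -= 1         # need a smaller sum
        else:
            left_val, right_val = arr[left], arr[right]
            if left_val == right_val:
                # Whole window holds one value: choose any 2 of its elements.
                cnt = right - left + 1
                pairs += cnt * (cnt - 1) // 2
                break
            # Distinct ends: count each run of duplicates, then pair them up.
            # Neither loop can run past the other end, since the values differ.
            cnt_left = cnt_right = 0
            while arr[left] == left_val:
                left += 1
                cnt_left += 1
            while arr[right] == right_val:
                right -= 1
                cnt_right += 1
            pairs += cnt_left * cnt_right
    return pairs


def count_triplets(arr, target):
    # Fix arr[i] and count pairs after it that complete the target.
    n = len(arr)
    return sum(count_pairs(arr, i + 1, n - 1, target - arr[i]) for i in range(n - 2))


print(count_triplets([-3, -1, -1, 0, 1, 2], -2))  # 4
print(count_triplets([-2, 0, 1, 1, 5], 1))        # 0
```

Splitting the loop out this way also makes it easy to reuse the pair counter for related problems on sorted arrays.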