Problem Statement
GeeksforGeeks: https://www.geeksforgeeks.org/problems/count-all-triplets-with-given-sum-in-sorted-array/1
Given a sorted array arr[] and a target value, the task is to count the triplets (i, j, k) of valid indices such that arr[i] + arr[j] + arr[k] = target and i < j < k.
Input: arr[] = [-3, -1, -1, 0, 1, 2], target = -2
Output: 4
Explanation: Four triplets that add up to -2 are:
arr[0] + arr[3] + arr[4] = (-3) + (0) + (1) = -2
arr[0] + arr[1] + arr[5] = (-3) + (-1) + (2) = -2
arr[0] + arr[2] + arr[5] = (-3) + (-1) + (2) = -2
arr[1] + arr[2] + arr[3] = (-1) + (-1) + (0) = -2
Input: arr[] = [-2, 0, 1, 1, 5], target = 1
Output: 0
Explanation: There is no triplet whose sum is equal to 1.
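To pin down exactly what is being counted, here is a brute-force reference sketch (the function name `count_triplets_brute` is my own, not part of the GeeksforGeeks template). It checks every index triple with `itertools.combinations`, so it is O(n³) and only meant for small inputs or for checking a faster solution.

```python
from itertools import combinations


def count_triplets_brute(arr, target):
    """Count index triples i < j < k with arr[i] + arr[j] + arr[k] == target.

    combinations() yields elements in index order, so every (i, j, k)
    with i < j < k is visited exactly once, duplicate values included.
    """
    return sum(1 for a, b, c in combinations(arr, 3) if a + b + c == target)


print(count_triplets_brute([-3, -1, -1, 0, 1, 2], -2))  # 4
print(count_triplets_brute([-2, 0, 1, 1, 5], 1))        # 0
```

Counting index triples (not distinct value triples) is why the two (-3) + (-1) + (2) combinations in the first example count separately.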
My Approach:
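Initially I tried an approach similar to this. All test cases but one passed, and it failed 6 times. Checking every triplet directly would be O(n³); the version below, which stores each value's indices in a hash map and binary-searches them for every pair, is O(n² log n).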

class Solution:
    def countTriplets(self, arr, target):
        hash_set = {}
        total = len(arr)
        cnt = 0
        # Build the hash_set with indices for each value in arr
        for i in range(total):
            if arr[i] not in hash_set:
                hash_set[arr[i]] = []
            hash_set[arr[i]].append(i)
        # Iterate through all pairs (itr, jtr)
        for itr in range(total):
            for jtr in range(itr + 1, total):
                rem = target - arr[itr] - arr[jtr]
                # Check for remaining value in hash_set
                if rem in hash_set:
                    # Use binary search to count indices greater than jtr
                    indices = hash_set[rem]
                    low, high = 0, len(indices)
                    while low < high:
                        mid = (low + high) // 2
                        if indices[mid] > jtr:
                            high = mid
                        else:
                            low = mid + 1
                    cnt += len(indices) - low
        return cnt
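The hand-written binary search above finds the first position in `indices` whose value is greater than `jtr`, which is what the standard library's `bisect.bisect_right` computes. A sketch of the same hash-map idea using it, written as a standalone function `count_triplets_hash` (a name I chose for illustration); it has the same O(n² log n) time and O(n) extra space.

```python
from bisect import bisect_right
from collections import defaultdict


def count_triplets_hash(arr, target):
    # Map each value to the list of indices where it occurs.
    # Indices are appended in increasing order, so each list is sorted.
    positions = defaultdict(list)
    for idx, value in enumerate(arr):
        positions[value].append(idx)

    count = 0
    n = len(arr)
    for i in range(n):
        for j in range(i + 1, n):
            rem = target - arr[i] - arr[j]
            # .get() avoids inserting empty lists into the defaultdict.
            indices = positions.get(rem)
            if indices:
                # Number of stored indices k with k > j.
                count += len(indices) - bisect_right(indices, j)
    return count


print(count_triplets_hash([-3, -1, -1, 0, 1, 2], -2))  # 4
```

This does not change the asymptotic cost; it only removes the chance of an off-by-one in a hand-rolled search.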
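After reading a few blogs, I switched to the two-pointer method. Because the array is already sorted, fixing arr[i] and moving left and right inward gives O(n²) time with O(1) extra space.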
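class Solution:
    def countTriplets(self, arr, target):
        n = len(arr)
        res = 0
        # Fix the first element, then close in on the other two from both ends.
        for i in range(n - 2):
            left = i + 1
            right = n - 1
            while left < right:
                # curr_sum instead of sum, to avoid shadowing the built-in
                curr_sum = arr[i] + arr[left] + arr[right]
                if curr_sum < target:
                    left += 1
                elif curr_sum > target:
                    right -= 1
                else:
                    # Found a match: count the run of equal values at each end.
                    ele1 = arr[left]
                    ele2 = arr[right]
                    cnt1 = 0
                    cnt2 = 0
                    while left <= right and arr[left] == ele1:
                        left += 1
                        cnt1 += 1
                    while left <= right and arr[right] == ele2:
                        right -= 1
                        cnt2 += 1
                    if ele1 == ele2:
                        # Every value in the window is equal: choose any 2 of cnt1.
                        res += (cnt1 * (cnt1 - 1)) // 2
                    else:
                        # Each of the cnt1 left values pairs with each of the cnt2 right values.
                        res += cnt1 * cnt2
        return res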
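The duplicate-handling branch is the easiest place to slip up, so one way to gain confidence is to compare the two-pointer method against brute force on many small random sorted arrays. This sketch assumes the two-pointer logic is pulled out of the class into a standalone function `count_triplets_two_pointer` (a name I chose for this check); the value range and array sizes are arbitrary, picked small so that duplicates show up often.

```python
import random
from itertools import combinations


def count_triplets_two_pointer(arr, target):
    # Same logic as Solution.countTriplets above, as a free function.
    n = len(arr)
    res = 0
    for i in range(n - 2):
        left, right = i + 1, n - 1
        while left < right:
            curr_sum = arr[i] + arr[left] + arr[right]
            if curr_sum < target:
                left += 1
            elif curr_sum > target:
                right -= 1
            else:
                ele1, ele2 = arr[left], arr[right]
                cnt1 = cnt2 = 0
                while left <= right and arr[left] == ele1:
                    left += 1
                    cnt1 += 1
                while left <= right and arr[right] == ele2:
                    right -= 1
                    cnt2 += 1
                if ele1 == ele2:
                    res += cnt1 * (cnt1 - 1) // 2
                else:
                    res += cnt1 * cnt2
    return res


def count_triplets_brute(arr, target):
    # O(n^3) reference: every index triple i < j < k exactly once.
    return sum(1 for a, b, c in combinations(arr, 3) if a + b + c == target)


random.seed(0)
for _ in range(2000):
    # Values in [-3, 3] force plenty of repeated elements.
    arr = sorted(random.choices(range(-3, 4), k=random.randint(3, 12)))
    target = random.randint(-6, 6)
    fast = count_triplets_two_pointer(arr, target)
    slow = count_triplets_brute(arr, target)
    assert fast == slow, (arr, target, fast, slow)
print("all random checks passed")
```

A passing run is not a proof, but it exercises both the cnt1 * cnt2 branch and the "choose 2 of cnt1" branch many times.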
